dgx-spark-playbooks/nvidia/vlm-finetuning/assets/ui_video/train/video_vlm.ipynb
2025-10-06 15:32:36 +00:00

{
"cells": [
{
"cell_type": "markdown",
"id": "31e8ca53",
"metadata": {},
"source": [
"# Copyright Notice\n",
"\n",
"```\n",
"SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n",
"SPDX-License-Identifier: Apache-2.0\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
"http://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License.\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "ae293c8d",
"metadata": {},
"source": [
"# VLM-Finetuning for Large Scale Data Analysis"
]
},
{
"cell_type": "markdown",
"id": "0c471d1d",
"metadata": {},
"source": [
"### 📓 Notebook Overview \n",
"In this notebook, we show how to train a VLM to generate structured metadata about videos. The target use case is analyzing driving videos with a VLM to produce JSON-formatted video descriptions and metadata, such as risky maneuvers, in order to understand dangerous driving patterns. \n",
"\n",
"This is just one example; the workflow generalizes to any large-scale video data analysis task where structured metadata and automated video analysis are helpful."
]
},
{
"cell_type": "markdown",
"id": "6d02f831",
"metadata": {},
"source": [
"### Initialize Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d06618bc",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"import logging\n",
"\n",
"logging.basicConfig(\n",
" level=logging.INFO,\n",
" format=\"%(asctime)s - %(levelname)s - %(message)s\"\n",
")\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"import os\n",
"import gc\n",
"import sys\n",
"import json\n",
"import random\n",
"import pathlib\n",
"\n",
"import torch\n",
"import numpy as np\n",
"from PIL import Image\n",
"from decord import VideoReader\n",
"from tensorboard import program\n",
"import torch.nn.functional as F\n",
"from datasets import load_dataset\n",
"from transformers.utils import hub\n",
"import torchvision.transforms as T\n",
"from huggingface_hub import snapshot_download\n",
"from trl import SFTTrainer, SFTConfig\n",
"from transformers.trainer_pt_utils import LabelSmoother\n",
"from transformers.trainer_utils import get_last_checkpoint\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor\n",
"from torchvision.transforms.functional import InterpolationMode"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3cc9f01",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# Set seed for reproducibility\n",
"def set_seed(seed):\n",
" random.seed(seed)\n",
" np.random.seed(seed)\n",
" torch.manual_seed(seed)\n",
" if torch.cuda.is_available():\n",
" torch.cuda.manual_seed_all(seed)\n",
"\n",
"\n",
"set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef437e65",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# Set some constants\n",
"seq_len = 8192\n",
"model_name = \"OpenGVLab/InternVL3-8B\"\n",
"ignore_token_id = LabelSmoother.ignore_index"
]
},
{
"cell_type": "markdown",
"id": "0bf4800f",
"metadata": {},
"source": [
"### Model Loading"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4e2e572",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(\n",
" model_name, \n",
" trust_remote_code=True,\n",
" use_fast=True\n",
")\n",
"tokenizer.padding_side = \"right\"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" torch_dtype=torch.bfloat16,\n",
" device_map=\"auto\",\n",
" trust_remote_code=True,\n",
" use_flash_attn=True,\n",
")\n",
"\n",
"# Load the processor\n",
"processor = AutoProcessor.from_pretrained(\n",
" model_name,\n",
" trust_remote_code=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b2c1476a",
"metadata": {},
"source": [
"### Data Processing"
]
},
{
"cell_type": "markdown",
"id": "d3458cd1",
"metadata": {},
"source": [
"**Action Required**: Please update the `dataset_path` with a path to your local dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05d6cd17",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"dataset_path = \"path/to/your/dataset\"\n",
"if os.path.exists(dataset_path):\n",
" print(f\"Found dataset at: {dataset_path}\")\n",
"else:\n",
" print(f\"Dataset path not found -- please update dataset_path: {dataset_path}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34fad55c",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# Load conversation template\n",
"model_dir = hub.cached_file(model_name, \"conversation.py\", trust_remote_code=True)\n",
"sys.path.append(os.path.dirname(model_dir))\n",
"\n",
"from conversation import get_conv_template"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f98733c2",
"metadata": {},
"outputs": [],
"source": [
"def get_jsonl_data(sample): \n",
" \"\"\" Data processing helper: build the single-line JSON label string in the\n",
" structured format we want the model to learn to generate\n",
" \"\"\"\n",
" answer_dict = {\n",
" \"caption\": sample['caption'],\n",
" \"event_type\": sample['event_type'],\n",
" \"rule_violations\": sample['rule_violations'],\n",
" \"intended_action\": sample['intended_action'],\n",
" \"traffic_density\": sample['traffic_density'],\n",
" \"scene\": sample['scene'],\n",
" \"visibility\": sample['visibility'],\n",
" }\n",
"\n",
" return json.dumps(answer_dict, ensure_ascii=False) # create a single line, valid JSON"
]
},
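{
"cell_type": "markdown",
"id": "3f9a1c2e",
"metadata": {},
"source": [
"As a quick sanity check, `get_jsonl_data` can be exercised on a hand-written sample dict. The field values below are made up for illustration; they are not taken from the dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d4b8e0a",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: a hypothetical sample with the fields get_jsonl_data expects\n",
"fake_sample = {\n",
"    \"caption\": \"The ego vehicle slows as a pedestrian crosses.\",\n",
"    \"event_type\": \"near_miss\",\n",
"    \"rule_violations\": [\"failure_to_yield\"],\n",
"    \"intended_action\": \"change_lanes\",\n",
"    \"traffic_density\": \"low\",\n",
"    \"scene\": \"Urban\",\n",
"    \"visibility\": \"good\",\n",
"}\n",
"print(get_jsonl_data(fake_sample))  # single-line, valid JSON"
]
},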
{
"cell_type": "markdown",
"id": "84d2f75f",
"metadata": {},
"source": [
"### Load the Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bfb85f7c",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# Load the dataset\n",
"dataset = load_dataset(dataset_path)\n",
"dataset = dataset['train']\n",
"dataset = dataset.map(lambda ex: {\"labels\": get_jsonl_data(ex)})\n",
"ds_splits = dataset.train_test_split(test_size=0.01, seed=42)\n",
"train_dataset, val_dataset = ds_splits['train'], ds_splits['test']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bca23383",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b98dbe9",
"metadata": {},
"outputs": [],
"source": [
"train_dataset"
]
},
{
"cell_type": "markdown",
"id": "9299e759",
"metadata": {},
"source": [
"### Data Visualization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8a95631c",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"def get_video_path(dataset_path, sample_path):\n",
" \"\"\" Dataset-specific helper -- joins the sample's relative path onto the dataset root to form the full video path \"\"\"\n",
" root_dir = dataset_path.split('/')[:-1]\n",
" video_path = '/'.join(root_dir) + '/' + sample_path\n",
" return video_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dcb4dfba",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"min_frames, max_frames = 8, 32\n",
"\n",
"# load frames from a video\n",
"def load_video(video_path):\n",
" video = VideoReader(video_path, num_threads=1)\n",
"\n",
" # sample a random number of equally-spaced frames from the video\n",
" frame_indices = np.linspace(\n",
" 0,\n",
" len(video) - 1,\n",
" random.randint(min_frames, max_frames),\n",
" dtype=int\n",
" )\n",
" frames = video.get_batch(frame_indices).asnumpy()\n",
" return [Image.fromarray(frames[i]) for i in range(frames.shape[0])]"
]
},
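{
"cell_type": "markdown",
"id": "5c1d9f3b",
"metadata": {},
"source": [
"The `np.linspace` call above picks equally spaced frame indices. For example, sampling 8 frames from a hypothetical 150-frame clip (the clip length here is made up):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e2a6c4d",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: 150 frames is a made-up clip length\n",
"demo_indices = np.linspace(0, 150 - 1, 8, dtype=int)\n",
"print(demo_indices)  # 8 evenly spaced indices from 0 to 149"
]
},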
{
"cell_type": "code",
"execution_count": null,
"id": "d388155a",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"video_path = get_video_path(dataset_path, train_dataset[0]['video'])\n",
"display(load_video(video_path)[0])"
]
},
{
"cell_type": "markdown",
"id": "2b21a463",
"metadata": {},
"source": [
"### Image Preprocessing\n",
"Functions taken from the [InternVL3 Documentation](https://internvl.readthedocs.io/en/latest/internvl3.0/quick_start.html#inference-with-transformers)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f8e4c46",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# preprocessing code adapted directly from the HF model card\n",
"IMAGENET_MEAN = (0.485, 0.456, 0.406)\n",
"IMAGENET_STD = (0.229, 0.224, 0.225)\n",
"\n",
"\n",
"def build_transform(input_size):\n",
" MEAN, STD = IMAGENET_MEAN, IMAGENET_STD\n",
" transform = T.Compose([\n",
" T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n",
" T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),\n",
" T.ToTensor(),\n",
" T.Normalize(mean=MEAN, std=STD)\n",
" ])\n",
" return transform\n",
"\n",
"\n",
"def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):\n",
" best_ratio_diff = float('inf')\n",
" best_ratio = (1, 1)\n",
" area = width * height\n",
" for ratio in target_ratios:\n",
" target_aspect_ratio = ratio[0] / ratio[1]\n",
" ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n",
" if ratio_diff < best_ratio_diff:\n",
" best_ratio_diff = ratio_diff\n",
" best_ratio = ratio\n",
" elif ratio_diff == best_ratio_diff:\n",
" if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n",
" best_ratio = ratio\n",
" return best_ratio\n",
"\n",
"\n",
"def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):\n",
" orig_width, orig_height = image.size\n",
" aspect_ratio = orig_width / orig_height\n",
"\n",
" # enumerate candidate tilings (i x j tiles with min_num <= i*j <= max_num)\n",
" target_ratios = set(\n",
" (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if\n",
" i * j <= max_num and i * j >= min_num)\n",
" target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n",
"\n",
" # find the closest aspect ratio to the target\n",
" target_aspect_ratio = find_closest_aspect_ratio(\n",
" aspect_ratio, target_ratios, orig_width, orig_height, image_size)\n",
"\n",
" # calculate the target width and height\n",
" target_width = image_size * target_aspect_ratio[0]\n",
" target_height = image_size * target_aspect_ratio[1]\n",
" blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n",
"\n",
" # resize the image\n",
" resized_img = image.resize((target_width, target_height))\n",
" processed_images = []\n",
" for i in range(blocks):\n",
" box = (\n",
" (i % (target_width // image_size)) * image_size,\n",
" (i // (target_width // image_size)) * image_size,\n",
" ((i % (target_width // image_size)) + 1) * image_size,\n",
" ((i // (target_width // image_size)) + 1) * image_size\n",
" )\n",
" # split the image\n",
" split_img = resized_img.crop(box)\n",
" processed_images.append(split_img)\n",
" assert len(processed_images) == blocks\n",
" if use_thumbnail and len(processed_images) != 1:\n",
" thumbnail_img = image.resize((image_size, image_size))\n",
" processed_images.append(thumbnail_img)\n",
" return processed_images\n",
"\n",
"\n",
"# build the transform and get number of tokens per image (per tile technically)\n",
"image_size = model.config.force_image_size\n",
"transform = build_transform(input_size=image_size)\n",
"num_image_tokens = int((image_size // model.config.vision_config.patch_size) ** 2 * (model.config.downsample_ratio ** 2))"
]
},
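{
"cell_type": "markdown",
"id": "2f7b3a9c",
"metadata": {},
"source": [
"To see the tiling behavior, `dynamic_preprocess` can be run on a blank 1280x720 image -- a made-up 16:9 frame, not an actual dataset frame. With the defaults, the closest allowed tile grid is 4x2, giving 448x448 tiles:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a0e4d8f",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: a blank 16:9 image standing in for a video frame\n",
"demo_img = Image.new(\"RGB\", (1280, 720))\n",
"demo_tiles = dynamic_preprocess(demo_img, min_num=1, max_num=12, image_size=448)\n",
"print(len(demo_tiles), demo_tiles[0].size)  # 8 tiles of (448, 448) from a 4x2 grid"
]
},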
{
"cell_type": "markdown",
"id": "6097f018",
"metadata": {},
"source": [
"### Define User Prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e24f2a4a",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"user_prompt = \"\"\"You are a vision-language assistant analyzing driving videos. You will receive a 5-second video clip of a specific scene. \n",
"\n",
"---\n",
"\n",
"### Task 1: Dense Caption\n",
"Generate a 2 sentence caption describing:\n",
"- Ego vehicle behavior\n",
"- Interactions with other vehicles or pedestrians\n",
"\n",
"Focus on **what happens**, **when**, and **who/what is involved**, using only visible information and metadata.\n",
"\n",
"---\n",
"\n",
"### Task 2: Structured JSON\n",
"Generate the caption from the perspective of the ego vehicle in a structured JSON object with:\n",
"\n",
"- `\"caption\"`: from Task 1 \n",
"- `\"event_type\"`: \"collision\" | \"near_miss\" | \"no_incident\" \n",
"- `\"rule_violations\"`: choose relevant items from [\"speeding\", \"failure_to_yield\", \"ignoring_traffic_signs\"] \n",
"- `\"intended_action\"`: \"turn_left\" | \"turn_right\" | \"change_lanes\" \n",
"- `\"traffic_density\"`: \"low\" | \"high\" \n",
"- `\"visibility\"`: \"good\" | \"bad\" \n",
"- `\"scene\"`: \"Urban\" | \"Sub-urban\" | \"Rural\" | \"Highway\"\n",
"\n",
"**Rules:**\n",
"1. Use only visible info and metadata. \n",
"2. Do not invent details. \n",
"3. Include all fields; enum values must match allowed options. \n",
"4. Output a single valid JSON object—no extra text or markdown. \n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e82e25f",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# View a sample label to confirm it's in the format we want\n",
"sample = train_dataset[0]\n",
"\n",
"answer_dict = {\n",
" \"caption\": sample['caption'],\n",
" \"event_type\": sample['event_type'],\n",
" \"rule_violations\": sample['rule_violations'],\n",
" \"intended_action\": sample['intended_action'],\n",
" \"traffic_density\": sample['traffic_density'],\n",
" \"scene\": sample['scene'],\n",
" \"visibility\": sample['visibility'],\n",
"}\n",
"answer_jsonl = json.dumps(answer_dict, ensure_ascii=False)\n",
"\n",
"print(answer_jsonl)"
]
},
{
"cell_type": "markdown",
"id": "24b3e453",
"metadata": {},
"source": [
"### Custom Data Preprocessing\n",
"\n",
"This custom data tokenization function takes a batch of samples and tokenizes them in the format InternVL3's modeling code expects, returning a processed batch of features with input IDs, labels, attention masks, position IDs, and pixel values."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59243cf6",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# Collate function for the dataset that does on-the-fly preprocessing and batching.\n",
"# SFTTrainer calls the collator with only the list of samples, so the tokenizer,\n",
"# transform, and other helpers are resolved from the notebook's global scope.\n",
"def collate_fn(samples):\n",
" input_ids_batch, labels_batch, attention_mask_batch, position_ids_batch, pixel_values_batch, image_flags_batch = [], [], [], [], [], []\n",
" for sample in samples:\n",
" # load the video frames\n",
" video_frames = load_video(get_video_path(dataset_path, sample['video']))\n",
" num_frames = len(video_frames)\n",
"\n",
" # preprocess the video frames\n",
" pixel_values = [transform(frame) for frame in video_frames]\n",
" pixel_values = torch.stack(pixel_values)\n",
" num_tiles = pixel_values.size(0)\n",
"\n",
" # prepend special video tokens to the user message\n",
" video_tokens = '\\n'.join(['Frame-{}: <image>'.format(i + 1) for i in range(num_frames)])\n",
"\n",
" # setup conversation\n",
" conv_template = get_conv_template(\"internvl2_5\")\n",
"\n",
" system_instruction = user_prompt\n",
" answer = get_jsonl_data(sample)\n",
"\n",
" conv_template.append_message(conv_template.roles[0], f'{video_tokens}\\n{system_instruction}')\n",
" conv_template.append_message(conv_template.roles[1], answer)\n",
"\n",
" # replace image tokens with context tokens\n",
" prompt = conv_template.get_prompt()\n",
" prompt = prompt.replace(\"<image>\", f\"<img>{'<IMG_CONTEXT>' * num_image_tokens}</img>\")\n",
"\n",
" # create a list of messages\n",
" messages = [f\"<|im_start|>{message}\" for message in prompt.split(\"<|im_start|>\")[1: ]]\n",
"\n",
" # tokenize the prompt (we manually truncate and pad the sequence)\n",
" input_ids = tokenizer(\n",
" messages,\n",
" return_tensors=\"np\",\n",
" padding=False,\n",
" max_length=seq_len,\n",
" truncation=False,\n",
" ).input_ids\n",
"\n",
" # create targets by masking out system and user messages\n",
" # since we only want to compute loss for the assistant message\n",
" targets = []\n",
" num_ignore_ids = tokenizer('<|im_start|>assistant\\n', return_tensors='np').input_ids[0].shape[0]\n",
" for idx, input_id in enumerate(input_ids):\n",
" if idx != 2:\n",
" targets.append(np.full(input_id.shape, ignore_token_id))\n",
" else:\n",
" target = input_id.copy()\n",
" target[: num_ignore_ids] = ignore_token_id\n",
" target[-1: ] = ignore_token_id\n",
" targets.append(target)\n",
"\n",
" # prepare the input_ids and targets\n",
" input_ids = torch.tensor(np.concatenate(input_ids))[: seq_len]\n",
" targets = torch.tensor(np.concatenate(targets))[: seq_len]\n",
"\n",
" # pad the input_ids and targets to the sequence length\n",
" pad_len = seq_len - input_ids.shape[0]\n",
" input_ids = F.pad(input_ids, (0, pad_len), value=tokenizer.pad_token_id)\n",
" targets = F.pad(targets, (0, pad_len), value=ignore_token_id)\n",
"\n",
" # generate attention mask to filter out padding tokens\n",
" attention_mask = input_ids.ne(tokenizer.pad_token_id)\n",
"\n",
" position_ids = attention_mask.long().cumsum(-1) - 1\n",
" position_ids.masked_fill_(attention_mask == 0, 1)\n",
"\n",
" input_ids_batch.append(input_ids)\n",
" labels_batch.append(targets)\n",
" attention_mask_batch.append(attention_mask)\n",
" position_ids_batch.append(position_ids)\n",
" pixel_values_batch.append(pixel_values)\n",
" image_flags_batch.append(torch.tensor([1] * num_tiles, dtype=torch.long))\n",
"\n",
" batch = {\n",
" \"input_ids\": torch.stack(input_ids_batch),\n",
" \"labels\": torch.stack(labels_batch),\n",
" \"attention_mask\": torch.stack(attention_mask_batch),\n",
" \"position_ids\": torch.stack(position_ids_batch),\n",
" \"pixel_values\": torch.cat(pixel_values_batch),\n",
" \"image_flags\": torch.cat(image_flags_batch)\n",
" }\n",
"\n",
" return batch"
]
},
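{
"cell_type": "markdown",
"id": "9b5f1e7c",
"metadata": {},
"source": [
"The label-masking rule above (compute loss only on the assistant's answer) can be illustrated with made-up token ids:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d8c2a6e",
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of the masking in collate_fn -- all token ids are made up.\n",
"# Messages 0 (system) and 1 (user) are fully ignored; in message 2 (assistant),\n",
"# the header ids and the final id are ignored, leaving only the answer tokens.\n",
"IGNORE = -100\n",
"fake_messages = [[11, 12], [21, 22, 23], [5, 6, 101, 102, 103, 7]]\n",
"num_header_ids = 2  # pretend '<|im_start|>assistant\\n' tokenizes to 2 ids\n",
"toy_targets = []\n",
"for idx, ids in enumerate(fake_messages):\n",
"    if idx != 2:\n",
"        toy_targets.append([IGNORE] * len(ids))\n",
"    else:\n",
"        t = list(ids)\n",
"        t[:num_header_ids] = [IGNORE] * num_header_ids\n",
"        t[-1] = IGNORE\n",
"        toy_targets.append(t)\n",
"print(toy_targets)"
]
},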
{
"cell_type": "markdown",
"id": "6a8912c4",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "markdown",
"id": "36c55eb2",
"metadata": {},
"source": [
"### Set Model Config Params for Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59766bcf",
"metadata": {},
"outputs": [],
"source": [
"model.img_context_token_id = tokenizer.convert_tokens_to_ids(\"<IMG_CONTEXT>\")\n",
"\n",
"model.train()\n",
"\n",
"model.language_model.config.use_cache = False\n",
"model.vision_model.gradient_checkpointing = True\n",
"model.vision_model.encoder.gradient_checkpointing = True\n",
"model.language_model._set_gradient_checkpointing()"
]
},
{
"cell_type": "markdown",
"id": "b6c2dec6",
"metadata": {},
"source": [
"### Define Training Params"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "686f3cfd",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"save_dir = \"../saved_model\"\n",
"\n",
"trainer = SFTTrainer(\n",
" model=model,\n",
" data_collator=collate_fn,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=val_dataset,\n",
" processing_class=processor,\n",
" args=SFTConfig(\n",
" num_train_epochs=30,\n",
" per_device_train_batch_size=1,\n",
" per_device_eval_batch_size=1,\n",
" eval_steps=250,\n",
" do_eval=True,\n",
" warmup_ratio=0.03,\n",
" lr_scheduler_type=\"cosine\",\n",
" eval_strategy=\"steps\",\n",
" label_names=[\"labels\"],\n",
" dataloader_num_workers=4,\n",
" gradient_accumulation_steps=4,\n",
" dataloader_persistent_workers=True,\n",
" learning_rate=2e-5,\n",
" weight_decay=0.05,\n",
" logging_steps=10,\n",
" logging_dir=\"logs\",\n",
" save_strategy=\"steps\",\n",
" save_steps=100,\n",
" output_dir=save_dir,\n",
" save_total_limit=2,\n",
" optim=\"adamw_torch\",\n",
" bf16=True,\n",
" remove_unused_columns=False,\n",
" report_to=\"wandb\",\n",
" dataset_kwargs={\"skip_prepare_dataset\": True},\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "7c364912",
"metadata": {},
"source": [
"### Train the model\n",
"**Note:** Remove the `resume_from_checkpoint` parameter of `trainer.train()` if you don't want to resume training from a checkpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ab9fa74",
"metadata": {},
"outputs": [],
"source": [
"last_ckpt = get_last_checkpoint(save_dir) if os.path.isdir(save_dir) else None\n",
"print(f\"Resuming from {last_ckpt}\" if last_ckpt else \"No checkpoint found -- training from scratch\")\n",
"trainer.train(resume_from_checkpoint=last_ckpt)"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"encoding": "# coding: utf-8",
"executable": "/usr/bin/env python",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}