{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "31e8ca53",
   "metadata": {},
   "source": [
    "# Copyright Notice\n",
    "\n",
    "```\n",
    "SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n",
    "SPDX-License-Identifier: Apache-2.0\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae293c8d",
   "metadata": {},
   "source": [
    "# VLM-Finetuning for Large Scale Data Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c471d1d",
   "metadata": {},
   "source": [
    "### 📓 Notebook Overview\n",
    "In this notebook, we show how to train a VLM to generate structured metadata about videos. The target use case is analyzing driving videos with a VLM to produce JSON-formatted video descriptions and metadata, such as risky maneuvers, in order to understand dangerous driving patterns.\n",
    "\n",
    "This is just one example; the workflow generalizes to any large-scale video data analysis task where structured metadata and automated video analysis are helpful."
   ]
  },
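  {
   "cell_type": "markdown",
   "id": "b7f01a2e",
   "metadata": {},
   "source": [
    "To make the target concrete, here is an illustrative example of the kind of single-line JSON record the fine-tuned model should produce for one clip. The field values below are hypothetical; the exact schema and allowed values are defined by the prompt and label-construction code later in this notebook.\n",
    "\n",
    "```json\n",
    "{\"caption\": \"The ego vehicle approaches an intersection to turn right and brakes late as a pedestrian enters the crosswalk, narrowly avoiding contact.\", \"event_type\": \"near_miss\", \"rule_violations\": [\"failure_to_yield\"], \"intended_action\": \"turn_right\", \"traffic_density\": \"low\", \"scene\": \"Urban\", \"visibility\": \"good\"}\n",
    "```"
   ]
  },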
  {
   "cell_type": "markdown",
   "id": "6d02f831",
   "metadata": {},
   "source": [
    "### Initialize Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06618bc",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format=\"%(asctime)s - %(levelname)s - %(message)s\"\n",
    ")\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "import os\n",
    "import gc\n",
    "import sys\n",
    "import json\n",
    "import random\n",
    "import pathlib\n",
    "from functools import partial  # used to bind the extra collate_fn arguments for the trainer\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from decord import VideoReader\n",
    "from tensorboard import program\n",
    "import torch.nn.functional as F\n",
    "from datasets import load_dataset\n",
    "from transformers.utils import hub\n",
    "import torchvision.transforms as T\n",
    "from huggingface_hub import snapshot_download\n",
    "from trl import SFTTrainer, SFTConfig\n",
    "from transformers.trainer_pt_utils import LabelSmoother\n",
    "from transformers.trainer_utils import get_last_checkpoint\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor\n",
    "from torchvision.transforms.functional import InterpolationMode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3cc9f01",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Set seed for reproducibility\n",
    "def set_seed(seed):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "\n",
    "set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef437e65",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Set some constants\n",
    "seq_len = 8192\n",
    "model_name = \"OpenGVLab/InternVL3-8B\"\n",
    "ignore_token_id = LabelSmoother.ignore_index"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bf4800f",
   "metadata": {},
   "source": [
    "### Model Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4e2e572",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    model_name,\n",
    "    trust_remote_code=True,\n",
    "    use_fast=True\n",
    ")\n",
    "tokenizer.padding_side = \"right\"\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=\"auto\",\n",
    "    trust_remote_code=True,\n",
    "    use_flash_attn=True,\n",
    ")\n",
    "\n",
    "# Load the processor\n",
    "processor = AutoProcessor.from_pretrained(\n",
    "    model_name,\n",
    "    trust_remote_code=True,\n",
    ")"
   ]
  },
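  {
   "cell_type": "markdown",
   "id": "c3d9e410",
   "metadata": {},
   "source": [
    "As an optional sanity check (a small sketch, not part of the original workflow), you can confirm that the model loaded in bfloat16 and inspect its size before training:\n",
    "\n",
    "```python\n",
    "# Optional: parameter count and dtype of the loaded model\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "print(f\"Parameters: {n_params / 1e9:.2f}B, dtype: {next(model.parameters()).dtype}\")\n",
    "```"
   ]
  },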
  {
   "cell_type": "markdown",
   "id": "b2c1476a",
   "metadata": {},
   "source": [
    "### Data Processing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3458cd1",
   "metadata": {},
   "source": [
    "**Action Required**: Please update the `dataset_path` with a path to your local dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05d6cd17",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "dataset_path = \"path/to/your/dataset\"  # <-- update this to your local dataset\n",
    "print(f\"Dataset path exists: {os.path.exists(dataset_path)}\")"
   ]
  },
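  {
   "cell_type": "markdown",
   "id": "9f4a7b21",
   "metadata": {},
   "source": [
    "The cells below assume the dataset can be loaded with `load_dataset(dataset_path)` and that each record contains a relative `video` path plus the annotation fields used to build the training labels. For illustration only (the field names come from the code below; the values are hypothetical), a record might look like:\n",
    "\n",
    "```json\n",
    "{\"video\": \"clips/clip_00042.mp4\", \"caption\": \"The ego vehicle merges onto the highway behind a truck.\", \"event_type\": \"no_incident\", \"rule_violations\": [], \"intended_action\": \"change_lanes\", \"traffic_density\": \"high\", \"scene\": \"Highway\", \"visibility\": \"good\"}\n",
    "```"
   ]
  },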
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34fad55c",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Load conversation template\n",
    "model_dir = hub.cached_file(model_name, \"conversation.py\", trust_remote_code=True)\n",
    "sys.path.append(os.path.dirname(model_dir))\n",
    "\n",
    "from conversation import get_conv_template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f98733c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_jsonl_data(sample):\n",
    "    \"\"\"Data processing helper: build the label string in the JSON format\n",
    "    we want the model to learn to generate.\n",
    "    \"\"\"\n",
    "    answer_dict = {\n",
    "        \"caption\": sample['caption'],\n",
    "        \"event_type\": sample['event_type'],\n",
    "        \"rule_violations\": sample['rule_violations'],\n",
    "        \"intended_action\": sample['intended_action'],\n",
    "        \"traffic_density\": sample['traffic_density'],\n",
    "        \"scene\": sample['scene'],\n",
    "        \"visibility\": sample['visibility'],\n",
    "    }\n",
    "\n",
    "    return json.dumps(answer_dict, ensure_ascii=False)  # a single-line, valid JSON string"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84d2f75f",
   "metadata": {},
   "source": [
    "### Load the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfb85f7c",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Load the dataset\n",
    "dataset = load_dataset(dataset_path)\n",
    "dataset = dataset['train']\n",
    "dataset = dataset.map(lambda ex: {\"labels\": get_jsonl_data(ex)})\n",
    "ds_splits = dataset.train_test_split(test_size=0.01, seed=42)\n",
    "train_dataset, val_dataset = ds_splits['train'], ds_splits['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bca23383",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b98dbe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9299e759",
   "metadata": {},
   "source": [
    "### Data Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a95631c",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "def get_video_path(dataset_path, sample_path):\n",
    "    \"\"\"Dataset-specific helper: append the sample's relative path to the dataset root to build the full video path.\"\"\"\n",
    "    root_dir = dataset_path.split('/')[:-1]\n",
    "    video_path = '/'.join(root_dir) + '/' + sample_path\n",
    "    return video_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcb4dfba",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "min_frames, max_frames = 8, 32\n",
    "\n",
    "# load frames from a video\n",
    "def load_video(video_path):\n",
    "    video = VideoReader(video_path, num_threads=1)\n",
    "\n",
    "    # sample a random number of equally-spaced frames from the video\n",
    "    frame_indices = np.linspace(\n",
    "        0,\n",
    "        len(video) - 1,\n",
    "        random.randint(min_frames, max_frames),\n",
    "        dtype=int\n",
    "    )\n",
    "    frames = video.get_batch(frame_indices).asnumpy()\n",
    "    return [Image.fromarray(frames[i]) for i in range(frames.shape[0])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d388155a",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "video_path = get_video_path(dataset_path, train_dataset[0]['video'])\n",
    "display(load_video(video_path)[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b21a463",
   "metadata": {},
   "source": [
    "### Data Processing\n",
    "Functions taken from [InternVL3 Documentation](https://internvl.readthedocs.io/en/latest/internvl3.0/quick_start.html#inference-with-transformers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f8e4c46",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# preprocessing code adapted directly from the HF model card\n",
    "IMAGENET_MEAN = (0.485, 0.456, 0.406)\n",
    "IMAGENET_STD = (0.229, 0.224, 0.225)\n",
    "\n",
    "\n",
    "def build_transform(input_size):\n",
    "    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD\n",
    "    transform = T.Compose([\n",
    "        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n",
    "        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),\n",
    "        T.ToTensor(),\n",
    "        T.Normalize(mean=MEAN, std=STD)\n",
    "    ])\n",
    "    return transform\n",
    "\n",
    "\n",
    "def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):\n",
    "    best_ratio_diff = float('inf')\n",
    "    best_ratio = (1, 1)\n",
    "    area = width * height\n",
    "    for ratio in target_ratios:\n",
    "        target_aspect_ratio = ratio[0] / ratio[1]\n",
    "        ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n",
    "        if ratio_diff < best_ratio_diff:\n",
    "            best_ratio_diff = ratio_diff\n",
    "            best_ratio = ratio\n",
    "        elif ratio_diff == best_ratio_diff:\n",
    "            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n",
    "                best_ratio = ratio\n",
    "    return best_ratio\n",
    "\n",
    "\n",
    "def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):\n",
    "    orig_width, orig_height = image.size\n",
    "    aspect_ratio = orig_width / orig_height\n",
    "\n",
    "    # calculate the existing image aspect ratio\n",
    "    target_ratios = set(\n",
    "        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if\n",
    "        i * j <= max_num and i * j >= min_num)\n",
    "    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n",
    "\n",
    "    # find the closest aspect ratio to the target\n",
    "    target_aspect_ratio = find_closest_aspect_ratio(\n",
    "        aspect_ratio, target_ratios, orig_width, orig_height, image_size)\n",
    "\n",
    "    # calculate the target width and height\n",
    "    target_width = image_size * target_aspect_ratio[0]\n",
    "    target_height = image_size * target_aspect_ratio[1]\n",
    "    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n",
    "\n",
    "    # resize the image\n",
    "    resized_img = image.resize((target_width, target_height))\n",
    "    processed_images = []\n",
    "    for i in range(blocks):\n",
    "        box = (\n",
    "            (i % (target_width // image_size)) * image_size,\n",
    "            (i // (target_width // image_size)) * image_size,\n",
    "            ((i % (target_width // image_size)) + 1) * image_size,\n",
    "            ((i // (target_width // image_size)) + 1) * image_size\n",
    "        )\n",
    "        # split the image\n",
    "        split_img = resized_img.crop(box)\n",
    "        processed_images.append(split_img)\n",
    "    assert len(processed_images) == blocks\n",
    "    if use_thumbnail and len(processed_images) != 1:\n",
    "        thumbnail_img = image.resize((image_size, image_size))\n",
    "        processed_images.append(thumbnail_img)\n",
    "    return processed_images\n",
    "\n",
    "\n",
    "# build the transform and get the number of tokens per image (per tile, technically)\n",
    "image_size = model.config.force_image_size\n",
    "transform = build_transform(input_size=image_size)\n",
    "num_image_tokens = int((image_size // model.config.vision_config.patch_size) ** 2 * (model.config.downsample_ratio ** 2))"
   ]
  },
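  {
   "cell_type": "markdown",
   "id": "5e6f0d3a",
   "metadata": {},
   "source": [
    "As a quick, optional check (a sketch that assumes the cells above have already run), you can transform one sampled frame and confirm the tile shape and the per-tile `<IMG_CONTEXT>` token budget the model expects:\n",
    "\n",
    "```python\n",
    "# Optional: transform one frame and inspect shapes / token budget\n",
    "frames = load_video(get_video_path(dataset_path, train_dataset[0]['video']))\n",
    "tile = transform(frames[0])\n",
    "print(tile.shape)                    # e.g. torch.Size([3, 448, 448])\n",
    "print(image_size, num_image_tokens)  # tile size and context tokens per tile\n",
    "```"
   ]
  },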
  {
   "cell_type": "markdown",
   "id": "6097f018",
   "metadata": {},
   "source": [
    "### Define user prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e24f2a4a",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "user_prompt = \"\"\"You are a vision-language assistant analyzing driving videos. You will receive a 5-second video clip of a specific scene.\n",
    "\n",
    "---\n",
    "\n",
    "### Task 1: Dense Caption\n",
    "Generate a 2 sentence caption describing:\n",
    "- Ego vehicle behavior\n",
    "- Interactions with other vehicles or pedestrians\n",
    "\n",
    "Focus on **what happens**, **when**, and **who/what is involved**, using only visible information and metadata.\n",
    "\n",
    "---\n",
    "\n",
    "### Task 2: Structured JSON\n",
    "Generate the caption from the perspective of the ego vehicle in a structured JSON object with:\n",
    "\n",
    "- `\"caption\"`: from Task 1\n",
    "- `\"event_type\"`: \"collision\" | \"near_miss\" | \"no_incident\"\n",
    "- `\"rule_violations\"`: choose relevant items from [\"speeding\", \"failure_to_yield\", \"ignoring_traffic_signs\"]\n",
    "- `\"intended_action\"`: \"turn_left\" | \"turn_right\" | \"change_lanes\"\n",
    "- `\"traffic_density\"`: \"low\" | \"high\"\n",
    "- `\"visibility\"`: \"good\" | \"bad\"\n",
    "- `\"scene\"`: \"Urban\" | \"Sub-urban\" | \"Rural\" | \"Highway\"\n",
    "\n",
    "**Rules:**\n",
    "1. Use only visible info and metadata.\n",
    "2. Do not invent details.\n",
    "3. Include all fields; enum values must match allowed options.\n",
    "4. Output a single valid JSON object—no extra text or markdown.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e82e25f",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# View a sample label to confirm it's in the format we want\n",
    "sample = train_dataset[0]\n",
    "\n",
    "answer_dict = {\n",
    "    \"caption\": sample['caption'],\n",
    "    \"event_type\": sample['event_type'],\n",
    "    \"rule_violations\": sample['rule_violations'],\n",
    "    \"intended_action\": sample['intended_action'],\n",
    "    \"traffic_density\": sample['traffic_density'],\n",
    "    \"scene\": sample['scene'],\n",
    "    \"visibility\": sample['visibility'],\n",
    "}\n",
    "answer_jsonl = json.dumps(answer_dict, ensure_ascii=False)\n",
    "\n",
    "print(answer_jsonl)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24b3e453",
   "metadata": {},
   "source": [
    "### Custom Data Preprocessing\n",
    "\n",
    "This data tokenization function takes a batch of samples and tokenizes it in the format the InternVL3 source code expects, returning a processed batch of features with input IDs, labels, attention masks, pixel values, and image flags."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59243cf6",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Collate function for the dataset that does on-the-fly preprocessing and batching\n",
    "def collate_fn(samples, tokenizer, transform, seq_len, num_image_tokens, get_conv_template, ignore_token_id, load_video, get_video_path, dataset_path):\n",
    "    input_ids_batch, labels_batch, attention_mask_batch, position_ids_batch, pixel_values_batch, image_flags_batch = [], [], [], [], [], []\n",
    "    for sample in samples:\n",
    "        # load the video frames\n",
    "        video_frames = load_video(get_video_path(dataset_path, sample['video']))\n",
    "        num_frames = len(video_frames)\n",
    "\n",
    "        # preprocess the video frames\n",
    "        pixel_values = [transform(frame) for frame in video_frames]\n",
    "        pixel_values = torch.stack(pixel_values)\n",
    "        num_tiles = pixel_values.size(0)\n",
    "\n",
    "        # prepend special video tokens to the user message\n",
    "        video_tokens = '\\n'.join(['Frame-{}: <image>'.format(i + 1) for i in range(num_frames)])\n",
    "\n",
    "        # setup conversation\n",
    "        conv_template = get_conv_template(\"internvl2_5\")\n",
    "\n",
    "        system_instruction = user_prompt\n",
    "        answer = get_jsonl_data(sample)\n",
    "\n",
    "        conv_template.append_message(conv_template.roles[0], f'{video_tokens}\\n{system_instruction}')\n",
    "        conv_template.append_message(conv_template.roles[1], answer)\n",
    "\n",
    "        # replace image tokens with context tokens\n",
    "        prompt = conv_template.get_prompt()\n",
    "        prompt = prompt.replace(\"<image>\", f\"<img>{'<IMG_CONTEXT>' * num_image_tokens}</img>\")\n",
    "\n",
    "        # create a list of messages\n",
    "        messages = [f\"<|im_start|>{message}\" for message in prompt.split(\"<|im_start|>\")[1:]]\n",
    "\n",
    "        # tokenize the prompt (we manually truncate and pad the sequence)\n",
    "        input_ids = tokenizer(\n",
    "            messages,\n",
    "            return_tensors=\"np\",\n",
    "            padding=False,\n",
    "            max_length=seq_len,\n",
    "            truncation=False,\n",
    "        ).input_ids\n",
    "\n",
    "        # create targets by masking out system and user messages\n",
    "        # since we only want to compute loss for the assistant message\n",
    "        targets = []\n",
    "        num_ignore_ids = tokenizer('<|im_start|>assistant\\n', return_tensors='np').input_ids[0].shape[0]\n",
    "        for idx, input_id in enumerate(input_ids):\n",
    "            if idx != 2:\n",
    "                targets.append(np.full(input_id.shape, ignore_token_id))\n",
    "            else:\n",
    "                target = input_id.copy()\n",
    "                target[:num_ignore_ids] = ignore_token_id\n",
    "                target[-1:] = ignore_token_id\n",
    "                targets.append(target)\n",
    "\n",
    "        # prepare the input_ids and targets\n",
    "        input_ids = torch.tensor(np.concatenate(input_ids))[:seq_len]\n",
    "        targets = torch.tensor(np.concatenate(targets))[:seq_len]\n",
    "\n",
    "        # pad the input_ids and targets to the sequence length\n",
    "        pad_len = seq_len - input_ids.shape[0]\n",
    "        input_ids = F.pad(input_ids, (0, pad_len), value=tokenizer.pad_token_id)\n",
    "        targets = F.pad(targets, (0, pad_len), value=ignore_token_id)\n",
    "\n",
    "        # generate attention mask to filter out padding tokens\n",
    "        attention_mask = input_ids.ne(tokenizer.pad_token_id)\n",
    "\n",
    "        position_ids = attention_mask.long().cumsum(-1) - 1\n",
    "        position_ids.masked_fill_(attention_mask == 0, 1)\n",
    "\n",
    "        input_ids_batch.append(input_ids)\n",
    "        labels_batch.append(targets)\n",
    "        attention_mask_batch.append(attention_mask)\n",
    "        position_ids_batch.append(position_ids)\n",
    "        pixel_values_batch.append(pixel_values)\n",
    "        image_flags_batch.append(torch.tensor([1] * num_tiles, dtype=torch.long))\n",
    "\n",
    "    batch = {\n",
    "        \"input_ids\": torch.stack(input_ids_batch),\n",
    "        \"labels\": torch.stack(labels_batch),\n",
    "        \"attention_mask\": torch.stack(attention_mask_batch),\n",
    "        \"position_ids\": torch.stack(position_ids_batch),\n",
    "        \"pixel_values\": torch.cat(pixel_values_batch),\n",
    "        \"image_flags\": torch.cat(image_flags_batch)\n",
    "    }\n",
    "\n",
    "    return batch"
   ]
  },
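  {
   "cell_type": "markdown",
   "id": "a0c4e8f2",
   "metadata": {},
   "source": [
    "`collate_fn` takes several arguments beyond the batch itself, while the trainer calls its data collator with only the list of samples, so the extra arguments must be bound first (the trainer cell below does this with `functools.partial`). A minimal sketch, using only objects defined above, for building one batch by hand and verifying that loss is computed only on the assistant's JSON answer:\n",
    "\n",
    "```python\n",
    "# Optional: build a single batch and check the label masking\n",
    "batch = collate_fn(\n",
    "    [train_dataset[0]], tokenizer, transform, seq_len, num_image_tokens,\n",
    "    get_conv_template, ignore_token_id, load_video, get_video_path, dataset_path,\n",
    ")\n",
    "labels = batch[\"labels\"][0]\n",
    "supervised = labels[labels != ignore_token_id]  # tokens that contribute to the loss\n",
    "print(tokenizer.decode(supervised))             # should correspond to the JSON label\n",
    "for key, value in batch.items():\n",
    "    print(key, tuple(value.shape))\n",
    "```"
   ]
  },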
  {
   "cell_type": "markdown",
   "id": "6a8912c4",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36c55eb2",
   "metadata": {},
   "source": [
    "### Set Model Config Params for Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59766bcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.img_context_token_id = tokenizer.convert_tokens_to_ids(\"<IMG_CONTEXT>\")\n",
    "\n",
    "model.train()\n",
    "\n",
    "model.language_model.config.use_cache = False\n",
    "model.vision_model.gradient_checkpointing = True\n",
    "model.vision_model.encoder.gradient_checkpointing = True\n",
    "model.language_model._set_gradient_checkpointing()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6c2dec6",
   "metadata": {},
   "source": [
    "### Define Training Params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "686f3cfd",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "save_dir = \"../saved_model\"\n",
    "\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    # bind the extra collate_fn arguments so the trainer can call the collator with just the samples\n",
    "    data_collator=partial(\n",
    "        collate_fn,\n",
    "        tokenizer=tokenizer,\n",
    "        transform=transform,\n",
    "        seq_len=seq_len,\n",
    "        num_image_tokens=num_image_tokens,\n",
    "        get_conv_template=get_conv_template,\n",
    "        ignore_token_id=ignore_token_id,\n",
    "        load_video=load_video,\n",
    "        get_video_path=get_video_path,\n",
    "        dataset_path=dataset_path,\n",
    "    ),\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset,\n",
    "    processing_class=processor,\n",
    "    args=SFTConfig(\n",
    "        num_train_epochs=30,\n",
    "        per_device_train_batch_size=1,\n",
    "        per_device_eval_batch_size=1,\n",
    "        eval_steps=250,\n",
    "        do_eval=True,\n",
    "        warmup_ratio=0.03,\n",
    "        lr_scheduler_type=\"cosine\",\n",
    "        eval_strategy=\"steps\",\n",
    "        label_names=[\"labels\"],\n",
    "        dataloader_num_workers=4,\n",
    "        gradient_accumulation_steps=4,\n",
    "        dataloader_persistent_workers=True,\n",
    "        learning_rate=2e-5,\n",
    "        weight_decay=0.05,\n",
    "        logging_steps=10,\n",
    "        logging_dir=\"logs\",\n",
    "        save_strategy=\"steps\",\n",
    "        save_steps=100,\n",
    "        output_dir=save_dir,\n",
    "        save_total_limit=2,\n",
    "        optim=\"adamw_torch\",\n",
    "        bf16=True,\n",
    "        remove_unused_columns=False,\n",
    "        report_to=\"wandb\",\n",
    "        dataset_kwargs={\"skip_prepare_dataset\": True},\n",
    "    )\n",
    ")"
   ]
  },
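  {
   "cell_type": "markdown",
   "id": "7d2b9c55",
   "metadata": {},
   "source": [
    "With `per_device_train_batch_size=1` and `gradient_accumulation_steps=4`, the effective batch size is 4 per device. Metrics are reported to Weights & Biases (`report_to=\"wandb\"`), which requires a logged-in `wandb` account. If you prefer TensorBoard, set `report_to=\"tensorboard\"` in the config above; a minimal sketch for launching a viewer in-process with the `tensorboard` package imported earlier (using its standard programmatic launcher) is:\n",
    "\n",
    "```python\n",
    "# Optional: launch TensorBoard against the trainer's logging directory\n",
    "tb = program.TensorBoard()\n",
    "tb.configure(argv=[None, \"--logdir\", \"logs\"])\n",
    "print(f\"TensorBoard running at {tb.launch()}\")\n",
    "```"
   ]
  },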
  {
   "cell_type": "markdown",
   "id": "7c364912",
   "metadata": {},
   "source": [
    "### Train the model\n",
    "**Note:** Remove the `resume_from_checkpoint` parameter of `trainer.train()` if you don't want to resume training from a checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab9fa74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resume from the most recent checkpoint if one exists; otherwise start from scratch\n",
    "last_ckpt = get_last_checkpoint(save_dir) if os.path.isdir(save_dir) else None\n",
    "print(f\"Resuming from {last_ckpt}\")\n",
    "trainer.train(resume_from_checkpoint=last_ckpt)"
   ]
  }
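,
  {
   "cell_type": "markdown",
   "id": "4b8e1f6c",
   "metadata": {},
   "source": [
    "After training completes, you may want to persist the final weights separately from the periodic checkpoints. A minimal sketch using the standard `transformers` saving APIs (the target directory is just an example):\n",
    "\n",
    "```python\n",
    "# Save the fine-tuned model, tokenizer, and processor for later inference\n",
    "final_dir = \"../saved_model/final\"  # example path\n",
    "trainer.save_model(final_dir)\n",
    "tokenizer.save_pretrained(final_dir)\n",
    "processor.save_pretrained(final_dir)\n",
    "```"
   ]
  }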
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "encoding": "# coding: utf-8",
   "executable": "/usr/bin/env python",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}