chore: Regenerate all playbooks
@@ -46,7 +46,6 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
- [Unsloth on DGX Spark](nvidia/unsloth/)
- [Vibe Coding in VS Code](nvidia/vibe-coding/)
- [Install and Use vLLM for Inference](nvidia/vllm/)
- [Vision-Language Model Fine-tuning](nvidia/vlm-finetuning/)
- [VS Code](nvidia/vscode/)
- [Build a Video Search and Summarization (VSS) Agent](nvidia/vss/)
@@ -64,6 +64,8 @@ All required assets can be found [in the ComfyUI repository on GitHub](https://g

* Model downloads are large (~2GB) and may fail due to network issues
* Port 8188 must be accessible for web interface functionality
* **Rollback:** Virtual environment can be deleted to remove all installed packages. Downloaded models can be removed manually from the checkpoints directory.
* **Last Updated:** 11/10/2025
* Update ComfyUI PyTorch to CUDA 13.0

## Instructions
@@ -1,334 +0,0 @@

# Vision-Language Model Fine-tuning

> Fine-tune Vision-Language Models for image and video understanding tasks using Qwen2.5-VL and InternVL3

## Table of Contents

- [Overview](#overview)
- [Instructions](#instructions)
- [Troubleshooting](#troubleshooting)

---

## Overview

## Basic idea

This playbook demonstrates how to fine-tune Vision-Language Models (VLMs) for both image and video understanding tasks on DGX Spark.
With 128GB of unified memory and powerful GPU acceleration, DGX Spark provides an ideal environment for training VRAM-intensive multimodal models that can understand and reason about visual content.

The playbook covers two distinct VLM fine-tuning approaches:
- **Image VLM fine-tuning**: Qwen2.5-VL-7B for wildfire detection from satellite imagery, trained with GRPO (Group Relative Policy Optimization)
- **Video VLM fine-tuning**: InternVL3-8B for dangerous driving detection and structured metadata generation from driving videos

Both approaches leverage advanced training techniques, including LoRA fine-tuning, preference optimization, and structured reasoning, to achieve strong performance on specialized tasks.
## What you'll accomplish

You will have fine-tuned VLM models capable of understanding and analyzing both images and videos for specialized use cases, accessible through interactive Web UIs.
The setup includes:
- **Image VLM**: Qwen2.5-VL fine-tuned for wildfire detection with reasoning capability
- **Video VLM**: InternVL3 fine-tuned for dangerous driving analysis and structured metadata generation
- Interactive Streamlit interfaces for both training and inference
- Side-by-side model comparison (base vs. fine-tuned) in the Web UIs
- Docker containerization for reproducible environments

## Prerequisites

- DGX Spark device is set up and accessible
- No other processes running on the DGX Spark GPU
- Enough disk space for model downloads and datasets
- NVIDIA Docker installed and configured
- Weights & Biases account for training monitoring (optional but recommended)

## Time & risk

* **Duration**:
  * 15-20 minutes for initial setup and model downloads
  * 30-60 minutes for image VLM training (depending on dataset size)
  * 1-2 hours for video VLM training (depending on video dataset size)
* **Risks**:
  * Docker permission issues may require user group changes and a session restart
  * Large model downloads and datasets may require significant disk space and time
  * Training requires sustained GPU usage and memory
  * Dataset preparation may require manual steps (Kaggle downloads, video processing)
* **Rollback**: Stop and remove the Docker containers, and delete downloaded models and datasets if needed.
## Instructions

## Step 1. Configure Docker permissions

To manage containers without `sudo`, your user must be in the `docker` group. If you skip this step, you will need to prefix Docker commands with `sudo`.

Open a new terminal and test Docker access:

```bash
docker ps
```

If you see a permission denied error (something like "permission denied while trying to connect to the Docker daemon socket"), add your user to the `docker` group:

```bash
sudo usermod -aG docker $USER
newgrp docker
```
## Step 2. Clone the repository

In a terminal, clone the repository:

```bash
git clone https://github.com/NVIDIA/dgx-spark-playbooks
```

## Step 3. Build the Docker container

Build the Docker image, which sets up the environment for both image and video VLM fine-tuning.
Before building, export your Hugging Face token as the `HF_TOKEN` environment variable. You may encounter warnings while building the image; this is expected and can be ignored.

```bash
# Enter the correct directory for building the image
cd dgx-spark-playbooks/nvidia/vlm-finetuning/assets

# Build the VLM fine-tuning container
docker build --build-arg HF_TOKEN=$HF_TOKEN -t vlm_demo .
```
## Step 4. Run the Docker container

```bash
# Run the container with GPU support
sh launch.sh

# Enter the mounted directory within the container
cd /vlm_finetuning
```

> [!NOTE]
> The same Docker container and launch commands work for both image and video VLM recipes. The container includes all necessary dependencies, including FFmpeg, Decord, and optimized libraries for both workflows.
## Step 5. [Option A] For image VLM fine-tuning (Wildfire Detection)

#### 5.1. Model download

```bash
hf download Qwen/Qwen2.5-VL-7B-Instruct
```

If you already have a fine-tuned checkpoint, place it in the `saved_model/` folder. Note that your checkpoint number can be different. For a comparative analysis against the base model, skip directly to the `Finetuned Model Inference` section.

#### 5.2. Download the wildfire dataset

The project uses a **Wildfire Detection Dataset** with satellite imagery for training the model to identify wildfire-affected regions. The dataset includes:
- Satellite and aerial imagery from wildfire-affected areas
- Binary classification: wildfire vs. no wildfire

```bash
mkdir -p ui_image/data
cd ui_image/data
```

For this fine-tuning playbook, we will use the [Wildfire Prediction Dataset](https://www.kaggle.com/datasets/abdelghaniaaba/wildfire-prediction-dataset) from Kaggle. On the Kaggle dataset page, click the download button, select the `cURL` option in the `Download Via` dropdown, and copy the curl command.

> [!NOTE]
> You will need to be logged into Kaggle and may need to accept the dataset terms before the download link works.

Run the following commands in your container:

```bash
# Paste and run the curl command from Kaggle here, then unzip the dataset
unzip -qq wildfire-prediction-dataset.zip
rm wildfire-prediction-dataset.zip
cd ..
```
#### 5.3. Base model inference

Before we start fine-tuning, let's spin up the demo UI to evaluate the base model's performance on this task.

```bash
streamlit run Image_VLM.py
```

Access the Streamlit demo at http://localhost:8501/.

When you access the demo for the first time, the backend spins up vLLM servers for the base model. You will see a spinner on the demo site while vLLM comes up for optimized inference. This step can take up to 15 minutes.

Since we are currently focused on base model inference, scroll down to the `Image Inference` section of the UI. Here, you should see a pre-loaded sample satellite image of a potentially wildfire-affected region.

Enter your prompt in the chat box and hit `Generate`. Your prompt will first be sent to the base model, and you should see the generated response in the left chat box. If you did not provide a fine-tuned model, you will not see any generations in the right chat box. You can use the following prompt to quickly test inference:

`Identify if this region has been affected by a wildfire`

As you can see, the base model is incapable of providing the right response for this domain-specific task. Let's try to improve the model's accuracy by performing GRPO fine-tuning.
#### 5.4. GRPO fine-tuning

We will perform GRPO fine-tuning to add reasoning capabilities to our base model and improve its understanding of the underlying domain. With the Streamlit demo already running, scroll to the `GRPO Training` section.

Configure the fine-tuning method and LoRA parameters from the following options:

- `Finetuning Method`: Choose from Full Finetuning or LoRA
- `LoRA Parameters`: Adjustable rank (8-64) and alpha (8-64)
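To see why LoRA at these ranks is so much cheaper than full fine-tuning, here is a back-of-envelope sketch. The 3584 hidden size is illustrative of a 7B-class model; the actual adapted shapes depend on which layers you toggle on.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    # LoRA adds two low-rank matrices per adapted weight:
    # A (rank x d_in) and B (d_out x rank), instead of
    # updating all d_out x d_in entries of the weight itself.
    return rank * d_in + d_out * rank

# One square 3584x3584 projection at rank 16:
full = 3584 * 3584
lora = lora_param_count(3584, 3584, rank=16)
print(f"trainable fraction: {lora / full:.4%}")  # → trainable fraction: 0.8929%
```

Even at the maximum rank of 64 the fraction is still only about 3.6% per adapted weight, which is why the UI can offer ranks up to 64 without approaching full fine-tuning cost.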
You can additionally choose which layers of the VLM to fine-tune. For the best performance, ensure that all options are toggled on; note that this also increases training time.

In this section, we can also select training parameters for our run:

- `Steps`: 1-1000
- `Batch Size`: 1, 2, 4, 8, or 16
- `Learning Rate`: 1e-6 to 1e-2
- `Optimizer`: AdamW or Adafactor

For a GRPO setup, we also have the flexibility to choose the reward assigned to the model based on certain criteria:

- `Format Reward`: 2.0 (reward for proper reasoning format)
- `Correctness Reward`: 5.0 (reward for correct answers)
- `Number of Generations`: 4 (for preference optimization)
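As a standalone sketch of how these rewards might combine, the function below scores one sampled completion using the UI's default reward values and the `<REASONING>`/`<SOLUTION>` tags this playbook's inference UI describes. The actual reward functions live in the playbook's training code; this version is illustrative only.

```python
import re

FORMAT_REWARD = 2.0       # UI default: reward for proper reasoning format
CORRECTNESS_REWARD = 5.0  # UI default: reward for a correct answer

def score_completion(completion: str, ground_truth: str) -> float:
    """Score one sampled completion the way a GRPO reward function might."""
    reward = 0.0
    # Format reward: the response wraps reasoning and answer in the expected tags.
    if re.search(r"<REASONING>.*?</REASONING>\s*<SOLUTION>.*?</SOLUTION>",
                 completion, re.DOTALL):
        reward += FORMAT_REWARD
    # Correctness reward: the extracted answer matches the label.
    match = re.search(r"<SOLUTION>(.*?)</SOLUTION>", completion, re.DOTALL)
    if match and match.group(1).strip().lower() == ground_truth.lower():
        reward += CORRECTNESS_REWARD
    return reward

good = "<REASONING>Burn scars visible.</REASONING><SOLUTION>wildfire</SOLUTION>"
print(score_completion(good, "wildfire"))  # 7.0
```

GRPO then samples a group of such completions per prompt (the `Number of Generations` setting) and uses their relative scores to update the policy.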
After configuring all the parameters, hit `Start Finetuning` to begin the training process. It takes about 15 minutes for the model to load and for metadata to start appearing on the UI. As training progresses, information such as the loss, step, and GRPO rewards is recorded in a live table.

The default configuration should give you reasonable accuracy, taking 100 steps of training over a period of up to 2 hours. We achieved our best accuracy with around 1000 steps of training, taking close to 16 hours.

After training reaches the desired number of steps, the script automatically merges the LoRA weights into the base model; this merge can take about 5 minutes.

If you wish to stop training early, hit the `Stop Finetuning` button. Use this button only to interrupt training; it does not guarantee that checkpoints will be properly stored or merged with the LoRA adapter layers.

Once you stop training, the UI automatically brings up the vLLM servers for the base model and the newly fine-tuned model.

#### 5.5. Fine-tuned model inference

Now we are ready to perform a comparative analysis between the base model and the fine-tuned model.

If you haven't spun up the Streamlit demo already, execute the following command. If you have just stopped training and are still within the live UI, skip this step.

```bash
streamlit run Image_VLM.py
```

Regardless of whether you just spun up the demo or just stopped training, allow about 15 minutes for the vLLM servers to come up.

Scroll down to the `Image Inference` section and enter your prompt in the provided chat box. Upon clicking `Generate`, your prompt is sent first to the base model and then to the fine-tuned model. You can use the following prompt to quickly test inference:

`Identify if this region has been affected by a wildfire`

If you trained your model sufficiently, you should see that the fine-tuned model is able to perform reasoning and provide a concise, accurate answer to the prompt. The reasoning steps are provided in markdown format, while the final answer is bolded at the end of the model's response.
## Step 6. [Option B] For video VLM fine-tuning (Driver Behaviour Analysis)

Within the same container, navigate to the `ui_video` directory.

```bash
cd /vlm_finetuning/ui_video
```

#### 6.1. Prepare your video dataset

Structure your dataset as follows. Ensure that `metadata.jsonl` contains one line of structured JSON data per video.

```
dataset/
├── videos/
│   ├── video1.mp4
│   ├── video2.mp4
│   └── ...
└── metadata.jsonl
```
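A minimal sketch of writing and reading such a file follows. The fields below are illustrative only: the real schema is whatever structured metadata you want the model to learn to emit.

```python
import json

# Hypothetical metadata.jsonl rows; the field names are illustrative,
# not the playbook's actual schema.
rows = [
    {"video": "videos/video1.mp4",
     "label": "unsafe",
     "events": ["tailgating", "hard_braking"]},
    {"video": "videos/video2.mp4",
     "label": "safe",
     "events": []},
]

# JSON Lines: one JSON object per line, no enclosing array.
with open("metadata.jsonl", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

# Reading it back, line by line:
with open("metadata.jsonl") as f:
    loaded = [json.loads(line) for line in f]
print(len(loaded), loaded[0]["label"])  # 2 unsafe
```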
#### 6.2. Model download

> [!NOTE]
> These instructions assume you are already inside the Docker container. For container setup, refer to the `Build the Docker container` section above.

```bash
hf download OpenGVLab/InternVL3-8B
```

#### 6.3. Base model inference

Before fine-tuning our video VLM for this task, let's see how the base InternVL3-8B does.

```bash
# cd into /vlm_finetuning/ui_video if you haven't already
streamlit run Video_VLM.py
```

Access the Streamlit demo at http://localhost:8501/.

When you access the demo for the first time, the backend loads the base model through Hugging Face. You will see a spinner on the demo site while the model loads, which can take up to 10 minutes.

First, select a video from the dashcam gallery. Upon clicking the green file-open icon next to a video, you should see the video render and play automatically for reference.

Scroll down, enter your prompt in the chat box, and hit `Generate`. Your prompt will first be sent to the base model, and you should see the generated response in the left chat box. If you did not provide a fine-tuned model, you will not see any generations in the right chat box. You can use the following prompt to quickly test inference:

`Analyze the dashcam footage for unsafe driver behavior`

If you are proceeding to train a fine-tuned model, bring the Streamlit demo down first by interrupting the terminal with a `Ctrl+C` keystroke.

> [!NOTE]
> To clear out any extra occupied memory from your system, execute the following command outside the container after stopping the Streamlit server.

```bash
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
```
#### 6.4. Run the training notebook

```bash
# Enter the correct directory
cd train

# Start the Jupyter notebook server
jupyter notebook video_vlm.ipynb
```

Access Jupyter at `http://localhost:8888`. Ensure that you set the path to your dataset correctly in the appropriate cell.

```python
dataset_path = "/path/to/your/dataset"
```
Here are some of the key configurable training parameters. Note that for reasonable quality, you will need to train your video VLM for at least 24 hours, given the complexity of processing spatio-temporal video sequences.

- **Model**: InternVL3-8B
- **Video Frames**: 12 to 16 frames per video
- **Sampling Mode**: Uniform temporal sampling
- **LoRA Configuration**: Efficient parameter updates for large-scale fine-tuning
- **Hyperparameters**: Exhaustive suite of hyperparameters to tune for video VLM fine-tuning
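Uniform temporal sampling reduces a clip to a fixed number of evenly spaced frames. The index arithmetic can be sketched in plain Python (illustrative only; in the container, Decord would then fetch the selected frames from the video):

```python
def uniform_frame_indices(total_frames: int, num_samples: int) -> list[int]:
    # Split the clip into num_samples equal segments and take the midpoint
    # of each, clamped so short clips never index past the last frame.
    seg = total_frames / num_samples
    return [min(int(seg * i + seg / 2), total_frames - 1)
            for i in range(num_samples)]

# A 300-frame clip sampled at 12 frames:
print(uniform_frame_indices(300, 12))
# → [12, 37, 62, 87, 112, 137, 162, 187, 212, 237, 262, 287]
```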
You can monitor training progress and metrics as they are continuously shown in the notebook.

After training, shut down the Jupyter kernel in the notebook and kill the Jupyter server in the terminal with a `Ctrl+C` keystroke.

> [!NOTE]
> To clear out any extra occupied memory from your system, execute the following command outside the container after stopping the Jupyter server.

```bash
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
```

#### 6.5. Fine-tuned model inference

Now we are ready to perform a comparative analysis between the base model and the fine-tuned model.

If you haven't spun up the Streamlit demo already, execute the following command. If you have just stopped training and are still within the live UI, skip to the next step.

```bash
# cd back to /vlm_finetuning/ui_video if you haven't already
streamlit run Video_VLM.py
```

Access the Streamlit demo at http://localhost:8501/.

If you trained your model sufficiently, you should see that the fine-tuned model is able to identify the salient events from the video and generate a structured output.

Since the model's output adheres to the schema we trained on, we can directly export the model's predictions into a database for video analytics.
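As a sketch of that export step, assuming an illustrative schema (the field names are hypothetical, and SQLite stands in for whatever analytics database you use):

```python
import json
import sqlite3

# Hypothetical structured prediction, as if emitted by the fine-tuned model.
prediction = json.loads(
    '{"video": "video1.mp4", "behavior": "unsafe", '
    '"events": ["tailgating", "hard_braking"]}'
)

# Insert the parsed fields into an analytics table.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE video_analytics (video TEXT, behavior TEXT, events TEXT)"
)
conn.execute(
    "INSERT INTO video_analytics VALUES (?, ?, ?)",
    (prediction["video"], prediction["behavior"], json.dumps(prediction["events"])),
)

row = conn.execute(
    "SELECT behavior FROM video_analytics WHERE video = 'video1.mp4'"
).fetchone()
print(row[0])  # unsafe
```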
Feel free to play around with the additional videos available in the gallery.

## Troubleshooting

> [!NOTE]
> DGX Spark uses a Unified Memory Architecture (UMA), which enables dynamic memory sharing between the GPU and CPU.
> With many applications still updating to take advantage of UMA, you may encounter memory issues even when within
> the memory capacity of DGX Spark. If that happens, manually flush the buffer cache with:

```bash
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
```
@@ -1,184 +0,0 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Dockerfile stage to build GPU-accelerated ffmpeg
FROM nvcr.io/nvidia/pytorch:25.09-py3 AS ffmpeg-builder

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        yasm \
        libx264-dev \
        libfaac-dev \
        libmp3lame-dev \
        libtheora-dev \
        libvorbis-dev \
        libxvidcore-dev \
        libxext-dev \
        libxfixes-dev \
        build-essential \
        git \
        pkg-config && \
    apt-get update && \
    apt-get install -y --no-install-recommends gcc-11 g++-11

ENV PATH=/usr/local/cuda/bin:${PATH} \
    CUDA_HOME=/usr/local/cuda \
    NVCC=/usr/local/cuda/bin/nvcc \
    CC=/usr/bin/gcc-11 \
    CXX=/usr/bin/g++-11 \
    CUDAHOSTCXX=/usr/bin/g++-11 \
    FFMPEG_VERSION=4.4.6 \
    LD_LIBRARY_PATH=/usr/local/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH

RUN mkdir -p /deps && \
    cd /deps && \
    wget -q https://ffmpeg.org/releases/ffmpeg-${FFMPEG_VERSION}.tar.xz && \
    tar -xf ffmpeg-${FFMPEG_VERSION}.tar.xz && \
    rm ffmpeg-${FFMPEG_VERSION}.tar.xz && \
    cd /deps/ffmpeg-${FFMPEG_VERSION} && \
    apt-get update && \
    apt-get install -y libdrm-dev && \
    ./configure \
        --prefix=/usr/local \
        --enable-nonfree \
        --enable-shared \
        --disable-static \
        --enable-libdrm \
        --enable-v4l2-m2m \
        --enable-gpl \
        --enable-libx264 \
        --extra-cflags="-I/usr/include/aarch64-linux-gnu" \
        --extra-ldflags="-L/usr/lib/aarch64-linux-gnu/tegra" \
        --nvccflags="-ccbin=/usr/bin/g++-11" \
        || (echo "---- tail ffbuild/config.log ----" && tail -n 200 ffbuild/config.log && exit 1) && \
    make -j"$(nproc)" && \
    make install && \
    ldconfig && \
    echo "✅ FFmpeg installed:" && \
    ffmpeg -hide_banner -version | head -n 8 && \
    rm -rf /var/lib/apt/lists/*
# Dockerfile stage to compile decord from source
FROM nvcr.io/nvidia/pytorch:25.09-py3 AS decord-builder
COPY --from=ffmpeg-builder /usr/local /usr/local

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        build-essential \
        git \
        cmake \
        ninja-build \
        pkg-config \
        python3-dev \
        python3-pip \
        gcc-11 \
        g++-11 \
    && rm -rf /var/lib/apt/lists/*

ENV CC=/usr/bin/gcc-11 \
    CXX=/usr/bin/g++-11 \
    PATH=/usr/local/bin:/lib/aarch64-linux-gnu:${PATH} \
    LD_LIBRARY_PATH=/usr/lib/aarch64-linux-gnu/:/lib/aarch64-linux-gnu/:/usr/local/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH} \
    PKG_CONFIG_PATH=/usr/local/lib/pkgconfig:/usr/lib/aarch64-linux-gnu/pkgconfig${PKG_CONFIG_PATH:+:$PKG_CONFIG_PATH}

RUN ln -sf /usr/lib/aarch64-linux-gnu/libnvcuvid.so.1 /usr/lib/aarch64-linux-gnu/libnvcuvid.so && \
    cp /usr/lib/aarch64-linux-gnu/libnvcuvid.so* /usr/local/cuda/lib64/ || true && \
    ln -sf /usr/lib/aarch64-linux-gnu/libnvcuvid.so.1 /usr/local/cuda/lib64/libnvcuvid.so && \
    echo "/usr/lib/aarch64-linux-gnu" > /etc/ld.so.conf.d/nvidia.conf && \
    ldconfig && \
    python3 -m pip install --no-cache-dir --upgrade pip wheel build numpy && \
    python3 -m pip install --no-cache-dir --upgrade pip ninja && \
    apt-get update && \
    apt-get install -y --no-install-recommends libnvidia-decode-575 libnvidia-encode-575

RUN cd /workspace && \
    git clone --recursive https://github.com/dmlc/decord && \
    cmake -S /workspace/decord -B /workspace/decord/build \
        -G Ninja \
        -DCMAKE_BUILD_TYPE=Release \
        -DUSE_CUDA=ON \
        -DCMAKE_CUDA_HOST_COMPILER=/usr/bin/g++-11 \
        -DCMAKE_C_COMPILER=/usr/bin/gcc-11 \
        -DCMAKE_CXX_COMPILER=/usr/bin/g++-11 \
        -DFFMPEG_ROOT=/usr/local \
        -DUSE_VIDEO_CODEC=OFF && \
    cd /workspace/decord/build && \
    ninja -j"$(nproc)" && \
    cd /workspace/decord/python && \
    python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel build && \
    python3 -m build --wheel
# Dockerfile stage for the demo
FROM nvcr.io/nvidia/pytorch:25.09-py3
COPY --from=ffmpeg-builder /usr/local /usr/local
COPY --from=decord-builder /workspace/decord/python/dist/*.whl /tmp/wheels/

ARG HF_TOKEN

RUN pip install --no-cache-dir /tmp/wheels/*.whl && \
    rm -rf /tmp/wheels && \
    apt-get update && \
    apt-get install -y libdrm2 libdrm-dev libx264-dev && \
    pip install streamlit timm wandb && \
    hf auth login --token $HF_TOKEN

# Set CUDA environment variables
ENV CUDA_HOME=/usr/local/cuda-13.0/
ENV CUDA_PATH=$CUDA_HOME
ENV PATH=$CUDA_HOME/bin:$PATH
ENV LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
ENV C_INCLUDE_PATH=$CUDA_HOME/include:$C_INCLUDE_PATH
ENV CPLUS_INCLUDE_PATH=$CUDA_HOME/include:$CPLUS_INCLUDE_PATH

# Install Triton from source for latest Blackwell support
RUN git clone https://github.com/triton-lang/triton.git && \
    cd triton && \
    git checkout c5d671f91d90f40900027382f98b17a3e04045f6 && \
    pip install -r python/requirements.txt && \
    pip install . && \
    cd ..

# Install xformers from source for Blackwell support
RUN git clone https://github.com/facebookresearch/xformers && \
    cd xformers && \
    git checkout 5146f2ab37b2163985c19fb4e8fbf6183e82f8ce && \
    git submodule update --init --recursive && \
    export TORCH_CUDA_ARCH_LIST="12.1" && \
    python setup.py install && \
    cd ..

# Install unsloth without dependencies so we can build them from source
RUN pip install unsloth==2025.9.11 unsloth_zoo==2025.9.14 bitsandbytes==0.48.0

CMD ["/bin/bash"]

# docker run \
#   --rm \
#   --gpus=all \
#   --ipc=host \
#   --net=host \
#   --ulimit memlock=-1 \
#   --ulimit stack=67108864 \
#   -w $(pwd) \
#   -v $(pwd):$(pwd) \
#   -v $HOME/.cache/huggingface:/root/.cache/huggingface \
#   nvcr.io/nvidia/vllm:25.09-py3 \
#   vllm serve "unsloth/Qwen2.5-VL-7B-Instruct" --port "8000" --served-model-name "base-model" --max-model-len 16384 --gpu-memory-utilization 0.3 --async-scheduling --enable_prefix_caching
@@ -1,47 +0,0 @@
# VLM Fine-tuning Recipes

This repository contains comprehensive fine-tuning recipes for Vision-Language Models (VLMs), supporting both **image** and **video** understanding tasks with modern models and training techniques.

## 🎯 Available Recipes

### 📸 Image VLM Fine-tuning (`ui_image/`)
- **Model**: Qwen2.5-VL-7B-Instruct
- **Task**: Wildfire detection from satellite imagery
- **Training Method**: GRPO (Group Relative Policy Optimization) and LoRA (Low-Rank Adaptation)

### 🎥 Video VLM Fine-tuning (`ui_video/`)
- **Model**: InternVL3-8B
- **Task**: Dangerous driving detection and structured metadata generation
- **Training Method**: Supervised Fine-tuning on Multimodal Instructions

## 🚀 Quick Start

### 1. Build the Docker Container

```bash
# Enter the correct directory for building the image
cd vlm-finetuning/assets

# Build the VLM fine-tuning container
docker build --build-arg HF_TOKEN=$HF_TOKEN -t vlm_demo .
```

### 2. Launch the Container

```bash
# Run the container with GPU support
sh launch.sh

# Enter the mounted directory within the container
cd /vlm_finetuning
```

> **Note**: The same Docker container and launch commands work for both image and video VLM recipes. The container includes all necessary dependencies, including FFmpeg, Decord, and optimized libraries for both workflows.

## 📚 Detailed Instructions

Each recipe includes comprehensive documentation:

- **[Image VLM README](ui_image/README.md)**: Complete guide for wildfire detection fine-tuning with Qwen2.5-VL, including dataset setup, GRPO training configuration, and interactive inference
- **[Video VLM README](ui_video/README.md)**: Full walkthrough for dangerous driving detection with InternVL3, covering video data preparation, training notebooks, and structured output generation
@@ -1,31 +0,0 @@
#!/bin/bash
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

docker run -it \
    --gpus=all \
    --net=host \
    --ipc=host \
    --ulimit memlock=-1 \
    --ulimit stack=67108864 \
    -e HOST_HOME=$HOME \
    -e HOST_PWD=$(pwd) \
    -v $(pwd):/vlm_finetuning \
    -v $HOME/.cache/huggingface:/root/.cache/huggingface \
    -v /var/run/docker.sock:/var/run/docker.sock \
    -v /usr/bin/docker:/usr/bin/docker \
    vlm_demo
@@ -1,3 +0,0 @@
[theme]
base = "dark"
greenTextColor = "#76b900"
@@ -1,509 +0,0 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import re
import json
import yaml
import glob
import time
import base64
import random
import requests
import subprocess

import pandas as pd
import streamlit as st
from transformers.trainer_utils import get_last_checkpoint


REASONING_START = "<REASONING>"
REASONING_END = "</REASONING>"
SOLUTION_START = "<SOLUTION>"
SOLUTION_END = "</SOLUTION>"


def load_config():
    # Load the image-VLM YAML config once and cache it in session state.
    config_key = "config"
    if getattr(st.session_state, config_key, None) is None:
        with open("src/image_vlm_config.yaml", "r") as f:
            config = yaml.safe_load(f)
        setattr(st.session_state, config_key, config)
    else:
        config = getattr(st.session_state, config_key)

    return config
@st.cache_resource
def start_vllm_server(model_id, model_type, max_seq_length, port):
    # Launch an OpenAI-compatible vLLM server for the given model in Docker
    return subprocess.Popen([
        "docker", "run",
        "--rm",
        "--gpus=all",
        "--ipc=host",
        "--net=host",
        "--ulimit", "memlock=-1",
        "--ulimit", "stack=67108864",
        "-v", f"{os.environ.get('HOST_HOME')}/.cache/huggingface:/root/.cache/huggingface",
        "-v", f"{os.environ.get('HOST_PWD')}/ui_image/saved_model:/workspace/saved_model",
        "nvcr.io/nvidia/vllm:25.09-py3",
        "vllm", "serve",
        model_id,
        "--port", str(port),
        "--served-model-name", model_type,
        "--max-model-len", str(max_seq_length),
        "--gpu-memory-utilization", "0.45",
        "--async-scheduling",
        "--enable_prefix_caching"
    ])


def check_vllm_health(model_type, port):
    try:
        output = json.loads(subprocess.check_output(
            ["curl", "-s", f"http://localhost:{port}/v1/models"],
            text=True
        ))

        return output["data"][0]["id"] == model_type
    except Exception:
        # Server not up yet, connection refused, or malformed response
        return False


def invoke_vllm_server(model_type, prompt, image, port):
    # Encode the image as base64 for the OpenAI-compatible chat endpoint
    with open(image, "rb") as f:
        image = base64.b64encode(f.read()).decode("utf-8")

    payload = json.dumps({
        "model": model_type,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image}"
                        }
                    }
                ]
            }
        ],
        "max_tokens": 1024,
        "temperature": 0,
        "top_p": 1,
    })

    return requests.post(
        f"http://localhost:{port}/v1/chat/completions",
        headers={"Content-Type": "application/json"},
        data=payload
    ).json()["choices"][0]["message"]["content"]


def initialize_state(config):
    st.session_state["mode"] = st.session_state.get("mode", "inference")

    st.session_state["base"] = st.session_state.get("base", {})
    st.session_state["finetuned"] = st.session_state.get("finetuned", {})

    st.session_state["base"]["port"] = st.session_state["base"].get("port", "8000")
    st.session_state["finetuned"]["port"] = st.session_state["finetuned"].get("port", "8001")

    if st.session_state["mode"] == "inference":
        st.session_state["base"]["process"] = start_vllm_server(
            config["model_id"], "base", config["max_seq_length"], st.session_state["base"]["port"])
        finetuned_model_path = get_last_checkpoint(config["finetuned_model_id"])
        if finetuned_model_path is not None:
            st.session_state["finetuned"]["process"] = start_vllm_server(
                finetuned_model_path, "finetuned", config["max_seq_length"], st.session_state["finetuned"]["port"])

        if not check_vllm_health("base", st.session_state["base"]["port"]):
            with st.spinner("Loading vLLM server for base model..."):
                while not check_vllm_health("base", st.session_state["base"]["port"]):
                    time.sleep(1)
            st.toast("Base model loaded", icon="✅", duration="short")

        if finetuned_model_path is not None:
            if not check_vllm_health("finetuned", st.session_state["finetuned"]["port"]):
                with st.spinner("Loading vLLM server for finetuned model..."):
                    while not check_vllm_health("finetuned", st.session_state["finetuned"]["port"]):
                        time.sleep(1)
                st.toast("Finetuned model loaded", icon="✅", duration="short")

    st.session_state["current_image"] = st.session_state.get("current_image", glob.glob("assets/image_vlm/images/*/*")[-1])
    st.session_state["train_process"] = st.session_state.get("train_process", None)


def main():
    # set page ui
    st.title("Image VLM Finetuning")
    st.caption("A DGX Spark showcase for on-device VLM finetuning")

    # load css
    with open("src/styles.css", "r") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    # load resources
    config = load_config()
    initialize_state(config["inference"])

    # train section
    st.markdown("---")
    train_section()

    # inference section
    st.markdown("---")
    inference_section()


def train_section():
    st.header("GRPO Training")

    column_1, column_2, column_3 = st.columns(3, gap="large")
    with column_1:
        finetuning_method = st.selectbox(
            "Finetuning Method:",
            ["LoRA", "Full Fine-tuning"],
        )

    # update lora config
    if finetuning_method == "LoRA":
        lora_config = st.session_state["config"]["train"]["model"]["lora_config"]

        with column_2:
            lora_rank = st.slider(
                "LoRA Rank",
                min_value=8,
                max_value=64,
                value=lora_config["rank"],
                step=8,
            )

        with column_3:
            lora_alpha = st.slider(
                "LoRA Alpha",
                min_value=8,
                max_value=64,
                value=lora_config["alpha"],
                step=8,
            )

        st.session_state["config"]["train"]["model"]["lora_config"].update({
            'rank': lora_rank,
            'alpha': lora_alpha,
        })

    # update model config based on selection
    st.session_state["config"]["train"]["model"]["use_lora"] = finetuning_method == "LoRA"

    # update train config
    st.write("")
    column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
    with column_1:
        finetune_vision_layers = st.toggle(
            "Finetune Vision Layers",
            value=st.session_state["config"]["train"]["model"]["finetune_vision_layers"])

    with column_2:
        finetune_language_layers = st.toggle(
            "Finetune Language Layers",
            value=st.session_state["config"]["train"]["model"]["finetune_language_layers"])

    with column_3:
        finetune_attention_modules = st.toggle(
            "Finetune Attention Modules",
            value=st.session_state["config"]["train"]["model"]["finetune_attention_modules"])

    with column_4:
        finetune_mlp_modules = st.toggle(
            "Finetune MLP Modules",
            value=st.session_state["config"]["train"]["model"]["finetune_mlp_modules"])

    st.write("")
    column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
    with column_1:
        steps = st.slider(
            "Steps",
            min_value=1,
            max_value=1000,
            value=st.session_state["config"]["train"]["hyperparameters"]["steps"])

    with column_2:
        batch_size = st.select_slider(
            "Batch Size",
            options=[1, 2, 4, 8, 16],
            value=st.session_state["config"]["train"]["hyperparameters"]["batch_size"])

    with column_3:
        learning_rate = st.number_input(
            "Learning Rate",
            min_value=1e-6,
            max_value=1e-2,
            value=float(st.session_state["config"]["train"]["hyperparameters"]["learning_rate"]),
            format="%.2e")

    with column_4:
        optimizer = st.selectbox(
            "Optimizer",
            options=["adamw_torch", "adafactor"])

    st.session_state["config"]["train"]["hyperparameters"].update({
        'steps': steps,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'optimizer': optimizer,
    })

    st.session_state["config"]["train"]["model"].update({
        'finetune_vision_layers': finetune_vision_layers,
        'finetune_language_layers': finetune_language_layers,
        'finetune_attention_modules': finetune_attention_modules,
        'finetune_mlp_modules': finetune_mlp_modules,
    })

    st.write("")
    column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
    with column_1:
        enable_grpo = st.toggle(
            "Enable GRPO",
            value=st.session_state["config"]["train"]["hyperparameters"]["enable_grpo"],
            disabled=True)

    with column_2:
        format_reward = st.number_input(
            "Reward for reasoning format",
            min_value=0.0,
            max_value=5.0,
            value=float(st.session_state["config"]["train"]["hyperparameters"]["format_reward"]),
            format="%.2f")

    with column_3:
        correctness_reward = st.number_input(
            "Reward for correct response",
            min_value=0.0,
            max_value=5.0,
            value=float(st.session_state["config"]["train"]["hyperparameters"]["correctness_reward"]),
            format="%.2f")

    with column_4:
        num_generations = st.number_input(
            "Number of generations",
            min_value=1,
            max_value=16,
            value=st.session_state["config"]["train"]["hyperparameters"]["num_generations"])

    # Training control
    st.write("")
    column_1, column_2, column_3 = st.columns([4, 4, 1])

    with column_1:
        button_type = "secondary" if st.session_state["train_process"] else "primary"
        if st.button("▶️ Start Finetuning", type=button_type, width="stretch", disabled=bool(st.session_state["train_process"])):
            if st.session_state["train_process"] is None:
                st.session_state["base"]["process"].terminate()
                st.session_state["base"]["process"].wait()
                st.session_state["base"]["process"] = None
                if "finetuned" in st.session_state and "process" in st.session_state["finetuned"]:
                    st.session_state["finetuned"]["process"].terminate()
                    st.session_state["finetuned"]["process"].wait()
                    st.session_state["finetuned"]["process"] = None
                st.session_state["mode"] = "train"
                st.cache_resource.clear()

                # store config
                with open("src/train.yaml", "w") as f:
                    yaml.dump(st.session_state["config"]["train"], f, default_flow_style=False)

                # start training
                with open("/tmp/logs.txt", "w") as f:
                    st.session_state["train_process"] = subprocess.Popen(
                        ["python", "-u", "src/train_image_vlm.py"],
                        stdout=f,
                        stderr=subprocess.STDOUT,
                        text=True
                    )
                st.toast("Training started", icon="✅", duration="short")
            else:
                st.toast("Training already in progress", icon="❌", duration="short")

    with column_2:
        button_type = "primary" if st.session_state["train_process"] else "secondary"
        if st.button("⏹️ Stop Finetuning", type=button_type, width="stretch", disabled=not bool(st.session_state["train_process"])):
            if st.session_state["train_process"] is not None:
                st.session_state["train_process"].terminate()
                st.session_state["train_process"].wait()
                st.session_state["train_process"] = None
                st.session_state["mode"] = "inference"
                st.toast("Training stopped", icon="✅", duration="short")
                st.rerun()
            else:
                st.toast("No training to stop", icon="❌", duration="short")

    with column_3:
        badge_holder = st.empty()

    # create empty holders
    columns = st.columns(4)
    with columns[0]:
        steps_holder = st.empty()
    with columns[1]:
        format_reward_holder = st.empty()
    with columns[2]:
        correctness_reward_holder = st.empty()
    with columns[3]:
        total_reward_holder = st.empty()
    df_holder = st.empty()

    # parse grpo logs
    if st.session_state["train_process"] is not None:
        import ast  # local import: only needed while parsing training logs

        while True:
            with open("/tmp/logs.txt", "r") as log_file:
                output = log_file.read().strip()

            logs = []
            for line in output.split("\n"):
                if "{" in line and "}" in line:
                    dict_match = re.search(r"\{[^}]+\}", line)
                    if dict_match:
                        try:
                            # safer than eval() for parsing logged dict literals
                            log_dict = ast.literal_eval(dict_match.group())
                        except (ValueError, SyntaxError):
                            continue
                        if isinstance(log_dict, dict) and any(k in log_dict for k in [
                            "rewards/format_reward_func/mean",
                            "rewards/correctness_reward_func/mean",
                            "reward",
                        ]):
                            logs.append(log_dict)

            df = pd.DataFrame(logs)
            if "reward" in df.columns:
                steps_holder.metric("Steps", f"{len(df)}" if len(df) > 0 else "N/A")
                format_reward_holder.metric("Format Reward", f"{df['rewards/format_reward_func/mean'].iloc[-1]:.4f}" if len(df) > 0 else "N/A")
                correctness_reward_holder.metric("Correctness Reward", f"{df['rewards/correctness_reward_func/mean'].iloc[-1]:.4f}" if len(df) > 0 else "N/A")
                total_reward_holder.metric("Total Reward", f"{df['reward'].iloc[-1]:.4f}" if len(df) > 0 else "N/A")

                badge_holder.badge("Running", icon=":material/hourglass_arrow_up:", color="green", width="stretch")
            else:
                badge_holder.badge("Loading", icon=":material/hourglass_empty:", color="yellow", width="stretch")

            df_holder.dataframe(df, width="stretch", hide_index=True)
            time.sleep(1)

            if st.session_state["train_process"].poll() is not None:
                # training process exited on its own; clean up and return to inference
                st.session_state["train_process"].wait()
                st.session_state["train_process"] = None
                st.session_state["mode"] = "inference"
                st.toast("Training stopped", icon="✅", duration="short")
                st.rerun()

    else:
        badge_holder.badge("Idle", icon=":material/hourglass_disabled:", color="red", width="stretch")


def inference_section():
    st.header("Image Inference")

    columns = st.columns([3, 3, 1, 2])
    with columns[1]:
        with st.container(border=True, horizontal_alignment="center", vertical_alignment="center"):
            image_holder = st.empty()
            image_holder.image(st.session_state["current_image"])

    with columns[3]:
        if st.button("🎲 Test another sample"):
            # pick a random sample different from the one on screen
            while True:
                current_image = random.choice(glob.glob("assets/image_vlm/images/*/*"))
                if current_image != st.session_state["current_image"]:
                    break
            st.session_state["current_image"] = current_image
            image_holder.image(st.session_state["current_image"])

    columns = st.columns(2, gap="small")
    with columns[0]:
        with st.container(border=True):
            st.write("##### :green[Base Qwen2.5-VL-7B]")
            base_generation = st.empty()
            base_generation.write("...")

    with columns[1]:
        with st.container(border=True):
            st.write("##### :green[Finetuned Qwen2.5-VL-7B]")
            finetuned_generation = st.empty()
            finetuned_generation.write("...")

    columns = st.columns([9, 1], gap="small")
    with columns[0]:
        prompt = st.text_input(
            "Prompt Input",
            label_visibility="collapsed",
            key="prompt_input",
            on_change=lambda: st.session_state.update(prompt=st.session_state["prompt_input"])
        )

    with columns[1]:
        if st.button("Generate", width="stretch"):
            if prompt:
                st.session_state["prompt"] = prompt

                with st.spinner("Running..."):
                    response = start_inference("base")
                base_generation.markdown(response)

                if "finetuned" in st.session_state and "process" in st.session_state["finetuned"]:
                    with st.spinner("Running..."):
                        response = start_inference("finetuned")
                    finetuned_generation.markdown(response)
                else:
                    finetuned_generation.markdown("```No response since there is no finetuned model```")


def start_inference(model_type):
    prompt = st.session_state["prompt"]
    if model_type == "finetuned":
        # ask the finetuned model for its reasoning and a Yes/No verdict
        prompt = (
            f"{prompt}. Also first provide your reasoning or working out"
            f" on how you would go about identifying the presence of wildfire affected regions between {REASONING_START} and {REASONING_END}"
            f" and then your final answer between {SOLUTION_START} and (put a simple Yes or No here) {SOLUTION_END}"
        )

    response = invoke_vllm_server(
        model_type,
        prompt,
        st.session_state["current_image"],
        st.session_state[model_type]["port"]
    )

    # format response
    if model_type == "finetuned":
        response = response.replace(REASONING_START, "```")
        response = response.replace(REASONING_END, "```")

        # Handle solution formatting with proper newline handling
        solution_pattern = f'{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}'
        solution_match = re.search(solution_pattern, response, re.DOTALL)
        if solution_match:
            solution_content = solution_match.group(1).strip()
            response = re.sub(solution_pattern, f"**{solution_content}**", response, flags=re.DOTALL)

    return response


if __name__ == "__main__":
    main()
@ -1,229 +0,0 @@
# Image VLM Fine-tuning with Qwen2.5-VL

This project demonstrates fine-tuning Vision-Language Models (VLMs) for image understanding tasks, specifically using the Qwen2.5-VL-7B model for wildfire detection from satellite imagery with GRPO (Group Relative Policy Optimization).

## Overview

The project includes:
- **Interactive Training Interface**: Streamlit-based UI for configuring and monitoring VLM fine-tuning
- **GRPO Training**: Reward-based preference optimization for better reasoning capabilities
- **Multiple Fine-tuning Methods**: Support for both LoRA and full fine-tuning
- **Side-by-side Inference with vLLM**: Run the base and fine-tuned models side by side to compare their outputs

## Contents

1. [Model Download](#1-model-download)
2. [Dataset Preparation](#2-dataset-preparation)
3. [Base Model Inference](#3-base-model-inference)
4. [GRPO Finetuning](#4-grpo-finetuning)
5. [Finetuned Model Inference](#5-finetuned-model-inference)

## 1. Model Download

> **Note**: These instructions assume you are already inside the Docker container. For container setup, refer to the main project README at `vlm-finetuning/assets/README.md`.

### 1.1 Download the pre-trained model

```bash
hf download Qwen/Qwen2.5-VL-7B-Instruct
```

### 1.2 (Optional) Download the fine-tuned model

If you already have a fine-tuned checkpoint, place it in the `saved_model/` folder. Your directory structure should look similar to the following (your checkpoint number may differ):

```
saved_model/
└── checkpoint-3/
    ├── config.json
    ├── generation_config.json
    ├── model.safetensors.index.json
    ├── model-00001-of-00004.safetensors
    ├── model-00002-of-00004.safetensors
    ├── model-00003-of-00004.safetensors
    ├── model-00004-of-00004.safetensors
    ├── preprocessor_config.json
    ├── special_tokens_map.json
    ├── tokenizer_config.json
    ├── tokenizer.json
    ├── merges.txt
    └── vocab.json
```

If you already have a finetuned checkpoint and only want to run a comparative analysis against the base model, skip directly to the [Finetuned Model Inference](#5-finetuned-model-inference) section.

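The demo locates the newest `checkpoint-N` folder with `get_last_checkpoint` from `transformers`. If you want to sanity-check your `saved_model/` layout without pulling in `transformers`, a minimal stdlib sketch does the same thing (the helper name `find_last_checkpoint` is ours, not part of the project):

```python
import os
import re
import tempfile

def find_last_checkpoint(folder):
    """Return the checkpoint-N subfolder with the highest N, or None.

    Mirrors what transformers' get_last_checkpoint() does, so you can
    verify your saved_model/ layout will be picked up by the demo.
    """
    pattern = re.compile(r"^checkpoint-(\d+)$")
    checkpoints = [
        d for d in os.listdir(folder)
        if pattern.match(d) and os.path.isdir(os.path.join(folder, d))
    ]
    if not checkpoints:
        return None
    latest = max(checkpoints, key=lambda d: int(pattern.match(d).group(1)))
    return os.path.join(folder, latest)

# Example: a fake saved_model/ with two checkpoints
with tempfile.TemporaryDirectory() as root:
    for step in (3, 12):
        os.makedirs(os.path.join(root, f"checkpoint-{step}"))
    print(find_last_checkpoint(root))  # .../checkpoint-12
```

Running `find_last_checkpoint("saved_model")` should print the folder that will be served as the finetuned model; `None` means the layout above is not in place.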
## 2. Dataset Preparation

The project uses a **Wildfire Detection Dataset** of satellite imagery for training the model to identify wildfire-affected regions. The dataset includes:
- Satellite and aerial imagery of wildfire-affected areas
- Binary classification: wildfire vs. no wildfire

### 2.1 Create a dataset folder

```bash
mkdir -p ui_image/data
cd ui_image/data
```

### 2.2 Dataset Download

For this finetuning playbook, we will use the [Wildfire Prediction Dataset](https://www.kaggle.com/datasets/abdelghaniaaba/wildfire-prediction-dataset) from Kaggle. On the dataset page, click the download button, select the `cURL` option in the `Download Via` dropdown, and copy the curl command.

> **Note**: You will need to be logged into Kaggle and may need to accept the dataset terms before the download link works.

Run the following commands in your container:

```bash
# Paste and run the curl command from Kaggle here, then unzip the dataset

unzip -qq wildfire-prediction-dataset.zip
rm wildfire-prediction-dataset.zip
cd ..
```

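After unzipping, it is worth confirming the extraction before training. A small sketch that tallies images per class folder; this assumes the archive unpacks into split folders (e.g. `train/`, `valid/`, `test/`) that each contain `wildfire/` and `nowildfire/` subfolders, so adjust the paths if your extraction differs:

```python
import os
import tempfile

def count_images(root):
    """Count files per class folder under each split of the dataset."""
    counts = {}
    for split in sorted(os.listdir(root)):
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            continue
        for label in sorted(os.listdir(split_dir)):
            label_dir = os.path.join(split_dir, label)
            if os.path.isdir(label_dir):
                counts[f"{split}/{label}"] = len(os.listdir(label_dir))
    return counts

# Example with a tiny fake layout; point it at ui_image/data for the real check
with tempfile.TemporaryDirectory() as root:
    for split in ("train", "valid"):
        for label in ("wildfire", "nowildfire"):
            d = os.path.join(root, split, label)
            os.makedirs(d)
            open(os.path.join(d, "img.jpg"), "w").close()
    print(count_images(root))
```

If a split or class folder is missing from the printed counts, re-check the unzip step before proceeding.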
## 3. Base Model Inference

Before we start finetuning, let's spin up the demo UI to evaluate the base model's performance on this task.

### 3.1 Spin up the Streamlit demo

```bash
streamlit run Image_VLM.py
```

Access the Streamlit demo at http://localhost:8501/.

### 3.2 Wait for demo spin-up
|
||||
|
||||
When you access the streamlit demo for the first time, the backend triggers vLLM servers to spin up for the base model. You will see a spinner on the demo site as vLLM is being brought up for optimized inference. This step can take upto 15 mins.
|
||||
|
||||
After the streamlit demo is fully loaded, you should be able to see a similar UI state that is ready for inference.
|
||||
|
||||
<figure>
|
||||
<img src="assets/inference_page.png" alt="Inference Page" width="1000"/>
|
||||
<figcaption>Inference demo on the UI</figcaption>
|
||||
</figure>
|
||||
|
||||
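If you prefer to watch readiness from a terminal rather than the spinner, you can poll the vLLM `/v1/models` endpoint yourself. A stdlib-only sketch (port 8000 is the base-model port the demo configures; the helper is illustrative, not part of the app):

```python
import json
import urllib.error
import urllib.request

def vllm_ready(port, timeout=2.0):
    """Return True once the vLLM server on `port` lists a served model."""
    url = f"http://localhost:{port}/v1/models"
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            data = json.load(resp)
        return bool(data.get("data"))
    except (urllib.error.URLError, json.JSONDecodeError, OSError):
        # connection refused, timeout, or non-JSON body: not ready yet
        return False

print(vllm_ready(8000))  # False until the server has finished loading
```

Run it in a loop (or under `watch`) and proceed once it reports `True`.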
### 3.3 Run base model inference

Since we are currently focused on the base model, scroll down to the `Image Inference` section of the UI. Here you should see a pre-loaded sample satellite image of a potentially wildfire-affected region.

Enter your prompt in the chat box and hit `Generate`. Your prompt is sent to the base model, and its response appears in the left chat box. If you did not provide a finetuned model, the right chat box produces no generations.

As you can see, the base model is incapable of providing the right response for this domain-specific task. Let's try to improve the model's accuracy by performing GRPO finetuning.

## 4. GRPO Finetuning

We will perform GRPO finetuning to add reasoning capabilities to the base model and improve its understanding of the underlying domain. With the Streamlit demo already running, scroll to the `GRPO Training` section.

<figure>
  <img src="assets/training_page.png" alt="Training Page" width="1000"/>
  <figcaption>Training parameters on the UI</figcaption>
</figure>

### 4.1 Model settings

Configure the finetuning method and LoRA parameters using the following options:

- `Finetuning Method`: Choose between full fine-tuning and LoRA
- `LoRA Parameters`: Adjustable rank (8-64) and alpha (8-64)

### 4.2 Finetune layers

You can additionally choose which layers of the VLM to finetune. For the best accuracy, keep all options toggled on; note that this also increases training time.

### 4.3 Training parameters

In this section, select the training hyperparameters for your run:

- `Steps`: 1-1000
- `Batch Size`: 1, 2, 4, 8, or 16
- `Learning Rate`: 1e-6 to 1e-2
- `Optimizer`: AdamW or Adafactor

### 4.4 GRPO settings

For a GRPO setup, we can also choose the rewards assigned to the model when its generations meet certain criteria:

- `Format Reward`: 2.0 (reward for following the reasoning format)
- `Correctness Reward`: 5.0 (reward for correct answers)
- `Number of Generations`: 4 (generations sampled per step for preference optimization)

### 4.5 Start training

After configuring the parameters, hit `Start Finetuning` to begin the training process. Expect to wait about 15 minutes for the model to load and for metadata to start appearing on the UI. As training progresses, information such as the step count and GRPO rewards is recorded in a live table.

The default configuration should give you reasonable accuracy, taking 100 steps of training over a period of up to 2 hours. We achieved our best accuracies with around 1000 steps of training, taking close to 16 hours.

After training reaches the configured number of steps, the script automatically merges the LoRA weights into the base model; this merge can take about 5 minutes.

### 4.6 Stop training

If you wish to stop training early, hit the `Stop Finetuning` button. Use this button only to interrupt training: it does not guarantee that checkpoints are properly stored or that LoRA adapter layers are merged.

Once you stop training, the UI automatically brings the vLLM servers for the base model and the newly finetuned model back up.

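To make the reward scheme concrete, here is a minimal sketch of what format and correctness reward functions can look like in this setup. The tag names mirror the demo's `<REASONING>`/`<SOLUTION>` markers, but these functions are an illustration; the exact rewards implemented in `train_image_vlm.py` may differ:

```python
import re

REASONING_START, REASONING_END = "<REASONING>", "</REASONING>"
SOLUTION_START, SOLUTION_END = "<SOLUTION>", "</SOLUTION>"

# A completion earns the format reward only if it shows reasoning first,
# then a bare Yes/No verdict inside the solution tags.
FORMAT_RE = re.compile(
    re.escape(REASONING_START) + r".+?" + re.escape(REASONING_END)
    + r".*?" + re.escape(SOLUTION_START) + r"\s*(Yes|No)\s*" + re.escape(SOLUTION_END),
    re.DOTALL,
)

def format_reward(completion, reward=2.0):
    """Reward completions that follow the reasoning/solution template."""
    return reward if FORMAT_RE.search(completion) else 0.0

def correctness_reward(completion, answer, reward=5.0):
    """Reward completions whose <SOLUTION> matches the ground-truth label."""
    m = re.search(
        re.escape(SOLUTION_START) + r"\s*(\w+)\s*" + re.escape(SOLUTION_END),
        completion,
    )
    return reward if m and m.group(1).lower() == answer.lower() else 0.0

sample = "<REASONING>Burn scars visible.</REASONING><SOLUTION>Yes</SOLUTION>"
print(format_reward(sample), correctness_reward(sample, "Yes"))  # 2.0 5.0
```

GRPO then favors generations that score higher on the sum of these rewards, which is why the live table reports the format, correctness, and total reward separately.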
## 5. Finetuned Model Inference

Now we are ready to perform a comparative analysis between the base model and the finetuned model.

### 5.1 (Optional) Spin up the Streamlit demo

If you haven't spun up the Streamlit demo already, execute the following command. If you have just stopped training and are still on the live UI, skip to the next step.

```bash
streamlit run Image_VLM.py
```

Access the Streamlit demo at http://localhost:8501/.

### 5.2 vLLM startup

Regardless of whether you just spun up the demo or just stopped training, allow about 15 minutes for the vLLM servers to be brought up.

### 5.3 Run finetuned model inference

Scroll down to the `Image Inference` section and enter your prompt in the provided chat box. Upon clicking `Generate`, your prompt is sent first to the base model and then to the finetuned model. You can use the following prompt to quickly test inference:

`Identify if this region has been affected by a wildfire`

If you trained the model sufficiently, you should see that the finetuned model performs reasoning and provides a concise, accurate answer to the prompt. The reasoning steps are rendered as Markdown, while the final answer is bolded at the end of the model's response.

For the image shown below, we trained the model for 1000 steps, which took about 16 hours.

### 5.4 Further analysis

If you wish to experiment with these models on additional images, the `Test another sample` button loads another random satellite image.

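Beyond spot-checking single images, you can score a batch of responses against the folder labels to get a quick accuracy number. A small sketch (these helpers are ours; the app itself only renders responses):

```python
import re

def extract_answer(response):
    """Pull the Yes/No verdict out of a <SOLUTION>...</SOLUTION> block."""
    m = re.search(r"<SOLUTION>\s*(Yes|No)\s*</SOLUTION>", response, re.IGNORECASE)
    return m.group(1).capitalize() if m else None

def accuracy(responses, labels):
    """Fraction of responses whose verdict matches the folder label
    ('wildfire' -> Yes, 'nowildfire' -> No)."""
    expected = {"wildfire": "Yes", "nowildfire": "No"}
    hits = sum(
        extract_answer(r) == expected[l] for r, l in zip(responses, labels)
    )
    return hits / len(labels)

responses = [
    "<REASONING>...</REASONING><SOLUTION>Yes</SOLUTION>",
    "<REASONING>...</REASONING><SOLUTION>No</SOLUTION>",
]
print(accuracy(responses, ["wildfire", "wildfire"]))  # 0.5
```

Collecting responses via the vLLM endpoint for a held-out set of images and feeding them through `accuracy` gives a more reliable picture than eyeballing individual generations.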
## File Structure

```
ui_image/
├── Image_VLM_Finetuning.py          # Main Streamlit application
├── README.md                        # This file
├── src/
│   ├── image_vlm_config.yaml        # Configuration file (update finetuned_model_id after training)
│   └── styles.css                   # Custom UI styling
├── assets/
│   ├── image_vlm/
│   │   └── images/
│   │       ├── wildfire/            # Wildfire-affected images
│   │       └── nowildfire/          # Non-wildfire images
│   └── inference_screenshot.png     # UI demonstration screenshot
└── saved_model/                     # Training checkpoints directory (update config to point here)
```

## Troubleshooting

If you are facing VRAM issues where the model fails to load or offloads to the CPU/meta device, bring down all Docker containers and flush dangling memory:

```bash
docker ps

docker rm -f <CONTAINER_ID_1>
docker rm -f <CONTAINER_ID_2>

docker system prune

sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
```

@ -1,16 +0,0 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
@ -1,49 +0,0 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

inference:
  model_id: unsloth/Qwen2.5-VL-7B-Instruct
  finetuned_model_id: saved_model
  max_seq_length: 8192

train:
  model:
    model_id: unsloth/Qwen2.5-VL-7B-Instruct
    max_seq_length: 16384
    use_lora: true
    lora_config:
      rank: 32
      alpha: 64
      dropout: 0.05
    finetune_vision_layers: true
    finetune_language_layers: true
    finetune_attention_modules: true
    finetune_mlp_modules: true

  data:
    dataset_id: data

  hyperparameters:
    steps: 100
    batch_size: 4
    enable_grpo: true
    num_generations: 2
    format_reward: 2.0
    learning_rate: 1e-5
    correctness_reward: 5.0
    optimizer: adamw_torch
    output_dir: saved_model
@ -1,178 +0,0 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
*/

/* VLM Fine-tuning Streamlit App Styles */
:root {
    --nvidia-green: #76b900;
}

/* Set maximum page width - multiple selectors for compatibility */
.main .block-container,
.block-container,
.main > div,
section.main > div {
    max-width: 1200px !important;
    margin: 0 auto !important;
    padding-left: 1rem !important;
    padding-right: 1rem !important;
}

.main > div {
    padding-top: 1rem;
}

/* Global button styling - applies to ALL buttons */
.stButton > button,
button[kind="primary"],
button[kind="secondary"] {
    width: 100% !important;
    font-size: 16px !important;
    font-weight: 600 !important;
    padding: 12px 20px !important;
    border-radius: 8px !important;
    border: none !important;
    transition: all 0.3s ease !important;
    cursor: pointer !important;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
}

/* Primary button styling */
.stButton > button[kind="primary"],
button[kind="primary"] {
    background: linear-gradient(135deg, #32864c 0%, #6edd93 100%) !important;
    color: white !important;
}

/* Primary button hover effect */
.stButton > button[kind="primary"]:hover,
button[kind="primary"]:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
}

/* Secondary button styling */
.stButton > button[kind="secondary"],
button[kind="secondary"] {
    background: linear-gradient(135deg, #b6fb93 0%, #57f586 100%) !important;
    color: darkslategray !important;
}

/* Regular button styling */
.stButton > button:not([kind]) {
    background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%) !important;
    color: white !important;
}

/* Button hover effect for all buttons */
.stButton > button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
}

/* Button active/pressed state */
.stButton > button:active {
    transform: translateY(0px) !important;
}

/* Metric card styling */
.metric-card {
    background-color: #f0f2f6;
    padding: 1rem;
    border-radius: 0.5rem;
    border-left: 4px solid #ff6b6b;
}

/* Center images and maintain aspect ratio */
|
||||
.stImage > img {
|
||||
display: block;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
max-width: 500px;
|
||||
height: auto;
|
||||
max-height: 300px
|
||||
}
|
||||
|
||||
/* Larger text input for questions */
|
||||
.stTextInput > div > div > input {
|
||||
font-size: 18px !important;
|
||||
padding: 10px !important;
|
||||
}
|
||||
|
||||
/* Larger inference button */
|
||||
.inference-button button,
|
||||
.inference-button .stButton > button {
|
||||
padding: 20px !important;
|
||||
height: 200px !important;
|
||||
min-height: 100% !important;
|
||||
max-height: 100% !important;
|
||||
display: flex !important;
|
||||
align-items: center !important;
|
||||
justify-content: center !important;
|
||||
line-height: 1.2 !important;
|
||||
white-space: nowrap !important;
|
||||
font-size: large !important;
|
||||
}
|
||||
|
||||
|
||||
.card {
|
||||
background-color: #1F2022;
|
||||
color: #bdc3c7;
|
||||
padding: 1.5rem;
|
||||
border-radius: 0.5rem;
|
||||
/* outline: 2px solid var(--nvidia-green); */
|
||||
box-shadow: 0 2px 8px rgba(0,0,0,0.10);
|
||||
text-align: left;
|
||||
height: 300px;
|
||||
overflow-y: auto;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.card--empty {
|
||||
text-align: center;
|
||||
font-style: italic;
|
||||
height: 300px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
/* Wandb iframe wrapper for cropping */
|
||||
.wandb-wrapper {
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
border-radius: 15px;
|
||||
height: var(--visible-height, 320px);
|
||||
}
|
||||
|
||||
/* Wandb iframe styling */
|
||||
.wandb-iframe {
|
||||
border: none;
|
||||
width: 100%;
|
||||
position: absolute;
|
||||
left: 0;
|
||||
height: var(--iframe-height, 600px);
|
||||
top: var(--shift-up, -230px);
|
||||
}
|
||||
@ -1,24 +0,0 @@
|
||||
data:
|
||||
dataset_id: data
|
||||
hyperparameters:
|
||||
batch_size: 4
|
||||
correctness_reward: 5.0
|
||||
enable_grpo: true
|
||||
format_reward: 2.0
|
||||
learning_rate: 1.0e-05
|
||||
num_generations: 2
|
||||
optimizer: adamw_torch
|
||||
output_dir: saved_model
|
||||
steps: 5
|
||||
model:
|
||||
finetune_attention_modules: true
|
||||
finetune_language_layers: true
|
||||
finetune_mlp_modules: true
|
||||
finetune_vision_layers: true
|
||||
lora_config:
|
||||
alpha: 64
|
||||
dropout: 0.05
|
||||
rank: 32
|
||||
max_seq_length: 16384
|
||||
model_id: unsloth/Qwen2.5-VL-7B-Instruct
|
||||
use_lora: true
|
||||
@ -1,228 +0,0 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from unsloth import FastVisionModel
|
||||
|
||||
import re
|
||||
import sys
|
||||
import yaml
|
||||
import shutil
|
||||
import signal
|
||||
|
||||
from PIL import ImageFile
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
REASONING_START = "<REASONING>"
|
||||
REASONING_END = "</REASONING>"
|
||||
SOLUTION_START = "<SOLUTION>"
|
||||
SOLUTION_END = "</SOLUTION>"
|
||||
|
||||
|
||||
def load_model_for_train(config):
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
model_name=config["model"]["model_id"],
|
||||
max_seq_length=config["model"]["max_seq_length"],
|
||||
load_in_4bit=False,
|
||||
)
|
||||
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers=config["model"]["finetune_vision_layers"],
|
||||
finetune_language_layers=config["model"]["finetune_language_layers"],
|
||||
finetune_attention_modules=config["model"]["finetune_attention_modules"],
|
||||
finetune_mlp_modules=config["model"]["finetune_mlp_modules"],
|
||||
r=config["model"]["lora_config"]["rank"],
|
||||
lora_alpha=config["model"]["lora_config"]["alpha"],
|
||||
lora_dropout=config["model"]["lora_config"]["dropout"],
|
||||
bias="none",
|
||||
random_state=42,
|
||||
use_rslora=False,
|
||||
loftq_config=None,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def format_instruction(sample, label_dict):
|
||||
label = label_dict.int2str(sample["label"])
|
||||
if label == "nowildfire":
|
||||
answer = "No"
|
||||
else:
|
||||
answer = "Yes"
|
||||
|
||||
# reasoning prompt
|
||||
prompt = "Identify if this region has been affected by a wildfire"
|
||||
prompt = (
|
||||
f"{prompt}. Also first provide your reasoning or working out"\
|
||||
f" on how you would go about identifying the presence of wildfire affected regions between {REASONING_START} and {REASONING_END}"
|
||||
f" and then your final answer between {SOLUTION_START} and (put a simple Yes or No here) {SOLUTION_END}"
|
||||
)
|
||||
|
||||
# convert image format
|
||||
image = sample["image"]
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# construct instruction prompt
|
||||
prompt = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
return {"prompt": prompt, "image": sample["image"], "answer": answer}
|
||||
|
||||
|
||||
def load_wildfire_dataset(config, tokenizer):
|
||||
# load dataset
|
||||
train_dataset = load_dataset(config["data"]["dataset_id"])["train"]
|
||||
|
||||
# preprocess the dataset
|
||||
train_dataset = train_dataset.map(
|
||||
lambda sample: format_instruction(sample, train_dataset.features["label"]),
|
||||
num_proc=8)
|
||||
train_dataset = train_dataset.map(
|
||||
lambda sample: {
|
||||
"prompt": tokenizer.apply_chat_template(
|
||||
sample["prompt"],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
}, num_proc=8)
|
||||
|
||||
return train_dataset
|
||||
|
||||
|
||||
def format_reward_func(completions, **kwargs):
|
||||
thinking_pattern = f'{REASONING_START}(.*?){REASONING_END}'
|
||||
answer_pattern = f'{SOLUTION_START}(.*?){SOLUTION_END}'
|
||||
|
||||
scores = []
|
||||
for completion in completions:
|
||||
score = 0
|
||||
thinking_matches = re.findall(thinking_pattern, completion, re.DOTALL)
|
||||
answer_matches = re.findall(answer_pattern, completion, re.DOTALL)
|
||||
if len(thinking_matches) == 1:
|
||||
score += 1.0
|
||||
if len(answer_matches) == 1:
|
||||
score += 1.0
|
||||
|
||||
# penalize excessive addCriterion predictions in qwen2.5-vl
|
||||
if len(completion) != 0:
|
||||
removal = completion.replace("addCriterion", "").replace("\n", "")
|
||||
if (len(completion)-len(removal))/len(completion) >= 0.5:
|
||||
score -= 1.0
|
||||
|
||||
scores.append(score)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
|
||||
answer_pattern = f'{SOLUTION_START}(.*?){SOLUTION_END}'
|
||||
|
||||
responses = [re.findall(answer_pattern, completion, re.DOTALL) for completion in completions]
|
||||
q = prompts[0]
|
||||
print("----------------------------------")
|
||||
print(f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{completions[0]}")
|
||||
return [
|
||||
5.0 if len(r)==1 and a == r[0].replace('\n','') else 0.0
|
||||
for r, a in zip(responses, answer)
|
||||
]
|
||||
|
||||
|
||||
def start_train(config):
|
||||
|
||||
# load dataset
|
||||
train_dataset = load_wildfire_dataset(config, tokenizer)
|
||||
|
||||
# define training arguments
|
||||
training_args = GRPOConfig(
|
||||
learning_rate=config["hyperparameters"]["learning_rate"],
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type="cosine",
|
||||
optim=config["hyperparameters"]["optimizer"],
|
||||
logging_steps=1,
|
||||
log_completions=False,
|
||||
per_device_train_batch_size=config["hyperparameters"]["batch_size"],
|
||||
gradient_accumulation_steps=1,
|
||||
num_generations=config["hyperparameters"]["num_generations"],
|
||||
max_prompt_length=config["model"]["max_seq_length"],
|
||||
max_completion_length=config["model"]["max_seq_length"],
|
||||
max_steps=config["hyperparameters"]["steps"],
|
||||
save_steps=5,
|
||||
save_total_limit=2,
|
||||
max_grad_norm=0.1,
|
||||
report_to="none",
|
||||
output_dir=config["hyperparameters"]["output_dir"],
|
||||
importance_sampling_level="sequence",
|
||||
mask_truncated_completions=False,
|
||||
loss_type="dr_grpo",
|
||||
)
|
||||
|
||||
# start training
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=train_dataset,
|
||||
reward_funcs=[
|
||||
format_reward_func,
|
||||
correctness_reward_func,
|
||||
],
|
||||
args=training_args,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
handle_termination(None, None)
|
||||
|
||||
|
||||
def handle_termination(signum, frame):
|
||||
latest_checkpoint = get_last_checkpoint(config["hyperparameters"]["output_dir"])
|
||||
if latest_checkpoint:
|
||||
if config["model"]["use_lora"]:
|
||||
print("Merging LoRA weights and saving the model")
|
||||
shutil.rmtree(latest_checkpoint)
|
||||
model.save_pretrained_merged(latest_checkpoint, tokenizer, save_method="merged_16bit")
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
signal.signal(signal.SIGTERM, handle_termination)
|
||||
signal.signal(signal.SIGINT, handle_termination)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("src/train.yaml", "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# load base model for finetuning
|
||||
model, tokenizer = load_model_for_train(config)
|
||||
|
||||
start_train(config)
|
||||
@ -1,3 +0,0 @@
|
||||
[theme]
|
||||
base="dark"
|
||||
greenTextColor = "#76b900"
|
||||
@ -1,262 +0,0 @@
|
||||
# Video VLM Fine-tuning with InternVL3
|
||||
|
||||
This project builds on top of the image VLM fine-tuning recipe to extend it to the video modality. The notebook demonstrates how to fine-tune the InternVL3 model for domain-specific video analysis. For this prototype example, we use driving dashcam footage from the [Nexar Scap Dataset](nexar-ai/nexar_collision_prediction) to generate structured data for fine-tuning.
|
||||
|
||||
## Workflow Overview
|
||||
|
||||
<figure>
|
||||
<img src="assets/training_video.png" alt="Workflow Overview" width="1000"/>
|
||||
<figcaption>Video VLM fine-tuning Workflow Overview</figcaption>
|
||||
</figure>
|
||||
|
||||
A typical video fine-tuning workflow includes the following steps:
|
||||
1. **Data Collection**: Collect raw footage/videos for a domain-specific task. If the videos are very long, chunk them into reasonably sized clips, for instance 5 seconds each.
|
||||
2. **Generate Structured Captions**: Collect a structured caption for each video, using either human-generated labels or a larger VLM.
|
||||
3. **Train InternVL3 Model**: Perform supervised fine-tuning (SFT) on InternVL3-8B to extract structured metadata.
|
||||
4. **Inference**: The fine-tuned model is now ready for analyzing domain-specific videos.
|
||||
|
||||
## Contents
|
||||
1. [Dataset Preparation](#1-dataset-preparation)
|
||||
2. [Model Download](#2-model-download)
|
||||
3. [Base Model Inference](#3-base-model-inference)
|
||||
4. [SFT Finetuning](#4-sft-finetuning)
|
||||
5. [Finetuned Model Inference](#5-finetuned-model-inference)
|
||||
|
||||
## 1. Dataset Preparation
|
||||
|
||||
### 1.1 Data Source
|
||||
Identify a video data source which would benefit from structured data analysis. The videos can be either live footage or shorter video clips. In our case, we have chosen the [Nexar Scap Dataset](nexar-ai/nexar_collision_prediction).
|
||||
|
||||
### 1.2 Caption Schema
|
||||
Based on the structured metadata that you would like to analyze from your video dataset, come up with a caption schema that can concisely capture your requirements. In our case, we have used the following schema.
|
||||
|
||||
```json
|
||||
{
|
||||
"video": "videos/video1.mp4",
|
||||
"caption": "Description of the video events",
|
||||
"event_type": "collision" | "near_miss" | "no_incident",
|
||||
"rule_violations": choose relevant items from ["speeding", "failure_to_yield", "ignoring_traffic_signs"],
|
||||
"intended_driving_action": "turn_left" | "turn_right" | "change_lanes",
|
||||
"traffic_density": "low" | "high",
|
||||
"visibility": "good" | "bad"
|
||||
}
|
||||
```
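Before fine-tuning, it can help to verify that every caption record actually adheres to the schema above. The following is a minimal sketch, with a hypothetical `validate_caption` helper (not part of this project's code); the allowed values mirror the schema fields listed in the README:

```python
# Hypothetical helper: validate one caption record against the schema's
# enumerated fields before adding it to the training set.
ALLOWED = {
    "event_type": {"collision", "near_miss", "no_incident"},
    "intended_driving_action": {"turn_left", "turn_right", "change_lanes"},
    "traffic_density": {"low", "high"},
    "visibility": {"good", "bad"},
}
ALLOWED_VIOLATIONS = {"speeding", "failure_to_yield", "ignoring_traffic_signs"}

def validate_caption(record: dict) -> list:
    """Return a list of schema violations (an empty list means the record is valid)."""
    errors = []
    for key in ("video", "caption"):
        if not isinstance(record.get(key), str) or not record[key]:
            errors.append(f"missing or empty field: {key}")
    for key, allowed in ALLOWED.items():
        if record.get(key) not in allowed:
            errors.append(f"invalid value for {key}: {record.get(key)!r}")
    if not set(record.get("rule_violations", [])) <= ALLOWED_VIOLATIONS:
        errors.append("unknown entry in rule_violations")
    return errors
```

Running such a check over the whole dataset before training catches typos in enum values early, which otherwise silently degrade the structured output the model learns.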
|
||||
|
||||
### 1.3 Caption Generation
|
||||
With the caption schema decided, we must now generate ground-truth structured captions for all our videos. This can be achieved either by leveraging a larger VLM for AI-assisted annotation or by having human labelers caption them manually.
|
||||
|
||||
### 1.4 Dataset structure
|
||||
|
||||
```bash
|
||||
# Enter the correct directory
|
||||
cd ui_video
|
||||
```
|
||||
|
||||
Place all your videos in `dataset/videos`. Additionally, the captions should be placed inside the `metadata.jsonl`.
|
||||
|
||||
Your dataset should be structured as follows:
|
||||
```
|
||||
dataset/
|
||||
├── videos/
|
||||
│ ├── video1.mp4
|
||||
│ ├── video2.mp4
|
||||
│ └── ...
|
||||
└── metadata.jsonl
|
||||
```
|
||||
|
||||
Your `metadata.jsonl` should look like this.
|
||||
|
||||
```
|
||||
{"video": ..., "caption": ..., "event_type": ...}
|
||||
{"video": ..., "caption": ..., "event_type": ...}
|
||||
{"video": ..., "caption": ..., "event_type": ...}
|
||||
```
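Since each line of `metadata.jsonl` is an independent JSON object, loading it is a simple line-by-line parse. The following is a small sketch (the `load_metadata` helper is illustrative, not part of this project's code):

```python
import json

def load_metadata(path: str) -> list:
    """Read metadata.jsonl: one JSON object per line, one entry per video."""
    records = []
    with open(path) as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            record = json.loads(line)
            # every record must at least name its video and caption
            assert "video" in record and "caption" in record, f"line {line_no} incomplete"
            records.append(record)
    return records
```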
|
||||
|
||||
## 2. Model Download
|
||||
|
||||
> **Note**: These instructions assume you are already inside the Docker container. For container setup, refer to the main project README at `vlm-finetuning/assets/README.md`.
|
||||
|
||||
### 2.1 Download the pre-trained model
|
||||
|
||||
```bash
|
||||
hf download OpenGVLab/InternVL3-8B
|
||||
```
|
||||
|
||||
### 2.2 (Optional) Download the fine-tuned model
|
||||
|
||||
If you already have a fine-tuned checkpoint, place it in the `saved_model/` folder. Your directory structure should look something like this. Note that your checkpoint number can be different.
|
||||
|
||||
```
|
||||
saved_model/
|
||||
└── checkpoint-3/
|
||||
├── config.json
|
||||
├── generation_config.json
|
||||
├── model.safetensors.index.json
|
||||
├── model-00001-of-00004.safetensors
|
||||
├── model-00002-of-00004.safetensors
|
||||
├── model-00003-of-00004.safetensors
|
||||
├── model-00004-of-00004.safetensors
|
||||
├── preprocessor_config.json
|
||||
├── special_tokens_map.json
|
||||
├── tokenizer_config.json
|
||||
├── tokenizer.json
|
||||
├── merges.txt
|
||||
└── vocab.json
|
||||
```
|
||||
|
||||
If you already have a finetuned checkpoint that you would like to just use for a comparative analysis against the base model, skip directly to the [Finetuned Model Inference](#5-finetuned-model-inference) section.
|
||||
|
||||
## 3. Base Model Inference
|
||||
|
||||
Before going ahead to finetune our video VLM for this task, let's see how the base InternVL3-8B does.
|
||||
|
||||
### 3.1 Spin up the Streamlit demo
|
||||
|
||||
```bash
|
||||
# cd into vlm_finetuning/assets/ui_video if you haven't already
|
||||
streamlit run Video_VLM.py
|
||||
```
|
||||
|
||||
Access the streamlit demo at http://localhost:8501/.
|
||||
|
||||
### 3.2 Wait for demo spin-up
|
||||
|
||||
When you access the streamlit demo for the first time, the backend loads the base model from Hugging Face. You will see a spinner on the demo site while the model loads, which can take up to 10 minutes.
|
||||
|
||||
### 3.3 Run base model inference
|
||||
|
||||
First, let's select a video from our dashcam gallery. Upon clicking the green file-open icon next to a video, the video should render and play automatically for reference.
|
||||
|
||||
Scroll down, enter your prompt in the chat box, and hit `Generate`. Your prompt is first sent to the base model, and you should see the generated response in the left chat box. If you did not provide a finetuned model, the right chat box will show no generation.
|
||||
|
||||
<figure>
|
||||
<img src="assets/inference_screenshot.png" alt="Inference Screenshot" width="1000"/>
|
||||
<figcaption>Base model inference on the UI</figcaption>
|
||||
</figure>
|
||||
|
||||
As you can see, the base model is incapable of identifying the right events for this domain-specific task. Even when the base model can identify these events, it merely converts one form of unstructured data into another, so we cannot conduct reasonable data analytics for insights on large-scale video footage. Let's improve the model's accuracy and structured captioning ability by performing SFT training.
|
||||
|
||||
If you are proceeding to train a finetuned model, ensure that the streamlit demo UI is brought down first. You can bring it down by interrupting the terminal with a `Ctrl+C` keystroke.
|
||||
|
||||
> **Note**: To clear out any extra occupied memory on your system, execute the following command outside the container after stopping the Streamlit server.
|
||||
```bash
|
||||
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
|
||||
```
|
||||
|
||||
## 4. SFT Finetuning
|
||||
|
||||
We will perform SFT finetuning to improve the quality of the base model and generate schema-adhering structured output.
|
||||
|
||||
### 4.1 Load the jupyter notebook
|
||||
|
||||
```bash
|
||||
# Inside the container, navigate to the training directory
|
||||
cd train
|
||||
jupyter notebook video_vlm.ipynb
|
||||
```
|
||||
|
||||
### 4.2 Train the model
|
||||
|
||||
Follow the instructions in the jupyter notebook to perform SFT finetuning on a video VLM. Ensure that you set the path to your dataset correctly in the appropriate cell.
|
||||
|
||||
```python
|
||||
dataset_path = "/path/to/your/dataset"
|
||||
```
|
||||
|
||||
### 4.3 Training Configuration
|
||||
|
||||
Here are some of the key configurable training parameters. Please note that for reasonable quality, you will need to train your video VLM for at least 24 hours, given the complexity of processing spatio-temporal video sequences.
|
||||
|
||||
- **Model**: InternVL3-8B
|
||||
- **Video Frames**: 12 to 16 frames per video
|
||||
- **Sampling Mode**: Uniform temporal sampling
|
||||
- **LoRA Configuration**: Efficient parameter updates for large-scale fine-tuning
|
||||
- **Hyperparameters**: A broad suite of configurable hyperparameters for video VLM fine-tuning
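The uniform temporal sampling mode above can be sketched in a few lines: pick frame indices at the centers of equal-width bins over the clip. This is an illustrative implementation, not the notebook's exact code:

```python
import numpy as np

def uniform_frame_indices(total_frames: int, num_frames: int = 16) -> list:
    """Pick `num_frames` indices evenly spread over the clip (uniform temporal sampling)."""
    if total_frames <= num_frames:
        return list(range(total_frames))
    # centers of num_frames equal-width bins over [0, total_frames)
    edges = np.linspace(0, total_frames, num_frames + 1)
    return [int((edges[i] + edges[i + 1]) / 2) for i in range(num_frames)]
```

For a 5-second clip at 30 fps (150 frames) with 12 frames requested, this yields one frame roughly every 12.5 frames, covering the whole clip instead of clustering at the start.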
|
||||
|
||||
### 4.4 Monitor Training
|
||||
|
||||
You can monitor and evaluate the training progress and metrics, as they will be continuously shown in the notebook.
|
||||
|
||||
### 4.5 Shutdown
|
||||
|
||||
After training, ensure that you shut down the Jupyter kernel in the notebook and kill the Jupyter server in the terminal with a `Ctrl+C` keystroke.
|
||||
|
||||
> **Note**: To clear out any extra occupied memory on your system, execute the following command outside the container after shutting down the Jupyter server.
|
||||
```bash
|
||||
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
|
||||
```
|
||||
|
||||
## 5. Finetuned Model Inference
|
||||
|
||||
Now we are ready to perform a comparative analysis between the base model and the finetuned model.
|
||||
|
||||
### 5.1 (Optional) Spin up the Streamlit demo
|
||||
|
||||
If you haven't spun up the streamlit demo already, execute the following command. If you have just stopped training and the UI is still live, skip to the next step.
|
||||
|
||||
```bash
|
||||
streamlit run Video_VLM.py
|
||||
```
|
||||
|
||||
Access the streamlit demo at http://localhost:8501/.
|
||||
|
||||
### 5.2 Wait for demo spin-up
|
||||
|
||||
When you access the streamlit demo for the first time, the backend loads the base model from Hugging Face. You will see a spinner on the demo site while the model loads, which can take up to 10 minutes.
|
||||
|
||||
### 5.3 Run finetuned model inference
|
||||
|
||||
Scroll down to the `Video Inference` section and enter your prompt in the provided chat box. Upon clicking `Generate`, your prompt is sent first to the base model and then to the finetuned model. You can use the following prompt to quickly test inference:
|
||||
|
||||
`Analyze the dashcam footage for unsafe driver behavior`
|
||||
|
||||
If you trained your model for long enough, you should see that the finetuned model identifies the salient events in the video and generates a structured output.
|
||||
|
||||
### 5.4 Further analysis
|
||||
|
||||
Since the model's output adheres to the schema we trained, we can directly export the model's prediction into a database for video analytics. For the image shown below, we have trained the model for over 24 hours.
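Since the model wraps its JSON output in a markdown code fence (which the demo app strips with `response[7:-3]`), exporting a prediction is a parse-and-append. The following is a minimal sketch assuming that fenced output format; `response_to_row` is a hypothetical helper, and the rows could be appended to any tabular store:

```python
import json

def response_to_row(response: str) -> dict:
    # The demo app wraps its JSON output in a markdown json fence;
    # drop the leading 7 and trailing 3 characters before parsing, as the app does.
    if response.startswith("```json"):
        response = response[7:-3]
    return json.loads(response.strip())

# Example: collect parsed rows for downstream analytics (a database, a DataFrame, ...)
rows = []
raw = '```json\n{"event_type": "near_miss", "traffic_density": "low", "visibility": "good"}\n```'
rows.append(response_to_row(raw))
```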
|
||||
|
||||
<figure>
|
||||
<img src="assets/finetuned_screenshot.png" alt="Finetuned Screenshot" width="1000"/>
|
||||
<figcaption>Finetuned model inference on the UI</figcaption>
|
||||
</figure>
|
||||
|
||||
Feel free to play around with additional videos available in the gallery.
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
ui_video/
|
||||
├── README.md # This file
|
||||
├── Video_VLM.py # Streamlit web interface for inference
|
||||
├── src/
|
||||
│ ├── styles.css # CSS styling for Streamlit app
|
||||
│ └── video_vlm_config.yaml # Model and inference configuration
|
||||
├── train/
|
||||
│ └── video_vlm.ipynb # Jupyter notebook for model training
|
||||
└── assets/
|
||||
└── video_vlm/
|
||||
├── videos/ # Sample video files
|
||||
└── thumbnails/ # Video thumbnail previews
|
||||
|
||||
# Root directory also contains:
|
||||
├── Dockerfile # Multi-stage Docker build with FFmpeg/Decord
|
||||
├── launch.sh                    # Docker launch script
└── saved_model/                 # Training checkpoints directory (update config to point here)
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If you are facing VRAM issues where the model fails to load or offloads to the CPU/meta device, ensure you bring down all Docker containers and free cached memory.
|
||||
|
||||
```bash
|
||||
docker ps
|
||||
|
||||
docker stop <CONTAINER_ID_1> && docker rm <CONTAINER_ID_1>
|
||||
docker stop <CONTAINER_ID_2> && docker rm <CONTAINER_ID_2>
|
||||
|
||||
docker system prune
|
||||
|
||||
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
|
||||
```
|
||||
@ -1,407 +0,0 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import json
|
||||
import yaml
|
||||
import string
|
||||
import random
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
import streamlit as st
|
||||
import torchvision.transforms as T
|
||||
from decord import VideoReader, cpu
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
|
||||
SCAP_PROMPT = """You are a vision-language assistant analyzing driving videos. You will receive a 5-second video clip of a specific scene. {prompt}
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Dense Caption
|
||||
Generate a 2 sentence caption describing:
|
||||
- Ego vehicle behavior
|
||||
- Interactions with other vehicles or pedestrians
|
||||
|
||||
Focus on **what happens**, **when**, and **who/what is involved**, using only visible information and metadata.
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Structured JSON
|
||||
Generate the caption from the perspective of the ego vehicle in a structured JSON object with:
|
||||
|
||||
- `"caption"`: from Task 1
|
||||
- `"event_type"`: "collision" | "near_miss" | "no_incident"
|
||||
- `"rule_violations"`: choose relevant items from ["speeding", "failure_to_yield", "ignoring_traffic_signs"]
|
||||
- `"intended_action"`: "turn_left" | "turn_right" | "change_lanes"
|
||||
- `"traffic_density"`: "low" | "high"
|
||||
- `"visibility"`: "good" | "bad"
|
||||
- `"scene"`: "Urban" | "Sub-urban" | "Rural" | "Highway"
|
||||
|
||||
**Rules:**
|
||||
1. Use only visible info and metadata.
|
||||
2. Do not invent details.
|
||||
3. Include all fields; enum values must match allowed options.
|
||||
4. Output a single valid JSON object—no extra text or markdown.
|
||||
"""
|
||||
|
||||
|
||||
def random_id():
|
||||
chars = string.ascii_letters + string.digits
|
||||
return "".join(random.choices(chars, k=8)).lower()
|
||||
|
||||
|
||||
def initialize_session_state(base_model, finetuned_model):
|
||||
# Initialize page-specific session state
|
||||
st.session_state["base_video_vlm"] = st.session_state.get("base_video_vlm", base_model)
|
||||
st.session_state["finetuned_video_vlm"] = st.session_state.get("finetuned_video_vlm", finetuned_model)
|
||||
st.session_state["current_sample"] = st.session_state.get("current_sample", None)
|
||||
st.session_state["df"] = st.session_state.get("df",
|
||||
pd.DataFrame(columns=[
|
||||
"Driver ID",
|
||||
"Event Type",
|
||||
"Rule Violations",
|
||||
"Intended Action",
|
||||
"Traffic Density",
|
||||
"Driving Scene",
|
||||
"Visibility"]))
|
||||
|
||||
|
||||
def load_config():
|
||||
config_key = "config_video_vlm"
|
||||
if getattr(st.session_state, config_key, None) is None:
|
||||
with open("src/video_vlm_config.yaml", "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
setattr(st.session_state, config_key, config)
|
||||
else:
|
||||
config = getattr(st.session_state, config_key)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def initialize_model(model_path):
|
||||
model = InternVLModel(model_path)
|
||||
return {"model": model}
|
||||
|
||||
|
||||
def main():
|
||||
# set page ui
|
||||
st.title(":green[Video VLM on DGX Spark]")
|
||||
st.caption("Driver Behavior Analysis via Structured Video Captioning")
|
||||
# st.page_link("https://github.com/your-username/your-repo", label="GitHub", icon=":material/github:")
|
||||
|
||||
# load css
|
||||
with open("src/styles.css", "r") as f:
|
||||
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
||||
|
||||
# load resources
|
||||
config = load_config()
|
||||
if st.session_state.get("base_video_vlm", None) is None:
|
||||
st.toast("Loading model", icon="⏳", duration="short")
|
||||
base_model = initialize_model(config["inference"]["model_id"])
|
||||
finetuned_model_path = get_last_checkpoint(config["inference"]["finetuned_model_id"])
|
||||
if finetuned_model_path is not None:
|
||||
finetuned_model = initialize_model(finetuned_model_path)
|
||||
else:
|
||||
finetuned_model = {"model": None}
|
||||
if st.session_state.get("base_video_vlm", None) is None:
|
||||
st.toast("Model loaded", icon="✅", duration="short")
|
||||
initialize_session_state(base_model, finetuned_model)
|
||||
|
||||
# gallery section
|
||||
st.markdown("---")
|
||||
st.header("Dashcam Gallery")
|
||||
|
||||
columns = st.columns([4, 1, 1, 4, 1, 1, 4, 1], gap="small")
|
||||
|
||||
with columns[0]:
|
||||
st.image("assets/video_vlm/thumbnails/1.png")
|
||||
|
||||
with columns[1]:
|
||||
if st.button(":material/file_open:", key="video_1"):
|
||||
st.session_state["current_sample"] = "assets/video_vlm/videos/1.mp4"
|
||||
|
||||
with columns[3]:
|
||||
st.image("assets/video_vlm/thumbnails/2.png")
|
||||
|
||||
with columns[4]:
|
||||
if st.button(":material/file_open:", key="video_2"):
|
||||
st.session_state["current_sample"] = "assets/video_vlm/videos/2.mp4"
|
||||
|
||||
with columns[6]:
|
||||
st.image("assets/video_vlm/thumbnails/3.png")
|
||||
|
||||
with columns[7]:
|
||||
if st.button(":material/file_open:", key="video_3"):
|
||||
st.session_state["current_sample"] = "assets/video_vlm/videos/3.mp4"
|
||||
|
||||
# inference section
|
||||
st.markdown("---")
|
||||
st.header("Video Inference")
|
||||
|
||||
with st.container(border=True):
|
||||
if st.session_state["current_sample"]:
|
||||
st.video(st.session_state["current_sample"], autoplay=True, loop=True, muted=True)
|
||||
else:
|
||||
st.write(":gray[Please select a video from the dashcam gallery.]")
|
||||
|
||||
columns = st.columns(2, gap="small")
|
||||
with columns[0]:
|
||||
with st.container(border=True):
|
||||
st.write("##### :green[Base InternVL3-8B]")
|
||||
base_generation = st.empty()
|
||||
base_generation.write("...")
|
||||
|
||||
with columns[1]:
|
||||
with st.container(border=True):
|
||||
st.write("##### :green[Finetuned InternVL3-8B]")
|
||||
finetuned_generation = st.empty()
|
||||
finetuned_generation.write("...")
|
||||
|
||||
columns = st.columns([9, 1], gap="small")
|
||||
with columns[0]:
|
||||
prompt = st.text_input(
|
||||
"Prompt Input",
|
||||
label_visibility="collapsed",
|
||||
key="prompt_input",
|
||||
on_change=lambda: st.session_state.update(prompt=st.session_state.prompt_input)
|
||||
)
|
||||
|
||||
with columns[1]:
|
||||
if st.button("Generate", width="stretch"):
|
||||
if st.session_state.get("prompt", None):
|
||||
st.session_state["prompt"] = prompt
|
||||
|
||||
with st.spinner("Running..."):
|
||||
response = start_inference("base")
|
||||
base_generation.markdown(response)
|
||||
|
||||
if st.session_state["finetuned_video_vlm"].get("model", None) is not None:
|
||||
with st.spinner("Running..."):
|
||||
response = start_inference("finetuned")
|
||||
finetuned_generation.markdown(response)
|
||||
|
||||
response = json.loads(response[7:-3].strip())  # strip the ```json fence added in start_inference before parsing
|
||||
response["caption"] = random_id() # replace caption with driver id
|
||||
st.session_state["df"].loc[len(st.session_state["df"])] = list(response.values())
|
||||
else:
|
||||
finetuned_generation.markdown("```No response: the fine-tuned model is not loaded```")
|
||||
|
||||
# data analysis section
|
||||
st.markdown("---")
|
||||
st.header("Data Analysis")
|
||||
|
||||
with st.container(border=True):
|
||||
st.dataframe(st.session_state["df"])
|
||||
|
||||
|
||||
class InternVLVideoProcessor:
|
||||
def __init__(self):
|
||||
self.frame_size = 448
|
||||
|
||||
self.transform = T.Compose([
|
||||
T.Resize(self.frame_size, interpolation=InterpolationMode.BICUBIC),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
|
||||
best_ratio_diff = float('inf')
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_ratio = ratio
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
def dynamic_preprocess(self, frame, min_num=1, max_num=12, use_thumbnail=False):
|
||||
orig_width, orig_height = frame.size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set(
|
||||
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
||||
i * j <= max_num and i * j >= min_num)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, self.frame_size)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = self.frame_size * target_aspect_ratio[0]
|
||||
target_height = self.frame_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
|
||||
# resize the image
|
||||
resized_img = frame.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // self.frame_size)) * self.frame_size,
|
||||
(i // (target_width // self.frame_size)) * self.frame_size,
|
||||
((i % (target_width // self.frame_size)) + 1) * self.frame_size,
|
||||
((i // (target_width // self.frame_size)) + 1) * self.frame_size
|
||||
)
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
assert len(processed_images) == blocks
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = frame.resize((self.frame_size, self.frame_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
return processed_images
|
||||
|
||||
def load_video(self, video_path, num_frames, start_frame=None, end_frame=None):
|
||||
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
||||
|
||||
if start_frame is None:
|
||||
start_frame = 0
|
||||
if end_frame is None:
|
||||
end_frame = len(vr) - 1
|
||||
|
||||
# sample num_frames equally-spaced frames from the video
|
||||
frame_indices = np.linspace(
|
||||
start_frame,
|
||||
end_frame,
|
||||
num_frames,
|
||||
dtype=int
|
||||
)
|
||||
|
||||
pixel_values_list, num_patches_list = [], []
|
||||
for frame_index in frame_indices:
|
||||
img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
|
||||
img = self.dynamic_preprocess(img, use_thumbnail=True, max_num=1)
|
||||
pixel_values = [self.transform(tile) for tile in img]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
num_patches_list.append(pixel_values.shape[0])
|
||||
pixel_values_list.append(pixel_values)
|
||||
pixel_values = torch.cat(pixel_values_list)
|
||||
return pixel_values, num_patches_list
|
||||
|
||||
|
||||
class InternVLModel:
|
||||
def __init__(self, model_path):
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
use_flash_attn=False,
|
||||
).eval()
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=True,
|
||||
use_fast=False
|
||||
)
|
||||
|
||||
self.processor = InternVLVideoProcessor()
|
||||
self.generation_config = dict(
|
||||
max_new_tokens=1024,
|
||||
do_sample=False,
|
||||
temperature=0.0,
|
||||
num_beams=1,
|
||||
top_k=1,
|
||||
top_p=1.0)
|
||||
|
||||
@torch.no_grad()
|
||||
def infer(self, video_path, prompt, sampling_mode, num_frames=32, chunk_duration=2):
|
||||
if sampling_mode == "default":
|
||||
return self._infer_default(video_path, prompt, num_frames)
|
||||
elif sampling_mode == "real-time":
|
||||
return self._infer_realtime(video_path, prompt, num_frames, chunk_duration)
|
||||
|
||||
def _infer_default(self, video_path, prompt, num_frames):
|
||||
pixel_values, num_patches_list = self.processor.load_video(video_path, num_frames)
|
||||
pixel_values = pixel_values.to(device=self.model.device, dtype=torch.bfloat16)
|
||||
|
||||
video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
|
||||
prompt = f"{video_prefix}{prompt}"
|
||||
|
||||
response = self.model.chat(
|
||||
self.tokenizer,
|
||||
pixel_values,
|
||||
prompt,
|
||||
self.generation_config,
|
||||
num_patches_list=num_patches_list
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _infer_realtime(self, video_path, prompt, num_frames, chunk_duration):
|
||||
video = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
||||
fps = video.get_avg_fps()
|
||||
total_frames = len(video)
|
||||
frames_per_chunk = int(chunk_duration * fps)
|
||||
|
||||
for start_frame in range(0, total_frames, frames_per_chunk):
|
||||
end_frame = start_frame + frames_per_chunk
|
||||
if end_frame > total_frames:
|
||||
return
|
||||
|
||||
pixel_values, num_patches_list = self.processor.load_video(video_path, num_frames, start_frame, end_frame)
|
||||
pixel_values = pixel_values.to(device=self.model.device, dtype=torch.bfloat16)
|
||||
|
||||
video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
|
||||
chunk_prompt = f"{video_prefix}{prompt}"  # use a per-chunk variable so the original prompt is not re-prefixed on every iteration
|
||||
|
||||
response = self.model.chat(
|
||||
self.tokenizer,
|
||||
pixel_values,
|
||||
chunk_prompt,
|
||||
self.generation_config,
|
||||
num_patches_list=num_patches_list
|
||||
)
|
||||
|
||||
yield response
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def start_inference(model_type):
|
||||
# define prompt
|
||||
prompt = st.session_state["prompt"]
|
||||
if model_type == "finetuned":
|
||||
prompt = SCAP_PROMPT.format(prompt=prompt)
|
||||
|
||||
print(model_type)
|
||||
response = st.session_state[f"{model_type}_video_vlm"]["model"].infer(
|
||||
st.session_state["current_sample"],
|
||||
prompt,
|
||||
num_frames=st.session_state["config_video_vlm"]["inference"]["num_frames"],
|
||||
sampling_mode=st.session_state["config_video_vlm"]["inference"]["sampling_mode"]
|
||||
)
|
||||
|
||||
if model_type == "finetuned":
|
||||
response = f"```json\n{json.dumps(json.loads(response), indent=2)}\n```"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -1,178 +0,0 @@
|
||||
/*
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
*/
|
||||
|
||||
/* VLM Fine-tuning Streamlit App Styles */
|
||||
:root {
|
||||
--nvidia-green: #76b900;
|
||||
}
|
||||
|
||||
/* Set maximum page width - multiple selectors for compatibility */
|
||||
.main .block-container,
|
||||
.block-container,
|
||||
.main > div,
|
||||
section.main > div {
|
||||
max-width: 1200px !important;
|
||||
margin: 0 auto !important;
|
||||
padding-left: 1rem !important;
|
||||
padding-right: 1rem !important;
|
||||
}
|
||||
|
||||
/* h3 {
|
||||
color:#76b900 !important;
|
||||
} */
|
||||
|
||||
.main > div {
|
||||
padding-top: 1rem;
|
||||
}
|
||||
|
||||
/* Global button styling - applies to ALL buttons */
|
||||
.stButton > button,
|
||||
button[kind="primary"],
|
||||
button[kind="secondary"] {
|
||||
width: 100% !important;
|
||||
font-size: 16px !important;
|
||||
font-weight: 600 !important;
|
||||
padding: 12px 20px !important;
|
||||
border-radius: 8px !important;
|
||||
border: none !important;
|
||||
transition: all 0.3s ease !important;
|
||||
cursor: pointer !important;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
|
||||
}
|
||||
|
||||
/* Primary button styling */
|
||||
.stButton > button[kind="primary"],
|
||||
button[kind="primary"] {
|
||||
background: linear-gradient(135deg, #32864c 0%, #6edd93 100%) !important;
|
||||
/* background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; */
|
||||
color: white !important;
|
||||
}
|
||||
|
||||
/* Primary button hover effect */
|
||||
.stButton > button[kind="primary"]:hover,
|
||||
button[kind="primary"]:hover {
|
||||
transform: translateY(-2px) !important;
|
||||
box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
|
||||
}
|
||||
|
||||
/* Secondary button styling */
|
||||
.stButton > button[kind="secondary"],
|
||||
button[kind="secondary"] {
|
||||
background: linear-gradient(135deg, #b6fb93 0%, #57f586 100%) !important;
|
||||
/* background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important; */
|
||||
/* color: white !important; */
|
||||
color: darkslategray !important;
|
||||
}
|
||||
|
||||
/* Regular button styling */
|
||||
.stButton > button:not([kind]) {
|
||||
background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%) !important;
|
||||
color: white !important;
|
||||
}
|
||||
|
||||
/* Button hover effect for all buttons */
|
||||
.stButton > button:hover {
|
||||
transform: translateY(-2px) !important;
|
||||
box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
|
||||
}
|
||||
|
||||
/* Button active/pressed state */
|
||||
.stButton > button:active {
|
||||
transform: translateY(0px) !important;
|
||||
}
|
||||
|
||||
/* Metric card styling */
|
||||
.metric-card {
|
||||
background-color: #f0f2f6;
|
||||
padding: 1rem;
|
||||
border-radius: 0.5rem;
|
||||
border-left: 4px solid #ff6b6b;
|
||||
}
|
||||
|
||||
/* Center images and maintain aspect ratio */
|
||||
.stImage > img {
|
||||
display: block;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
max-width: 500px;
|
||||
height: auto;
|
||||
max-height: 300px;
|
||||
}
|
||||
|
||||
/* Larger text input for questions */
|
||||
.stTextInput > div > div > input {
|
||||
font-size: 18px !important;
|
||||
padding: 10px !important;
|
||||
}
|
||||
|
||||
/* Larger inference button */
|
||||
.inference-button button,
|
||||
.inference-button .stButton > button {
|
||||
padding: 20px !important;
|
||||
height: 200px !important;
|
||||
min-height: 100% !important;
|
||||
max-height: 100% !important;
|
||||
display: flex !important;
|
||||
align-items: center !important;
|
||||
justify-content: center !important;
|
||||
line-height: 1.2 !important;
|
||||
white-space: nowrap !important;
|
||||
font-size: large !important;
|
||||
}
|
||||
|
||||
|
||||
.card {
|
||||
background-color: #1F2022;
|
||||
color: #bdc3c7;
|
||||
padding: 1.5rem;
|
||||
border-radius: 0.5rem;
|
||||
/* outline: 2px solid var(--nvidia-green); */
|
||||
box-shadow: 0 2px 8px rgba(0,0,0,0.10);
|
||||
text-align: left;
|
||||
height: 300px;
|
||||
overflow-y: auto;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.card--empty {
|
||||
text-align: center;
|
||||
font-style: italic;
|
||||
height: 300px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
/* Wandb iframe wrapper for cropping */
|
||||
.wandb-wrapper {
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
border-radius: 15px;
|
||||
height: var(--visible-height, 320px);
|
||||
}
|
||||
|
||||
/* Wandb iframe styling */
|
||||
.wandb-iframe {
|
||||
border: none;
|
||||
width: 100%;
|
||||
position: absolute;
|
||||
left: 0;
|
||||
height: var(--iframe-height, 600px);
|
||||
top: var(--shift-up, -230px);
|
||||
}
|
||||
@ -1,22 +0,0 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
inference:
|
||||
model_id: OpenGVLab/InternVL3-8B
|
||||
finetuned_model_id: saved_model
|
||||
num_frames: 12
|
||||
sampling_mode: default
|
||||
@ -1,767 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "31e8ca53",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Copyright Notice\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n",
|
||||
"SPDX-License-Identifier: Apache-2.0\n",
|
||||
"\n",
|
||||
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"you may not use this file except in compliance with the License.\n",
|
||||
"You may obtain a copy of the License at\n",
|
||||
"\n",
|
||||
"http://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"\n",
|
||||
"Unless required by applicable law or agreed to in writing, software\n",
|
||||
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"See the License for the specific language governing permissions and\n",
|
||||
"limitations under the License.\n",
|
||||
"```\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ae293c8d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# VLM-Finetuning for Large Scale Data Analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0c471d1d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 📓 Notebook Overview \n",
|
||||
"In this notebook, we show how to train a VLM to generate structured metadata about videos. The use case we target is using a VLM to analyze driving videos and generate JSON-formatted video descriptions and metadata, such as risky maneuvers, to understand dangerous driving patterns. \n",
|
||||
"\n",
|
||||
"This is just one example; this workflow generalizes to any large-scale video data analysis task where structured metadata and automated video analysis are helpful."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6d02f831",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Initialize Imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d06618bc",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"\n",
|
||||
"logging.basicConfig(\n",
|
||||
" level=logging.INFO,\n",
|
||||
" format=\"%(asctime)s - %(levelname)s - %(message)s\"\n",
|
||||
")\n",
|
||||
"logger = logging.getLogger(__name__)\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import gc\n",
|
||||
"import sys\n",
|
||||
"import json\n",
|
||||
"import random\n",
|
||||
"import pathlib\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import numpy as np\n",
|
||||
"from PIL import Image\n",
|
||||
"from decord import VideoReader\n",
|
||||
"from tensorboard import program\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"from transformers.utils import hub\n",
|
||||
"import torchvision.transforms as T\n",
|
||||
"from huggingface_hub import snapshot_download\n",
|
||||
"from trl import SFTTrainer, SFTConfig\n",
|
||||
"from transformers.trainer_pt_utils import LabelSmoother\n",
|
||||
"from transformers.trainer_utils import get_last_checkpoint\n",
|
||||
"from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor\n",
|
||||
"from torchvision.transforms.functional import InterpolationMode"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f3cc9f01",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set seed for reproducibility\n",
|
||||
"def set_seed(seed):\n",
|
||||
" random.seed(seed)\n",
|
||||
" np.random.seed(seed)\n",
|
||||
" torch.manual_seed(seed)\n",
|
||||
" if torch.cuda.is_available():\n",
|
||||
" torch.cuda.manual_seed_all(seed)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"set_seed(42)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ef437e65",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set some constants\n",
|
||||
"seq_len = 8192\n",
|
||||
"model_name = \"OpenGVLab/InternVL3-8B\"\n",
|
||||
"ignore_token_id = LabelSmoother.ignore_index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0bf4800f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Model Loading"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a4e2e572",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\n",
|
||||
" model_name, \n",
|
||||
" trust_remote_code=True,\n",
|
||||
" use_fast=True\n",
|
||||
")\n",
|
||||
"tokenizer.padding_side = \"right\"\n",
|
||||
"\n",
|
||||
"model = AutoModelForCausalLM.from_pretrained(\n",
|
||||
" model_name,\n",
|
||||
" torch_dtype=torch.bfloat16,\n",
|
||||
" device_map=\"auto\",\n",
|
||||
" trust_remote_code=True,\n",
|
||||
" use_flash_attn=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Load the processor\n",
|
||||
"processor = AutoProcessor.from_pretrained(\n",
|
||||
" model_name,\n",
|
||||
" trust_remote_code=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b2c1476a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Data Processing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d3458cd1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Action Required**: Please update the `dataset_path` with a path to your local dataset."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "05d6cd17",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset_path = \"path/to/your/dataset\"\n",
|
||||
"assert os.path.exists(dataset_path), f\"Dataset path not found: {dataset_path}\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "34fad55c",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load conversation template\n",
|
||||
"model_dir = hub.cached_file(model_name, \"conversation.py\", trust_remote_code=True)\n",
|
||||
"sys.path.append(os.path.dirname(model_dir))\n",
|
||||
"\n",
|
||||
"from conversation import get_conv_template"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f98733c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_jsonl_data(sample): \n",
|
||||
" \"\"\" Data processing helper: Define the labels based on the desired format we hope to teach \n",
|
||||
" the model to generate \n",
|
||||
" \"\"\"\n",
|
||||
" answer_dict = {\n",
|
||||
" \"caption\": sample['caption'],\n",
|
||||
" \"event_type\": sample['event_type'],\n",
|
||||
" \"rule_violations\": sample['rule_violations'],\n",
|
||||
" \"intended_action\": sample['intended_action'],\n",
|
||||
" \"traffic_density\": sample['traffic_density'],\n",
|
||||
" \"scene\": sample['scene'],\n",
|
||||
" \"visibility\": sample['visibility'],\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" return json.dumps(answer_dict, ensure_ascii=False) # create a single line, valid JSON"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "84d2f75f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Load the Dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bfb85f7c",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the dataset\n",
|
||||
"dataset = load_dataset(dataset_path)\n",
|
||||
"dataset = dataset['train']\n",
|
||||
"dataset = dataset.map(lambda ex: {\"labels\": get_jsonl_data(ex)})\n",
|
||||
"ds_splits = dataset.train_test_split(test_size=0.01, seed=42)\n",
|
||||
"train_dataset, val_dataset = ds_splits['train'], ds_splits['test']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bca23383",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6b98dbe9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9299e759",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Data Visualization"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8a95631c",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_video_path(dataset_path, sample_path):\n",
|
||||
" \"\"\" Dataset specific helper function -- this appends the sample path to the root path to create the full video path \"\"\"\n",
|
||||
" root_dir = dataset_path.split('/')[:-1]\n",
|
||||
" video_path = '/'.join(root_dir) + '/' + sample_path\n",
|
||||
" return video_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dcb4dfba",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"min_frames, max_frames = 8, 32\n",
|
||||
"\n",
|
||||
"# load frames from a video\n",
|
||||
"def load_video(video_path):\n",
|
||||
" video = VideoReader(video_path, num_threads=1)\n",
|
||||
"\n",
|
||||
" # sample a random number of equally-spaced frames from the video\n",
|
||||
" frame_indices = np.linspace(\n",
|
||||
" 0,\n",
|
||||
" len(video) - 1,\n",
|
||||
" random.randint(min_frames, max_frames),\n",
|
||||
" dtype=int\n",
|
||||
" )\n",
|
||||
" frames = video.get_batch(frame_indices).asnumpy()\n",
|
||||
" return [Image.fromarray(frames[i]) for i in range(frames.shape[0])]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d388155a",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"video_path = get_video_path(dataset_path, train_dataset[0]['video'])\n",
|
||||
"display(load_video(video_path)[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2b21a463",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Data Processing\n",
|
||||
"Functions taken from [InternVL3 Documentation](https://internvl.readthedocs.io/en/latest/internvl3.0/quick_start.html#inference-with-transformers)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8f8e4c46",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# preprocessing code directly adopted HF model card\n",
|
||||
"IMAGENET_MEAN = (0.485, 0.456, 0.406)\n",
|
||||
"IMAGENET_STD = (0.229, 0.224, 0.225)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def build_transform(input_size):\n",
|
||||
" MEAN, STD = IMAGENET_MEAN, IMAGENET_STD\n",
|
||||
" transform = T.Compose([\n",
|
||||
" T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n",
|
||||
" T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),\n",
|
||||
" T.ToTensor(),\n",
|
||||
" T.Normalize(mean=MEAN, std=STD)\n",
|
||||
" ])\n",
|
||||
" return transform\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):\n",
|
||||
" best_ratio_diff = float('inf')\n",
|
||||
" best_ratio = (1, 1)\n",
|
||||
" area = width * height\n",
|
||||
" for ratio in target_ratios:\n",
|
||||
" target_aspect_ratio = ratio[0] / ratio[1]\n",
|
||||
" ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n",
|
||||
" if ratio_diff < best_ratio_diff:\n",
|
||||
" best_ratio_diff = ratio_diff\n",
|
||||
" best_ratio = ratio\n",
|
||||
" elif ratio_diff == best_ratio_diff:\n",
|
||||
" if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n",
|
||||
" best_ratio = ratio\n",
|
||||
" return best_ratio\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):\n",
|
||||
" orig_width, orig_height = image.size\n",
|
||||
" aspect_ratio = orig_width / orig_height\n",
|
||||
"\n",
|
||||
" # calculate the existing image aspect ratio\n",
|
||||
" target_ratios = set(\n",
|
||||
" (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if\n",
|
||||
" i * j <= max_num and i * j >= min_num)\n",
|
||||
" target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n",
|
||||
"\n",
|
||||
" # find the closest aspect ratio to the target\n",
|
||||
" target_aspect_ratio = find_closest_aspect_ratio(\n",
|
||||
" aspect_ratio, target_ratios, orig_width, orig_height, image_size)\n",
|
||||
"\n",
|
||||
" # calculate the target width and height\n",
|
||||
" target_width = image_size * target_aspect_ratio[0]\n",
|
||||
" target_height = image_size * target_aspect_ratio[1]\n",
|
||||
" blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n",
|
||||
"\n",
|
||||
" # resize the image\n",
|
||||
" resized_img = image.resize((target_width, target_height))\n",
|
||||
" processed_images = []\n",
|
||||
" for i in range(blocks):\n",
|
||||
" box = (\n",
|
||||
" (i % (target_width // image_size)) * image_size,\n",
|
||||
" (i // (target_width // image_size)) * image_size,\n",
|
||||
" ((i % (target_width // image_size)) + 1) * image_size,\n",
|
||||
" ((i // (target_width // image_size)) + 1) * image_size\n",
|
||||
" )\n",
|
||||
" # split the image\n",
|
||||
" split_img = resized_img.crop(box)\n",
|
||||
" processed_images.append(split_img)\n",
|
||||
" assert len(processed_images) == blocks\n",
|
||||
" if use_thumbnail and len(processed_images) != 1:\n",
|
||||
" thumbnail_img = image.resize((image_size, image_size))\n",
|
||||
" processed_images.append(thumbnail_img)\n",
|
||||
" return processed_images\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# build the transform and get number of tokens per image (per tile technically)\n",
|
||||
"image_size = model.config.force_image_size\n",
|
||||
"transform = build_transform(input_size=image_size)\n",
|
||||
"num_image_tokens = int((image_size // model.config.vision_config.patch_size) ** 2 * (model.config.downsample_ratio ** 2))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6097f018",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Define user prompt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e24f2a4a",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"user_prompt = \"\"\"You are a vision-language assistant analyzing driving videos. You will receive a 5-second video clip of a specific scene. \n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"### Task 1: Dense Caption\n",
|
||||
"Generate a 2 sentence caption describing:\n",
"- Ego vehicle behavior\n",
"- Interactions with other vehicles or pedestrians\n",
"\n",
"Focus on **what happens**, **when**, and **who/what is involved**, using only visible information and metadata.\n",
"\n",
"---\n",
"\n",
"### Task 2: Structured JSON\n",
"Generate the caption from the perspective of the ego vehicle in a structured JSON object with:\n",
"\n",
"- `\"caption\"`: from Task 1 \n",
"- `\"event_type\"`: \"collision\" | \"near_miss\" | \"no_incident\" \n",
"- `\"rule_violations\"`: choose relevant items from [\"speeding\", \"failure_to_yield\", \"ignoring_traffic_signs\"] \n",
"- `\"intended_action\"`: \"turn_left\" | \"turn_right\" | \"change_lanes\" \n",
"- `\"traffic_density\"`: \"low\" | \"high\" \n",
"- `\"visibility\"`: \"good\" | \"bad\" \n",
"- `\"scene\"`: \"Urban\" | \"Sub-urban\" | \"Rural\" | \"Highway\"\n",
"\n",
"**Rules:**\n",
"1. Use only visible info and metadata. \n",
"2. Do not invent details. \n",
"3. Include all fields; enum values must match allowed options. \n",
"4. Output a single valid JSON object—no extra text or markdown. \n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e82e25f",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# View a sample label to confirm it's in the format we want\n",
"sample = train_dataset[0]\n",
"\n",
"answer_dict = {\n",
"    \"caption\": sample['caption'],\n",
"    \"event_type\": sample['event_type'],\n",
"    \"rule_violations\": sample['rule_violations'],\n",
"    \"intended_action\": sample['intended_action'],\n",
"    \"traffic_density\": sample['traffic_density'],\n",
"    \"scene\": sample['scene'],\n",
"    \"visibility\": sample['visibility'],\n",
"}\n",
"answer_jsonl = json.dumps(answer_dict, ensure_ascii=False)\n",
"\n",
"print(answer_jsonl)"
]
},
{
"cell_type": "markdown",
"id": "24b3e453",
"metadata": {},
"source": [
"### Custom Data Preprocessing\n",
"\n",
"This novel data tokenization function takes in a batch of samples and tokenizes them according to what InternVL3 source code expects, returning a processed batch of features with input IDs, labels, attention masks, etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59243cf6",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# Collate function for the dataset that does on-the-fly preprocessing and batching\n",
"def collate_fn(samples, tokenizer, transform, seq_len, num_image_tokens, get_conv_template, ignore_token_id, load_video, get_video_path, dataset_path):\n",
"    input_ids_batch, labels_batch, attention_mask_batch, position_ids_batch, pixel_values_batch, image_flags_batch = [], [], [], [], [], []\n",
"    for sample in samples:\n",
"        # load the video frames\n",
"        video_frames = load_video(get_video_path(dataset_path, sample['video']))\n",
"        num_frames = len(video_frames)\n",
"\n",
"        # preprocess the video frames\n",
"        pixel_values = [transform(frame) for frame in video_frames]\n",
"        pixel_values = torch.stack(pixel_values)\n",
"        num_tiles = pixel_values.size(0)\n",
"\n",
"        # prepend special video tokens to the user message\n",
"        video_tokens = '\\n'.join(['Frame-{}: <image>'.format(i + 1) for i in range(num_frames)])\n",
"\n",
"        # setup conversation\n",
"        conv_template = get_conv_template(\"internvl2_5\")\n",
"\n",
"        system_instruction = user_prompt\n",
"        answer = get_jsonl_data(sample)\n",
"\n",
"        conv_template.append_message(conv_template.roles[0], f'{video_tokens}\\n{system_instruction}')\n",
"        conv_template.append_message(conv_template.roles[1], answer)\n",
"\n",
"        # replace image tokens with context tokens\n",
"        prompt = conv_template.get_prompt()\n",
"        prompt = prompt.replace(\"<image>\", f\"<img>{'<IMG_CONTEXT>' * num_image_tokens}</img>\")\n",
"\n",
"        # create a list of messages\n",
"        messages = [f\"<|im_start|>{message}\" for message in prompt.split(\"<|im_start|>\")[1: ]]\n",
"\n",
"        # tokenize the prompt (we manually truncate and pad the sequence)\n",
"        input_ids = tokenizer(\n",
"            messages,\n",
"            return_tensors=\"np\",\n",
"            padding=False,\n",
"            max_length=seq_len,\n",
"            truncation=False,\n",
"        ).input_ids\n",
"\n",
"        # create targets by masking out system and user messages\n",
"        # since we only want to compute loss for the assistant message\n",
"        targets = []\n",
"        num_ignore_ids = tokenizer('<|im_start|>assistant\\n', return_tensors='np').input_ids[0].shape[0]\n",
"        for idx, input_id in enumerate(input_ids):\n",
"            if idx != 2:\n",
"                targets.append(np.full(input_id.shape, ignore_token_id))\n",
"            else:\n",
"                target = input_id.copy()\n",
"                target[: num_ignore_ids] = ignore_token_id\n",
"                target[-1: ] = ignore_token_id\n",
"                targets.append(target)\n",
"\n",
"        # prepare the input_ids and targets\n",
"        input_ids = torch.tensor(np.concatenate(input_ids))[: seq_len]\n",
"        targets = torch.tensor(np.concatenate(targets))[: seq_len]\n",
"\n",
"        # pad the input_ids and targets to the sequence length\n",
"        pad_len = seq_len - input_ids.shape[0]\n",
"        input_ids = F.pad(input_ids, (0, pad_len), value=tokenizer.pad_token_id)\n",
"        targets = F.pad(targets, (0, pad_len), value=ignore_token_id)\n",
"\n",
"        # generate attention mask to filter out padding tokens\n",
"        attention_mask = input_ids.ne(tokenizer.pad_token_id)\n",
"\n",
"        position_ids = attention_mask.long().cumsum(-1) - 1\n",
"        position_ids.masked_fill_(attention_mask == 0, 1)\n",
"\n",
"        input_ids_batch.append(input_ids)\n",
"        labels_batch.append(targets)\n",
"        attention_mask_batch.append(attention_mask)\n",
"        position_ids_batch.append(position_ids)\n",
"        pixel_values_batch.append(pixel_values)\n",
"        image_flags_batch.append(torch.tensor([1] * num_tiles, dtype=torch.long))\n",
"\n",
"    batch = {\n",
"        \"input_ids\": torch.stack(input_ids_batch),\n",
"        \"labels\": torch.stack(labels_batch),\n",
"        \"attention_mask\": torch.stack(attention_mask_batch),\n",
"        \"position_ids\": torch.stack(position_ids_batch),\n",
"        \"pixel_values\": torch.cat(pixel_values_batch),\n",
"        \"image_flags\": torch.cat(image_flags_batch)\n",
"    }\n",
"\n",
"    return batch"
]
},
{
"cell_type": "markdown",
"id": "6a8912c4",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "markdown",
"id": "36c55eb2",
"metadata": {},
"source": [
"### Set Model Config Params for Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59766bcf",
"metadata": {},
"outputs": [],
"source": [
"model.img_context_token_id = tokenizer.convert_tokens_to_ids(\"<IMG_CONTEXT>\")\n",
"\n",
"model.train()\n",
"\n",
"model.language_model.config.use_cache = False\n",
"model.vision_model.gradient_checkpointing = True\n",
"model.vision_model.encoder.gradient_checkpointing = True\n",
"model.language_model._set_gradient_checkpointing()"
]
},
{
"cell_type": "markdown",
"id": "b6c2dec6",
"metadata": {},
"source": [
"### Define Training Params"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "686f3cfd",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"save_dir = \"../saved_model\"\n",
"\n",
"trainer = SFTTrainer(\n",
"    model=model,\n",
"    data_collator=collate_fn,\n",
"    train_dataset=train_dataset,\n",
"    eval_dataset=val_dataset,\n",
"    processing_class=processor,\n",
"    args=SFTConfig(\n",
"        num_train_epochs=30,\n",
"        per_device_train_batch_size=1,\n",
"        per_device_eval_batch_size=1,\n",
"        eval_steps=250,\n",
"        do_eval=True,\n",
"        warmup_ratio=0.03,\n",
"        lr_scheduler_type=\"cosine\",\n",
"        eval_strategy=\"steps\",\n",
"        label_names=[\"labels\"],\n",
"        dataloader_num_workers=4,\n",
"        gradient_accumulation_steps=4,\n",
"        dataloader_persistent_workers=True,\n",
"        learning_rate=2e-5,\n",
"        weight_decay=0.05,\n",
"        logging_steps=10,\n",
"        logging_dir=\"logs\",\n",
"        save_strategy=\"steps\",\n",
"        save_steps=100,\n",
"        output_dir=save_dir,\n",
"        save_total_limit=2,\n",
"        optim=\"adamw_torch\",\n",
"        bf16=True,\n",
"        remove_unused_columns=False,\n",
"        report_to=\"wandb\",\n",
" dataset_kwargs = {\"skip_prepare_dataset\": True},\n",
|
||||
"\n",
|
||||
" )\n",
|
||||
")"
]
},
{
"cell_type": "markdown",
"id": "7c364912",
"metadata": {},
"source": [
"### Train the model\n",
"**Note:** Remove the `resume_from_checkpoint` parameter of `trainer.train()` if you don't want to resume training from a checkpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ab9fa74",
"metadata": {},
"outputs": [],
"source": [
||||
"last_ckpt = get_last_checkpoint(save_dir)\n",
|
||||
"print(f\"Resuming from {last_ckpt}\")\n",
|
||||
"trainer.train(resume_from_checkpoint=last_ckpt)"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"encoding": "# coding: utf-8",
"executable": "/usr/bin/env python",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}