chore: Regenerate all playbooks

GitLab CI 2025-10-06 12:57:08 +00:00
parent 6818481902
commit 24114f735c
20 changed files with 603 additions and 488 deletions

View File

@@ -94,6 +94,15 @@ For the provided prompt and random seed, the base Flux model generated the following
 <figcaption>Base FLUX.1 model workflow without custom concept knowledge</figcaption>
 </figure>
+After playing around with the base model, you have 2 possible next steps.
+* If you already have fine-tuned LoRAs placed inside `models/loras/`, please skip to the [Load the finetuned workflow](#52-load-the-finetuned-workflow) section.
+* If you wish to train a LoRA for your custom concepts, first make sure that the ComfyUI inference container is brought down before proceeding to train. You can bring it down by interrupting the terminal with the `Ctrl+C` keystroke.
+> **Note**: To clear out any extra occupied memory from your system, execute the following command after interrupting the ComfyUI server.
+```bash
+sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
+```
 ## 3. Dataset Preparation
 Let's prepare our dataset to perform Dreambooth LoRA finetuning on the FLUX.1-dev 12B model. However, if you wish to continue with the provided dataset of Toy Jensen and DGX Spark, feel free to skip to the [Training](#training) section. This dataset is a collection of public assets accessible via Google Images.
@@ -137,8 +146,6 @@ Now, let's modify the `flux_data/data.toml` file to reflect the concepts chosen.
 ### 4.1 Build the docker image
-Make sure that the ComfyUI inference container is brought down before proceeding to train. You can bring it down by interrupting the terminal with the `Ctrl+C` keystroke.
 ```bash
 # Build the training docker image
 docker build -f Dockerfile.train -t flux-train .

View File

@@ -105,9 +105,7 @@ Install the package in editable mode with metrics support for training evaluation
 pip install -e ".[metrics]"
 ```
-## Step 5. Configure PyTorch for CUDA 12.9 (if needed)
-*If using standalone Python (skip if using Docker container)*
+## Step 5. Configure PyTorch for CUDA 12.9 (skip if using Docker container from Step 2)
 In a python virtual environment, uninstall existing PyTorch and reinstall with CUDA 12.9 support for ARM64 architecture.
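Once the reinstall finishes, a quick check from Python confirms the new wheel actually sees the GPU; a minimal sketch (the exact version strings you see will depend on your build):

```python
# verify the reinstalled PyTorch build and its CUDA support
import torch

print("torch:", torch.__version__)                  # expect an ARM64/CUDA build
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA runtime:", torch.version.cuda)      # expect 12.9 per this step
    print("device:", torch.cuda.get_device_name(0))
```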

View File

@@ -119,6 +119,7 @@ From the root directory of the multi-agent-chatbot project, run the following commands
 docker compose -f docker-compose.yml -f docker-compose-models.yml down
 docker volume rm "$(basename "$PWD")_postgres_data"
+sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
 ```
 You can optionally run `docker volume prune` to remove all unused volumes at the end of the demo.
 > **Note**: If you do not execute these commands, containers will continue to run and take up memory.

View File

@@ -474,7 +474,7 @@ class ChatAgent:
         config = {"configurable": {"thread_id": chat_id}}
         try:
-            existing_messages = await self.conversation_store.get_messages(chat_id, limit=10)
+            existing_messages = await self.conversation_store.get_messages(chat_id, limit=1)
             base_system_prompt = self.system_prompt
             if image_data:

View File

@@ -19,7 +19,7 @@ from typing import Dict
 SUPERVISOR_AGENT_STR = """
-You are a supervisor agent whose role is to be a helpful planner that can use tools to answer questions. DO NOT WRITE CODE YOURSELF, ALWAYS USE THE TOOLS.
+You are a supervisor agent whose role is to be a helpful planner that can use tools to answer questions. Please be concise and to the point.
 {% if tools %}
 IMPORTANT: You have access to these tools and you MUST use them when applicable and use tool response in your final answer:
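For context, `SUPERVISOR_AGENT_STR` is a Jinja template (note the `{% if tools %}` guard). A minimal sketch of rendering such a template, assuming the `jinja2` package; the tool names and the `{% for %}` loop are hypothetical, purely to illustrate the conditional:

```python
# illustrative render of a supervisor-style Jinja prompt template
from jinja2 import Template

TEMPLATE = """
You are a supervisor agent whose role is to be a helpful planner that can use tools to answer questions. Please be concise and to the point.
{% if tools %}
IMPORTANT: You have access to these tools and you MUST use them when applicable and use tool response in your final answer:
{% for tool in tools %}- {{ tool }}
{% endfor %}{% endif %}
"""

# hypothetical tool list, for illustration only
print(Template(TEMPLATE).render(tools=["web_search", "code_executor"]))
```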

View File

@@ -122,9 +122,9 @@ services:
       - "--host"
       - "0.0.0.0"
       - "-n"
-      - "4096"
+      - "65536"
       - "--n-gpu-layers"
-      - "999"
+      - "70"
       - "--jinja"
   deepseek-coder:

View File

@@ -15,7 +15,7 @@
  # limitations under the License.
  */
 import type React from "react";
-import { useRef, useEffect, useState } from "react";
+import { useRef, useEffect, useState, useCallback } from "react";
 import styles from "@/styles/QuerySection.module.css";
 import ReactMarkdown from 'react-markdown'; // NEW
 import remarkGfm from 'remark-gfm'; // NEW
@@ -188,22 +188,16 @@ export default function QuerySection({
   const chatContainerRef = useRef<HTMLDivElement>(null);
   const [showButtons, setShowButtons] = useState(false);
   const [showWelcome, setShowWelcome] = useState(true);
-  const [inferenceStats, setInferenceStats] = useState({
-    tokensReceived: 0,
-    startTime: Date.now(),
-    tokensPerSecond: 0
-  });
   const [selectedSources, setSelectedSources] = useState<string[]>([]);
   const wsRef = useRef<WebSocket | null>(null);
-  const [uploadedImage, setUploadedImage] = useState<string | null>(null);
-  const [imagePreview, setImagePreview] = useState<string | null>(null);
-  const [isDragging, setIsDragging] = useState(false);
   const [toolOutput, setToolOutput] = useState("");
   const [graphStatus, setGraphStatus] = useState("");
   const [isPinnedToolOutputVisible, setPinnedToolOutputVisible] = useState(false);
   const [isToolContentVisible, setIsToolContentVisible] = useState(false);
   const [fadeIn, setFadeIn] = useState(false);
   const firstTokenReceived = useRef(false);
+  const hasAssistantContent = useRef(false);
+  const fadeTimeoutRef = useRef<NodeJS.Timeout | null>(null);

   useEffect(() => {
     const timer = setTimeout(() => {
@@ -212,15 +206,6 @@ export default function QuerySection({
     return () => clearTimeout(timer);
   }, []);
-  useEffect(() => {
-    if (!isStreaming) {
-      setInferenceStats(prev => ({
-        ...prev,
-        tokensReceived: 0,
-        startTime: 0
-      }));
-    }
-  }, [isStreaming]);

   useEffect(() => {
     const fetchSelectedSources = async () => {
@@ -276,9 +261,8 @@ export default function QuerySection({
           case "token": {
             if (!text) break;
             if (!firstTokenReceived.current) {
-              console.log('TTFT: ', new Date().toISOString());
               firstTokenReceived.current = true;
-              setIsStreaming(false);
+              hasAssistantContent.current = true;
             }
             setResponse(prev => {
               try {
@@ -310,9 +294,6 @@ export default function QuerySection({
           case "tool_end":
           case "node_end": {
             console.log(type, msg.data);
-            if (msg.data === 'generate') {
-              console.log('generate complete. time: ', new Date().toISOString());
-            }
             setGraphStatus("");
             break;
           }
@@ -342,7 +323,6 @@ export default function QuerySection({
     return () => {
       if (wsRef.current) {
         wsRef.current.close();
-        setIsStreaming(false);
       }
     };
   }, [currentChatId]);
@@ -361,70 +341,90 @@ export default function QuerySection({
   useEffect(() => {
     if (graphStatus) {
       setPinnedToolOutputVisible(true);
-    } else if (isPinnedToolOutputVisible) {
+      // Trigger fade-in on next tick
+      if (fadeTimeoutRef.current) {
+        clearTimeout(fadeTimeoutRef.current);
+      }
+      setFadeIn(false);
+      fadeTimeoutRef.current = setTimeout(() => setFadeIn(true), 10);
+    } else {
       // Delay hiding to allow fade-out
+      setFadeIn(false);
       const timeout = setTimeout(() => {
         setPinnedToolOutputVisible(false);
       }, 800); // match CSS transition duration
-      return () => clearTimeout(timeout);
+      return () => {
+        clearTimeout(timeout);
+        if (fadeTimeoutRef.current) {
+          clearTimeout(fadeTimeoutRef.current);
+        }
+      };
     }
-  }, [graphStatus, isPinnedToolOutputVisible]);
-
-  // Replace the effect for fade logic with this minimal version
-  useEffect(() => {
-    if (isPinnedToolOutputVisible && graphStatus) {
-      setFadeIn(false);
-      const t = setTimeout(() => setFadeIn(true), 10); // next tick for fade-in
-      return () => clearTimeout(t);
-    } else {
-      setFadeIn(false);
-    }
-  }, [isPinnedToolOutputVisible, graphStatus]);
-
-  // Cleanup image preview URL on unmount
-  useEffect(() => {
-    return () => {
-      if (imagePreview) {
-        URL.revokeObjectURL(imagePreview);
-      }
-    };
-  }, [imagePreview]);
+  }, [graphStatus]);

   const programmaticScroll = useRef(false);
   const scrollTimeout = useRef<number | null>(null);
+  const isUserScrollingRef = useRef(false);
+  const isNearBottomRef = useRef(true);

-  const handleDragEnter = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); setIsDragging(true); };
-  const handleDragLeave = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); setIsDragging(false); };
-  const handleDragOver = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); };
-
-  const handleDrop = async (e: React.DragEvent) => {
-    e.preventDefault();
-    e.stopPropagation();
-    setIsDragging(false);
-    const files = Array.from(e.dataTransfer.files);
-    const imageFile = files.find(file => file.type.startsWith('image/'));
-    if (imageFile) {
-      const previewUrl = URL.createObjectURL(imageFile);
-      setImagePreview(previewUrl);
-      const formData = new FormData();
-      formData.append('image', imageFile);
-      formData.append('chat_id', currentChatId || '');
-      try {
-        const response = await fetch('/api/upload-image', { method: 'POST', body: formData });
-        const result = await response.json();
-        setUploadedImage(result.image_id);
-      } catch (error) {
-        console.error('Error uploading image:', error);
-        URL.revokeObjectURL(previewUrl);
-        setImagePreview(null);
-      }
-    }
-  };
+  // Check if user is near the bottom of the chat
+  const checkScrollPosition = useCallback(() => {
+    if (chatContainerRef.current) {
+      const container = chatContainerRef.current;
+      const threshold = 100; // pixels from bottom
+      const isNear = container.scrollHeight - container.scrollTop - container.clientHeight < threshold;
+      isNearBottomRef.current = isNear;
+    }
+  }, []);
+
+  // Handle scroll events to detect user scrolling
+  useEffect(() => {
+    const container = chatContainerRef.current;
+    if (!container) return;
+    let scrollTimer: NodeJS.Timeout;
+    const handleScroll = () => {
+      isUserScrollingRef.current = true;
+      checkScrollPosition();
+      // Reset user scrolling flag after scroll stops
+      clearTimeout(scrollTimer);
+      scrollTimer = setTimeout(() => {
+        isUserScrollingRef.current = false;
+      }, 150);
+    };
+    container.addEventListener('scroll', handleScroll, { passive: true });
+    return () => {
+      container.removeEventListener('scroll', handleScroll);
+      clearTimeout(scrollTimer);
+    };
+  }, [checkScrollPosition]);
+
+  // Auto-scroll to bottom when response changes
+  useEffect(() => {
+    // Only scroll if we have assistant content and user hasn't manually scrolled away
+    if (!hasAssistantContent.current || isUserScrollingRef.current || !isNearBottomRef.current) {
+      return;
+    }
+    const scrollToBottom = () => {
+      if (messagesEndRef.current) {
+        messagesEndRef.current.scrollIntoView({
+          behavior: 'smooth',
+          block: 'end'
+        });
+      }
+      if (chatContainerRef.current) {
+        chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight;
+      }
+    };
+    scrollToBottom();
+  }, [response]);

   const handleQuerySubmit = async (e: React.FormEvent<HTMLFormElement>) => {
     e.preventDefault();
@@ -434,13 +434,11 @@ export default function QuerySection({
     setQuery("");
     setIsStreaming(true);
     firstTokenReceived.current = false;
+    hasAssistantContent.current = false;
     try {
-      console.log('sending uploaded image: ', uploadedImage, ' with query: ', currentQuery)
-      console.log('current time: ', new Date().toISOString());
       wsRef.current.send(JSON.stringify({
-        message: currentQuery,
-        image_id: uploadedImage
+        message: currentQuery
       }));
       setResponse(prev => {
@@ -455,14 +453,6 @@ export default function QuerySection({
           return prev + `\n\nHuman: ${currentQuery}\n\nAssistant: `;
         }
       });
-      // NEW CODE
-      if (imagePreview) {
-        URL.revokeObjectURL(imagePreview);
-      }
-      setUploadedImage(null);
-      setImagePreview(null);
-      // NEW CODE
     } catch (error) {
       console.error("Error sending message:", error);
       setIsStreaming(false);
@@ -581,36 +571,7 @@ export default function QuerySection({
       </div>
       <form onSubmit={handleQuerySubmit} className={styles.inputContainer}>
-        {/* NEW CODE - Image preview moved to the left of inputWrapper */}
-        {imagePreview && (
-          <div className={styles.imagePreview}>
-            <img
-              src={imagePreview}
-              alt="Image preview"
-              className={styles.previewImage}
-            />
-            <button
-              className={styles.removeImageButton}
-              onClick={() => {
-                if (imagePreview) {
-                  URL.revokeObjectURL(imagePreview);
-                }
-                setUploadedImage(null);
-                setImagePreview(null);
-              }}
-            >
-            </button>
-          </div>
-        )}
-        {/* NEW CODE */}
-        <div
-          className={`${styles.inputWrapper} ${isDragging ? styles.dragging : ''}`}
-          onDragEnter={handleDragEnter}
-          onDragLeave={handleDragLeave}
-          onDragOver={handleDragOver}
-          onDrop={handleDrop}
-        >
+        <div className={styles.inputWrapper}>
           <textarea
             rows={1}
             value={query}
@@ -648,14 +609,10 @@ export default function QuerySection({
       <div className={styles.disclaimer}>
         This is a concept demo to showcase multiple models and MCP use. It is not optimized for performance. Developers can customize and further optimize it for performance.
         <br />
+        <span className={styles.info}>Note: If a response is cut short, please start a new chat to continue.</span>
+        <br />
         <span className={styles.warning}>Don't forget to shutdown docker containers at the end of the demo.</span>
       </div>
-      {inferenceStats.tokensPerSecond > 0 && (
-        <div className={styles.inferenceStats}>
-          {inferenceStats.tokensPerSecond} tokens/sec
-        </div>
-      )}
     </div>
   );
 }

View File

@@ -165,7 +165,7 @@ export default function Sidebar({
   }, [refreshTrigger, fetchSources]);

   // Add function to fetch chat metadata
-  const fetchChatMetadata = async (chatId: string) => {
+  const fetchChatMetadata = useCallback(async (chatId: string) => {
     try {
       const response = await fetch(`/api/chat/${chatId}/metadata`);
       if (response.ok) {
@@ -178,7 +178,7 @@ export default function Sidebar({
     } catch (error) {
       console.error(`Error fetching metadata for chat ${chatId}:`, error);
     }
-  };
+  }, []);

   // Update fetchChats to also fetch metadata
   const fetchChats = useCallback(async () => {
@@ -201,7 +201,7 @@ export default function Sidebar({
     } finally {
       setIsLoadingChats(false);
     }
-  }, []);
+  }, [fetchChatMetadata]);

   // Fetch chats when history section is expanded
   useEffect(() => {
View File

@@ -346,11 +346,12 @@
 .inputContainer {
   display: flex;
-  padding: 12px 16px;
+  padding: 12px 16px 4px 16px;
   position: relative;
   justify-content: center;
   align-items: center;
   background-color: #f9fafb;
+  margin-top: 0;
   margin-bottom: 0;
   width: 100%;
   opacity: 0;
@@ -473,7 +474,7 @@
   }
 }

-.sendButton.show, .streamingCancelButton.show, .uploadButton.show {
+.sendButton.show, .streamingCancelButton.show {
   animation: buttonAppear 0.3s ease forwards;
 }
@@ -559,21 +560,6 @@
   }
 }

-.inferenceStats {
-  position: fixed;
-  bottom: 10px;
-  right: 10px;
-  background-color: rgba(0, 0, 0, 0.6);
-  color: white;
-  padding: 4px 8px;
-  border-radius: 4px;
-  font-size: 12px;
-  font-family: var(--font-geist-mono);
-  z-index: 10;
-  pointer-events: none;
-  opacity: 0.8;
-}
-
 .querySection {
   display: flex;
   flex-direction: column;
@@ -665,97 +651,6 @@
   40% { transform: scale(1); }
 }

-.uploadButton {
-  width: 36px;
-  height: 36px;
-  border-radius: 50%;
-  display: flex;
-  align-items: center;
-  justify-content: center;
-  border: none;
-  cursor: pointer;
-  background-color: transparent;
-  color: #76B900;
-  transition: all 0.2s ease;
-  padding: 0;
-  opacity: 0;
-  visibility: hidden;
-  position: absolute;
-  left: 8px;
-  top: 50%;
-  transform: translateY(-50%);
-  z-index: 2;
-}
-
-.uploadButton.show {
-  animation: buttonAppear 0.3s ease forwards;
-}
-
-.uploadButton:hover {
-  transform: translateY(-50%) scale(1.1);
-  color: #669f00;
-}
-
-:global(.dark) .uploadButton {
-  color: #76B900;
-}
-
-:global(.dark) .uploadButton:hover {
-  color: #669f00;
-}
-
-.inputWrapper.dragging {
-  border: 2px dashed #76B900;
-  background-color: rgba(118, 185, 0, 0.1);
-}
-
-.imagePreview {
-  position: relative;
-  width: 150px;
-  height: 100px;
-  display: flex;
-  align-items: center;
-  justify-content: center;
-}
-
-/* :global(.dark) .imagePreview {
-  background: #1f2937;
-} */
-
-.previewImage {
-  width: 100%;
-  height: 100%;
-  object-fit: cover;
-  display: block;
-  border-radius: 8px;
-}
-
-.removeImageButton {
-  position: absolute;
-  top: -2px;
-  right: -2px;
-  width: 20px;
-  height: 20px;
-  border-radius: 50%;
-  background: rgba(0, 0, 0, 0.251); /* black at ~25% opacity */
-  color: white;
-  border: none;
-  cursor: pointer;
-  display: flex;
-  align-items: center;
-  justify-content: center;
-  font-size: 12px;
-  font-weight: bold;
-  transition: background-color 0.2s ease;
-}
-
-.removeImageButton:hover {
-  background: #e53e3e;
-}
-
-/* NEW CODE */
 /* Adding styles for tool calling interface */
 .toolCallingWrapper {
   display: flex;
@@ -1033,10 +928,10 @@
 .disclaimer {
   text-align: center;
-  font-size: 12px;
+  font-size: 10px;
   color: #6b7280;
   padding: 2px 16px;
-  margin: -16px auto 0 auto;
+  margin: -12px auto 0 auto;
   max-width: 600px;
   line-height: 1.4;
   opacity: 0.8;
@@ -1047,6 +942,15 @@
   color: #9ca3af;
 }

+.info {
+  color: #3b82f6;
+  font-weight: 500;
+}
+
+:global(.dark) .info {
+  color: #60a5fa;
+}
+
 .warning {
   color: #f59e0b;
   font-weight: 500;

View File

@@ -159,10 +159,25 @@ RUN git clone https://github.com/triton-lang/triton.git && \

 # install xformers from source for blackwell support
 RUN git clone --depth=1 https://github.com/facebookresearch/xformers --recursive && \
     cd xformers && \
+    git checkout 5146f2ab37b2163985c19fb4e8fbf6183e82f8ce && \
     export TORCH_CUDA_ARCH_LIST="12.1" && \
     python setup.py install && \
     cd ..

 # install unsloth without dependencies so we can build them from source
-RUN pip install unsloth unsloth_zoo bitsandbytes==0.48.0
+RUN pip install unsloth==2025.9.11 unsloth_zoo==2025.9.14 bitsandbytes==0.48.0

 CMD ["/bin/bash"]
+
+# docker run \
+#     --rm \
+#     --gpus=all \
+#     --ipc=host \
+#     --net=host \
+#     --ulimit memlock=-1 \
+#     --ulimit stack=67108864 \
+#     -w $(pwd) \
+#     -v $(pwd):$(pwd) \
+#     -v $HOME/.cache/huggingface:/root/.cache/huggingface \
+#     nvcr.io/nvidia/vllm:25.09-py3 \
+#     vllm serve "unsloth/Qwen2.5-VL-7B-Instruct" --port "8000" --served-model-name "base-model" --max-model-len 16384 --gpu-memory-utilization 0.3 --async-scheduling --enable_prefix_caching

View File

@@ -26,18 +26,13 @@ docker build --build-arg HF_TOKEN=$HF_TOKEN -t vlm_demo .

 ### 2. Launch the Container
 ```bash
-# Run the container with GPU support
-docker run -it \
-  --gpus=all \
-  --net=host \
-  --ipc=host \
-  --ulimit memlock=-1 \
-  --ulimit stack=67108864 \
-  -v $(pwd):/vlm_finetuning \
-  -v $HOME/.cache/huggingface:/root/.cache/huggingface \
-  vlm_demo
+# Enter the correct directory for building the image
+cd vlm-finetuning/assets

-# Enter the mounted directory
+# Run the container with GPU support
+sh launch.sh
+
+# Enter the mounted directory within the container
 cd /vlm_finetuning
 ```

View File

@@ -20,7 +20,12 @@ docker run -it \
   --gpus=all \
   --net=host \
   --ipc=host \
-  -w $HOME \
-  -v $HOME:$HOME \
+  --ulimit memlock=-1 \
+  --ulimit stack=67108864 \
+  -e HOST_HOME=$HOME \
+  -e HOST_PWD=$(pwd) \
+  -v $(pwd):/vlm_finetuning \
   -v $HOME/.cache/huggingface:/root/.cache/huggingface \
+  -v /var/run/docker.sock:/var/run/docker.sock \
+  -v /usr/bin/docker:/usr/bin/docker \
   vlm_demo

View File

@@ -15,20 +15,20 @@
 # limitations under the License.
 #

-from unsloth import FastVisionModel
 import os
 import re
-import gc
+import json
 import yaml
 import glob
+import time
+import base64
 import random
+import requests
 import subprocess
-import wandb
-import torch
-from PIL import Image
+import pandas as pd
 import streamlit as st
+from transformers.trainer_utils import get_last_checkpoint

 REASONING_START = "<REASONING>"
@@ -37,14 +37,6 @@ SOLUTION_START = "<SOLUTION>"
 SOLUTION_END = "</SOLUTION>"

-def initialize_session_state(resources):
-    # Initialize page-specific session state
-    st.session_state["base"] = st.session_state.get("base", resources["base"])
-    st.session_state["finetuned"] = st.session_state.get("finetuned", resources["finetuned"])
-    st.session_state["current_image"] = st.session_state.get("current_image", glob.glob("assets/image_vlm/images/*/*")[0])
-    st.session_state["train_process"] = st.session_state.get("train_process", None)
-
 def load_config():
     config_key = "config"
     if getattr(st.session_state, config_key, None) is None:
@@ -58,14 +50,106 @@ def load_config():

 @st.cache_resource
-def initialize_resources(inference_config):
-    base_model, base_tokenizer = load_model_for_inference(inference_config, "base")
-    finetuned_model, finetuned_tokenizer = load_model_for_inference(inference_config, "finetuned")
-    return {
-        "base": {"model": base_model, "tokenizer": base_tokenizer},
-        "finetuned": {"model": finetuned_model, "tokenizer": finetuned_tokenizer},
-    }
+def start_vllm_server(model_id, model_type, max_seq_length, port):
+    # launch a vLLM container that serves the requested model on the given port
+    return subprocess.Popen([
+        "docker", "run",
+        "--rm",
+        "--gpus=all",
+        "--ipc=host",
+        "--net=host",
+        "--ulimit", "memlock=-1",
+        "--ulimit", "stack=67108864",
+        "-v", f"{os.environ.get('HOST_HOME')}/.cache/huggingface:/root/.cache/huggingface",
+        "-v", f"{os.environ.get('HOST_PWD')}/ui_image/saved_model:/workspace/saved_model",
+        "nvcr.io/nvidia/vllm:25.09-py3",
+        "vllm", "serve",
+        model_id,
+        "--port", str(port),
+        "--served-model-name", model_type,
+        "--max-model-len", str(max_seq_length),
+        "--gpu-memory-utilization", "0.45",
+        "--async-scheduling",
+        "--enable_prefix_caching"
+    ])
+
+def check_vllm_health(model_type, port):
+    try:
+        output = json.loads(subprocess.check_output(
+            ["curl", "-s", f"http://localhost:{port}/v1/models"],
+            text=True
+        ))
+        return output["data"][0]["id"] == model_type
+    except Exception:
+        return False
+
+def invoke_vllm_server(model_type, prompt, image, port):
+    with open(image, "rb") as f:
+        image = base64.b64encode(f.read()).decode("utf-8")
+    payload = json.dumps({
+        "model": model_type,
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prompt},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{image}"
+                        }
+                    }
+                ]
+            }
+        ],
+        "max_tokens": 1024,
+        "temperature": 0,
+        "top_p": 1,
+    })
+    return requests.post(
+        f"http://localhost:{port}/v1/chat/completions",
+        headers={"Content-Type": "application/json"},
+        data=payload
+    ).json()["choices"][0]["message"]["content"]
+
+def initialize_state(config):
+    st.session_state["mode"] = st.session_state.get("mode", "inference")
+    st.session_state["base"] = st.session_state.get("base", {})
+    st.session_state["finetuned"] = st.session_state.get("finetuned", {})
+    st.session_state["base"]["port"] = st.session_state["base"].get("port", "8000")
+    st.session_state["finetuned"]["port"] = st.session_state["finetuned"].get("port", "8001")
+    if st.session_state["mode"] == "inference":
+        st.session_state["base"]["process"] = start_vllm_server(
+            config["model_id"], "base", config["max_seq_length"], st.session_state["base"]["port"])
+        finetuned_model_path = get_last_checkpoint(config["finetuned_model_id"])
+        if finetuned_model_path is not None:
+            st.session_state["finetuned"]["process"] = start_vllm_server(
+                finetuned_model_path, "finetuned", config["max_seq_length"], st.session_state["finetuned"]["port"])
+        if not check_vllm_health("base", st.session_state["base"]["port"]):
+            with st.spinner("Loading vLLM server for base model..."):
+                while not check_vllm_health("base", st.session_state["base"]["port"]):
+                    time.sleep(1)
+            st.toast("Base model loaded", icon="", duration="short")
+        if finetuned_model_path is not None:
+            if not check_vllm_health("finetuned", st.session_state["finetuned"]["port"]):
+                with st.spinner("Loading vLLM server for finetuned model..."):
+                    while not check_vllm_health("finetuned", st.session_state["finetuned"]["port"]):
+                        time.sleep(1)
+                st.toast("Finetuned model loaded", icon="", duration="short")
+    st.session_state["current_image"] = st.session_state.get("current_image", glob.glob("assets/image_vlm/images/*/*")[-1])
+    st.session_state["train_process"] = st.session_state.get("train_process", None)

 def main():
@@ -80,12 +164,7 @@ def main():

     # load resources
     config = load_config()
-    if st.session_state.get("base", None) is None:
-        st.toast("Loading model", icon="", duration="short")
-    resource = initialize_resources(config["inference"])
-    if st.session_state.get("base", None) is None:
-        st.toast("Model loaded", icon="", duration="short")
-    initialize_session_state(resource)
+    initialize_state(config["inference"])

     # train section
     st.markdown("---")
@@ -103,11 +182,11 @@ def train_section():
     with column_1:
         finetuning_method = st.selectbox(
             "Finetuning Method:",
-            ["LoRA", "QLoRA", "Full Fine-tuning"],
+            ["LoRA", "Full Fine-tuning"],
         )

     # update lora config
-    if finetuning_method in ("QLoRA", "LoRA"):
+    if finetuning_method == "LoRA":
         lora_config = st.session_state["config"]["train"]["model"]["lora_config"]

     with column_2:
@@ -135,7 +214,6 @@ def train_section():
     # update model config based on selection
     st.session_state["config"]["train"]["model"]["use_lora"] = finetuning_method == "LoRA"
-    st.session_state["config"]["train"]["model"]["use_qlora"] = finetuning_method == "QLoRA"

     # update train config
     st.write("")
@@ -149,12 +227,12 @@ def train_section():
         finetune_language_layers = st.toggle(
             "Finetune Language Layers",
             value=st.session_state["config"]["train"]["model"]["finetune_language_layers"])

     with column_3:
         finetune_attention_modules = st.toggle(
             "Finetune Attention Modules",
             value=st.session_state["config"]["train"]["model"]["finetune_attention_modules"])

     with column_4:
         finetune_mlp_modules = st.toggle(
             "Finetune MLP Modules",
@@ -163,11 +241,11 @@ def train_section():
     st.write("")
     column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
     with column_1:
-        epochs = st.slider(
-            "Epochs",
+        steps = st.slider(
+            "Steps",
             min_value=1,
-            max_value=100,
-            value=st.session_state["config"]["train"]["hyperparameters"]["epochs"])
+            max_value=1000,
+            value=st.session_state["config"]["train"]["hyperparameters"]["steps"])
     with column_2:
         batch_size = st.select_slider(
@@ -189,7 +267,7 @@ def train_section():
             options=["adamw_torch", "adafactor"])

     st.session_state["config"]["train"]["hyperparameters"].update({
-        'epochs': epochs,
+        'steps': steps,
         'batch_size': batch_size,
         'learning_rate': learning_rate,
         'optimizer': optimizer,
@@ -216,7 +294,7 @@ def train_section():
             min_value=0.0,
             max_value=5.0,
             value=float(st.session_state["config"]["train"]["hyperparameters"]["format_reward"]),
-            format="%.2e")
+            format="%.2f")

     with column_3:
         correctness_reward = st.number_input(
@@ -224,15 +302,14 @@ def train_section():
             min_value=0.0,
             max_value=5.0,
             value=float(st.session_state["config"]["train"]["hyperparameters"]["correctness_reward"]),
-            format="%.2e")
+            format="%.2f")

     with column_4:
         num_generations = st.number_input(
             "Number of generations",
             min_value=1,
             max_value=16,
-            value=st.session_state["config"]["train"]["hyperparameters"]["num_generations"],
-            format="%.2e")
+            value=st.session_state["config"]["train"]["hyperparameters"]["num_generations"])

     # Training control
     st.write("")
@@ -242,15 +319,29 @@ def train_section():
     button_type = "secondary" if st.session_state["train_process"] else "primary"
     if st.button("▶️ Start Finetuning", type=button_type, width="stretch", disabled=bool(st.session_state["train_process"])):
         if st.session_state["train_process"] is None:
+            st.session_state["base"]["process"].terminate()
+            st.session_state["base"]["process"].wait()
+            st.session_state["base"]["process"] = None
+            if "finetuned" in st.session_state and "process" in st.session_state["finetuned"]:
+                st.session_state["finetuned"]["process"].terminate()
+                st.session_state["finetuned"]["process"].wait()
+                st.session_state["finetuned"]["process"] = None
+            st.session_state["mode"] = "train"
+            st.cache_resource.clear()
+
             # store config
             with open("src/train.yaml", "w") as f:
                 yaml.dump(st.session_state["config"]["train"], f, default_flow_style=False)

             # start training
-            st.session_state["train_process"] = subprocess.Popen(
-                ["python", "src/train_image_vlm.py"],
-                stdout=None, stderr=None
-            )
+            with open("/tmp/logs.txt", "w") as f:
+                st.session_state["train_process"] = subprocess.Popen(
+                    ["python", "-u", "src/train_image_vlm.py"],
+                    stdout=f,
+                    stderr=subprocess.STDOUT,
+                    text=True
+                )
+            st.toast("Training started", icon="", duration="short")
         else:
             st.toast("Training already in progress", icon="", duration="short")
@@ -259,38 +350,71 @@ def train_section():
         if st.button("⏹️ Stop Finetuning", type=button_type, width="stretch", disabled=not bool(st.session_state["train_process"])):
             if st.session_state["train_process"] is not None:
                 st.session_state["train_process"].terminate()
+                st.session_state["train_process"].wait()
                 st.session_state["train_process"] = None
+                st.session_state["mode"] = "inference"
                 st.toast("Training stopped", icon="", duration="short")
-                st.toast("Re-deploy the app with updated finetuned model", icon=":material/info:", duration="short")
+                st.rerun()
             else:
                 st.toast("No training to stop", icon="", duration="short")

     with column_3:
-        if st.session_state["train_process"]:
-            st.badge("Running", icon=":material/hourglass_arrow_up:", color="green", width="stretch")
-        else:
-            st.badge("Idle", icon=":material/hourglass_disabled:", color="red", width="stretch")
-
-    # display wandb
-    runs = wandb.Api().runs(f"{os.environ.get('WANDB_ENTITY')}/{os.environ.get('WANDB_PROJECT')}")
-    if runs:
-        base_url = runs[0].url
-        loss_url = f"{base_url}?panelDisplayName=train%2Floss&panelSectionName=train"
-        memory_url = f"{base_url}?panelDisplayName=GPU+Memory+Allocated+%28%25%29&panelSectionName=System"
-
-        column_1, column_2 = st.columns(2)
-        with column_1:
-            st.markdown(f"""
-                <div class="wandb-wrapper">
-                    <iframe src="{loss_url}" class="wandb-iframe"></iframe>
-                </div>
-            """, unsafe_allow_html=True)
-        with column_2:
-            st.markdown(f"""
-                <div class="wandb-wrapper">
-                    <iframe src="{memory_url}" class="wandb-iframe"></iframe>
-                </div>
-            """, unsafe_allow_html=True)
+        badge_holder = st.empty()
+
+    # create empty holders
+    columns = st.columns(4)
+    with columns[0]:
+        steps_holder = st.empty()
+    with columns[1]:
+        format_reward_holder = st.empty()
+    with columns[2]:
+        correctness_reward_holder = st.empty()
+    with columns[3]:
+        total_reward_holder = st.empty()
+    df_holder = st.empty()
+
+    # parse grpo logs
+    if st.session_state["train_process"] is not None:
+        while True:
+            output = open("/tmp/logs.txt", "r").read().strip()
+
+            logs = []
+            for line in output.split("\n"):
+                if "{" in line and "}" in line:
+                    dict_match = re.search(r"\{[^}]+\}", line)
+                    if dict_match:
+                        log_dict = eval(dict_match.group())
+                        if isinstance(log_dict, dict) and any(k in log_dict for k in [
+                            "rewards/format_reward_func/mean",
+                            "rewards/correctness_reward_func/mean",
+                            "reward",
+                        ]):
+                            logs.append(log_dict)
+
+            df = pd.DataFrame(logs)
+            if "reward" in df.columns:
+                steps_holder.metric("Steps", f"{len(df)}" if len(df) > 0 else "N/A")
+                format_reward_holder.metric("Format Reward", f"{df['rewards/format_reward_func/mean'].iloc[-1]:.4f}" if len(df) > 0 else "N/A")
+                correctness_reward_holder.metric("Correctness Reward", f"{df['rewards/correctness_reward_func/mean'].iloc[-1]:.4f}" if len(df) > 0 else "N/A")
+                total_reward_holder.metric("Total Reward", f"{df['reward'].iloc[-1]:.4f}" if len(df) > 0 else "N/A")
+                badge_holder.badge("Running", icon=":material/hourglass_arrow_up:", color="green", width="stretch")
+            else:
+                badge_holder.badge("Loading", icon=":material/hourglass_empty:", color="yellow", width="stretch")
+            df_holder.dataframe(df, width="stretch", hide_index=True)
+            time.sleep(1)
+
+            if st.session_state["train_process"] is None or st.session_state["train_process"].poll() is not None:
+                st.session_state["train_process"].terminate()
+                st.session_state["train_process"].wait()
+                st.session_state["train_process"] = None
+                st.session_state["mode"] = "inference"
+                st.toast("Training stopped", icon="", duration="short")
+                st.rerun()
+    else:
+        badge_holder.badge("Idle", icon=":material/hourglass_disabled:", color="red", width="stretch")

 def inference_section():
@@ -342,33 +466,15 @@ def inference_section():
         response = start_inference("base")
         base_generation.markdown(response)

-    with st.spinner("Running..."):
-        response = start_inference("finetuned")
-        finetuned_generation.markdown(response)
+    if "finetuned" in st.session_state and "process" in st.session_state["finetuned"]:
+        with st.spinner("Running..."):
+            response = start_inference("finetuned")
+            finetuned_generation.markdown(response)
+    else:
+        finetuned_generation.markdown("```No response since there is no finetuned model```")

-def load_model_for_inference(config, model_type):
-    if model_type == "finetuned":
-        model_name = config["finetuned_model_id"]
-    elif model_type == "base":
-        model_name = config["model_id"]
-    else:
-        raise ValueError(f"Invalid model type: {model_type}")
-
-    model, tokenizer = FastVisionModel.from_pretrained(
-        model_name=model_name,
-        max_seq_length=config["max_seq_length"],
-        load_in_4bit=False,
-    )
-    FastVisionModel.for_inference(model)
-    return model, tokenizer
-
-@torch.no_grad()
 def start_inference(model_type):
-    # define prompt
     prompt = st.session_state["prompt"]
     if model_type == "finetuned":
         prompt = (
@@ -377,48 +483,13 @@ def start_inference(model_type):
             f" and then your final answer between {SOLUTION_START} and (put a simple Yes or No here) {SOLUTION_END}"
         )

-    # load image
-    image = Image.open(st.session_state["current_image"])
-    if image.mode != "RGB":
-        image = image.convert("RGB")
-
-    # construct instruction prompt
-    prompt = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "image"},
-                {"type": "text", "text": prompt},
-            ],
-        },
-    ]
-
-    # apply chat template
-    prompt = st.session_state[f"{model_type}_image_vlm"]["tokenizer"].apply_chat_template(
-        prompt,
-        tokenize=False,
-        add_generation_prompt=True,
-    )
-
-    # tokenize inputs
-    inputs = st.session_state[f"{model_type}_image_vlm"]["tokenizer"](
-        image,
-        prompt,
-        add_special_tokens=False,
-        return_tensors="pt",
-    ).to("cuda")
-
-    # perform inference
-    response = st.session_state[f"{model_type}_image_vlm"]["model"].generate(
-        **inputs,
-        max_new_tokens=1024,
-        use_cache=True,
-        do_sample=False
-    )[0][inputs["input_ids"].shape[1]:]
-
-    # decode tokens
-    response = st.session_state[f"{model_type}_image_vlm"]["tokenizer"].decode(response, skip_special_tokens=True)
+    response = invoke_vllm_server(
+        model_type,
+        prompt,
+        st.session_state["current_image"],
+        st.session_state[model_type]["port"]
+    )

     # format response
     if model_type == "finetuned":
         response = response.replace(REASONING_START, "```")

View File

@@ -7,111 +7,172 @@ This project demonstrates fine-tuning Vision-Language Models (VLMs) for image understanding
 The project includes:
 - **Interactive Training Interface**: Streamlit-based UI for configuring and monitoring VLM fine-tuning
 - **GRPO Training**: Advanced preference optimization for better reasoning capabilities
-- **Multiple Fine-tuning Methods**: Support for LoRA, QLoRA, and Full Finetuning
-- **Side-by-side Inference**: Compare base model vs fine-tuned model performance
+- **Multiple Fine-tuning Methods**: Support for LoRA and Full Finetuning
+- **Side-by-side Inference using vLLM**: Run the base model and fine-tuned model side-by-side to compare performance

-## Getting Started
+## Contents
+1. [Model Download](#1-model-download)
+2. [Dataset Preparation](#2-dataset-preparation)
+3. [Base Model Inference](#3-base-model-inference)
+4. [GRPO Finetuning](#4-grpo-finetuning)
+5. [Finetuned Model Inference](#5-finetuned-model-inference)

-> **Note**: These instructions assume you are already inside the Docker container. For container setup, refer to the main project README at `vlm-finetuning/`.
+## 1. Model Download

-### 1. Set Up Weights & Biases
+> **Note**: These instructions assume you are already inside the Docker container. For container setup, refer to the main project README at `vlm-finetuning/assets/README.md`.

-Configure your wandb credentials for training monitoring:
+### 1.1 Download the pre-trained model
 ```bash
-export WANDB_PROJECT="vlm_finetuning"
-export WANDB_ENTITY=<WANDB_USERNAME>
-export WANDB_API_KEY=<WANDB_API_KEY>
+hf download Qwen/Qwen2.5-VL-7B-Instruct
 ```

-### 2. Launch the Application
-```bash
-# Start the Streamlit interface
-streamlit run Image_VLM.py
-```
+### 1.2 (Optional) Download the fine-tuned model
+If you already have a fine-tuned checkpoint, place it in the `saved_model/` folder.

-The application will be available at `http://localhost:8501`
+# TODO: SHOW TREE AND SKIP TO INFERENCE
+If you already have a finetuned checkpoint that you would like to use for a comparative analysis against the base model, skip directly to the [Finetuned Model Inference](#5-finetuned-model-inference) section.

-## Training
-
-### Dataset
-The project uses a **wildfire detection dataset** with satellite imagery for training the model to identify wildfire-affected regions. The dataset includes:
+## 2. Dataset Preparation
+The project uses a **Wildfire Detection Dataset** with satellite imagery for training the model to identify wildfire-affected regions. The dataset includes:
 - Satellite and aerial imagery from wildfire-affected areas
 - Binary classification: wildfire vs no wildfire

-#### Dataset Setup
-1. **Download from Kaggle**: Visit the [Wildfire Prediction Dataset](https://www.kaggle.com/datasets/abdelghaniaaba/wildfire-prediction-dataset) on Kaggle
-2. **Get the curl command**: On the Kaggle dataset page, click the download button and copy the curl command provided
-3. **Download and extract**: Run the following commands in your container:
+### 2.1 Create a dataset folder
 ```bash
-mkdir data
-cd data
+mkdir -p ui_image/data
+cd ui_image/data
+```

-# Paste the curl command from Kaggle here, and then continue to unzip the dataset
+### 2.2 Dataset Download
+For this finetuning playbook, we will use the [Wildfire Prediction Dataset](https://www.kaggle.com/datasets/abdelghaniaaba/wildfire-prediction-dataset) from Kaggle. Visit the Kaggle dataset page and click the download button. Select the `cURL` option in the `Download Via` dropdown and copy the curl command.
+> **Note**: You will need to be logged into Kaggle and may need to accept the dataset terms before the download link works.
+
+Run the following commands in your container:
+```bash
+# Paste and run the curl command from Kaggle here, and then continue to unzip the dataset
 unzip -qq wildfire-prediction-dataset.zip
 rm wildfire-prediction-dataset.zip
 cd ..
 ```
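Before moving on, it can help to sanity-check the extracted dataset from Python. A minimal sketch, assuming the archive unpacks into split folders (train/valid/test) with one subfolder per class; adjust the paths if your layout differs:

```python
# count extracted images per split/class folder (run from ui_image/data)
import os
from collections import Counter

counts = Counter()
for dirpath, _dirnames, filenames in os.walk("."):
    n_images = sum(f.lower().endswith((".jpg", ".jpeg", ".png")) for f in filenames)
    if n_images:
        counts[os.path.relpath(dirpath, ".")] += n_images

for folder, n in sorted(counts.items()):
    print(f"{folder}: {n} images")
```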
-> **Note**: You'll need to be logged into Kaggle and may need to accept the dataset terms before the download link works.
+## 3. Base Model Inference

-### Training Configuration
+Before we start finetuning, let's spin up the demo UI to evaluate the base model's performance on this task.

-Configure training through the interactive interface:
+### 3.1 Spin up the Streamlit demo
+```bash
+streamlit run Image_VLM.py
+```

-#### Model Settings
-- **Base Model**: Qwen/Qwen2.5-VL-7B-Instruct
-- **Fine-tuning Method**: Choose from LoRA, QLoRA, or Full Finetuning
-- **LoRA Parameters**: Adjustable rank (8-64) and alpha (8-64)
+Access the streamlit demo at http://localhost:8501/.

-#### Training Parameters
-- **Epochs**: 1-100 (default: 10)
-- **Batch Size**: 1, 2, 4, 8, or 16 (default: 2)
-- **Learning Rate**: 1e-6 to 1e-2 (default: 1e-5)
-- **Optimizer**: AdamW or Adafactor
+### 3.2 Wait for demo spin-up

-#### GRPO Settings
-- **Format Reward**: 2.0 (reward for proper reasoning format)
-- **Correctness Reward**: 5.0 (reward for correct answers)
-- **Number of Generations**: 4 (for preference optimization)
+When you access the streamlit demo for the first time, the backend triggers vLLM servers to spin up for the base model. You will see a spinner on the demo site while vLLM is being brought up for optimized inference. This step can take up to 15 minutes.

-### Training Process
-1. **Configure Parameters**: Use the web interface to set training hyperparameters
-2. **Start Training**: Click "▶️ Start Finetuning" to begin GRPO training
-3. **Monitor Progress**: View real-time loss curves and GPU utilization via embedded wandb charts
-4. **Stop if Needed**: Use "⏹️ Stop Finetuning" to halt training early
+After the streamlit demo is fully loaded, you should see a similar UI state that is ready for inference.

-> **Important**: After training completes, follow these steps:
-> 1. **Stop the UI**: Use Ctrl+C to stop the Streamlit application
-> 2. **Update Config**: Edit `src/image_vlm_config.yaml` and change the `finetuned_model_id` path to point to your newly trained model in the `saved_model/` directory
-> 3. **Restart UI**: Launch the application again to test your fine-tuned model
+<figure>
+<img src="assets/inference_page.png" alt="Inference Page" width="1000"/>
+<figcaption>Inference demo on the UI</figcaption>
+</figure>
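If you prefer to watch readiness from a terminal, you can poll the same endpoint the app uses for its own health check. A minimal sketch using `requests` in place of the app's `curl` call; it assumes the default port 8000 and the served model name `base` configured by the app:

```python
# poll the OpenAI-compatible /v1/models endpoint until the base model is served
import time
import requests

def is_ready(port: int, expected_name: str) -> bool:
    try:
        data = requests.get(f"http://localhost:{port}/v1/models", timeout=2).json()
        return data["data"][0]["id"] == expected_name
    except Exception:
        return False

while not is_ready(8000, "base"):
    time.sleep(5)
print("base model is up")
```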
-## Inference
+### 3.3 Run base model inference

-### Interactive Comparison
+Since we are currently focused on inferring the base model, let's scroll down to the `Image Inference` section of the UI. Here, you should see a sample pre-loaded satellite image of a potentially wildfire-affected region.

-![Inference Interface](assets/inference_screenshot.png)
-*Side-by-side comparison showing base model vs fine-tuned model performance on wildfire detection*
+Enter your prompt in the chat box and hit `Generate`. Your prompt is first sent to the base model, and you should see the generated response in the left chat box. If you did not provide a finetuned model, you should not see any generations in the right chat box.

-The inference section provides:
-- **Sample Images**: Test on wildfire detection samples from both categories (wildfire/no wildfire)
-- **Dual Inference**: Run both base and fine-tuned models simultaneously
-- **Random Sampling**: Test different samples with the "🎲 Test another sample" button
-- **Structured Reasoning**: Fine-tuned model provides reasoning in `<REASONING>` tags before final answer
+As you can see, the base model is incapable of providing the right response for this domain-specific task. Let's try to improve the model's accuracy by performing GRPO finetuning.

-### Sample Questions
-The interface includes prompts for wildfire detection:
-- "Identify if this region has been affected by a wildfire"
-- The fine-tuned model provides structured reasoning followed by a Yes/No answer
+## 4. GRPO Finetuning
+We will perform GRPO finetuning to add reasoning capabilities to our base model and improve the model's understanding of the underlying domain. Considering that you have already spun up the streamlit demo, scroll to the `GRPO Training` section.
+
+<figure>
+<img src="assets/training_page.png" alt="Training Page" width="1000"/>
+<figcaption>Training parameters on the UI</figcaption>
+</figure>
+
+### 4.1 Model Settings
+Configure the finetuning method and LoRA parameters based on the following options.
+- `Finetuning Method`: Choose from Full Finetuning or LoRA
+- `LoRA Parameters`: Adjustable rank (8-64) and alpha (8-64)
+
+### 4.2 Finetune layers
+You can additionally choose which layers of the VLM you want to finetune. For the best performance, ensure that all options are toggled on. Note that this will increase the model training time as well.
+
+### 4.3 Training parameters
+In this section, we can select the training parameters relevant to our run.
+- `Steps`: 1-1000
+- `Batch Size`: 1, 2, 4, 8, or 16
+- `Learning Rate`: 1e-6 to 1e-2
+- `Optimizer`: AdamW or Adafactor
+
+### 4.4 GRPO settings
+For a GRPO setup, we also have the flexibility to choose the reward assigned to the model based on certain criteria (a minimal sketch of such reward functions follows this list):
+- `Format Reward`: 2.0 (reward for proper reasoning format)
+- `Correctness Reward`: 5.0 (reward for correct answers)
+- `Number of Generations`: 4 (for preference optimization)
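To make these knobs concrete, here is a minimal sketch of what reward functions of this shape can look like for the `<REASONING>`/`<SOLUTION>` tag format used in this playbook; the actual implementation lives in `src/train_image_vlm.py` and may parse completions differently:

```python
# illustrative GRPO-style rewards for the tag format used in this playbook
import re

REASONING_START, REASONING_END = "<REASONING>", "</REASONING>"
SOLUTION_START, SOLUTION_END = "<SOLUTION>", "</SOLUTION>"
FORMAT_REWARD = 2.0       # paid when the completion follows the tag structure
CORRECTNESS_REWARD = 5.0  # paid when the extracted answer matches the label

def format_reward(completion: str) -> float:
    pattern = (re.escape(REASONING_START) + r".+?" + re.escape(REASONING_END)
               + r".*?" + re.escape(SOLUTION_START) + r".+?" + re.escape(SOLUTION_END))
    return FORMAT_REWARD if re.search(pattern, completion, re.DOTALL) else 0.0

def correctness_reward(completion: str, answer: str) -> float:
    match = re.search(re.escape(SOLUTION_START) + r"(.*?)" + re.escape(SOLUTION_END),
                      completion, re.DOTALL)
    predicted = match.group(1).strip() if match else ""
    return CORRECTNESS_REWARD if predicted.lower() == answer.lower() else 0.0

demo = f"{REASONING_START}Burn scars are visible.{REASONING_END}{SOLUTION_START}Yes{SOLUTION_END}"
print(format_reward(demo) + correctness_reward(demo, "Yes"))  # 7.0
```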
+### 4.5 Start training
+After configuring all the parameters, hit `Start Finetuning` to begin the training process. You will need to wait about 15 minutes for the model to load and start recording metadata on the UI. As the training progresses, information such as the current step and GRPO rewards will be recorded in a live table.
+
+### 4.6 Stop training
+If you wish to stop training, just hit the `Stop Finetuning` button. Make sure at least 50 steps have completed before stopping, so that a finetuned checkpoint is stored.
+
+Once you stop training, the UI will automatically bring up the vLLM servers for the base model and the newly finetuned model.
+
+## 5. Finetuned Model Inference
+Now we are ready to perform a comparative analysis between the base model and the finetuned model.
+
+### 5.1 (Optional) Spin up the Streamlit demo
+If you haven't spun up the streamlit demo already, execute the following command. If you have just stopped training and are still within the live UI, skip to the next step.
+```bash
+streamlit run Image_VLM.py
+```
+Access the streamlit demo at http://localhost:8501/.
+
+### 5.2 vLLM startup
+Regardless of whether you just spun up the demo or just stopped training, please wait about 15 minutes for the vLLM servers to be brought up.
+
+### 5.3 Run finetuned model inference
+Scroll down to the `Image Inference` section, and enter your prompt in the provided chat box. Upon clicking `Generate`, your prompt is sent first to the base model and then to the finetuned model. You can use the following prompt to quickly test inference:
+`Identify if this region has been affected by a wildfire`
+
+If you trained your model sufficiently, you should see that the finetuned model is able to perform reasoning and provide a concise, accurate answer to the prompt. The reasoning steps are provided in markdown format, while the final answer is bolded and provided at the end of the model's response.
+
+For the image shown below, we have trained the model for 1000 steps, which took about 4 hours.
+
+### 5.4 Further analysis
+If you wish to play around with these models with additional images, the `Test another sample` button will load another random satellite image. A standalone sketch of the underlying inference call follows.
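Under the hood, `Generate` posts the prompt plus a base64-encoded image to the two OpenAI-compatible vLLM endpoints. A minimal standalone sketch of the same call, modeled on the app's `invoke_vllm_server` (ports and served model names as configured by the app; the image path is a placeholder):

```python
# query the base and finetuned vLLM servers with one prompt and image
import base64
import requests

def ask(port: int, model_name: str, prompt: str, image_path: str) -> str:
    with open(image_path, "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")
    payload = {
        "model": model_name,
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
            ],
        }],
        "max_tokens": 1024,
        "temperature": 0,
    }
    r = requests.post(f"http://localhost:{port}/v1/chat/completions", json=payload)
    return r.json()["choices"][0]["message"]["content"]

prompt = "Identify if this region has been affected by a wildfire"
image = "assets/image_vlm/images/sample.jpg"  # placeholder path
print("base:", ask(8000, "base", prompt, image))
print("finetuned:", ask(8001, "finetuned", prompt, image))
```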
 ## File Structure
@@ -131,7 +192,3 @@ ui_image/
 │   └── inference_screenshot.png # UI demonstration screenshot
 └── saved_model/                 # Training checkpoints directory (update config to point here)
 ```
-
-## Configuration
-
-The `src/image_vlm_config.yaml` file contains all training and inference settings:

Binary file not shown. (new file, 753 KiB)

Binary file not shown. (new file, 358 KiB)

View File

@@ -17,15 +17,14 @@
 inference:
   model_id: unsloth/Qwen2.5-VL-7B-Instruct
-  finetuned_model_id: RLakshmi24/qwen_wildfire_qrpo_lora
-  max_seq_length: 16384
+  finetuned_model_id: saved_model
+  max_seq_length: 8192

 train:
   model:
     model_id: unsloth/Qwen2.5-VL-7B-Instruct
     max_seq_length: 16384
     use_lora: true
-    use_qlora: false
     lora_config:
       rank: 32
       alpha: 64
@@ -39,10 +38,10 @@ train:
     dataset_id: data
   hyperparameters:
-    epochs: 10
-    batch_size: 2
+    steps: 100
+    batch_size: 4
     enable_grpo: true
-    num_generations: 4
+    num_generations: 2
     format_reward: 2.0
     learning_rate: 1e-5
     correctness_reward: 5.0

View File

@@ -1,15 +1,15 @@
 data:
   dataset_id: data
 hyperparameters:
-  batch_size: 1
+  batch_size: 4
   correctness_reward: 5.0
   enable_grpo: true
-  epochs: 2
   format_reward: 2.0
   learning_rate: 1.0e-05
-  num_generations: 4
+  num_generations: 2
   optimizer: adamw_torch
   output_dir: saved_model
+  steps: 5
 model:
   finetune_attention_modules: true
   finetune_language_layers: true
@@ -18,8 +18,7 @@ model:
   lora_config:
     alpha: 64
     dropout: 0.05
-    rank: 16
+    rank: 32
   max_seq_length: 16384
   model_id: unsloth/Qwen2.5-VL-7B-Instruct
   use_lora: true
-  use_qlora: false

View File

@@ -18,11 +18,15 @@
 from unsloth import FastVisionModel
 import re
+import sys
 import yaml
+import shutil
+import signal
 from PIL import ImageFile
 from datasets import load_dataset
 from trl import GRPOConfig, GRPOTrainer
+from transformers.trainer_utils import get_last_checkpoint

 ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -37,7 +41,7 @@ def load_model_for_train(config):
     model, tokenizer = FastVisionModel.from_pretrained(
         model_name=config["model"]["model_id"],
         max_seq_length=config["model"]["max_seq_length"],
-        load_in_4bit=config["model"]["use_qlora"],
+        load_in_4bit=False,
     )

     model = FastVisionModel.get_peft_model(
@@ -152,8 +156,6 @@ def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:

 def start_train(config):
-    # load base model for finetuning
-    model, tokenizer = load_model_for_train(config)
-
     # load dataset
     train_dataset = load_wildfire_dataset(config, tokenizer)
@ -171,17 +173,17 @@ def start_train(config):
log_completions=False, log_completions=False,
per_device_train_batch_size=config["hyperparameters"]["batch_size"], per_device_train_batch_size=config["hyperparameters"]["batch_size"],
gradient_accumulation_steps=1, gradient_accumulation_steps=1,
num_generations=2, num_generations=config["hyperparameters"]["num_generations"],
max_prompt_length=config["model"]["max_seq_length"], max_prompt_length=config["model"]["max_seq_length"],
max_completion_length=config["model"]["max_seq_length"], max_completion_length=config["model"]["max_seq_length"],
num_train_epochs=config["hyperparameters"]["epochs"], max_steps=config["hyperparameters"]["steps"],
save_steps=100, save_steps=3,
max_grad_norm=0.1, max_grad_norm=0.1,
report_to="none", report_to="none",
output_dir=config["hyperparameters"]["output_dir"], output_dir=config["hyperparameters"]["output_dir"],
# importance_sampling_level="sequence", importance_sampling_level="sequence",
# mask_truncated_completions=False, mask_truncated_completions=False,
# loss_type="dr_grpo", loss_type="dr_grpo",
) )
# start training # start training
@ -197,8 +199,29 @@ def start_train(config):
) )
trainer.train() trainer.train()
handle_termination(None, None)
def handle_termination(signum, frame):
latest_checkpoint = get_last_checkpoint(config["hyperparameters"]["output_dir"])
if latest_checkpoint:
if config["model"]["use_lora"]:
print("Merging LoRA weights and saving the model")
shutil.rmtree(latest_checkpoint)
model.save_pretrained_merged(latest_checkpoint, tokenizer, save_method="merged_16bit")
sys.exit(0)
signal.signal(signal.SIGTERM, handle_termination)
signal.signal(signal.SIGINT, handle_termination)
if __name__ == "__main__": if __name__ == "__main__":
with open("src/train.yaml", "r") as f: with open("src/train.yaml", "r") as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
start_train(config) # load base model for finetuning
model, tokenizer = load_model_for_train(config)
start_train(config)
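For reference, the shutdown path added above can be read as the following self-contained sketch: on SIGTERM/SIGINT (or after training finishes), the newest `checkpoint-*` directory is replaced with the LoRA weights merged into the base model in 16-bit. Names mirror the diff; the factory wrapper is illustrative only, assuming Unsloth's `save_pretrained_merged` API used above:

```python
import shutil
import signal
import sys

from transformers.trainer_utils import get_last_checkpoint

def make_termination_handler(model, tokenizer, output_dir, use_lora=True):
    """Return a signal handler that merges LoRA adapters into the newest checkpoint."""
    def handle_termination(signum, frame):
        latest_checkpoint = get_last_checkpoint(output_dir)  # newest checkpoint-* dir, or None
        if latest_checkpoint and use_lora:
            print("Merging LoRA weights and saving the model")
            # Replace the adapter-only checkpoint with a merged 16-bit model.
            shutil.rmtree(latest_checkpoint)
            model.save_pretrained_merged(
                latest_checkpoint, tokenizer, save_method="merged_16bit"
            )
        sys.exit(0)
    return handle_termination

# Usage (mirrors the diff, where model/tokenizer/config are module-level):
# handler = make_termination_handler(model, tokenizer, config["hyperparameters"]["output_dir"])
# signal.signal(signal.SIGTERM, handler)
# signal.signal(signal.SIGINT, handler)
```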

View File

@ -0,0 +1,84 @@
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))== Unsloth 2025.9.11: Fast Qwen2_5_Vl patching. Transformers: 4.56.2.
\\ /| NVIDIA GB10. Num GPUs = 1. Max memory: 119.699 GB. Platform: Linux.
O^O/ \_/ \ Torch: 2.9.0a0+50eac811a6.nv25.09. CUDA: 12.1. CUDA Toolkit: 13.0. Triton: 3.4.0
\ / Bfloat16 = TRUE. FA [Xformers = 0.0.33+5146f2a.d20251005. FA2 = True]
"-____-" Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.
Loading checkpoint shards: 100%|██████████| 4/4 [01:36<00:00, 24.20s/it]
Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.
==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1
\\ /| Num examples = 30,250 | Num Epochs = 1 | Total steps = 30,250
O^O/ \_/ \ Batch size per device = 2 | Gradient accumulation steps = 1
\ / Data Parallel GPUs = 1 | Total batch size (2 x 1 x 1) = 2
"-____-" Trainable parameters = 103,043,072 of 8,395,209,728 (1.23% trained)
0%| | 0/30250 [00:00<?, ?it/s]
`generation_config` default values have been modified to match model-specific defaults: {'max_length': 32768, 'temperature': 1e-06, 'repetition_penalty': 1.05}. If this is not desired, please set these values explicitly.
----------------------------------
Question:
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>Identify if this region has been affected by a wildfire. Also first provide your reasoning or working out on how you would go about identifying the presence of wildfire affected regions between <REASONING> and </REASONING> and then your final answer between <SOLUTION> and (put a simple Yes or No here) </SOLUTION><|im_end|>
<|im_start|>assistant
Answer:
No
Response:
<REASONING>
To identify if the region has been affected by a wildfire, we need to look for signs such as charred areas, absence of vegetation, or changes in the landscape that might indicate fire activity. In satellite imagery, these signs can be subtle but may include:
1. **Burnt Vegetation**: Areas where the vegetation appears to have been consumed by fire, often appearing as darkened patches.
2. **Ash or Ash Patches**: These can appear as lighter-colored areas compared to surrounding areas.
3. **Changes in Land Cover**: The presence of new structures or changes in the type of land cover that could indicate recent construction or clearing.
4. **Smoke Plumes**: While not visible in this still image, smoke plumes from wildfires can sometimes be seen in satellite imagery.
In this image:
- There are no obvious darkened or charred areas that would suggest recent fire activity.
- The vegetation appears intact, with no signs of ash or burnt areas.
- There are no visible smoke plumes or signs of recent construction that would indicate recent fire activity.
</REASONING>
<SOLUTION>
No
</SOLUTION>
0%| | 1/30250 [01:15<631:36:01, 75.17s/it]
{'loss': 0.0, 'grad_norm': 3.589693005778827e-05, 'learning_rate': 0.0, 'num_tokens': 879.0, 'completions/mean_length': 205.5, 'completions/min_length': 177.0, 'completions/max_length': 234.0, 'completions/clipped_ratio': 0.0, 'completions/mean_terminated_length': 205.5, 'completions/min_terminated_length': 177.0, 'completions/max_terminated_length': 234.0, 'rewards/format_reward_func/mean': 2.0, 'rewards/format_reward_func/std': 0.0, 'rewards/correctness_reward_func/mean': 5.0, 'rewards/correctness_reward_func/std': 0.0, 'reward': 7.0, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'completion_length': 205.5, 'kl': 0.0005636619171127677, 'epoch': 0.0}
----------------------------------
Question:
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>Identify if this region has been affected by a wildfire. Also first provide your reasoning or working out on how you would go about identifying the presence of wildfire affected regions between <REASONING> and </REASONING> and then your final answer between <SOLUTION> and (put a simple Yes or No here) </SOLUTION><|im_end|>
<|im_start|>assistant
Answer:
Yes
Response:
<REASONING>
To identify if a region has been affected by a wildfire, one would typically look for signs such as charred areas, burn scars, or changes in vegetation color and texture. In satellite imagery, these changes can be subtle but often noticeable.
1. **Burn Scars**: These are areas where the ground has been scorched, often appearing as lighter or different shades compared to surrounding areas.
2. **Vegetation Changes**: Wildfires can alter the appearance of vegetation, making it appear more brown or less dense than before.
3. **Ash and Debris**: Ash and debris from a fire can leave a distinct mark on the landscape, which might not be immediately obvious but can be seen with close inspection.
In the provided image:
- The area appears to have a mix of green and brown tones, which could indicate recent vegetation change.
- There is a distinct darker area that could be interpreted as a burn scar, but without more detailed information or higher resolution imagery, it's difficult to confirm.
- The overall pattern suggests a mix of forested and open areas, which is typical for many regions.
However, without more specific indicators like smoke plumes, ash deposits, or detailed post-fire analysis, it's challenging to definitively conclude the presence of a wildfire.
</REASONING>
<SOLUTION>
No
</SOLUTION>
Unsloth: Will smartly offload gradients to save VRAM!
0%| | 2/30250 [02:21<587:04:23, 69.87s/it]
{'loss': 0.0, 'grad_norm': 3.785433727898635e-05, 'learning_rate': 3.3057851239669425e-09, 'num_tokens': 1864.0, 'completions/mean_length': 258.5, 'completions/min_length': 242.0, 'completions/max_length': 275.0, 'completions/clipped_ratio': 0.0, 'completions/mean_terminated_length': 258.5, 'completions/min_terminated_length': 242.0, 'completions/max_terminated_length': 275.0, 'rewards/format_reward_func/mean': 2.0, 'rewards/format_reward_func/std': 0.0, 'rewards/correctness_reward_func/mean': 0.0, 'rewards/correctness_reward_func/std': 0.0, 'reward': 2.0, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'completion_length': 258.5, 'kl': 0.0006416599499061704, 'epoch': 0.0}
0%| | 2/30250 [02:21<587:04:23, 69.87s/it]