mirror of https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-04-22 18:13:52 +00:00

chore: Regenerate all playbooks

This commit is contained in:
parent 6818481902
commit 24114f735c
@@ -94,6 +94,15 @@ For the provided prompt and random seed, the base Flux model generated the following.
<figcaption>Base FLUX.1 model workflow without custom concept knowledge</figcaption>
</figure>

After playing around with the base model, you have two possible next steps.

* If you already have fine-tuned LoRAs placed inside `models/loras/`, skip to the [Load the finetuned workflow](#52-load-the-finetuned-workflow) section.
* If you wish to train a LoRA for your custom concepts, first make sure that the ComfyUI inference container is brought down before proceeding to train. You can bring it down by interrupting the terminal with the `Ctrl+C` keystroke.

> **Note**: To clear out any extra occupied memory from your system, execute the following command after interrupting the ComfyUI server.

```bash
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
```

## 3. Dataset Preparation

Let's prepare our dataset to perform Dreambooth LoRA finetuning on the FLUX.1-dev 12B model. If you wish to continue with the provided dataset of Toy Jensen and DGX Spark, feel free to skip to the [Training](#training) section. This dataset is a collection of public assets accessible via Google Images.
@@ -137,8 +146,6 @@ Now, let's modify the `flux_data/data.toml` file to reflect the concepts chosen.

### 4.1 Build the docker image

Make sure that the ComfyUI inference container is brought down before proceeding to train. You can bring it down by interrupting the terminal with the `Ctrl+C` keystroke.

```bash
# Build the training docker image
docker build -f Dockerfile.train -t flux-train .
@@ -105,9 +105,7 @@ Install the package in editable mode with metrics support for training evaluation.
pip install -e ".[metrics]"
```

## Step 5. Configure PyTorch for CUDA 12.9 (if needed)

*If using standalone Python (skip if using Docker container)*
## Step 5. Configure PyTorch for CUDA 12.9 (skip if using Docker container from Step 2)

In a Python virtual environment, uninstall existing PyTorch and reinstall with CUDA 12.9 support for the ARM64 architecture.

@@ -119,6 +119,7 @@ From the root directory of the multi-agent-chatbot project, run the following commands.
docker compose -f docker-compose.yml -f docker-compose-models.yml down

docker volume rm "$(basename "$PWD")_postgres_data"
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
```
You can optionally run `docker volume prune` to remove all unused volumes at the end of the demo.
> **Note**: If you do not execute these commands, containers will continue to run and take up memory.
@@ -474,7 +474,7 @@ class ChatAgent:
config = {"configurable": {"thread_id": chat_id}}

try:
existing_messages = await self.conversation_store.get_messages(chat_id, limit=10)
existing_messages = await self.conversation_store.get_messages(chat_id, limit=1)

base_system_prompt = self.system_prompt
if image_data:
@@ -19,7 +19,7 @@ from typing import Dict


SUPERVISOR_AGENT_STR = """
You are a supervisor agent whose role is to be a helpful planner that can use tools to answer questions. DO NOT WRITE CODE YOURSELF, ALWAYS USE THE TOOLS.
You are a supervisor agent whose role is to be a helpful planner that can use tools to answer questions. Please be concise and to the point.

{% if tools %}
IMPORTANT: You have access to these tools and you MUST use them when applicable and use tool response in your final answer:
@@ -122,9 +122,9 @@ services:
- "--host"
- "0.0.0.0"
- "-n"
- "4096"
- "65536"
- "--n-gpu-layers"
- "999"
- "70"
- "--jinja"

deepseek-coder:
@@ -15,7 +15,7 @@
# limitations under the License.
*/
import type React from "react";
import { useRef, useEffect, useState } from "react";
import { useRef, useEffect, useState, useCallback } from "react";
import styles from "@/styles/QuerySection.module.css";
import ReactMarkdown from 'react-markdown'; // NEW
import remarkGfm from 'remark-gfm'; // NEW
@@ -188,22 +188,16 @@ export default function QuerySection({
const chatContainerRef = useRef<HTMLDivElement>(null);
const [showButtons, setShowButtons] = useState(false);
const [showWelcome, setShowWelcome] = useState(true);
const [inferenceStats, setInferenceStats] = useState({
tokensReceived: 0,
startTime: Date.now(),
tokensPerSecond: 0
});
const [selectedSources, setSelectedSources] = useState<string[]>([]);
const wsRef = useRef<WebSocket | null>(null);
const [uploadedImage, setUploadedImage] = useState<string | null>(null);
const [imagePreview, setImagePreview] = useState<string | null>(null);
const [isDragging, setIsDragging] = useState(false);
const [toolOutput, setToolOutput] = useState("");
const [graphStatus, setGraphStatus] = useState("");
const [isPinnedToolOutputVisible, setPinnedToolOutputVisible] = useState(false);
const [isToolContentVisible, setIsToolContentVisible] = useState(false);
const [fadeIn, setFadeIn] = useState(false);
const firstTokenReceived = useRef(false);
const hasAssistantContent = useRef(false);
const fadeTimeoutRef = useRef<NodeJS.Timeout | null>(null);

useEffect(() => {
const timer = setTimeout(() => {
@@ -212,15 +206,6 @@ export default function QuerySection({
return () => clearTimeout(timer);
}, []);

useEffect(() => {
if (!isStreaming) {
setInferenceStats(prev => ({
...prev,
tokensReceived: 0,
startTime: 0
}));
}
}, [isStreaming]);

useEffect(() => {
const fetchSelectedSources = async () => {
@@ -276,9 +261,8 @@ export default function QuerySection({
case "token": {
if (!text) break;
if (!firstTokenReceived.current) {
console.log('TTFT: ', new Date().toISOString());
firstTokenReceived.current = true;
setIsStreaming(false);
hasAssistantContent.current = true;
}
setResponse(prev => {
try {
@@ -310,9 +294,6 @@ export default function QuerySection({
case "tool_end":
case "node_end": {
console.log(type, msg.data);
if (msg.data === 'generate') {
console.log('generate complete. time: ', new Date().toISOString());
}
setGraphStatus("");
break;
}
@@ -342,7 +323,6 @@ export default function QuerySection({
return () => {
if (wsRef.current) {
wsRef.current.close();
setIsStreaming(false);
}
};
}, [currentChatId]);
@@ -361,70 +341,90 @@ export default function QuerySection({
useEffect(() => {
if (graphStatus) {
setPinnedToolOutputVisible(true);
} else if (isPinnedToolOutputVisible) {
// Trigger fade-in on next tick
if (fadeTimeoutRef.current) {
clearTimeout(fadeTimeoutRef.current);
}
setFadeIn(false);
fadeTimeoutRef.current = setTimeout(() => setFadeIn(true), 10);
} else {
// Delay hiding to allow fade-out
setFadeIn(false);
const timeout = setTimeout(() => {
setPinnedToolOutputVisible(false);
}, 800); // match CSS transition duration
return () => clearTimeout(timeout);
return () => {
clearTimeout(timeout);
if (fadeTimeoutRef.current) {
clearTimeout(fadeTimeoutRef.current);
}
};
}
}, [graphStatus, isPinnedToolOutputVisible]);

// Replace the effect for fade logic with this minimal version
useEffect(() => {
if (isPinnedToolOutputVisible && graphStatus) {
setFadeIn(false);
const t = setTimeout(() => setFadeIn(true), 10); // next tick for fade-in
return () => clearTimeout(t);
} else {
setFadeIn(false);
}
}, [isPinnedToolOutputVisible, graphStatus]);

// Cleanup image preview URL on unmount
useEffect(() => {
return () => {
if (imagePreview) {
URL.revokeObjectURL(imagePreview);
}
};
}, [imagePreview]);
}, [graphStatus]);

const programmaticScroll = useRef(false);
const scrollTimeout = useRef<number | null>(null);
const isUserScrollingRef = useRef(false);
const isNearBottomRef = useRef(true);

const handleDragEnter = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); setIsDragging(true); };
const handleDragLeave = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); setIsDragging(false); };
const handleDragOver = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); };

const handleDrop = async (e: React.DragEvent) => {
e.preventDefault();
e.stopPropagation();
setIsDragging(false);

const files = Array.from(e.dataTransfer.files);
const imageFile = files.find(file => file.type.startsWith('image/'));

if (imageFile) {
const previewUrl = URL.createObjectURL(imageFile);
setImagePreview(previewUrl);

const formData = new FormData();
formData.append('image', imageFile);
formData.append('chat_id', currentChatId || '');

try {
const response = await fetch('/api/upload-image', { method: 'POST', body: formData });
const result = await response.json();
setUploadedImage(result.image_id);
} catch (error) {
console.error('Error uploading image:', error);
URL.revokeObjectURL(previewUrl);
setImagePreview(null);
}
// Check if user is near the bottom of the chat
const checkScrollPosition = useCallback(() => {
if (chatContainerRef.current) {
const container = chatContainerRef.current;
const threshold = 100; // pixels from bottom
const isNear = container.scrollHeight - container.scrollTop - container.clientHeight < threshold;
isNearBottomRef.current = isNear;
}
};
}, []);

// Handle scroll events to detect user scrolling
useEffect(() => {
const container = chatContainerRef.current;
if (!container) return;

let scrollTimer: NodeJS.Timeout;

const handleScroll = () => {
isUserScrollingRef.current = true;
checkScrollPosition();

// Reset user scrolling flag after scroll stops
clearTimeout(scrollTimer);
scrollTimer = setTimeout(() => {
isUserScrollingRef.current = false;
}, 150);
};

container.addEventListener('scroll', handleScroll, { passive: true });

return () => {
container.removeEventListener('scroll', handleScroll);
clearTimeout(scrollTimer);
};
}, [checkScrollPosition]);

// Auto-scroll to bottom when response changes
useEffect(() => {
// Only scroll if we have assistant content and user hasn't manually scrolled away
if (!hasAssistantContent.current || isUserScrollingRef.current || !isNearBottomRef.current) {
return;
}

const scrollToBottom = () => {
if (messagesEndRef.current) {
messagesEndRef.current.scrollIntoView({
behavior: 'smooth',
block: 'end'
});
}

if (chatContainerRef.current) {
chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight;
}
};

scrollToBottom();
}, [response]);

const handleQuerySubmit = async (e: React.FormEvent<HTMLFormElement>) => {
e.preventDefault();
@@ -434,13 +434,11 @@ export default function QuerySection({
setQuery("");
setIsStreaming(true);
firstTokenReceived.current = false;
hasAssistantContent.current = false;

try {
console.log('sending uploaded image: ', uploadedImage, ' with query: ', currentQuery)
console.log('current time: ', new Date().toISOString());
wsRef.current.send(JSON.stringify({
message: currentQuery,
image_id: uploadedImage
message: currentQuery
}));

setResponse(prev => {
@@ -455,14 +453,6 @@ export default function QuerySection({
return prev + `\n\nHuman: ${currentQuery}\n\nAssistant: `;
}
});

// NEW CODE
if (imagePreview) {
URL.revokeObjectURL(imagePreview);
}
setUploadedImage(null);
setImagePreview(null);
// NEW CODE
} catch (error) {
console.error("Error sending message:", error);
setIsStreaming(false);
@@ -581,36 +571,7 @@ export default function QuerySection({
</div>

<form onSubmit={handleQuerySubmit} className={styles.inputContainer}>
{/* NEW CODE - Image preview moved to the left of inputWrapper */}
{imagePreview && (
<div className={styles.imagePreview}>
<img
src={imagePreview}
alt="Image preview"
className={styles.previewImage}
/>
<button
className={styles.removeImageButton}
onClick={() => {
if (imagePreview) {
URL.revokeObjectURL(imagePreview);
}
setUploadedImage(null);
setImagePreview(null);
}}
>
✕
</button>
</div>
)}
{/* NEW CODE */}
<div
className={`${styles.inputWrapper} ${isDragging ? styles.dragging : ''}`}
onDragEnter={handleDragEnter}
onDragLeave={handleDragLeave}
onDragOver={handleDragOver}
onDrop={handleDrop}
>
<div className={styles.inputWrapper}>
<textarea
rows={1}
value={query}
@@ -648,14 +609,10 @@ export default function QuerySection({
<div className={styles.disclaimer}>
This is a concept demo to showcase multiple models and MCP use. It is not optimized for performance. Developers can customize and further optimize it for performance.
<br />
<span className={styles.info}>Note: If a response is cut short, please start a new chat to continue.</span>
<br />
<span className={styles.warning}>Don't forget to shut down docker containers at the end of the demo.</span>
</div>

{inferenceStats.tokensPerSecond > 0 && (
<div className={styles.inferenceStats}>
{inferenceStats.tokensPerSecond} tokens/sec
</div>
)}
</div>
);
}
@@ -165,7 +165,7 @@ export default function Sidebar({
}, [refreshTrigger, fetchSources]);

// Add function to fetch chat metadata
const fetchChatMetadata = async (chatId: string) => {
const fetchChatMetadata = useCallback(async (chatId: string) => {
try {
const response = await fetch(`/api/chat/${chatId}/metadata`);
if (response.ok) {
@@ -178,7 +178,7 @@ export default function Sidebar({
} catch (error) {
console.error(`Error fetching metadata for chat ${chatId}:`, error);
}
};
}, []);

// Update fetchChats to also fetch metadata
const fetchChats = useCallback(async () => {
@@ -201,7 +201,7 @@ export default function Sidebar({
} finally {
setIsLoadingChats(false);
}
}, []);
}, [fetchChatMetadata]);

// Fetch chats when history section is expanded
useEffect(() => {
@@ -346,11 +346,12 @@

.inputContainer {
display: flex;
padding: 12px 16px;
padding: 12px 16px 4px 16px;
position: relative;
justify-content: center;
align-items: center;
background-color: #f9fafb;
margin-top: 0;
margin-bottom: 0;
width: 100%;
opacity: 0;
@@ -473,7 +474,7 @@
}
}

.sendButton.show, .streamingCancelButton.show, .uploadButton.show {
.sendButton.show, .streamingCancelButton.show {
animation: buttonAppear 0.3s ease forwards;
}

@@ -559,21 +560,6 @@
}
}

.inferenceStats {
position: fixed;
bottom: 10px;
right: 10px;
background-color: rgba(0, 0, 0, 0.6);
color: white;
padding: 4px 8px;
border-radius: 4px;
font-size: 12px;
font-family: var(--font-geist-mono);
z-index: 10;
pointer-events: none;
opacity: 0.8;
}

.querySection {
display: flex;
flex-direction: column;
@@ -665,97 +651,6 @@
40% { transform: scale(1); }
}

.uploadButton {
width: 36px;
height: 36px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
border: none;
cursor: pointer;
background-color: transparent;
color: #76B900;
transition: all 0.2s ease;
padding: 0;
opacity: 0;
visibility: hidden;
position: absolute;
left: 8px;
top: 50%;
transform: translateY(-50%);
z-index: 2;
}

.uploadButton.show {
animation: buttonAppear 0.3s ease forwards;
}

.uploadButton:hover {
transform: translateY(-50%) scale(1.1);
color: #669f00;
}

:global(.dark) .uploadButton {
color: #76B900;
}

:global(.dark) .uploadButton:hover {
color: #669f00;
}

.inputWrapper.dragging {
border: 2px dashed #76B900;
background-color: rgba(118, 185, 0, 0.1);
}

.imagePreview {
position: relative;
width: 150px;
height: 100px;
display: flex;
align-items: center;
justify-content: center;
}

/* :global(.dark) .imagePreview {
background: #1f2937;
} */

.previewImage {
width: 100%;
height: 100%;
object-fit: cover;
display: block;
border-radius: 8px;
}

.removeImageButton {
position: absolute;
top: -2px;
right: -2px;
width: 20px;
height: 20px;
border-radius: 50%;
background: rgba(0, 0, 0, 0.251); /* black at ~25% opacity */
color: white;
border: none;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
font-size: 12px;
font-weight: bold;
transition: background-color 0.2s ease;
}

.removeImageButton:hover {
background: #e53e3e;
}
/* NEW CODE */


/* Adding styles for tool calling interface */
.toolCallingWrapper {
display: flex;
@@ -1033,10 +928,10 @@

.disclaimer {
text-align: center;
font-size: 12px;
font-size: 10px;
color: #6b7280;
padding: 2px 16px;
margin: -16px auto 0 auto;
margin: -12px auto 0 auto;
max-width: 600px;
line-height: 1.4;
opacity: 0.8;
@@ -1047,6 +942,15 @@
color: #9ca3af;
}

.info {
color: #3b82f6;
font-weight: 500;
}

:global(.dark) .info {
color: #60a5fa;
}

.warning {
color: #f59e0b;
font-weight: 500;
@@ -159,10 +159,25 @@ RUN git clone https://github.com/triton-lang/triton.git && \
# install xformers from source for blackwell support
RUN git clone --depth=1 https://github.com/facebookresearch/xformers --recursive && \
cd xformers && \
git checkout 5146f2ab37b2163985c19fb4e8fbf6183e82f8ce && \
export TORCH_CUDA_ARCH_LIST="12.1" && \
python setup.py install && \
cd ..
# install unsloth without dependencies so we can build them from source
RUN pip install unsloth unsloth_zoo bitsandbytes==0.48.0
RUN pip install unsloth==2025.9.11 unsloth_zoo==2025.9.14 bitsandbytes==0.48.0

CMD ["/bin/bash"]
CMD ["/bin/bash"]


# docker run \
# --rm \
# --gpus=all \
# --ipc=host \
# --net=host \
# --ulimit memlock=-1 \
# --ulimit stack=67108864 \
# -w $(pwd) \
# -v $(pwd):$(pwd) \
# -v $HOME/.cache/huggingface:/root/.cache/huggingface \
# nvcr.io/nvidia/vllm:25.09-py3 \
# vllm serve "unsloth/Qwen2.5-VL-7B-Instruct" --port "8000" --served-model-name "base-model" --max-model-len 16384 --gpu-memory-utilization 0.3 --async-scheduling --enable_prefix_caching
@@ -26,18 +26,13 @@ docker build --build-arg HF_TOKEN=$HF_TOKEN -t vlm_demo .
### 2. Launch the Container

```bash
# Run the container with GPU support
docker run -it \
--gpus=all \
--net=host \
--ipc=host \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
-v $(pwd):/vlm_finetuning \
-v $HOME/.cache/huggingface:/root/.cache/huggingface \
vlm_demo
# Enter the correct directory for building the image
cd vlm-finetuning/assets

# Enter the mounted directory
# Run the container with GPU support
sh launch.sh

# Enter the mounted directory within the container
cd /vlm_finetuning
```

@@ -20,7 +20,12 @@ docker run -it \
--gpus=all \
--net=host \
--ipc=host \
-w $HOME \
-v $HOME:$HOME \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
-e HOST_HOME=$HOME \
-e HOST_PWD=$(pwd) \
-v $(pwd):/vlm_finetuning \
-v $HOME/.cache/huggingface:/root/.cache/huggingface \
-v /var/run/docker.sock:/var/run/docker.sock \
-v /usr/bin/docker:/usr/bin/docker \
vlm_demo
@@ -15,20 +15,20 @@
# limitations under the License.
#

from unsloth import FastVisionModel

import os
import re
import gc
import json
import yaml
import glob
import time
import base64
import random
import requests
import subprocess

import wandb
import torch
from PIL import Image
import pandas as pd
import streamlit as st
from transformers.trainer_utils import get_last_checkpoint


REASONING_START = "<REASONING>"
@@ -37,14 +37,6 @@ SOLUTION_START = "<SOLUTION>"
SOLUTION_END = "</SOLUTION>"


def initialize_session_state(resources):
# Initialize page-specific session state
st.session_state["base"] = st.session_state.get("base", resources["base"])
st.session_state["finetuned"] = st.session_state.get("finetuned", resources["finetuned"])
st.session_state["current_image"] = st.session_state.get("current_image", glob.glob("assets/image_vlm/images/*/*")[0])
st.session_state["train_process"] = st.session_state.get("train_process", None)


def load_config():
config_key = "config"
if getattr(st.session_state, config_key, None) is None:
@@ -58,14 +50,106 @@ def load_config():

@st.cache_resource
def initialize_resources(inference_config):
base_model, base_tokenizer = load_model_for_inference(inference_config, "base")
finetuned_model, finetuned_tokenizer = load_model_for_inference(inference_config, "finetuned")
def start_vllm_server(model_id, model_type, max_seq_length, port):
# get pwd
return subprocess.Popen([
"docker", "run",
"--rm",
"--gpus=all",
"--ipc=host",
"--net=host",
"--ulimit", "memlock=-1",
"--ulimit", "stack=67108864",
"-v", f"{os.environ.get('HOST_HOME')}/.cache/huggingface:/root/.cache/huggingface",
"-v", f"{os.environ.get('HOST_PWD')}/ui_image/saved_model:/workspace/saved_model",
"nvcr.io/nvidia/vllm:25.09-py3",
"vllm", "serve",
model_id,
"--port", str(port),
"--served-model-name", model_type,
"--max-model-len", str(max_seq_length),
"--gpu-memory-utilization", "0.45",
"--async-scheduling",
"--enable_prefix_caching"
])

return {
"base": {"model": base_model, "tokenizer": base_tokenizer},
"finetuned": {"model": finetuned_model, "tokenizer": finetuned_tokenizer},
}

def check_vllm_health(model_type, port):
try:
output = json.loads(subprocess.check_output(
["curl", "-s", f"http://localhost:{port}/v1/models"],
text=True
))

return output["data"][0]["id"] == model_type
except Exception:
return False


def invoke_vllm_server(model_type, prompt, image, port):
with open(image, "rb") as f:
image = base64.b64encode(f.read()).decode("utf-8")

payload = json.dumps({
"model": model_type,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image}"
}
}
]
}
],
"max_tokens": 1024,
"temperature": 0,
"top_p": 1,
})

return requests.post(
f"http://localhost:{port}/v1/chat/completions",
headers={"Content-Type": "application/json"},
data=payload
).json()["choices"][0]["message"]["content"]


def initialize_state(config):
st.session_state["mode"] = st.session_state.get("mode", "inference")

st.session_state["base"] = st.session_state.get("base", {})
st.session_state["finetuned"] = st.session_state.get("finetuned", {})

st.session_state["base"]["port"] = st.session_state["base"].get("port", "8000")
st.session_state["finetuned"]["port"] = st.session_state["finetuned"].get("port", "8001")

if st.session_state["mode"] == "inference":
st.session_state["base"]["process"] = start_vllm_server(
config["model_id"], "base", config["max_seq_length"], st.session_state["base"]["port"])
finetuned_model_path = get_last_checkpoint(config["finetuned_model_id"])
if finetuned_model_path is not None:
st.session_state["finetuned"]["process"] = start_vllm_server(
finetuned_model_path, "finetuned", config["max_seq_length"], st.session_state["finetuned"]["port"])

if not check_vllm_health("base", st.session_state["base"]["port"]):
with st.spinner("Loading vLLM server for base model..."):
while not check_vllm_health("base", st.session_state["base"]["port"]):
time.sleep(1)
st.toast("Base model loaded", icon="✅", duration="short")

if finetuned_model_path is not None:
if not check_vllm_health("finetuned", st.session_state["finetuned"]["port"]):
with st.spinner("Loading vLLM server for finetuned model..."):
while not check_vllm_health("finetuned", st.session_state["finetuned"]["port"]):
time.sleep(1)
st.toast("Finetuned model loaded", icon="✅", duration="short")

st.session_state["current_image"] = st.session_state.get("current_image", glob.glob("assets/image_vlm/images/*/*")[-1])
st.session_state["train_process"] = st.session_state.get("train_process", None)


def main():
@@ -80,12 +164,7 @@ def main():

# load resources
config = load_config()
if st.session_state.get("base", None) is None:
st.toast("Loading model", icon="⏳", duration="short")
resource = initialize_resources(config["inference"])
if st.session_state.get("base", None) is None:
st.toast("Model loaded", icon="✅", duration="short")
initialize_session_state(resource)
initialize_state(config["inference"])

# train section
st.markdown("---")
@@ -103,11 +182,11 @@ def train_section():
with column_1:
finetuning_method = st.selectbox(
"Finetuning Method:",
["LoRA", "QLoRA", "Full Fine-tuning"],
["LoRA", "Full Fine-tuning"],
)

# update lora config
if finetuning_method in ("QLoRA", "LoRA"):
if finetuning_method == "LoRA":
lora_config = st.session_state["config"]["train"]["model"]["lora_config"]

with column_2:
@@ -135,7 +214,6 @@ def train_section():

# update model config based on selection
st.session_state["config"]["train"]["model"]["use_lora"] = finetuning_method == "LoRA"
st.session_state["config"]["train"]["model"]["use_qlora"] = finetuning_method == "QLoRA"

# update train config
st.write("")
@@ -149,12 +227,12 @@ def train_section():
finetune_language_layers = st.toggle(
"Finetune Language Layers",
value=st.session_state["config"]["train"]["model"]["finetune_language_layers"])

with column_3:
finetune_attention_modules = st.toggle(
"Finetune Attention Modules",
value=st.session_state["config"]["train"]["model"]["finetune_attention_modules"])

with column_4:
finetune_mlp_modules = st.toggle(
"Finetune MLP Modules",
@@ -163,11 +241,11 @@ def train_section():
st.write("")
column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
with column_1:
epochs = st.slider(
"Epochs",
steps = st.slider(
"Steps",
min_value=1,
max_value=100,
value=st.session_state["config"]["train"]["hyperparameters"]["epochs"])
max_value=1000,
value=st.session_state["config"]["train"]["hyperparameters"]["steps"])

with column_2:
batch_size = st.select_slider(
@@ -189,7 +267,7 @@ def train_section():
options=["adamw_torch", "adafactor"])

st.session_state["config"]["train"]["hyperparameters"].update({
'epochs': epochs,
'steps': steps,
'batch_size': batch_size,
'learning_rate': learning_rate,
'optimizer': optimizer,
@@ -216,7 +294,7 @@ def train_section():
min_value=0.0,
max_value=5.0,
value=float(st.session_state["config"]["train"]["hyperparameters"]["format_reward"]),
format="%.2e")
format="%.2f")

with column_3:
correctness_reward = st.number_input(
@@ -224,15 +302,14 @@ def train_section():
min_value=0.0,
max_value=5.0,
value=float(st.session_state["config"]["train"]["hyperparameters"]["correctness_reward"]),
format="%.2e")
format="%.2f")

with column_4:
num_generations = st.number_input(
"Number of generations",
min_value=1,
max_value=16,
value=st.session_state["config"]["train"]["hyperparameters"]["num_generations"],
format="%.2e")
value=st.session_state["config"]["train"]["hyperparameters"]["num_generations"])

# Training control
st.write("")
@@ -242,15 +319,29 @@ def train_section():
button_type = "secondary" if st.session_state["train_process"] else "primary"
if st.button("▶️ Start Finetuning", type=button_type, width="stretch", disabled=bool(st.session_state["train_process"])):
if st.session_state["train_process"] is None:
st.session_state["base"]["process"].terminate()
st.session_state["base"]["process"].wait()
st.session_state["base"]["process"] = None
if "finetuned" in st.session_state and "process" in st.session_state["finetuned"]:
st.session_state["finetuned"]["process"].terminate()
st.session_state["finetuned"]["process"].wait()
st.session_state["finetuned"]["process"] = None
st.session_state["mode"] = "train"
st.cache_resource.clear()

# store config
with open("src/train.yaml", "w") as f:
yaml.dump(st.session_state["config"]["train"], f, default_flow_style=False)

# start training
st.session_state["train_process"] = subprocess.Popen(
["python", "src/train_image_vlm.py"],
stdout=None, stderr=None
)
with open("/tmp/logs.txt", "w") as f:
st.session_state["train_process"] = subprocess.Popen(
["python", "-u", "src/train_image_vlm.py"],
stdout=f,
stderr=subprocess.STDOUT,
text=True
)
st.toast("Training started", icon="✅", duration="short")
else:
st.toast("Training already in progress", icon="❌", duration="short")
|
||||
|
||||
@ -259,38 +350,71 @@ def train_section():
|
||||
if st.button("⏹️ Stop Finetuning", type=button_type, width="stretch", disabled=not bool(st.session_state["train_process"])):
|
||||
if st.session_state["train_process"] is not None:
|
||||
st.session_state["train_process"].terminate()
|
||||
st.session_state["train_process"].wait()
|
||||
st.session_state["train_process"] = None
|
||||
st.session_state["mode"] = "inference"
|
||||
st.toast("Training stopped", icon="✅", duration="short")
|
||||
st.toast("Re-deploy the app with updated finetuned model", icon=":material/info:", duration="short")
|
||||
st.rerun()
|
||||
else:
|
||||
st.toast("No training to stop", icon="❌", duration="short")
|
||||
|
||||
|
||||
with column_3:
|
||||
if st.session_state["train_process"]:
|
||||
st.badge("Running", icon=":material/hourglass_arrow_up:", color="green", width="stretch")
|
||||
else:
|
||||
st.badge("Idle", icon=":material/hourglass_disabled:", color="red", width="stretch")
|
||||
badge_holder = st.empty()
|
||||
|
||||
# display wandb
|
||||
runs = wandb.Api().runs(f"{os.environ.get('WANDB_ENTITY')}/{os.environ.get('WANDB_PROJECT')}")
|
||||
if runs:
|
||||
base_url = runs[0].url
|
||||
loss_url = f"{base_url}?panelDisplayName=train%2Floss&panelSectionName=train"
|
||||
memory_url = f"{base_url}?panelDisplayName=GPU+Memory+Allocated+%28%25%29&panelSectionName=System"
|
||||
# create empty holders
|
||||
columns = st.columns(4)
|
||||
with columns[0]:
|
||||
steps_holder = st.empty()
|
||||
with columns[1]:
|
||||
format_reward_holder = st.empty()
|
||||
with columns[2]:
|
||||
correctness_reward_holder = st.empty()
|
||||
with columns[3]:
|
||||
total_reward_holder = st.empty()
|
||||
df_holder = st.empty()
|
||||
|
||||
column_1, column_2 = st.columns(2)
|
||||
with column_1:
|
||||
st.markdown(f"""
|
||||
<div class="wandb-wrapper">
|
||||
<iframe src="{loss_url}" class="wandb-iframe"></iframe>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
with column_2:
|
||||
st.markdown(f"""
|
||||
<div class="wandb-wrapper">
|
||||
<iframe src="{memory_url}" class="wandb-iframe"></iframe>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
# parse grpo logs
|
||||
if st.session_state["train_process"] is not None:
|
||||
while True:
|
||||
output = open("/tmp/logs.txt", "r").read().strip()
|
||||
|
||||
logs = []
|
||||
for line in output.split("\n"):
|
||||
if "{" in line and "}" in line:
|
||||
dict_match = re.search(r"\{[^}]+\}", line)
|
||||
if dict_match:
|
||||
log_dict = eval(dict_match.group())
|
||||
if isinstance(log_dict, dict) and any(k in log_dict for k in [
|
||||
"rewards/format_reward_func/mean",
|
||||
"rewards/correctness_reward_func/mean",
|
||||
"reward",
|
||||
]):
|
||||
logs.append(log_dict)
|
||||
|
||||
df = pd.DataFrame(logs)
|
||||
if "reward" in df.columns:
|
||||
steps_holder.metric("Steps", f"{len(df)}" if len(df) > 0 else "N/A")
|
||||
format_reward_holder.metric("Format Reward", f"{df['rewards/format_reward_func/mean'].iloc[-1]:.4f}" if len(df) > 0 else "N/A")
|
||||
correctness_reward_holder.metric("Correctness Reward", f"{df['rewards/correctness_reward_func/mean'].iloc[-1]:.4f}" if len(df) > 0 else "N/A")
|
||||
total_reward_holder.metric("Total Reward", f"{df['reward'].iloc[-1]:.4f}" if len(df) > 0 else "N/A")
|
||||
|
||||
badge_holder.badge("Running", icon=":material/hourglass_arrow_up:", color="green", width="stretch")
|
||||
else:
|
||||
badge_holder.badge("Loading", icon=":material/hourglass_empty:", color="yellow", width="stretch")
|
||||
|
||||
df_holder.dataframe(df, width="stretch", hide_index=True)
|
||||
time.sleep(1)
|
||||
|
||||
if st.session_state["train_process"] is None or st.session_state["train_process"].poll() is not None:
|
||||
st.session_state["train_process"].terminate()
|
||||
st.session_state["train_process"].wait()
|
||||
st.session_state["train_process"] = None
|
||||
st.session_state["mode"] = "inference"
|
||||
st.toast("Training stopped", icon="✅", duration="short")
|
||||
st.rerun()
|
||||
|
||||
else:
|
||||
badge_holder.badge("Idle", icon=":material/hourglass_disabled:", color="red", width="stretch")
|
||||
|
||||
|
||||
def inference_section():
|
||||
@ -342,33 +466,15 @@ def inference_section():
|
||||
response = start_inference("base")
|
||||
base_generation.markdown(response)
|
||||
|
||||
with st.spinner("Running..."):
|
||||
response = start_inference("finetuned")
|
||||
finetuned_generation.markdown(response)
|
||||
if "finetuned" in st.session_state and "process" in st.session_state["finetuned"]:
|
||||
with st.spinner("Running..."):
|
||||
response = start_inference("finetuned")
|
||||
finetuned_generation.markdown(response)
|
||||
else:
|
||||
finetuned_generation.markdown("```No response since there is no finetuned model```")
|
||||
|
||||
|
||||
def load_model_for_inference(config, model_type):
|
||||
if model_type == "finetuned":
|
||||
model_name = config["finetuned_model_id"]
|
||||
elif model_type == "base":
|
||||
model_name = config["model_id"]
|
||||
else:
|
||||
raise ValueError(f"Invalid model type: {model_type}")
|
||||
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=config["max_seq_length"],
|
||||
load_in_4bit=False,
|
||||
)
|
||||
|
||||
FastVisionModel.for_inference(model)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def start_inference(model_type):
|
||||
# define prompt
|
||||
prompt = st.session_state["prompt"]
|
||||
if model_type == "finetuned":
|
||||
prompt = (
|
||||
@ -377,48 +483,13 @@ def start_inference(model_type):
|
||||
f" and then your final answer between {SOLUTION_START} and (put a simple Yes or No here) {SOLUTION_END}"
|
||||
)
|
||||
|
||||
# load image
|
||||
image = Image.open(st.session_state["current_image"])
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# construct instruction prompt
|
||||
prompt = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# apply chat template
|
||||
prompt = st.session_state[f"{model_type}_image_vlm"]["tokenizer"].apply_chat_template(
|
||||
response = invoke_vllm_server(
|
||||
model_type,
|
||||
prompt,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
st.session_state["current_image"],
|
||||
st.session_state[model_type]["port"]
|
||||
)
|
||||
|
||||
# tokenize inputs
|
||||
inputs = st.session_state[f"{model_type}_image_vlm"]["tokenizer"](
|
||||
image,
|
||||
prompt,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
).to("cuda")
|
||||
|
||||
# perform inference
|
||||
response = st.session_state[f"{model_type}_image_vlm"]["model"].generate(
|
||||
**inputs,
|
||||
max_new_tokens=1024,
|
||||
use_cache=True,
|
||||
do_sample=False
|
||||
)[0][inputs["input_ids"].shape[1]: ]
|
||||
|
||||
# decode tokens
|
||||
response = st.session_state[f"{model_type}_image_vlm"]["tokenizer"].decode(response, skip_special_tokens=True)
|
||||
|
||||
# format response
|
||||
if model_type == "finetuned":
|
||||
response = response.replace(REASONING_START, "```")
|
||||
|
||||
@@ -7,111 +7,172 @@ This project demonstrates fine-tuning Vision-Language Models (VLMs) for image un
The project includes:
- **Interactive Training Interface**: Streamlit-based UI for configuring and monitoring VLM fine-tuning
- **GRPO Training**: Advanced preference optimization for better reasoning capabilities
- **Multiple Fine-tuning Methods**: Support for LoRA, QLoRA, and Full Finetuning
- **Side-by-side Inference**: Compare base model vs fine-tuned model performance
- **Multiple Fine-tuning Methods**: Support for LoRA and Full Finetuning
- **Side-by-side Inference using vLLM**: Run the base model and fine-tuned model side-by-side to compare performance

## Getting Started
## Contents
1. [Model Download](#1-model-download)
2. [Dataset Preparation](#2-dataset-preparation)
3. [Base Model Inference](#3-base-model-inference)
4. [GRPO Finetuning](#4-grpo-finetuning)
5. [Finetuned Model Inference](#5-finetuned-model-inference)

> **Note**: These instructions assume you are already inside the Docker container. For container setup, refer to the main project README at `vlm-finetuning/`.
## 1. Model Download

### 1. Set Up Weights & Biases
> **Note**: These instructions assume you are already inside the Docker container. For container setup, refer to the main project README at `vlm-finetuning/assets/README.md`.

Configure your wandb credentials for training monitoring:
### 1.1 Download the pre-trained model

```bash
export WANDB_PROJECT="vlm_finetuning"
export WANDB_ENTITY=<WANDB_USERNAME>
export WANDB_API_KEY=<WANDB_API_KEY>
hf download Qwen/Qwen2.5-VL-7B-Instruct
```

### 2. Launch the Application
### 1.2 (Optional) Download the fine-tuned model

```bash
# Start the Streamlit interface
streamlit run Image_VLM.py
```
If you already have a fine-tuned checkpoint, place it in the `saved_model/` folder.

The application will be available at `http://localhost:8501`
# TODO: SHOW TREE AND SKIP TO INFERENCE

## Training
If you already have a finetuned checkpoint that you would like to use for a comparative analysis against the base model, skip directly to the [Finetuned Model Inference](#5-finetuned-model-inference) section.

### Dataset
## 2. Dataset Preparation

The project uses a **wildfire detection dataset** with satellite imagery for training the model to identify wildfire-affected regions. The dataset includes:
The project uses a **Wildfire Detection Dataset** with satellite imagery for training the model to identify wildfire-affected regions. The dataset includes:
- Satellite and aerial imagery from wildfire-affected areas
- Binary classification: wildfire vs no wildfire

#### Dataset Setup

1. **Download from Kaggle**: Visit the [Wildfire Prediction Dataset](https://www.kaggle.com/datasets/abdelghaniaaba/wildfire-prediction-dataset) on Kaggle

2. **Get the curl command**: On the Kaggle dataset page, click the download button and copy the curl command provided

3. **Download and extract**: Run the following commands in your container:
### 2.1 Create a dataset folder

```bash
mkdir data
cd data
mkdir -p ui_image/data
cd ui_image/data
```

# Paste the curl command from Kaggle here, and then continue to unzip the dataset
### 2.2 Dataset Download

For this finetuning playbook, we will use the [Wildfire Prediction Dataset](https://www.kaggle.com/datasets/abdelghaniaaba/wildfire-prediction-dataset) from Kaggle. Visit the Kaggle dataset page [here](https://www.kaggle.com/datasets/abdelghaniaaba/wildfire-prediction-dataset) and click the download button. Select the `cURL` option in the `Download Via` dropdown and copy the curl command.

> **Note**: You will need to be logged into Kaggle and may need to accept the dataset terms before the download link works.

Run the following commands in your container:

```bash
# Paste and run the curl command from Kaggle here, and then continue to unzip the dataset

unzip -qq wildfire-prediction-dataset.zip
rm wildfire-prediction-dataset.zip
cd ..
```

> **Note**: You'll need to be logged into Kaggle and may need to accept the dataset terms before the download link works.
## 3. Base Model Inference

### Training Configuration
Before we start finetuning, let's spin up the demo UI to evaluate the base model's performance on this task.

Configure training through the interactive interface:
### 3.1 Spin up the Streamlit demo

#### Model Settings
- **Base Model**: Qwen/Qwen2.5-VL-7B-Instruct
- **Fine-tuning Method**: Choose from LoRA, QLoRA, or Full Finetuning
- **LoRA Parameters**: Adjustable rank (8-64) and alpha (8-64)
```bash
streamlit run Image_VLM.py
```

#### Training Parameters
- **Epochs**: 1-100 (default: 10)
- **Batch Size**: 1, 2, 4, 8, or 16 (default: 2)
- **Learning Rate**: 1e-6 to 1e-2 (default: 1e-5)
- **Optimizer**: AdamW or Adafactor
Access the streamlit demo at http://localhost:8501/.

#### GRPO Settings
- **Format Reward**: 2.0 (reward for proper reasoning format)
- **Correctness Reward**: 5.0 (reward for correct answers)
- **Number of Generations**: 4 (for preference optimization)
### 3.2 Wait for demo spin-up

### Training Process
When you access the streamlit demo for the first time, the backend triggers vLLM servers to spin up for the base model. You will see a spinner on the demo site as vLLM is being brought up for optimized inference. This step can take up to 15 minutes.

1. **Configure Parameters**: Use the web interface to set training hyperparameters
2. **Start Training**: Click "▶️ Start Finetuning" to begin GRPO training
3. **Monitor Progress**: View real-time loss curves and GPU utilization via embedded wandb charts
4. **Stop if Needed**: Use "⏹️ Stop Finetuning" to halt training early
After the streamlit demo is fully loaded, you should see a UI similar to the one below, ready for inference.

> **Important**: After training completes, follow these steps:
> 1. **Stop the UI**: Use Ctrl+C to stop the Streamlit application
> 2. **Update Config**: Edit `src/image_vlm_config.yaml` and change the `finetuned_model_id` path to point to your newly trained model in the `saved_model/` directory
> 3. **Restart UI**: Launch the application again to test your fine-tuned model
<figure>
<img src="assets/inference_page.png" alt="Inference Page" width="1000"/>
<figcaption>Inference demo on the UI</figcaption>
</figure>

## Inference
### 3.3 Run base model inference

### Interactive Comparison
Since we are currently focused on inferring the base model, let's scroll down to the `Image Inference` section of the UI. Here, you should see a pre-loaded sample satellite image of a potentially wildfire-affected region.


*Side-by-side comparison showing base model vs fine-tuned model performance on wildfire detection*
Enter your prompt in the chat box and hit `Generate`. Your prompt is first sent to the base model, and you should see the generated response in the left chat box. If you did not provide a finetuned model, you will not see any generations in the right chat box.
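Under the hood, each prompt is paired with the selected image as a single chat turn before it is sent to the model. A minimal sketch of that message structure, mirroring the prompt construction in the app's `start_inference` helper (the `build_message` wrapper is hypothetical; the exact payload your inference server expects may differ):

```python
# Illustrative sketch of the chat message the demo builds per request.
# Mirrors the role/content layout used in Image_VLM.py; not the exact vLLM payload.

def build_message(prompt_text):
    """Wrap a text prompt and an image placeholder into one user turn."""
    return [
        {
            "role": "user",
            "content": [
                {"type": "image"},                      # the sampled satellite image
                {"type": "text", "text": prompt_text},  # the user's question
            ],
        },
    ]

message = build_message("Identify if this region has been affected by a wildfire")
print(message[0]["role"])          # → user
print(len(message[0]["content"]))  # → 2
```

The image placeholder is resolved to the currently displayed sample when the request is dispatched.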

The inference section provides:
- **Sample Images**: Test on wildfire detection samples from both categories (wildfire/no wildfire)
- **Dual Inference**: Run both base and fine-tuned models simultaneously
- **Random Sampling**: Test different samples with the "🎲 Test another sample" button
- **Structured Reasoning**: Fine-tuned model provides reasoning in `<REASONING>` tags before final answer
As you can see, the base model is incapable of providing the right response for this domain-specific task. Let's try to improve the model's accuracy by performing GRPO finetuning.

### Sample Questions
## 4. GRPO Finetuning

The interface includes prompts for wildfire detection:
- "Identify if this region has been affected by a wildfire"
- The fine-tuned model provides structured reasoning followed by a Yes/No answer
We will perform GRPO finetuning to add reasoning capabilities to our base model and improve its understanding of the underlying domain. Considering that you have already spun up the streamlit demo, scroll to the `GRPO Training` section.

<figure>
<img src="assets/training_page.png" alt="Training Page" width="1000"/>
<figcaption>Training parameters on the UI</figcaption>
</figure>

### 4.1 Model Settings

Configure the finetuning method and LoRA parameters based on the following options.

- `Finetuning Method`: Choose from Full Finetuning or LoRA
- `LoRA Parameters`: Adjustable rank (8-64) and alpha (8-64)

### 4.2 Finetune layers

You can additionally choose which layers of the VLM to finetune. For the best performance, ensure that all options are toggled on; note that this will increase the training time as well.

### 4.3 Training parameters

In this section, we can select the training hyperparameters relevant to our run.

- `Steps`: 1-1000
- `Batch Size`: 1, 2, 4, 8, or 16
- `Learning Rate`: 1e-6 to 1e-2
- `Optimizer`: AdamW or Adafactor

### 4.4 GRPO settings

For a GRPO setup, we also have the flexibility of choosing the reward that is assigned to the model based on certain criteria:

- `Format Reward`: 2.0 (reward for proper reasoning format)
- `Correctness Reward`: 5.0 (reward for correct answers)
- `Number of Generations`: 4 (for preference optimization)
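The two rewards can be read as simple scoring functions over each sampled completion. Below is a hedged sketch of how such scoring might look; the names mirror `format_reward_func` and `correctness_reward_func` in `src/train_image_vlm.py`, but the exact logic in the repo may differ:

```python
import re

# Tag markers used in the training prompt (assumed to match the repo's constants).
REASONING_START, REASONING_END = "<REASONING>", "</REASONING>"
SOLUTION_START, SOLUTION_END = "<SOLUTION>", "</SOLUTION>"

def format_reward(completion, reward=2.0):
    """Reward completions that wrap reasoning and answer in the expected tags."""
    pattern = rf"{REASONING_START}.*?{REASONING_END}.*?{SOLUTION_START}.*?{SOLUTION_END}"
    return reward if re.search(pattern, completion, re.DOTALL) else 0.0

def correctness_reward(completion, answer, reward=5.0):
    """Reward completions whose <SOLUTION> block matches the Yes/No label."""
    match = re.search(rf"{SOLUTION_START}\s*(Yes|No)\s*{SOLUTION_END}", completion, re.DOTALL)
    return reward if match and match.group(1) == answer else 0.0

completion = "<REASONING>No burn scars visible.</REASONING>\n<SOLUTION>\nNo\n</SOLUTION>"
print(format_reward(completion) + correctness_reward(completion, "No"))  # → 7.0
```

A completion that satisfies both criteria scores 2.0 + 5.0 = 7.0, which is exactly the total `reward` recorded at the first logged step in `training_logs.txt`.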

### 4.5 Start training

After configuring all the parameters, hit `Start Finetuning` to begin the training process. You will need to wait about 15 minutes for the model to load and start recording metadata on the UI. As the training progresses, information such as the loss, epoch, and GRPO rewards will be recorded in a live table.
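The live table is populated by scraping metric dictionaries out of the trainer's stdout log. A minimal sketch of that parsing step (the UI code applies a similar regex to `/tmp/logs.txt` and evaluates the match; `ast.literal_eval` is used here as a safer stand-in, and the sample line is abbreviated from the committed `training_logs.txt`):

```python
import ast
import re

def parse_metric_lines(log_text):
    """Extract {'loss': ..., 'reward': ...} style dicts from raw trainer logs."""
    metrics = []
    for line in log_text.split("\n"):
        match = re.search(r"\{[^}]+\}", line)
        if match:
            try:
                record = ast.literal_eval(match.group())  # parse the dict literal safely
            except (ValueError, SyntaxError):
                continue  # skip lines whose braces are not a Python dict literal
            if isinstance(record, dict) and "reward" in record:
                metrics.append(record)
    return metrics

sample = "1/30250 [01:15<631:36:01, 75.17s/it] {'loss': 0.0, 'reward': 7.0, 'epoch': 0.0}"
print(parse_metric_lines(sample)[0]["reward"])  # → 7.0
```

Each parsed record feeds the step counter, reward metrics, and dataframe shown while training runs.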

### 4.6 Stop training

If you wish to stop training, just hit the `Stop Finetuning` button. Ensure that at least 50 steps have completed before stopping so that a finetuned checkpoint is stored.

Once you stop training, the UI will automatically bring up the vLLM servers for the base model and the newly finetuned model.

## 5. Finetuned Model Inference

Now we are ready to perform a comparative analysis between the base model and the finetuned model.

### 5.1 (Optional) Spin up the Streamlit demo

If you haven't spun up the streamlit demo already, execute the following command. If you have just stopped training and are still within the live UI, skip to the next step.

```bash
streamlit run Image_VLM.py
```

Access the streamlit demo at http://localhost:8501/.

### 5.2 vLLM startup

Regardless of whether you just spun up the demo or just stopped training, please wait about 15 minutes for the vLLM servers to be brought up.

### 5.3 Run finetuned model inference

Scroll down to the `Image Inference` section, and enter your prompt in the provided chat box. Upon clicking `Generate`, your prompt is first sent to the base model and then to the finetuned model. You can use the following prompt to quickly test inference:

`Identify if this region has been affected by a wildfire`

If you trained your model sufficiently, you should see that the finetuned model is able to perform reasoning and provide a concise, accurate answer to the prompt. The reasoning steps are provided in markdown format, while the final answer is bolded at the end of the model's response.
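Because the finetuned model always emits its verdict between `<SOLUTION>` tags, the final Yes/No can also be pulled out programmatically, which is handy for scoring a batch of images outside the UI. A hedged sketch (the `extract_answer` helper is hypothetical and not part of the repo; tag names come from the training prompt):

```python
import re

def extract_answer(response):
    """Return the Yes/No verdict from a <SOLUTION>...</SOLUTION> block, else None."""
    match = re.search(r"<SOLUTION>\s*(Yes|No)\s*</SOLUTION>", response, re.DOTALL)
    return match.group(1) if match else None

response = (
    "<REASONING>Charred patches and ash deposits are visible.</REASONING>\n"
    "<SOLUTION>\nYes\n</SOLUTION>"
)
print(extract_answer(response))  # → Yes
```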

For the image shown below, we have trained the model for 1000 steps, which took about 4 hours.

### 5.4 Further analysis

If you wish to play around with these models on additional images, the `Test another sample` button will load another random satellite image.

## File Structure

@@ -131,7 +192,3 @@ ui_image/
│ └── inference_screenshot.png # UI demonstration screenshot
└── saved_model/ # Training checkpoints directory (update config to point here)
```

## Configuration

The `src/image_vlm_config.yaml` file contains all training and inference settings:
BIN nvidia/vlm-finetuning/assets/ui_image/assets/inference_page.png (new file, 753 KiB; binary file not shown)
BIN nvidia/vlm-finetuning/assets/ui_image/assets/training_page.png (new file, 358 KiB; binary file not shown)
@@ -17,15 +17,14 @@

inference:
  model_id: unsloth/Qwen2.5-VL-7B-Instruct
  finetuned_model_id: RLakshmi24/qwen_wildfire_qrpo_lora
  max_seq_length: 16384
  finetuned_model_id: saved_model
  max_seq_length: 8192

train:
  model:
    model_id: unsloth/Qwen2.5-VL-7B-Instruct
    max_seq_length: 16384
    use_lora: true
    use_qlora: false
    lora_config:
      rank: 32
      alpha: 64
@@ -39,10 +38,10 @@ train:
    dataset_id: data

  hyperparameters:
    epochs: 10
    batch_size: 2
    steps: 100
    batch_size: 4
    enable_grpo: true
    num_generations: 4
    num_generations: 2
    format_reward: 2.0
    learning_rate: 1e-5
    correctness_reward: 5.0

@@ -1,15 +1,15 @@
data:
  dataset_id: data
hyperparameters:
  batch_size: 1
  batch_size: 4
  correctness_reward: 5.0
  enable_grpo: true
  epochs: 2
  format_reward: 2.0
  learning_rate: 1.0e-05
  num_generations: 4
  num_generations: 2
  optimizer: adamw_torch
  output_dir: saved_model
  steps: 5
model:
  finetune_attention_modules: true
  finetune_language_layers: true
@@ -18,8 +18,7 @@ model:
  lora_config:
    alpha: 64
    dropout: 0.05
    rank: 16
    rank: 32
  max_seq_length: 16384
  model_id: unsloth/Qwen2.5-VL-7B-Instruct
  use_lora: true
  use_qlora: false

@@ -18,11 +18,15 @@
from unsloth import FastVisionModel

import re
import sys
import yaml
import shutil
import signal

from PIL import ImageFile
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
from transformers.trainer_utils import get_last_checkpoint

ImageFile.LOAD_TRUNCATED_IMAGES = True

@@ -37,7 +41,7 @@ def load_model_for_train(config):
    model, tokenizer = FastVisionModel.from_pretrained(
        model_name=config["model"]["model_id"],
        max_seq_length=config["model"]["max_seq_length"],
        load_in_4bit=config["model"]["use_qlora"],
        load_in_4bit=False,
    )

    model = FastVisionModel.get_peft_model(
@@ -152,8 +156,6 @@ def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[floa


def start_train(config):
    # load base model for finetuning
    model, tokenizer = load_model_for_train(config)

    # load dataset
    train_dataset = load_wildfire_dataset(config, tokenizer)
@@ -171,17 +173,17 @@ def start_train(config):
        log_completions=False,
        per_device_train_batch_size=config["hyperparameters"]["batch_size"],
        gradient_accumulation_steps=1,
        num_generations=2,
        num_generations=config["hyperparameters"]["num_generations"],
        max_prompt_length=config["model"]["max_seq_length"],
        max_completion_length=config["model"]["max_seq_length"],
        num_train_epochs=config["hyperparameters"]["epochs"],
        save_steps=100,
        max_steps=config["hyperparameters"]["steps"],
        save_steps=3,
        max_grad_norm=0.1,
        report_to="none",
        output_dir=config["hyperparameters"]["output_dir"],
        # importance_sampling_level="sequence",
        # mask_truncated_completions=False,
        # loss_type="dr_grpo",
        importance_sampling_level="sequence",
        mask_truncated_completions=False,
        loss_type="dr_grpo",
    )

    # start training
@@ -197,8 +199,29 @@ def start_train(config):
    )
    trainer.train()

    handle_termination(None, None)


def handle_termination(signum, frame):
    latest_checkpoint = get_last_checkpoint(config["hyperparameters"]["output_dir"])
    if latest_checkpoint:
        if config["model"]["use_lora"]:
            print("Merging LoRA weights and saving the model")
            shutil.rmtree(latest_checkpoint)
            model.save_pretrained_merged(latest_checkpoint, tokenizer, save_method="merged_16bit")

    sys.exit(0)


signal.signal(signal.SIGTERM, handle_termination)
signal.signal(signal.SIGINT, handle_termination)


if __name__ == "__main__":
    with open("src/train.yaml", "r") as f:
        config = yaml.safe_load(f)

    start_train(config)
    # load base model for finetuning
    model, tokenizer = load_model_for_train(config)

    start_train(config)

84 nvidia/vlm-finetuning/assets/ui_image/training_logs.txt (new file)
@@ -0,0 +1,84 @@
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))== Unsloth 2025.9.11: Fast Qwen2_5_Vl patching. Transformers: 4.56.2.
\\ /| NVIDIA GB10. Num GPUs = 1. Max memory: 119.699 GB. Platform: Linux.
O^O/ \_/ \ Torch: 2.9.0a0+50eac811a6.nv25.09. CUDA: 12.1. CUDA Toolkit: 13.0. Triton: 3.4.0
\ / Bfloat16 = TRUE. FA [Xformers = 0.0.33+5146f2a.d20251005. FA2 = True]
"-____-" Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.
Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 25%|██▌ | 1/4 [00:29<01:28, 29.57s/it]
Loading checkpoint shards: 50%|█████ | 2/4 [00:58<00:58, 29.48s/it]
Loading checkpoint shards: 75%|███████▌ | 3/4 [01:27<00:28, 28.99s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [01:36<00:00, 21.25s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [01:36<00:00, 24.20s/it]
Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.
==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1
\\ /| Num examples = 30,250 | Num Epochs = 1 | Total steps = 30,250
O^O/ \_/ \ Batch size per device = 2 | Gradient accumulation steps = 1
\ / Data Parallel GPUs = 1 | Total batch size (2 x 1 x 1) = 2
"-____-" Trainable parameters = 103,043,072 of 8,395,209,728 (1.23% trained)
0%| | 0/30250 [00:00<?, ?it/s]`generation_config` default values have been modified to match model-specific defaults: {'max_length': 32768, 'temperature': 1e-06, 'repetition_penalty': 1.05}. If this is not desired, please set these values explicitly.
----------------------------------
Question:
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>Identify if this region has been affected by a wildfire. Also first provide your reasoning or working out on how you would go about identifying the presence of wildfire affected regions between <REASONING> and </REASONING> and then your final answer between <SOLUTION> and (put a simple Yes or No here) </SOLUTION><|im_end|>
<|im_start|>assistant

Answer:
No
Response:
<REASONING>
To identify if the region has been affected by a wildfire, we need to look for signs such as charred areas, absence of vegetation, or changes in the landscape that might indicate fire activity. In satellite imagery, these signs can be subtle but may include:

1. **Burnt Vegetation**: Areas where the vegetation appears to have been consumed by fire, often appearing as darkened patches.
2. **Ash or Ash Patches**: These can appear as lighter-colored areas compared to surrounding areas.
3. **Changes in Land Cover**: The presence of new structures or changes in the type of land cover that could indicate recent construction or clearing.
4. **Smoke Plumes**: While not visible in this still image, smoke plumes from wildfires can sometimes be seen in satellite imagery.

In this image:
- There are no obvious darkened or charred areas that would suggest recent fire activity.
- The vegetation appears intact, with no signs of ash or burnt areas.
- There are no visible smoke plumes or signs of recent construction that would indicate recent fire activity.

</REASONING>

<SOLUTION>
No
</SOLUTION>
0%| | 1/30250 [01:15<631:36:01, 75.17s/it]
{'loss': 0.0, 'grad_norm': 3.589693005778827e-05, 'learning_rate': 0.0, 'num_tokens': 879.0, 'completions/mean_length': 205.5, 'completions/min_length': 177.0, 'completions/max_length': 234.0, 'completions/clipped_ratio': 0.0, 'completions/mean_terminated_length': 205.5, 'completions/min_terminated_length': 177.0, 'completions/max_terminated_length': 234.0, 'rewards/format_reward_func/mean': 2.0, 'rewards/format_reward_func/std': 0.0, 'rewards/correctness_reward_func/mean': 5.0, 'rewards/correctness_reward_func/std': 0.0, 'reward': 7.0, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'completion_length': 205.5, 'kl': 0.0005636619171127677, 'epoch': 0.0}
0%| | 1/30250 [01:15<631:36:01, 75.17s/it]----------------------------------
Question:
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>Identify if this region has been affected by a wildfire. Also first provide your reasoning or working out on how you would go about identifying the presence of wildfire affected regions between <REASONING> and </REASONING> and then your final answer between <SOLUTION> and (put a simple Yes or No here) </SOLUTION><|im_end|>
<|im_start|>assistant

Answer:
Yes
Response:
<REASONING>
To identify if a region has been affected by a wildfire, one would typically look for signs such as charred areas, burn scars, or changes in vegetation color and texture. In satellite imagery, these changes can be subtle but often noticeable.

1. **Burn Scars**: These are areas where the ground has been scorched, often appearing as lighter or different shades compared to surrounding areas.
2. **Vegetation Changes**: Wildfires can alter the appearance of vegetation, making it appear more brown or less dense than before.
3. **Ash and Debris**: Ash and debris from a fire can leave a distinct mark on the landscape, which might not be immediately obvious but can be seen with close inspection.

In the provided image:
- The area appears to have a mix of green and brown tones, which could indicate recent vegetation change.
- There is a distinct darker area that could be interpreted as a burn scar, but without more detailed information or higher resolution imagery, it's difficult to confirm.
- The overall pattern suggests a mix of forested and open areas, which is typical for many regions.

However, without more specific indicators like smoke plumes, ash deposits, or detailed post-fire analysis, it's challenging to definitively conclude the presence of a wildfire.
|
||||
|
||||
</REASONING>
|
||||
|
||||
<SOLUTION>
|
||||
No
|
||||
</SOLUTION>
|
||||
Unsloth: Will smartly offload gradients to save VRAM!
|
||||
0%| | 2/30250 [02:21<587:04:23, 69.87s/it]
{'loss': 0.0, 'grad_norm': 3.785433727898635e-05, 'learning_rate': 3.3057851239669425e-09, 'num_tokens': 1864.0, 'completions/mean_length': 258.5, 'completions/min_length': 242.0, 'completions/max_length': 275.0, 'completions/clipped_ratio': 0.0, 'completions/mean_terminated_length': 258.5, 'completions/min_terminated_length': 242.0, 'completions/max_terminated_length': 275.0, 'rewards/format_reward_func/mean': 2.0, 'rewards/format_reward_func/std': 0.0, 'rewards/correctness_reward_func/mean': 0.0, 'rewards/correctness_reward_func/std': 0.0, 'reward': 2.0, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'completion_length': 258.5, 'kl': 0.0006416599499061704, 'epoch': 0.0}
|
||||
0%| | 2/30250 [02:21<587:04:23, 69.87s/it]
|
||||
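The metric columns in the log above (`rewards/format_reward_func/mean`, `rewards/correctness_reward_func/mean`) come from the reward functions passed to the trainer. A minimal sketch of what such functions might look like is shown below; the function names and the point values (2.0 for a well-formed response, 5.0 for a correct answer, matching the `reward': 7.0` and `reward': 2.0` totals seen in the log) are inferred from the log output, and the exact signatures used by the actual training script may differ.

```python
import re

def format_reward_func(completion: str) -> float:
    """Reward a well-formed response: a <REASONING> block followed by a
    <SOLUTION> block (hypothetical scoring inferred from the log)."""
    pattern = r"<REASONING>.*?</REASONING>\s*<SOLUTION>.*?</SOLUTION>"
    return 2.0 if re.search(pattern, completion, re.DOTALL) else 0.0

def correctness_reward_func(completion: str, answer: str) -> float:
    """Reward a Yes/No inside <SOLUTION> tags that matches the ground truth."""
    m = re.search(r"<SOLUTION>\s*(Yes|No)\s*</SOLUTION>", completion, re.DOTALL)
    if m and m.group(1).strip().lower() == answer.strip().lower():
        return 5.0
    return 0.0
```

Under this scheme, step 1 above (format OK, answer "No" matching ground truth "No") would score 2.0 + 5.0 = 7.0, while step 2 (format OK, answer "No" against ground truth "Yes") would score only 2.0, consistent with the logged totals.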