#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Import order is kept exactly as in the original file (unsloth first).
from unsloth import FastVisionModel
import os
import re
import gc
import yaml
import glob
import random
import subprocess
import wandb
import torch
from PIL import Image
import streamlit as st


# Tags the finetuned model is asked to wrap its reasoning and final answer in.
# NOTE(review): these four values were empty strings in the reviewed copy of
# this file (the angle-bracket content looks stripped). Empty markers break
# both the prompt text and the response formatting, so they are restored
# here -- confirm they match the tags used by src/train_image_vlm.py.
REASONING_START = "<REASONING>"
REASONING_END = "</REASONING>"
SOLUTION_START = "<SOLUTION>"
SOLUTION_END = "</SOLUTION>"

# Glob pattern for the sample images shown in the inference section.
IMAGE_GLOB = "assets/image_vlm/images/*/*"


def initialize_session_state(resources):
    """Fill in the page-specific ``st.session_state`` keys that are missing.

    Keys that already exist are left untouched, so calling this on every
    rerun is safe.

    Args:
        resources: Mapping with ``"base"`` and ``"finetuned"`` entries, as
            returned by ``initialize_resources``.

    Raises:
        FileNotFoundError: If no sample image matches ``IMAGE_GLOB`` when the
            initial image has to be chosen.
    """
    # Initialize page-specific session state
    for key in ("base", "finetuned"):
        if key not in st.session_state:
            st.session_state[key] = resources[key]

    # Only touch the filesystem when the key is actually missing; the
    # original globbed (and could raise IndexError) on every rerun.
    if "current_image" not in st.session_state:
        images = glob.glob(IMAGE_GLOB)
        if not images:
            raise FileNotFoundError(f"No sample images found matching {IMAGE_GLOB!r}")
        st.session_state["current_image"] = images[0]

    if "train_process" not in st.session_state:
        st.session_state["train_process"] = None


def load_config():
    """Return the app config, loading it from YAML on first use.

    The parsed config is stored in ``st.session_state["config"]`` so that
    edits made through the UI persist across reruns.
    """
    config = st.session_state.get("config", None)
    if config is None:
        with open("src/image_vlm_config.yaml", "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        st.session_state["config"] = config
    return config


@st.cache_resource
def initialize_resources(inference_config):
    """Load the base and finetuned models once and cache them.

    Args:
        inference_config: The ``"inference"`` section of the app config.

    Returns:
        A dict with ``"base"`` and ``"finetuned"`` entries, each holding a
        ``{"model": ..., "tokenizer": ...}`` mapping.
    """
    base_model, base_tokenizer = load_model_for_inference(inference_config, "base")
    finetuned_model, finetuned_tokenizer = load_model_for_inference(inference_config, "finetuned")
    return {
        "base": {"model": base_model, "tokenizer": base_tokenizer},
        "finetuned": {"model": finetuned_model, "tokenizer": finetuned_tokenizer},
    }


def main():
    """Render the page: title, stylesheet, training section, inference section."""
    # set page ui
    st.title("Image VLM Finetuning")
    st.caption("A DGX Spark showcase for on-device VLM finetuning")
    # st.page_link("https://github.com/your-username/your-repo", label="GitHub", icon=":material/github:")

    # load css
    # The original opened the stylesheet but injected an empty string, so the
    # file contents were never applied; wrap them in a <style> tag.
    with open("src/styles.css", "r", encoding="utf-8") as css_file:
        st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

    # load resources
    config = load_config()
    first_load = st.session_state.get("base", None) is None
    if first_load:
        st.toast("Loading model", icon="⏳", duration="short")
    resource = initialize_resources(config["inference"])
    if first_load:
        st.toast("Model loaded", icon="✅", duration="short")
    # Only fills in missing keys, so it is safe to call on every rerun.
    initialize_session_state(resource)

    # train section
    st.markdown("---")
    train_section()

    # inference Section
    st.markdown("---")
    inference_section()


def train_section():
    """Render the training UI and keep the train config in sync with it.

    Widget values are written into ``st.session_state["config"]["train"]``;
    the start button dumps that config to ``src/train.yaml`` and launches
    ``src/train_image_vlm.py`` as a subprocess.
    """
    st.header("GRPO Training")

    train_config = st.session_state["config"]["train"]
    _render_finetuning_method(train_config["model"])
    _render_model_toggles(train_config["model"])
    _render_hyperparameters(train_config["hyperparameters"])
    _render_grpo_settings(train_config["hyperparameters"])
    _render_training_controls(train_config)
    _render_wandb_charts()


def _render_finetuning_method(model_config):
    """Render the finetuning-method selector and the LoRA rank/alpha sliders."""
    column_1, column_2, column_3 = st.columns(3, gap="large")
    with column_1:
        finetuning_method = st.selectbox(
            "Finetuning Method:",
            ["LoRA", "QLoRA", "Full Fine-tuning"],
        )

    # update lora config
    if finetuning_method in ("QLoRA", "LoRA"):
        lora_config = model_config["lora_config"]
        with column_2:
            lora_rank = st.slider(
                "LoRA Rank",
                min_value=8,
                max_value=64,
                value=lora_config["rank"],
                step=8,
            )
        with column_3:
            lora_alpha = st.slider(
                "LoRA Alpha",
                min_value=8,
                max_value=64,
                value=lora_config["alpha"],
                step=8,
            )
        lora_config.update({
            'rank': lora_rank,
            'alpha': lora_alpha,
        })

    # update model config based on selection
    model_config["use_lora"] = finetuning_method == "LoRA"
    model_config["use_qlora"] = finetuning_method == "QLoRA"


def _render_model_toggles(model_config):
    """Render the toggles selecting which parts of the model are finetuned."""
    st.write("")
    toggles = (
        ("finetune_vision_layers", "Finetune Vision Layers"),
        ("finetune_language_layers", "Finetune Language Layers"),
        ("finetune_attention_modules", "Finetune Attention Modules"),
        ("finetune_mlp_modules", "Finetune MLP Modules"),
    )
    selected = {}
    for column, (key, label) in zip(st.columns(4, gap="large"), toggles):
        with column:
            selected[key] = st.toggle(label, value=model_config[key])
    model_config.update(selected)


def _render_hyperparameters(hyperparameters):
    """Render epochs, batch size, learning rate and optimizer inputs."""
    st.write("")
    column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
    with column_1:
        epochs = st.slider(
            "Epochs",
            min_value=1,
            max_value=100,
            value=hyperparameters["epochs"])
    with column_2:
        batch_size = st.select_slider(
            "Batch Size",
            options=[1, 2, 4, 8, 16],
            value=hyperparameters["batch_size"])
    with column_3:
        learning_rate = st.number_input(
            "Learning Rate",
            min_value=1e-6,
            max_value=1e-2,
            value=float(hyperparameters["learning_rate"]),
            format="%.2e")
    with column_4:
        optimizer = st.selectbox(
            "Optimizer",
            options=["adamw_torch", "adafactor"])

    hyperparameters.update({
        'epochs': epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'optimizer': optimizer,
    })


def _render_grpo_settings(hyperparameters):
    """Render the GRPO reward settings and store them in the train config."""
    st.write("")
    column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
    with column_1:
        # Display only: the toggle is disabled, so its value cannot change.
        st.toggle(
            "Enable GRPO",
            value=hyperparameters["enable_grpo"],
            disabled=True)
    with column_2:
        format_reward = st.number_input(
            "Reward for reasoning format",
            min_value=0.0,
            max_value=5.0,
            value=float(hyperparameters["format_reward"]),
            format="%.2e")
    with column_3:
        correctness_reward = st.number_input(
            "Reward for correct response",
            min_value=0.0,
            max_value=5.0,
            value=float(hyperparameters["correctness_reward"]),
            format="%.2e")
    with column_4:
        # Integer input: use an integer display format (was "%.2e").
        num_generations = st.number_input(
            "Number of generations",
            min_value=1,
            max_value=16,
            value=int(hyperparameters["num_generations"]),
            format="%d")

    # The original read these widgets but never stored their values, so
    # changes made in the UI never reached src/train.yaml.
    hyperparameters.update({
        'format_reward': format_reward,
        'correctness_reward': correctness_reward,
        'num_generations': num_generations,
    })


def _render_training_controls(train_config):
    """Render the start/stop buttons and the running/idle badge."""
    # A trainer that exited on its own leaves a stale handle behind; clear it
    # so the buttons and badge reflect the real state.
    process = st.session_state["train_process"]
    if process is not None and process.poll() is not None:
        st.session_state["train_process"] = None

    # Training control
    st.write("")
    column_1, column_2, column_3 = st.columns([4, 4, 1])
    with column_1:
        button_type = "secondary" if st.session_state["train_process"] else "primary"
        if st.button("▶️ Start Finetuning", type=button_type, width="stretch", disabled=bool(st.session_state["train_process"])):
            if st.session_state["train_process"] is None:
                # store config
                with open("src/train.yaml", "w") as f:
                    yaml.dump(train_config, f, default_flow_style=False)

                # start training
                st.session_state["train_process"] = subprocess.Popen(
                    ["python", "src/train_image_vlm.py"],
                    stdout=None, stderr=None
                )
            else:
                st.toast("Training already in progress", icon="❌", duration="short")

    with column_2:
        button_type = "primary" if st.session_state["train_process"] else "secondary"
        if st.button("⏹️ Stop Finetuning", type=button_type, width="stretch", disabled=not bool(st.session_state["train_process"])):
            if st.session_state["train_process"] is not None:
                _stop_process(st.session_state["train_process"])
                st.session_state["train_process"] = None
                st.toast("Training stopped", icon="✅", duration="short")
                st.toast("Re-deploy the app with updated finetuned model", icon=":material/info:", duration="short")
            else:
                st.toast("No training to stop", icon="❌", duration="short")

    with column_3:
        if st.session_state["train_process"]:
            st.badge("Running", icon=":material/hourglass_arrow_up:", color="green", width="stretch")
        else:
            st.badge("Idle", icon=":material/hourglass_disabled:", color="red", width="stretch")


def _stop_process(process, timeout=10):
    """Terminate ``process`` and reap it, killing it if it does not exit in time."""
    process.terminate()
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()


def _render_wandb_charts():
    """Embed loss and GPU-memory panels from the first listed W&B run."""
    # display wandb
    entity = os.environ.get('WANDB_ENTITY')
    project = os.environ.get('WANDB_PROJECT')
    if not entity or not project:
        # The original queried the literal path "None/None" in this case.
        st.info("Set WANDB_ENTITY and WANDB_PROJECT to display training charts.")
        return

    runs = wandb.Api().runs(f"{entity}/{project}")
    if runs:
        base_url = runs[0].url
        loss_url = f"{base_url}?panelDisplayName=train%2Floss&panelSectionName=train"
        memory_url = f"{base_url}?panelDisplayName=GPU+Memory+Allocated+%28%25%29&panelSectionName=System"

        # NOTE(review): the markup inside these two st.markdown calls was
        # empty in the reviewed copy, leaving both URLs unused. A plain
        # iframe is reconstructed here -- confirm the intended markup/styling.
        column_1, column_2 = st.columns(2)
        with column_1:
            st.markdown(f"""
<iframe src="{loss_url}" style="border:none; width:100%; height:400px;"></iframe>
""", unsafe_allow_html=True)
        with column_2:
            st.markdown(f"""
<iframe src="{memory_url}" style="border:none; width:100%; height:400px;"></iframe>
""", unsafe_allow_html=True)


def inference_section():
    """Render the sample image, the prompt box and both models' responses."""
    st.header("Image Inference")

    columns = st.columns([3, 3, 1, 2])
    with columns[1]:
        with st.container(border=True, horizontal_alignment="center", vertical_alignment="center"):
            image_holder = st.empty()
            image_holder.image(st.session_state["current_image"])

    with columns[3]:
        if st.button("🎲 Test another sample"):
            # Choose among the *other* images. The original retried
            # random.choice in a loop, which never ends with a single image
            # and raises when there are none.
            candidates = [
                path for path in glob.glob(IMAGE_GLOB)
                if path != st.session_state["current_image"]
            ]
            if candidates:
                st.session_state["current_image"] = random.choice(candidates)
                image_holder.image(st.session_state["current_image"])
            else:
                st.toast("No other sample available", icon="❌", duration="short")

    columns = st.columns(2, gap="small")
    with columns[0]:
        with st.container(border=True):
            st.write("##### :green[Base Qwen2.5-VL-7B]")
            base_generation = st.empty()
            base_generation.write("...")

    with columns[1]:
        with st.container(border=True):
            st.write("##### :green[Finetuned Qwen2.5-VL-7B]")
            finetuned_generation = st.empty()
            finetuned_generation.write("...")

    columns = st.columns([9, 1], gap="small")
    with columns[0]:
        prompt = st.text_input(
            "Prompt Input",
            label_visibility="collapsed",
            key="prompt_input",
            on_change=lambda: st.session_state.update(prompt=st.session_state["prompt_input"])
        )

    with columns[1]:
        if st.button("Generate", width="stretch"):
            # Gate on the text box's current value. The original checked the
            # stored prompt instead, so an empty box fell through to a
            # KeyError inside start_inference.
            if prompt and prompt.strip():
                st.session_state["prompt"] = prompt

                with st.spinner("Running..."):
                    response = start_inference("base")
                base_generation.markdown(response)

                with st.spinner("Running..."):
                    response = start_inference("finetuned")
                finetuned_generation.markdown(response)
            else:
                st.toast("Enter a prompt first", icon="❌", duration="short")


def load_model_for_inference(config, model_type):
    """Load a vision model and tokenizer and switch the model to inference mode.

    Args:
        config: Inference config providing ``"model_id"``,
            ``"finetuned_model_id"`` and ``"max_seq_length"``.
        model_type: Either ``"base"`` or ``"finetuned"``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``model_type`` is not one of the two supported values.
    """
    if model_type == "finetuned":
        model_name = config["finetuned_model_id"]
    elif model_type == "base":
        model_name = config["model_id"]
    else:
        raise ValueError(f"Invalid model type: {model_type}")

    model, tokenizer = FastVisionModel.from_pretrained(
        model_name=model_name,
        max_seq_length=config["max_seq_length"],
        load_in_4bit=False,
    )
    FastVisionModel.for_inference(model)

    return model, tokenizer
@torch.no_grad()
def start_inference(model_type):
    """Run one generation for the current image and prompt.

    Reads the prompt from ``st.session_state["prompt"]`` and the image path
    from ``st.session_state["current_image"]``.

    Args:
        model_type: ``"base"`` or ``"finetuned"``; selects which model and
            tokenizer pair from the session state is used.

    Returns:
        The decoded response text. For the finetuned model the reasoning
        markers are turned into code fences and the solution is shown in bold.
    """
    # The rest of this file stores the models under "base" / "finetuned";
    # the "<type>_image_vlm" key is tried first for backward compatibility.
    resources = st.session_state.get(f"{model_type}_image_vlm") or st.session_state[model_type]
    model = resources["model"]
    tokenizer = resources["tokenizer"]

    # define prompt
    prompt = st.session_state["prompt"]
    if model_type == "finetuned":
        prompt = (
            f"{prompt}. Also first provide your reasoning or working out"
            f" on how you would go about identifying the presence of wildfire affected regions between {REASONING_START} and {REASONING_END}"
            f" and then your final answer between {SOLUTION_START} and (put a simple Yes or No here) {SOLUTION_END}"
        )

    # load image
    # convert() returns a fully loaded copy, so the file handle can be closed.
    with Image.open(st.session_state["current_image"]) as image_file:
        image = image_file.convert("RGB")

    # construct instruction prompt
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]

    # apply chat template
    chat_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # tokenize inputs
    inputs = tokenizer(
        image,
        chat_prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).to("cuda")

    # perform inference
    output_ids = model.generate(
        **inputs,
        max_new_tokens=1024,
        use_cache=True,
        do_sample=False
    )
    # Keep only the tokens generated after the prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]

    # decode tokens
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # format response
    if model_type == "finetuned":
        response = _format_finetuned_response(
            response, REASONING_START, REASONING_END, SOLUTION_START, SOLUTION_END
        )

    return response


def _format_finetuned_response(response, reasoning_start, reasoning_end, solution_start, solution_end):
    """Turn reasoning markers into code fences and bold each solution span."""
    for marker in (reasoning_start, reasoning_end):
        # str.replace("", x) inserts x between every character, so empty
        # markers must be skipped.
        if marker:
            response = response.replace(marker, "```")

    # Handle solution formatting with proper newline handling
    if solution_start and solution_end:
        solution_pattern = re.compile(
            f"{re.escape(solution_start)}(.*?){re.escape(solution_end)}",
            re.DOTALL,
        )
        # A function replacement keeps backslashes in the model output
        # literal and formats each match with its own content.
        response = solution_pattern.sub(
            lambda match: f"**{match.group(1).strip()}**",
            response,
        )

    return response


if __name__ == "__main__":
    main()