chore: Regenerate all playbooks

This commit is contained in:
GitLab CI 2025-10-04 21:21:42 +00:00
parent 0d492abd3f
commit dd5e5be079
107 changed files with 27116 additions and 10 deletions


@ -0,0 +1,39 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
FROM nvcr.io/nvidia/pytorch:25.09-py3
ARG HF_TOKEN
RUN cd /workspace/ && \
git clone https://github.com/comfyanonymous/ComfyUI.git && \
cd ComfyUI && \
git checkout 4ffea0e864275301329ddb5ecc3fbc7211d7a802 && \
sed -i '/torch/d' requirements.txt && \
pip install -r requirements.txt && \
pip install torchsde && \
mkdir -p /workspace/ComfyUI/user/default/workflows/
COPY . /workspace/sd-scripts
RUN hf auth login --token $HF_TOKEN
RUN cd /workspace/ComfyUI && \
hf download black-forest-labs/FLUX.1-dev ae.safetensors --local-dir models/vae && \
hf download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models/checkpoints && \
hf download comfyanonymous/flux_text_encoders clip_l.safetensors --local-dir models/text_encoders && \
hf download comfyanonymous/flux_text_encoders t5xxl_fp16.safetensors --local-dir models/text_encoders && \
hf download RLakshmi24/flux-dreambooth-lora-tj-spark flux_dreambooth.safetensors --local-dir models/loras && \
cp /workspace/sd-scripts/workflows/finetuned_flux.json /workspace/ComfyUI/user/default/workflows/
CMD ["/bin/bash"]


@ -0,0 +1,42 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
FROM nvcr.io/nvidia/pytorch:25.09-py3
ARG HF_TOKEN
RUN cd /workspace/ && \
git clone https://github.com/kohya-ss/sd-scripts.git && \
cd sd-scripts && \
git checkout sd3 && \
pip install -r requirements.txt && \
apt update && \
apt install -y libgl1-mesa-dev
COPY . /workspace/sd-scripts
RUN hf auth login --token $HF_TOKEN
RUN cd /workspace/sd-scripts/ && \
hf download black-forest-labs/FLUX.1-dev ae.safetensors --local-dir models && \
hf download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models && \
hf download comfyanonymous/flux_text_encoders clip_l.safetensors --local-dir models && \
hf download comfyanonymous/flux_text_encoders t5xxl_fp16.safetensors --local-dir models
WORKDIR /workspace/sd-scripts
CMD ["/bin/bash"]


@ -0,0 +1,203 @@
# FLUX.1 Fine-tuning with LoRA
This project demonstrates fine-tuning the 12B-parameter FLUX.1-dev model using DreamBooth LoRA (Low-Rank Adaptation) for custom image generation. The demo includes training on custom concepts and inference through both command-line scripts and ComfyUI.
## Results
Fine-tuning FLUX.1 with custom concepts enables the model to generate images with your specific objects and styles:
<figure>
<img src="assets/before_finetuning.png" alt="Before Fine-tuning" width="400"/>
<figcaption>Base FLUX.1 model without custom concept knowledge</figcaption>
</figure>
<br>
<figure>
<img src="assets/after_finetuning.png" alt="After Fine-tuning" width="400"/>
<figcaption>FLUX.1 model after LoRA fine-tuning with custom "tjtoy" and "sparkgpu" concepts</figcaption>
</figure>
## Overview
The project includes:
- **FLUX.1-dev Fine-tuning**: LoRA-based fine-tuning using sd-scripts
- **Custom Concept Training**: Train on custom "tjtoy" toy and "sparkgpu" GPU concepts
- **Command-line Inference**: Generate images using trained LoRA weights
- **ComfyUI Integration**: Intuitive workflows for inference with custom models
- **Docker Support**: Complete containerized environment
## Training
### 1. Build Docker Image by providing `HF_TOKEN`
```bash
# Build the Docker image (this will download FLUX models automatically)
docker build -f Dockerfile.train --build-arg HF_TOKEN=$HF_TOKEN -t flux-training .
```
**Note**: The Docker build automatically downloads the required FLUX models:
- `flux1-dev.safetensors` (~23GB)
- `ae.safetensors` (~335MB)
- `clip_l.safetensors` (~246MB)
- `t5xxl_fp16.safetensors` (~9.8GB)
### 2. Run Docker Container
```bash
# Run with GPU support and mount current directory
docker run -it \
--gpus all \
--ipc=host \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
--net=host \
flux-training
```
### 3. Train the Model
```bash
# Inside the container, navigate to sd-scripts and run training
cd /workspace/sd-scripts
sh train.sh
```
### 4. Run Inference
The `inference.sh` script generates 8 images with different seeds.
After training, you can generate images using the learned concepts. For example:
- `"tjtoy toy"` - Your custom toy concept
- `"sparkgpu gpu"` - Your custom GPU concept
- Combine them: `"tjtoy toy holding sparkgpu gpu"`
```bash
# Generate images using the trained LoRA
sh inference.sh
```
### Dataset Structure
The training data is organized in the `data/` directory:
```
data/
├── data.toml # Training configuration
├── tjtoy/ # Custom toy concept images (6 images)
│ ├── 1.png
│ ├── 2.jpg
│ ├── 3.png
│ ├── 4.png
│ ├── 5.png
│ └── 6.png
└── sparkgpu/ # Custom GPU concept images (7 images)
├── 1.jpeg
├── 2.jpg
├── 3.jpg
├── 4.jpg
├── 6.png
├── 7.png
└── 8.png
```
### Training Parameters
Key training settings in `train.sh`:
- **Network Type**: LoRA with dimension 256
- **Learning Rate**: 1.0 (with Prodigy optimizer)
- **Epochs**: 100 (saves every 25 epochs)
- **Resolution**: 1024x1024
- **Mixed Precision**: bfloat16
- **Optimizations**: Torch compile, gradient checkpointing, cached latents
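A rough, hedged sketch of why the resulting LoRA file stays small relative to the 23 GB base checkpoint: LoRA only trains two low-rank matrices per adapted weight. The layer dimensions below are illustrative, not the exact FLUX.1 shapes.

```python
# Back-of-envelope estimate of LoRA adapter size (illustrative numbers only;
# actual layer counts and dims depend on the sd-scripts network config).
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    # LoRA adds two low-rank matrices per adapted weight:
    # A (rank x d_in) and B (d_out x rank)
    return rank * d_in + d_out * rank

# Hypothetical example: one 3072x3072 projection at rank 256
p = lora_params(3072, 3072, 256)
bytes_bf16 = p * 2  # bfloat16 = 2 bytes per parameter
print(p, bytes_bf16 / 2**20)  # ~1.57M params, 3.0 MiB for this single layer
```

Summed over every adapted layer, this lands in the hundreds of megabytes rather than tens of gigabytes.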
## ComfyUI
ComfyUI provides an intuitive visual interface for using your fine-tuned LoRA models. The beauty of LoRA fine-tuning is that you can easily add your custom concepts to any FLUX workflow with just a single node.
### 1. Build Docker Image by providing `HF_TOKEN`
```bash
# Build the Docker image (this will download FLUX models automatically)
docker build -f Dockerfile.inference --build-arg HF_TOKEN=$HF_TOKEN -t flux-comfyui .
```
### 2. Run Docker Container
```bash
# Run with GPU support and mount current directory
docker run -it \
--gpus all \
--ipc=host \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
--net=host \
flux-comfyui
```
### 3. Running ComfyUI
```bash
# Start ComfyUI server
cd /workspace/ComfyUI
python main.py
```
Access ComfyUI at `http://localhost:8188`
### 4. ComfyUI Workflow Example
![ComfyUI Workflow](assets/comfyui_workflow.png)
*ComfyUI workflow showing how easily LoRA can be integrated into the base FLUX model*
The workflow demonstrates the simplicity of LoRA integration:
1. **Load Checkpoint**: Base FLUX.1-dev model remains unchanged
2. **Load LoRA**: Simply add your trained LoRA file (`flux_dreambooth.safetensors`)
3. **Adjust Strength**: Fine-tune the influence of your custom concepts (0.8-1.2 typically works well)
4. **Generate**: Use your custom trigger words (`tjtoy toy`, `sparkgpu gpu`) in prompts
This modular approach means you can:
- **Preserve base model quality**: The original FLUX capabilities remain intact
- **Easy experimentation**: Quickly swap different LoRA models or adjust strengths
- **Combine concepts**: Mix multiple LoRA models or use them with other techniques
- **Minimal storage**: LoRA files are typically 100-200MB vs 23GB+ for full models
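The single-node integration works because a LoRA adapter is just an additive low-rank update to each frozen weight. A minimal NumPy sketch (toy dimensions, not the real FLUX shapes; ComfyUI's LoRA loader performs the equivalent merge internally):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 64, 8                      # toy dimensions; this project trains rank 256
W = rng.standard_normal((d, d))   # frozen base weight
A = rng.standard_normal((r, d))   # LoRA down-projection
B = np.zeros((d, r))              # LoRA up-projection (initialized to zero)
alpha, strength = r, 1.0          # "strength" corresponds to the ComfyUI slider

x = rng.standard_normal(d)
y = W @ x + strength * (alpha / r) * (B @ (A @ x))
# With B at zero the adapter is a no-op, so the output matches the base model:
assert np.allclose(y, W @ x)
```

As training moves B away from zero, the same formula blends the custom concepts into the frozen base, and `strength` scales that blend at inference time.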
### ComfyUI Model Structure
Organize models in ComfyUI as follows:
```
ComfyUI/models/
├── checkpoints/
│ └── flux1-dev.safetensors # Main FLUX model
├── vae/
│ └── ae.safetensors # FLUX VAE
├── clip/
│ ├── clip_l.safetensors # CLIP text encoder
│ └── t5xxl_fp16.safetensors # T5 text encoder
└── loras/
└── flux_dreambooth.safetensors # Your trained LoRA
```
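A small helper can verify that the layout above is in place before launching ComfyUI (a sketch; the directory names follow the tree shown, and the default root assumes the `flux-comfyui` container):

```python
from pathlib import Path

# Expected files per ComfyUI model subdirectory, from the layout above
expected = {
    "checkpoints": ["flux1-dev.safetensors"],
    "vae": ["ae.safetensors"],
    "clip": ["clip_l.safetensors", "t5xxl_fp16.safetensors"],
    "loras": ["flux_dreambooth.safetensors"],
}

def missing_models(root: str = "/workspace/ComfyUI/models") -> list:
    """Return paths of expected model files that are absent under root."""
    base = Path(root)
    return [str(base / d / f)
            for d, files in expected.items()
            for f in files
            if not (base / d / f).exists()]
```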
## Custom Concepts
The fine-tuning process teaches FLUX.1 to understand two custom concepts:
### TJToy Concept
- **Trigger phrase**: `tjtoy toy`
- **Training images**: 6 high-quality images of custom toy figures
- **Use case**: Generate images featuring the specific toy character in various scenes
### SparkGPU Concept
- **Trigger phrase**: `sparkgpu gpu`
- **Training images**: 7 images of custom GPU hardware
- **Use case**: Generate images featuring the specific GPU design in different contexts
### Combined Usage
You can combine both concepts in prompts:
- `"tjtoy toy holding sparkgpu gpu"`
- `"tjtoy toy standing next to sparkgpu gpu in a data center"`
- `"sparkgpu gpu being examined by tjtoy toy"`
## Credits
This project uses [sd-scripts](https://github.com/kohya-ss/sd-scripts) repository by `kohya-ss` for FLUX.1 fine-tuning.


@ -0,0 +1,38 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
[general]
shuffle_caption = false
keep_tokens = 2
[[datasets]]
resolution = 1024
batch_size = 1
[[datasets.subsets]]
image_dir = "flux_data/tjtoy"
class_tokens = "tjtoy toy"
num_repeats = 1
is_reg = false
flip_aug = true
[[datasets.subsets]]
image_dir = "flux_data/sparkgpu"
class_tokens = "sparkgpu gpu"
num_repeats = 2
is_reg = false
flip_aug = true
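Under these settings, `num_repeats` controls how often each subset is sampled per epoch; a quick sketch of the resulting step count (with `batch_size = 1` as configured, steps equal the weighted image count, and `flip_aug` augments on the fly without changing it):

```python
# Per-epoch step count implied by data.toml
subsets = [
    {"name": "tjtoy", "images": 6, "num_repeats": 1},
    {"name": "sparkgpu", "images": 7, "num_repeats": 2},
]
steps_per_epoch = sum(s["images"] * s["num_repeats"] for s in subsets)
print(steps_per_epoch)  # 6*1 + 7*2 = 20
```

The doubled repeats give the 7-image sparkgpu set roughly twice the sampling weight of the 6-image tjtoy set.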

13 binary image files added (README assets and training images; 36 KiB–3.0 MiB each); previews not shown.


@ -0,0 +1,36 @@
#!/bin/bash
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
for SEED in $(seq 0 7); do
python flux_minimal_inference.py \
--ckpt_path="models/flux1-dev.safetensors" \
--model_type="flux" \
--clip_l="models/clip_l.safetensors" \
--t5xxl="models/t5xxl_fp16.safetensors" \
--ae="models/ae.safetensors" \
--output_dir="outputs" \
--lora_weights="saved_models/flux_dreambooth.safetensors" \
--merge_lora_weights \
--prompt="tjtoy toy holding sparkgpu gpu in a datacenter" \
--width=1024 \
--height=1024 \
--steps=50 \
--guidance=1.0 \
--cfg_scale=1.0 \
--seed=$SEED \
--dtype="bfloat16"
done


@ -0,0 +1,52 @@
#!/bin/bash
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
accelerate launch \
--num_processes=1 --num_machines=1 --mixed_precision=bf16 \
--main_process_ip=127.0.0.1 --main_process_port=29500 \
--num_cpu_threads_per_process=2 \
flux_train_network.py \
--pretrained_model_name_or_path="models/flux1-dev.safetensors" \
--clip_l="models/clip_l.safetensors" \
--t5xxl="models/t5xxl_fp16.safetensors" \
--ae="models/ae.safetensors" \
--dataset_config="flux_data/data.toml" \
--output_dir="saved_models" \
--prior_loss_weight=1.0 \
--output_name="flux_dreambooth" \
--save_model_as=safetensors \
--network_module=networks.lora_flux \
--network_dim=256 \
--network_alpha=256 \
--learning_rate=1.0 \
--optimizer_type="Prodigy" \
--lr_scheduler="cosine_with_restarts" \
--sdpa \
--max_train_epochs=100 \
--save_every_n_epochs=25 \
--mixed_precision="bf16" \
--guidance_scale=1.0 \
--timestep_sampling="flux_shift" \
--model_prediction_type="raw" \
--torch_compile \
--persistent_data_loader_workers \
--cache_latents \
--cache_latents_to_disk \
--cache_text_encoder_outputs \
--cache_text_encoder_outputs_to_disk \
--gradient_checkpointing

File diff suppressed because one or more lines are too long


@ -45,7 +45,7 @@ GPU acceleration and performance optimization capabilities.
[ ] Docker or container runtime installed
[ ] NVIDIA Container Toolkit configured
[ ] Verify GPU access: `nvidia-smi`
[ ] Verify Docker GPU support: `docker run --gpus all nvidia/cuda:12.0-base-ubuntu20.04 nvidia-smi`
[ ] Verify Docker GPU support: `docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi`
[ ] Port 8080 available for marimo notebook access
## Ancillary files
@ -92,17 +92,23 @@ If the `docker` command fails with a permission error, you can either
To add yourself to the `docker` group, first run `sudo usermod -aG docker $USER`. Then, as your user account, either run `newgrp docker` or log out and log back in.
## Step 2. Build a Docker image
## Step 3. Clone the playbook repository
```bash
git clone https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets.git
```
## Step 3. Build the Docker image
> **Warning:** This command will download a base image and build a container locally to support this environment
```bash
cd jax-assets
cd jax/assets
docker build -t jax-on-spark .
```
## Step 3. Launch Docker container
## Step 4. Launch Docker container
Run the JAX development environment in a Docker container with GPU support and port forwarding for marimo access.
@ -113,7 +119,7 @@ docker run --gpus all --rm -it \
jax-on-spark
```
## Step 4. Access marimo interface
## Step 5. Access marimo interface
Connect to the marimo notebook server to begin the JAX tutorial.
@ -124,7 +130,7 @@ Connect to the marimo notebook server to begin the JAX tutorial.
The interface will load a table-of-contents display and brief introduction to marimo.
## Step 5. Complete JAX introduction tutorial
## Step 6. Complete JAX introduction tutorial
Work through the introductory material to understand JAX programming model differences from NumPy.
@ -133,7 +139,7 @@ Navigate to and complete the JAX introduction notebook, which covers:
- Key differences from NumPy
- Performance evaluation techniques
## Step 6. Implement NumPy baseline
## Step 7. Implement NumPy baseline
Complete the NumPy-based self-organized map (SOM) implementation to establish a performance
baseline.
@ -143,7 +149,7 @@ Work through the NumPy SOM notebook to:
- Implement the algorithm using familiar NumPy operations
- Record performance metrics for comparison
## Step 7. Optimize with JAX implementations
## Step 8. Optimize with JAX implementations
Progress through the iteratively refined JAX implementations to see performance improvements.
@ -153,7 +159,7 @@ Complete the JAX SOM notebook sections:
- GPU-accelerated parallel JAX implementation
- Compare performance across all versions
## Step 8. Validate performance gains
## Step 9. Validate performance gains
The notebooks will show you how to check the performance of each SOM training implementation; you'll see that the JAX implementations outperform the NumPy baseline, some by a wide margin.


@ -51,7 +51,7 @@ def _(mo):
The following video shows an example of training a self-organizing map from color data using a batch algorithm, from a random initial map to a relatively converged set of colorful clusters.
"""
))
mo.output.append(mo.image("batch-som.mp4"))
mo.output.append(mo.video("batch-som.mp4"))
mo.output.append(mo.md(
r"""


@ -0,0 +1,76 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
ARG UBUNTU_VERSION=22.04
ARG CUDA_VERSION=13.0.1
ARG BASE_CUDA_DEV_CONTAINER=nvcr.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
ARG BASE_CUDA_RUN_CONTAINER=nvcr.io/nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
FROM ${BASE_CUDA_DEV_CONTAINER} AS build
ARG CUDA_DOCKER_ARCH="121"
RUN apt-get update && \
apt-get install -y build-essential cmake python3 python3-pip git libcurl4-openssl-dev libgomp1
WORKDIR /app
RUN git clone https://github.com/ggml-org/llama.cpp.git .
RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \
fi && \
cmake -B build -DGGML_CUDA_ENABLE_UNIFIED_MEMORY=1 -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DGGML_BACKEND_DL=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake --build build --config Release -j$(nproc)
RUN mkdir -p /app/lib && \
find build -name "*.so" -exec cp {} /app/lib \;
RUN mkdir -p /app/full \
&& cp build/bin/* /app/full \
&& cp *.py /app/full \
&& cp -r gguf-py /app/full \
&& cp -r requirements /app/full \
&& cp requirements.txt /app/full \
&& cp .devops/tools.sh /app/full/tools.sh
FROM ${BASE_CUDA_RUN_CONTAINER} AS base
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
&& find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
&& find /var/cache -type f -delete
COPY --from=build /app/lib/ /app
FROM base
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
WORKDIR /app
HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8000/health" ]
ENTRYPOINT [ "/app/llama-server" ]
CMD ["--host", "0.0.0.0", "--port", "8000"]
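The `HEALTHCHECK` above polls llama-server's `/health` endpoint, which returns `{"status": "ok"}` once the model is loaded. A small Python probe doing the same from the host might look like this (a sketch; assumes the default port mapping):

```python
import json
from urllib.request import urlopen
from urllib.error import URLError

def is_ready(base_url: str = "http://localhost:8000", timeout: float = 2.0) -> bool:
    """Return True once llama-server reports a loaded model on /health."""
    try:
        with urlopen(f"{base_url}/health", timeout=timeout) as resp:
            return json.load(resp).get("status") == "ok"
    except (URLError, OSError, json.JSONDecodeError):
        return False
```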


@ -0,0 +1,105 @@
# Chatbot Spark: A Local Multi-Agent System for DGX Spark
## Project Overview
Chatbot Spark is a fully local multi-agent system built on DGX Spark. With 128GB of unified memory, DGX Spark can run multiple LLMs and VLMs in parallel — enabling interactions across agents.
At the core is a supervisor agent powered by GPT-OSS-120B, orchestrating specialized downstream agents for coding, retrieval-augmented generation (RAG), and image understanding. Thanks to DGX Spark's out-of-the-box support for popular AI frameworks and libraries, development and prototyping were fast and frictionless. Together, these components demonstrate how complex, multimodal workflows can be executed efficiently on local, high-performance hardware.
This project is designed as a customizable framework that developers can extend.
## Key Features
- **MCP Server Integration**: Chatbot Spark also showcases the ability to connect to custom MCP servers through a simple and customizable multi-server client
- **Tool Calling**: This project uses an agents-as-tools framework and showcases the ability to create additional agents connected as tools. General tools can also be added.
- **Easily Swappable Models**: Models are loaded with llama.cpp and Ollama and served through an OpenAI-compatible API. Any OpenAI-compatible model can be integrated into the project.
- **Vector Indexing & Retrieval**: GPU-accelerated Milvus for high-performance document retrieval.
- **Real-time LLM Streaming**: Custom LLM-streaming infrastructure makes it easy to stream supervisor responses from any OpenAI-compatible model.
- **gpt-oss Integration**: The default chat/tool-calling model is gpt-oss:120b, providing seamless integration with OpenAI's latest open-source tool-calling model.
## System Overview
<img src="assets/system-diagram.png" alt="System Diagram" style="max-width:600px;border-radius:5px;justify-content:center">
## Default Models
| Model | Quantization | Model Type | VRAM |
|------------------------------|--------------|------------|-------------|
| GPT-OSS:120B | MXFP4 | Chat | ~ 63.5 GB |
| Deepseek-Coder:6.7B-Instruct | Q8 | Coding | ~ 9.5 GB |
| Qwen2.5-VL:7B-Instruct | BF16 | Image | ~ 35.4 GB |
| Qwen3-Embedding-4B | Q8 | Embedding | ~ 5.39 GB |
**Total VRAM required:** ~114 GB
> **Warning**:
> Since the default models use the majority of available VRAM, check with `nvidia-smi` that nothing else is already running on DGX Spark. If something is, switch to `gpt-oss-20b` following [this guide](#using-different-models).
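As a quick sanity check against the table above, the default model mix nearly fills the unified memory:

```python
# VRAM figures from the table above, in GB
models = {
    "gpt-oss-120b": 63.5,
    "deepseek-coder-6.7b": 9.5,
    "qwen2.5-vl-7b": 35.4,
    "qwen3-embedding-4b": 5.39,
}
total = sum(models.values())
print(f"{total:.1f} GB of 128 GB")  # ~113.8 GB, leaving little headroom
```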
---
## Quick Start
#### 1. Clone the repository and change directories to the multi-agent chatbot directory.
#### 2. Run the setup script
The setup script pulls model GGUF files from Hugging Face, builds the base llama.cpp server images, and starts all required Docker services: the model servers, the backend API server, and the frontend UI.
```bash
chmod +x setup.sh
./setup.sh
```
Wait for all the containers to become ready and healthy.
```bash
watch 'docker ps --format "table {{.ID}}\t{{.Names}}\t{{.Status}}"'
```
> Note: Downloading model files may take ~10 minutes and starting containers may take another 10 minutes depending on network speed. Look for "server is listening on http://0.0.0.0:8000" in the logs of model server containers.
#### 3. Access the frontend UI
Open your browser and go to: [http://localhost:3000](http://localhost:3000)
> Note: If you are running this on a remote GPU over SSH, run the following in a new terminal window so you can access the UI at localhost:3000 and the UI can reach the backend at localhost:8000:
>```bash
> ssh -L 3000:localhost:3000 -L 8000:localhost:8000 username@IP-address
>```
You should see the following UI in your browser:
<img src="assets/multi-agent-chatbot.png" alt="Frontend UI" style="max-width:600px;border-radius:5px;justify-content:center">
### 4. Try out the sample prompts
Click on any of the tiles on the frontend to try out the supervisor and the other agents.
#### RAG Agent:
Before trying out the RAG agent, upload the example PDF document [NVIDIA Blackwell Whitepaper](https://images.nvidia.com/aem-dam/Solutions/geforce/blackwell/nvidia-rtx-blackwell-gpu-architecture.pdf) as context by clicking the "Attach" icon in the text input area at the bottom of the UI, then check the box in the "Select Sources" section on the left side of the UI.
<img src="assets/upload-image.png" alt="Upload Image" style="max-width:300px;border-radius:5px;justify-content:center">
<img src="assets/document-ingestion.png" alt="Ingest Documents" style="max-width:300px;border-radius:5px;justify-content:center">
#### Image Understanding Agent:
You can either provide URLs or drag and drop images.
**Example Prompt:**
Describe this image: https://en.wikipedia.org/wiki/London_Bridge#/media/File:London_Bridge_from_St_Olaf_Stairs.jpg
## Customizations
### Using different models
You can swap the model that the supervisor agent uses, for example to gpt-oss-20b.
1. In `setup.sh`, uncomment the line to download gpt-oss-20b.
> Note: If you already have the model files downloaded, you can skip to step 2.
2. In `docker-compose-models.yml`, uncomment the block for gpt-oss-20b.
> Note: Since the default models use all of the existing VRAM, you will need to comment out the block for gpt-oss-120b in `docker-compose-models.yml`.
3. In `docker-compose.yml`, add `gpt-oss-20b` to the `MODELS` environment variable (line 40).
> Note: This name should match the container name that you set for this model in `docker-compose-models.yml`.
### Adding MCP servers and tools
1. You can add more MCP servers and tools under [backend/tools/mcp_servers](backend/tools/mcp_servers/) following existing examples.
2. If you added an MCP server, remember to add it to the server configs in [backend/client.py](backend/client.py)


@ -0,0 +1 @@
3.10


@ -0,0 +1,33 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
FROM python:3.12
RUN apt-get update && \
apt-get install -y curl && \
curl -LsSf https://astral.sh/uv/install.sh | sh
WORKDIR /app
# uv installs to /root/.local/bin; sourcing its env in a RUN layer does not
# persist, so expose it via PATH instead
ENV PATH="/root/.local/bin:$PATH"
COPY pyproject.toml uv.lock ./
COPY . .
RUN uv sync
CMD ["uv", "run", "--", "uvicorn", "main:app", "--reload", "--host", "0.0.0.0", "--port", "8000"]


@ -0,0 +1 @@
# Chatbot Backend API Server


@ -0,0 +1,16 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


@ -0,0 +1,586 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""ChatAgent implementation for LLM-powered conversational AI with tool calling."""
import asyncio
import contextlib
import json
from typing import AsyncIterator, List, Dict, Any, TypedDict, Optional, Callable, Awaitable
from langchain_core.messages import HumanMessage, AIMessage, AnyMessage, SystemMessage, ToolMessage, ToolCall
from langchain_core.utils.function_calling import convert_to_openai_tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from openai import AsyncOpenAI
from client import MCPClient
from logger import logger
from prompts import Prompts
from postgres_storage import PostgreSQLConversationStorage
from utils import convert_langgraph_messages_to_openai
memory = MemorySaver()
SENTINEL = object()
StreamCallback = Callable[[Dict[str, Any]], Awaitable[None]]
class State(TypedDict, total=False):
iterations: int
messages: List[AnyMessage]
chat_id: Optional[str]
image_data: Optional[str]
class ChatAgent:
"""Main conversational agent with tool calling and agent delegation capabilities.
This agent orchestrates conversation flow using a LangGraph state machine that can:
- Generate responses using LLMs
- Execute tool calls (including MCP tools)
- Handle image processing
- Manage conversation history via PostgreSQL
"""
def __init__(self, vector_store, config_manager, postgres_storage: PostgreSQLConversationStorage):
"""Initialize the chat agent.
Args:
vector_store: VectorStore instance for document retrieval
config_manager: ConfigManager for reading configuration
postgres_storage: PostgreSQL storage for conversation persistence
"""
self.vector_store = vector_store
self.config_manager = config_manager
self.conversation_store = postgres_storage
self.current_model = None
self.max_iterations = 3
self.mcp_client = None
self.openai_tools = None
self.tools_by_name = None
self.system_prompt = None
self.graph = self._build_graph()
self.stream_callback = None
self.last_state = None
@classmethod
async def create(cls, vector_store, config_manager, postgres_storage: PostgreSQLConversationStorage):
"""
Asynchronously creates and initializes a ChatAgent instance.
This factory method ensures that all async setup, like loading tools,
is completed before the agent is ready to be used.
"""
agent = cls(vector_store, config_manager, postgres_storage)
await agent.init_tools()
available_tools = list(agent.tools_by_name.values()) if agent.tools_by_name else []
template_vars = {
"tools": "\n".join([f"- {tool.name}: {tool.description}" for tool in available_tools]) if available_tools else "No tools available",
}
agent.system_prompt = Prompts.get_template("supervisor_agent").render(template_vars)
logger.debug(f"Agent initialized with {len(available_tools)} tools.")
agent.set_current_model(config_manager.get_selected_model())
return agent
async def init_tools(self) -> None:
"""Initialize MCP client and tools with retry logic.
Sets up the MCP client, retrieves available tools, converts them to OpenAI format,
and initializes specialized agents like the coding agent.
"""
self.mcp_client = await MCPClient().init()
base_delay, max_retries = 0.1, 10
mcp_tools = []
for attempt in range(max_retries):
try:
mcp_tools = await self.mcp_client.get_tools()
break
except Exception as e:
logger.warning(f"MCP tools initialization attempt {attempt + 1} failed: {e}")
if attempt == max_retries - 1:
logger.error(f"MCP servers not ready after {max_retries} attempts, continuing without MCP tools")
mcp_tools = []
break
wait_time = base_delay * (2 ** attempt)
await asyncio.sleep(wait_time)
logger.info(f"MCP servers not ready, retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})")
self.tools_by_name = {tool.name: tool for tool in mcp_tools}
logger.debug(f"Loaded {len(mcp_tools)} MCP tools: {list(self.tools_by_name.keys())}")
if mcp_tools:
mcp_tools_openai = [convert_to_openai_tool(tool) for tool in mcp_tools]
logger.debug(f"MCP tools converted to OpenAI format: {mcp_tools_openai}")
self.openai_tools = [
{"type": "function", "function": tool['function']}
for tool in mcp_tools_openai
]
logger.debug(f"Final OpenAI tools format: {self.openai_tools}")
else:
self.openai_tools = []
logger.warning("No MCP tools available - agent will run with limited functionality")
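The retry loop in `init_tools` can be sketched in isolation. This is a minimal, dependency-free version of the same exponential-backoff pattern; `flaky` is a hypothetical stand-in for `mcp_client.get_tools()` that fails twice before succeeding:

```python
import asyncio

async def retry_with_backoff(fetch, base_delay=0.1, max_retries=10):
    """Retry an async callable with exponential backoff, as init_tools does."""
    for attempt in range(max_retries):
        try:
            return await fetch()
        except Exception:
            if attempt == max_retries - 1:
                return None  # give up; the caller degrades gracefully
            await asyncio.sleep(base_delay * (2 ** attempt))  # 0.1s, 0.2s, 0.4s, ...

attempts = 0

async def flaky():
    # Hypothetical stand-in for get_tools(): fails twice, then succeeds.
    global attempts
    attempts += 1
    if attempts < 3:
        raise ConnectionError("MCP server not ready")
    return ["rag_search", "explain_image"]

tools = asyncio.run(retry_with_backoff(flaky, base_delay=0.001))
print(attempts, tools)  # 3 ['rag_search', 'explain_image']
```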
def set_current_model(self, model_name: str) -> None:
"""Set the current model for completions.
Args:
model_name: Name of the model to use
Raises:
ValueError: If the model is not available
"""
available_models = self.config_manager.get_available_models()
if model_name not in available_models:
raise ValueError(f"Model {model_name} is not available. Available models: {available_models}")
self.current_model = model_name
logger.info(f"Switched to model: {model_name}")
self.model_client = AsyncOpenAI(
base_url=f"http://{self.current_model}:8000/v1",
api_key="api_key"
)
def should_continue(self, state: State) -> str:
"""Determine whether to continue the tool calling loop.
Args:
state: Current graph state
Returns:
"end" if no more tool calls or max iterations reached, "continue" otherwise
"""
messages = state.get("messages", [])
if not messages:
return "end"
last_message = messages[-1]
iterations = state.get("iterations", 0)
has_tool_calls = bool(last_message.tool_calls) if hasattr(last_message, 'tool_calls') else False
logger.debug({
"message": "GRAPH: should_continue decision",
"chat_id": state.get("chat_id"),
"iterations": iterations,
"max_iterations": self.max_iterations,
"has_tool_calls": has_tool_calls,
"tool_calls_count": len(last_message.tool_calls) if has_tool_calls else 0
})
if iterations >= self.max_iterations:
logger.debug({
"message": "GRAPH: should_continue → END (max iterations reached)",
"chat_id": state.get("chat_id"),
"final_message_preview": str(last_message)[:100] + "..." if len(str(last_message)) > 100 else str(last_message)
})
return "end"
if not has_tool_calls:
logger.debug({"message": "GRAPH: should_continue → END (no tool calls)", "chat_id": state.get("chat_id")})
return "end"
logger.debug({"message": "GRAPH: should_continue → CONTINUE (has tool calls)", "chat_id": state.get("chat_id")})
return "continue"
async def tool_node(self, state: State) -> Dict[str, Any]:
"""Execute tools from the last AI message's tool calls.
Args:
state: Current graph state
Returns:
Updated state with tool results and incremented iteration count
"""
logger.debug({
"message": "GRAPH: ENTERING NODE - action/tool_node",
"chat_id": state.get("chat_id"),
"iterations": state.get("iterations", 0)
})
await self.stream_callback({'type': 'node_start', 'data': 'tool_node'})
outputs = []
messages = state.get("messages", [])
last_message = messages[-1]
for i, tool_call in enumerate(last_message.tool_calls):
logger.debug(f'Executing tool {i+1}/{len(last_message.tool_calls)}: {tool_call["name"]} with args: {tool_call["args"]}')
await self.stream_callback({'type': 'tool_start', 'data': tool_call["name"]})
try:
if tool_call["name"] == "explain_image" and state.get("image_data"):
tool_args = tool_call["args"].copy()
tool_args["image"] = state["image_data"]
logger.info(f'Executing tool {tool_call["name"]} with args: {tool_args}')
tool_result = await self.tools_by_name[tool_call["name"]].ainvoke(tool_args)
state["process_image_used"] = True
else:
tool_result = await self.tools_by_name[tool_call["name"]].ainvoke(tool_call["args"])
if "code" in tool_call["name"]:
content = str(tool_result)
elif isinstance(tool_result, str):
content = tool_result
else:
content = json.dumps(tool_result)
except Exception as e:
logger.error(f'Error executing tool {tool_call["name"]}: {str(e)}', exc_info=True)
content = f"Error executing tool '{tool_call['name']}': {str(e)}"
await self.stream_callback({'type': 'tool_end', 'data': tool_call["name"]})
outputs.append(
ToolMessage(
content=content,
name=tool_call["name"],
tool_call_id=tool_call["id"],
)
)
state["iterations"] = state.get("iterations", 0) + 1
logger.debug({
"message": "GRAPH: EXITING NODE - action/tool_node",
"chat_id": state.get("chat_id"),
"iterations": state.get("iterations"),
"tools_executed": len(outputs),
"next_step": "→ returning to generate"
})
await self.stream_callback({'type': 'node_end', 'data': 'tool_node'})
return {"messages": messages + outputs, "iterations": state["iterations"]}
async def generate(self, state: State) -> Dict[str, Any]:
"""Generate AI response using the current model.
Args:
state: Current graph state
Returns:
Updated state with new AI message
"""
messages = convert_langgraph_messages_to_openai(state.get("messages", []))
logger.debug({
"message": "GRAPH: ENTERING NODE - generate",
"chat_id": state.get("chat_id"),
"iterations": state.get("iterations", 0),
"current_model": self.current_model,
"message_count": len(state.get("messages", []))
})
await self.stream_callback({'type': 'node_start', 'data': 'generate'})
supports_tools = self.current_model in {"gpt-oss-20b", "gpt-oss-120b"}
has_tools = supports_tools and self.openai_tools and len(self.openai_tools) > 0
logger.debug({
"message": "Tool calling debug info",
"chat_id": state.get("chat_id"),
"current_model": self.current_model,
"supports_tools": supports_tools,
"openai_tools_count": len(self.openai_tools) if self.openai_tools else 0,
"openai_tools": self.openai_tools,
"has_tools": has_tools
})
tool_params = {}
if has_tools:
tool_params = {
"tools": self.openai_tools,
"tool_choice": "auto"
}
stream = await self.model_client.chat.completions.create(
model=self.current_model,
messages=messages,
temperature=0,
top_p=1,
stream=True,
**tool_params
)
llm_output_buffer, tool_calls_buffer = await self._stream_response(stream, self.stream_callback)
tool_calls = self._format_tool_calls(tool_calls_buffer)
raw_output = "".join(llm_output_buffer)
logger.debug({
"message": "Tool call generation results",
"chat_id": state.get("chat_id"),
"tool_calls_buffer": tool_calls_buffer,
"formatted_tool_calls": tool_calls,
"tool_calls_count": len(tool_calls),
"raw_output_length": len(raw_output),
"raw_output": raw_output[:200] + "..." if len(raw_output) > 200 else raw_output
})
response = AIMessage(
content=raw_output,
**({"tool_calls": tool_calls} if tool_calls else {})
)
logger.debug({
"message": "GRAPH: EXITING NODE - generate",
"chat_id": state.get("chat_id"),
"iterations": state.get("iterations", 0),
"response_length": len(response.content) if response.content else 0,
"tool_calls_generated": len(tool_calls),
"tool_calls_names": [tc["name"] for tc in tool_calls] if tool_calls else [],
"next_step": "→ should_continue decision"
})
await self.stream_callback({'type': 'node_end', 'data': 'generate'})
return {"messages": state.get("messages", []) + [response]}
def _build_graph(self) -> StateGraph:
"""Build the LangGraph state machine for conversation flow.
Returns:
Compiled StateGraph with nodes and conditional edges
"""
workflow = StateGraph(State)
workflow.add_node("generate", self.generate)
workflow.add_node("action", self.tool_node)
workflow.add_edge(START, "generate")
workflow.add_conditional_edges(
"generate",
self.should_continue,
{
"continue": "action",
"end": END,
},
)
workflow.add_edge("action", "generate")
return workflow.compile(checkpointer=memory)
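The control flow that `_build_graph` wires up with LangGraph (generate, then either end or execute tools and loop back) can be sketched with plain functions, no LangGraph required. `fake_generate` and `fake_tools` below are hypothetical stand-ins for the real nodes:

```python
def run_loop(generate, execute_tools, max_iterations=3):
    """Plain-Python sketch of the generate → should_continue → action loop."""
    state = {"messages": [], "iterations": 0}
    while True:
        msg = generate(state)
        state["messages"].append(msg)
        # should_continue: stop when there are no tool calls or the cap is hit.
        if not msg.get("tool_calls") or state["iterations"] >= max_iterations:
            return state
        state["messages"].extend(execute_tools(msg["tool_calls"]))
        state["iterations"] += 1

def fake_generate(state):
    # Request a tool on the first pass, answer plainly afterwards.
    if state["iterations"] == 0 and not state["messages"]:
        return {"role": "assistant", "tool_calls": [{"name": "rag_search"}]}
    return {"role": "assistant", "content": "done"}

def fake_tools(calls):
    return [{"role": "tool", "name": c["name"], "content": "result"} for c in calls]

final = run_loop(fake_generate, fake_tools)
print(final["iterations"], final["messages"][-1]["content"])  # 1 done
```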
def _format_tool_calls(self, tool_calls_buffer: Dict[int, Dict[str, str]]) -> List[ToolCall]:
"""Parse streamed tool call buffer into ToolCall objects.
Args:
tool_calls_buffer: Buffer of streamed tool call data
Returns:
List of formatted ToolCall objects
"""
if not tool_calls_buffer:
return []
tool_calls = []
for i in sorted(tool_calls_buffer):
item = tool_calls_buffer[i]
try:
parsed_args = json.loads(item["arguments"] or "{}")
except json.JSONDecodeError:
parsed_args = {}
tool_calls.append(
ToolCall(
name=item["name"],
args=parsed_args,
id=item["id"] or f"call_{i}",
)
)
return tool_calls
async def _stream_response(self, stream, stream_callback: StreamCallback) -> tuple[List[str], Dict[int, Dict[str, str]]]:
"""Process streaming LLM response and extract content and tool calls.
Args:
stream: Async stream from LLM
stream_callback: Callback for streaming events
Returns:
Tuple of (content_buffer, tool_calls_buffer)
"""
llm_output_buffer = []
tool_calls_buffer = {}
saw_tool_finish = False
async for chunk in stream:
for choice in getattr(chunk, "choices", []) or []:
delta = getattr(choice, "delta", None)
if not delta:
continue
content = getattr(delta, "content", None)
if content:
await stream_callback({"type": "token", "data": content})
llm_output_buffer.append(content)
for tc in getattr(delta, "tool_calls", []) or []:
idx = getattr(tc, "index", None)
if idx is None:
idx = 0 if not tool_calls_buffer else max(tool_calls_buffer) + 1
entry = tool_calls_buffer.setdefault(idx, {"id": None, "name": None, "arguments": ""})
if getattr(tc, "id", None):
entry["id"] = tc.id
fn = getattr(tc, "function", None)
if fn:
if getattr(fn, "name", None):
entry["name"] = fn.name
if getattr(fn, "arguments", None):
entry["arguments"] += fn.arguments
finish_reason = getattr(choice, "finish_reason", None)
if finish_reason == "tool_calls":
saw_tool_finish = True
break
if saw_tool_finish:
break
return llm_output_buffer, tool_calls_buffer
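The fragment merging that `_stream_response` and `_format_tool_calls` perform together can be shown on plain dicts. This sketch assumes hypothetical fragment dicts standing in for the OpenAI streaming delta objects; argument JSON arrives in pieces and must be concatenated before parsing:

```python
import json

def accumulate_tool_calls(fragments):
    """Merge streamed tool-call fragments, keyed by index, into complete calls."""
    buffer = {}
    for frag in fragments:
        entry = buffer.setdefault(frag.get("index", 0), {"id": None, "name": None, "arguments": ""})
        if frag.get("id"):
            entry["id"] = frag["id"]
        if frag.get("name"):
            entry["name"] = frag["name"]
        entry["arguments"] += frag.get("arguments", "")  # concatenate partial JSON
    return [
        {"id": e["id"], "name": e["name"], "args": json.loads(e["arguments"] or "{}")}
        for _, e in sorted(buffer.items())
    ]

calls = accumulate_tool_calls([
    {"index": 0, "id": "call_0", "name": "get_weather", "arguments": '{"ci'},
    {"index": 0, "arguments": 'ty": "Oslo"}'},
])
print(calls)  # [{'id': 'call_0', 'name': 'get_weather', 'args': {'city': 'Oslo'}}]
```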
async def query(self, query_text: str, chat_id: str, image_data: str = None) -> AsyncIterator[Dict[str, Any]]:
"""Process user query and stream response tokens.
Args:
query_text: User's input text
chat_id: Unique chat identifier
image_data: Optional base64 data URI of an image attached to the message
Yields:
Streaming events and tokens
"""
logger.debug({
"message": "GRAPH: STARTING EXECUTION",
"chat_id": chat_id,
"query": query_text[:100] + "..." if len(query_text) > 100 else query_text,
"graph_flow": "START → generate → should_continue → action → generate → END"
})
config = {"configurable": {"thread_id": chat_id}}
try:
existing_messages = await self.conversation_store.get_messages(chat_id, limit=10)
base_system_prompt = self.system_prompt
if image_data:
image_context = "\n\nIMAGE CONTEXT: The user has uploaded an image with their message. You MUST use the explain_image tool to analyze it."
system_prompt_with_image = base_system_prompt + image_context
messages_to_process = [SystemMessage(content=system_prompt_with_image)]
else:
messages_to_process = [SystemMessage(content=base_system_prompt)]
if existing_messages:
for msg in existing_messages:
if not isinstance(msg, SystemMessage):
messages_to_process.append(msg)
messages_to_process.append(HumanMessage(content=query_text))
config_obj = self.config_manager.read_config()
initial_state = {
"iterations": 0,
"chat_id": chat_id,
"messages": messages_to_process,
"image_data": image_data if image_data else None,
"process_image_used": False
}
model_name = self.config_manager.get_selected_model()
if self.current_model != model_name:
self.set_current_model(model_name)
logger.debug({
"message": "GRAPH: LAUNCHING EXECUTION",
"chat_id": chat_id,
"initial_state": {
"iterations": initial_state["iterations"],
"message_count": len(initial_state["messages"]),
}
})
self.last_state = None
token_q: asyncio.Queue[Any] = asyncio.Queue()
self.stream_callback = lambda event: self._queue_writer(event, token_q)
runner = asyncio.create_task(self._run_graph(initial_state, config, chat_id, token_q))
try:
while True:
item = await token_q.get()
if item is SENTINEL:
break
yield item
except Exception as stream_error:
logger.error({"message": "Error in streaming", "error": str(stream_error)}, exc_info=True)
finally:
with contextlib.suppress(asyncio.CancelledError):
await runner
logger.debug({
"message": "GRAPH: EXECUTION COMPLETED",
"chat_id": chat_id,
"final_iterations": self.last_state.get("iterations", 0) if self.last_state else 0
})
except Exception as e:
logger.error({"message": "GRAPH: EXECUTION FAILED", "error": str(e), "chat_id": chat_id}, exc_info=True)
yield {"type": "error", "data": f"Error performing query: {str(e)}"}
async def _queue_writer(self, event: Dict[str, Any], token_q: asyncio.Queue) -> None:
"""Write events to the streaming queue.
Args:
event: Event data to queue
token_q: Queue for streaming events
"""
await token_q.put(event)
async def _run_graph(self, initial_state: Dict[str, Any], config: Dict[str, Any], chat_id: str, token_q: asyncio.Queue) -> None:
"""Run the graph execution in background task.
Args:
initial_state: Starting state for graph
config: LangGraph configuration
chat_id: Chat identifier
token_q: Queue for streaming events
"""
try:
async for final_state in self.graph.astream(
initial_state,
config=config,
stream_mode="values",
stream_writer=lambda event: self._queue_writer(event, token_q)
):
self.last_state = final_state
finally:
try:
if self.last_state and self.last_state.get("messages"):
final_msg = self.last_state["messages"][-1]
try:
logger.debug(f'Saving messages to conversation store for chat: {chat_id}')
await self.conversation_store.save_messages(chat_id, self.last_state["messages"])
except Exception as save_err:
logger.warning({"message": "Failed to persist conversation", "chat_id": chat_id, "error": str(save_err)})
content = getattr(final_msg, "content", None)
if content:
await token_q.put(content)
finally:
await token_q.put(SENTINEL)

View File

@ -0,0 +1,93 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Multi-Server MCP Client for connecting to multiple MCP servers.
This module provides a unified client interface for connecting to and managing
multiple Model Context Protocol (MCP) servers. It handles server configuration,
initialization, and tool retrieval across different server types.
"""
from typing import List, Optional
from langchain_mcp_adapters.client import MultiServerMCPClient
from mcp.types import Tool
class MCPClient:
"""Client for managing connections to multiple MCP servers.
Provides a unified interface for connecting to and interacting with
various MCP servers including RAG, image understanding, and weather services.
"""
def __init__(self):
"""Initialize the MCP client with predefined server configurations."""
self.server_configs = {
"image-understanding-server": {
"command": "python",
"args": ["tools/mcp_servers/image_understanding.py"],
"transport": "stdio",
},
"code-generation-server": {
"command": "python",
"args": ["tools/mcp_servers/code_generation.py"],
"transport": "stdio",
},
"rag-server": {
"command": "python",
"args": ["tools/mcp_servers/rag.py"],
"transport": "stdio",
},
"weather-server": {
"command": "python",
"args": ["tools/mcp_servers/weather_test.py"],
"transport": "stdio",
}
}
self.mcp_client: MultiServerMCPClient | None = None
async def init(self):
"""Initialize the multi-server MCP client.
Returns:
MCPClient: Self for method chaining
Raises:
Exception: If client initialization fails
"""
self.mcp_client = MultiServerMCPClient(self.server_configs)
return self
async def get_tools(self):
"""Retrieve available tools from all connected MCP servers.
Returns:
List[Tool]: List of available tools from all servers
Raises:
RuntimeError: If client is not initialized
Exception: If tool retrieval fails
"""
if not self.mcp_client:
raise RuntimeError("MCP client not initialized. Call `await init()` first.")
try:
tools = await self.mcp_client.get_tools()
return tools
except Exception as error:
print("Error encountered connecting to MCP server. Is the server running? Is your config server path correct?\n")
raise error

View File

@ -0,0 +1,165 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""ConfigManager for managing the configuration of the chat application."""
import json
import os
import logging
import threading
from typing import List
from logger import logger
from models import ChatConfig
class ConfigManager:
def __init__(self, config_path: str):
"""Initialize the ConfigManager"""
self.config_path = config_path
self.config = None
self._last_modified = 0
self._lock = threading.Lock()
self._ensure_config_exists()
self.read_config()
def _ensure_config_exists(self) -> None:
"""Ensure config.json exists, creating it with default values if not."""
models = os.getenv("MODELS", "")
if models:
models = [model.strip() for model in models.split(",") if model.strip()]
else:
logger.warning("MODELS environment variable not set, using empty models list")
if not os.path.exists(self.config_path):
logger.debug(f"Config file {self.config_path} not found, creating default config")
default_config = ChatConfig(
sources=[],
models=models,
selected_model=models[0] if models else None,
selected_sources=[],
current_chat_id=None
)
with open(self.config_path, "w") as f:
json.dump(default_config.model_dump(), f, indent=2)
else:
try:
with open(self.config_path, "r") as f:
data = json.load(f)
existing_config = ChatConfig(**data)
if models:
existing_config.models = models
if not existing_config.selected_model or existing_config.selected_model not in models:
existing_config.selected_model = models[0]
with open(self.config_path, "w") as f:
json.dump(existing_config.model_dump(), f, indent=2)
logger.debug(f"Updated existing config with models: {models}")
except Exception as e:
logger.error(f"Error updating existing config: {e}")
default_config = ChatConfig(
sources=[],
models=models,
selected_model=models[0] if models else None,
selected_sources=[],
current_chat_id=None
)
with open(self.config_path, "w") as f:
json.dump(default_config.model_dump(), f, indent=2)
def read_config(self) -> ChatConfig:
"""Read config from file, but only if it has changed since last read."""
with self._lock:
try:
current_mtime = os.path.getmtime(self.config_path)
if self.config is None or current_mtime > self._last_modified:
with open(self.config_path, "r") as f:
data = json.load(f)
self.config = ChatConfig(**data)
self._last_modified = current_mtime
return self.config
except Exception as e:
logger.error(f"Error reading config: {e}")
if self.config is None:
models = os.getenv("MODELS", "")
if models:
models = [model.strip() for model in models.split(",") if model.strip()]
self.config = ChatConfig(
sources=[],
models=models,
selected_model=models[0] if models else "gpt-oss-120b",
selected_sources=[],
current_chat_id="1"
)
return self.config
def write_config(self, new_config: ChatConfig) -> None:
"""Thread-safe write config to file."""
with self._lock:
with open(self.config_path, "w") as f:
json.dump(new_config.model_dump(), f, indent=2)
self.config = new_config
self._last_modified = os.path.getmtime(self.config_path)
def get_sources(self) -> List[str]:
"""Return list of available sources."""
self.config = self.read_config()
return self.config.sources
def get_selected_sources(self) -> List[str]:
"""Return list of selected sources."""
self.config = self.read_config()
return self.config.selected_sources
def get_available_models(self) -> List[str]:
"""Return list of available models."""
self.config = self.read_config()
return self.config.models
def get_selected_model(self) -> str:
"""Return the selected model."""
self.config = self.read_config()
logger.debug(f"Selected model: {self.config.selected_model}")
return self.config.selected_model
def get_current_chat_id(self) -> str:
"""Return the current chat id."""
self.config = self.read_config()
return self.config.current_chat_id
def updated_selected_sources(self, new_sources: List[str]) -> None:
"""Update the selected sources in the config."""
self.config = self.read_config().model_copy(update={"selected_sources": new_sources})
self.write_config(self.config)
def updated_selected_model(self, new_model: str) -> None:
"""Update the selected model in the config."""
self.config = self.read_config().model_copy(update={"selected_model": new_model})
logger.debug(f"Updated selected model to: {new_model}")
self.write_config(self.config)
def updated_current_chat_id(self, new_chat_id: str) -> None:
"""Update the current chat id in the config."""
self.config = self.read_config().model_copy(update={"current_chat_id": new_chat_id})
self.write_config(self.config)
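The mtime-gated re-read in `read_config` can be demonstrated without the `ChatConfig` model. This sketch caches a parsed JSON file and re-parses only when the file's modification time advances:

```python
import json
import os
import tempfile

class CachedJson:
    """Sketch of the mtime-gated re-read used by ConfigManager.read_config."""
    def __init__(self, path):
        self.path = path
        self._mtime = 0.0
        self._data = None

    def read(self):
        mtime = os.path.getmtime(self.path)
        if self._data is None or mtime > self._mtime:  # re-parse only on change
            with open(self.path) as f:
                self._data = json.load(f)
            self._mtime = mtime
        return self._data

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"selected_model": "gpt-oss-120b"}, f)
    path = f.name

cfg = CachedJson(path)
first = cfg.read()
second = cfg.read()  # file unchanged, so the cached object is returned
print(first["selected_model"], first is second)  # gpt-oss-120b True
os.remove(path)
```

One caveat of this pattern, shared by the real class: filesystem mtime granularity can be coarse, so two writes within the same tick may not trigger a re-read.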

View File

@ -0,0 +1,145 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import sys
import traceback
from datetime import datetime
from typing import Any, Dict, Optional
class JsonFormatter(logging.Formatter):
"""
Formatter that outputs JSON strings after parsing the log record.
"""
def __init__(self, **kwargs):
super().__init__()
self.default_keys = {
'timestamp': 'asctime',
'level': 'levelname',
'message': 'message'
}
self.default_keys.update(kwargs)
def format(self, record: logging.LogRecord) -> str:
"""
Format the log record as JSON.
"""
log_record = {}
log_record['timestamp'] = datetime.utcnow().isoformat() + 'Z'
log_record['level'] = record.levelname
log_record['logger'] = record.name
if isinstance(record.msg, dict):
log_record['message'] = record.msg.get('message', '')
for key, value in record.msg.items():
if key != 'message':
log_record[key] = value
else:
log_record['message'] = record.getMessage()
if record.exc_info:
log_record['exception'] = {
'type': record.exc_info[0].__name__,
'message': str(record.exc_info[1]),
'traceback': traceback.format_exception(*record.exc_info)
}
for key, value in record.__dict__.items():
if key not in ['msg', 'args', 'exc_info', 'exc_text', 'stack_info', 'lineno',
'funcName', 'created', 'msecs', 'relativeCreated', 'levelname',
'levelno', 'pathname', 'filename', 'module', 'name', 'thread',
'threadName', 'processName', 'process']:
log_record[key] = value
return json.dumps(log_record)
def setup_logger(name: str = 'backend',
level: int = logging.INFO,
log_file: Optional[str] = 'app.log') -> logging.Logger:
"""
Set up a JSON logger with console and file handlers.
Args:
name: Logger name
level: Logging level
log_file: Path to log file (None for no file logging)
Returns:
Configured logger instance
"""
logger = logging.getLogger(name)
logger.setLevel(level)
logger.propagate = False
for handler in logger.handlers[:]:
logger.removeHandler(handler)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(JsonFormatter())
logger.addHandler(console_handler)
if log_file:
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(JsonFormatter())
logger.addHandler(file_handler)
return logger
logger = setup_logger()
def log_request(request_data: Dict[str, Any], endpoint: str) -> None:
"""
Log an API request with structured data.
"""
logger.info({
'message': f'API request to {endpoint}',
'endpoint': endpoint,
'request_data': request_data
})
def log_response(response_data: Dict[str, Any], endpoint: str, status_code: int = 200) -> None:
"""
Log an API response with structured data.
"""
logger.info({
'message': f'API response from {endpoint}',
'endpoint': endpoint,
'status_code': status_code,
'response_data': response_data
})
def log_error(error: Exception, endpoint: Optional[str] = None, request_data: Optional[Dict[str, Any]] = None) -> None:
"""
Log an error with structured data.
"""
error_data = {
'message': f'Error: {str(error)}',
'error_type': error.__class__.__name__,
}
if endpoint:
error_data['endpoint'] = endpoint
if request_data:
error_data['request_data'] = request_data
logger.error(error_data, exc_info=True)
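The dict-flattening behavior of `JsonFormatter` can be exercised with a cut-down formatter attached to an in-memory stream. `MiniJsonFormatter` below mirrors only the dict-message branch of the real class:

```python
import io
import json
import logging

class MiniJsonFormatter(logging.Formatter):
    """Cut-down JsonFormatter: dict-valued messages are flattened into the output."""
    def format(self, record):
        out = {"level": record.levelname}
        if isinstance(record.msg, dict):
            out["message"] = record.msg.get("message", "")
            out.update({k: v for k, v in record.msg.items() if k != "message"})
        else:
            out["message"] = record.getMessage()
        return json.dumps(out)

stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(MiniJsonFormatter())
demo = logging.getLogger("demo_json")
demo.setLevel(logging.INFO)
demo.handlers = [handler]
demo.propagate = False

demo.info({"message": "API request", "endpoint": "/ingest"})
entry = json.loads(stream.getvalue())
print(entry["level"], entry["endpoint"])  # INFO /ingest
```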

View File

@ -0,0 +1,516 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""FastAPI backend server for the chatbot application.
This module provides the main HTTP API endpoints and WebSocket connections for:
- Real-time chat via WebSocket
- File upload and document ingestion
- Configuration management (models, sources, chat settings)
- Chat history management
- Vector store operations
"""
import base64
import json
import os
import uuid
from contextlib import asynccontextmanager
from typing import List, Optional, Dict
from fastapi import FastAPI, File, Form, UploadFile, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from agent import ChatAgent
from config import ConfigManager
from logger import logger, log_request, log_response, log_error
from models import ChatIdRequest, ChatRenameRequest, SelectedModelRequest
from postgres_storage import PostgreSQLConversationStorage
from utils import process_and_ingest_files_background
from vector_store import create_vector_store_with_config
POSTGRES_HOST = os.getenv("POSTGRES_HOST", "postgres")
POSTGRES_PORT = int(os.getenv("POSTGRES_PORT", 5432))
POSTGRES_DB = os.getenv("POSTGRES_DB", "chatbot")
POSTGRES_USER = os.getenv("POSTGRES_USER", "chatbot_user")
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "chatbot_password")
config_manager = ConfigManager("./config.json")
postgres_storage = PostgreSQLConversationStorage(
host=POSTGRES_HOST,
port=POSTGRES_PORT,
database=POSTGRES_DB,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD
)
vector_store = create_vector_store_with_config(config_manager)
vector_store._initialize_store()
agent: ChatAgent | None = None
indexing_tasks: Dict[str, str] = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager for startup and shutdown tasks."""
global agent
logger.debug("Initializing PostgreSQL storage and agent...")
try:
await postgres_storage.init_pool()
logger.info("PostgreSQL storage initialized successfully")
logger.debug("Initializing ChatAgent...")
agent = await ChatAgent.create(
vector_store=vector_store,
config_manager=config_manager,
postgres_storage=postgres_storage
)
logger.info("ChatAgent initialized successfully.")
except Exception as e:
logger.error(f"Failed to initialize PostgreSQL storage: {e}")
raise
yield
try:
await postgres_storage.close()
logger.debug("PostgreSQL storage closed successfully")
except Exception as e:
logger.error(f"Error closing PostgreSQL storage: {e}")
app = FastAPI(
title="Chatbot API",
description="Backend API for LLM-powered chatbot with RAG capabilities",
version="1.0.0",
lifespan=lifespan
)
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.websocket("/ws/chat/{chat_id}")
async def websocket_endpoint(websocket: WebSocket, chat_id: str):
"""WebSocket endpoint for real-time chat communication.
Args:
websocket: WebSocket connection
chat_id: Unique chat identifier
"""
logger.debug(f"WebSocket connection attempt for chat_id: {chat_id}")
try:
await websocket.accept()
logger.debug(f"WebSocket connection accepted for chat_id: {chat_id}")
history_messages = await postgres_storage.get_messages(chat_id)
history = [postgres_storage._message_to_dict(msg) for i, msg in enumerate(history_messages) if i != 0]
await websocket.send_json({"type": "history", "messages": history})
while True:
data = await websocket.receive_text()
client_message = json.loads(data)
new_message = client_message.get("message")
image_id = client_message.get("image_id")
image_data = None
if image_id:
image_data = await postgres_storage.get_image(image_id)
logger.debug(f"Retrieved image data for image_id: {image_id}, data length: {len(image_data) if image_data else 0}")
try:
async for event in agent.query(query_text=new_message, chat_id=chat_id, image_data=image_data):
await websocket.send_json(event)
except Exception as query_error:
logger.error(f"Error in agent.query: {str(query_error)}", exc_info=True)
await websocket.send_json({"type": "error", "content": f"Error processing request: {str(query_error)}"})
final_messages = await postgres_storage.get_messages(chat_id)
final_history = [postgres_storage._message_to_dict(msg) for i, msg in enumerate(final_messages) if i != 0]
await websocket.send_json({"type": "history", "messages": final_history})
except WebSocketDisconnect:
logger.debug(f"Client disconnected from chat {chat_id}")
except Exception as e:
logger.error(f"WebSocket error for chat {chat_id}: {str(e)}", exc_info=True)
@app.post("/upload-image")
async def upload_image(image: UploadFile = File(...), chat_id: str = Form(...)):
"""Upload and store an image for chat processing.
Args:
image: Uploaded image file
chat_id: Chat identifier for context
Returns:
Dictionary with generated image_id
"""
image_data = await image.read()
image_base64 = base64.b64encode(image_data).decode('utf-8')
data_uri = f"data:{image.content_type};base64,{image_base64}"
image_id = str(uuid.uuid4())
await postgres_storage.store_image(image_id, data_uri)
return {"image_id": image_id}
@app.post("/ingest")
async def ingest_files(background_tasks: BackgroundTasks, files: Optional[List[UploadFile]] = File(None)):
"""Ingest documents for vector search and RAG.
Args:
files: List of uploaded files to process
background_tasks: FastAPI background tasks manager
Returns:
Task information for tracking ingestion progress
"""
if not files:
raise HTTPException(status_code=400, detail="No files provided")
try:
log_request({"file_count": len(files)}, "/ingest")
task_id = str(uuid.uuid4())
file_info = []
for file in files:
content = await file.read()
file_info.append({
"filename": file.filename,
"content": content
})
indexing_tasks[task_id] = "queued"
background_tasks.add_task(
process_and_ingest_files_background,
file_info,
vector_store,
config_manager,
task_id,
indexing_tasks
)
response = {
"message": f"Files queued for processing. Indexing {len(files)} files in the background.",
"files": [file.filename for file in files],
"status": "queued",
"task_id": task_id
}
log_response(response, "/ingest")
return response
except HTTPException:
raise
except Exception as e:
log_error(e, "/ingest")
raise HTTPException(
status_code=500,
detail=f"Error queuing files for ingestion: {str(e)}"
)
@app.get("/ingest/status/{task_id}")
async def get_indexing_status(task_id: str):
"""Get the status of a file ingestion task.
Args:
task_id: Unique task identifier
Returns:
Current task status
"""
if task_id in indexing_tasks:
return {"status": indexing_tasks[task_id]}
else:
raise HTTPException(status_code=404, detail="Task not found")
@app.get("/sources")
async def get_sources():
"""Get all available document sources."""
try:
config = config_manager.read_config()
return {"sources": config.sources}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting sources: {str(e)}")
@app.get("/selected_sources")
async def get_selected_sources():
"""Get currently selected document sources for RAG."""
try:
config = config_manager.read_config()
return {"sources": config.selected_sources}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting selected sources: {str(e)}")
@app.post("/selected_sources")
async def update_selected_sources(selected_sources: List[str]):
"""Update the selected document sources for RAG.
Args:
selected_sources: List of source names to use for retrieval
"""
try:
config_manager.updated_selected_sources(selected_sources)
return {"status": "success", "message": "Selected sources updated successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error updating selected sources: {str(e)}")
@app.get("/selected_model")
async def get_selected_model():
"""Get the currently selected LLM model."""
try:
model = config_manager.get_selected_model()
return {"model": model}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting selected model: {str(e)}")
@app.post("/selected_model")
async def update_selected_model(request: SelectedModelRequest):
"""Update the selected LLM model.
Args:
request: Model selection request with model name
"""
try:
logger.debug(f"Updating selected model to: {request.model}")
config_manager.updated_selected_model(request.model)
return {"status": "success", "message": "Selected model updated successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error updating selected model: {str(e)}")
@app.get("/available_models")
async def get_available_models():
"""Get list of all available LLM models."""
try:
models = config_manager.get_available_models()
return {"models": models}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting available models: {str(e)}")
@app.get("/chats")
async def list_chats():
"""Get list of all chat conversations."""
try:
chat_ids = await postgres_storage.list_conversations()
return {"chats": chat_ids}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error listing chats: {str(e)}")
@app.get("/chat_id")
async def get_chat_id():
"""Get the current active chat ID, creating a conversation if it doesn't exist."""
try:
config = config_manager.read_config()
current_chat_id = config.current_chat_id
if current_chat_id and await postgres_storage.exists(current_chat_id):
return {
"status": "success",
"chat_id": current_chat_id
}
new_chat_id = str(uuid.uuid4())
await postgres_storage.save_messages_immediate(new_chat_id, [])
await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
config_manager.updated_current_chat_id(new_chat_id)
return {
"status": "success",
"chat_id": new_chat_id
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error getting chat ID: {str(e)}"
)
@app.post("/chat_id")
async def update_chat_id(request: ChatIdRequest):
"""Update the current active chat ID.
Args:
request: Chat ID update request
"""
try:
config_manager.updated_current_chat_id(request.chat_id)
return {
"status": "success",
"message": f"Current chat ID updated to {request.chat_id}",
"chat_id": request.chat_id
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error updating chat ID: {str(e)}"
)
@app.get("/chat/{chat_id}/metadata")
async def get_chat_metadata(chat_id: str):
"""Get metadata for a specific chat.
Args:
chat_id: Unique chat identifier
Returns:
Chat metadata including name
"""
try:
metadata = await postgres_storage.get_chat_metadata(chat_id)
return metadata
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error getting chat metadata: {str(e)}"
)
@app.post("/chat/rename")
async def rename_chat(request: ChatRenameRequest):
"""Rename a chat conversation.
Args:
request: Chat rename request with chat_id and new_name
"""
try:
await postgres_storage.set_chat_metadata(request.chat_id, request.new_name)
return {
"status": "success",
"message": f"Chat {request.chat_id} renamed to {request.new_name}"
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error renaming chat: {str(e)}"
)
@app.post("/chat/new")
async def create_new_chat():
"""Create a new chat conversation and set it as current."""
try:
new_chat_id = str(uuid.uuid4())
await postgres_storage.save_messages_immediate(new_chat_id, [])
await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
config_manager.updated_current_chat_id(new_chat_id)
return {
"status": "success",
"message": "New chat created",
"chat_id": new_chat_id
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error creating new chat: {str(e)}"
)
@app.delete("/chat/{chat_id}")
async def delete_chat(chat_id: str):
"""Delete a specific chat and its messages.
Args:
chat_id: Unique chat identifier to delete
"""
try:
success = await postgres_storage.delete_conversation(chat_id)
if success:
return {
"status": "success",
"message": f"Chat {chat_id} deleted successfully"
}
else:
raise HTTPException(
status_code=404,
detail=f"Chat {chat_id} not found"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error deleting chat: {str(e)}"
)
@app.delete("/chats/clear")
async def clear_all_chats():
"""Clear all chat conversations and create a new default chat."""
try:
chat_ids = await postgres_storage.list_conversations()
cleared_count = 0
for chat_id in chat_ids:
if await postgres_storage.delete_conversation(chat_id):
cleared_count += 1
new_chat_id = str(uuid.uuid4())
await postgres_storage.save_messages_immediate(new_chat_id, [])
await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
config_manager.updated_current_chat_id(new_chat_id)
return {
"status": "success",
"message": f"Cleared {cleared_count} chats and created new chat",
"new_chat_id": new_chat_id,
"cleared_count": cleared_count
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error clearing all chats: {str(e)}"
)
@app.delete("/collections/{collection_name}")
async def delete_collection(collection_name: str):
"""Delete a document collection from the vector store.
Args:
collection_name: Name of the collection to delete
"""
try:
success = vector_store.delete_collection(collection_name)
if success:
return {"status": "success", "message": f"Collection '{collection_name}' deleted successfully"}
else:
raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found or could not be deleted")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=False)


@ -0,0 +1,35 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pydantic import BaseModel
from typing import Optional, List
class ChatConfig(BaseModel):
sources: List[str]
models: List[str]
selected_model: Optional[str] = None
selected_sources: Optional[List[str]] = None
current_chat_id: Optional[str] = None
class ChatIdRequest(BaseModel):
chat_id: str
class ChatRenameRequest(BaseModel):
chat_id: str
new_name: str
class SelectedModelRequest(BaseModel):
model: str


@ -0,0 +1,571 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""PostgreSQL-based conversation storage with caching and I/O optimization."""
import json
import time
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime, timedelta
import asyncio
import asyncpg
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage, ToolMessage
from logger import logger
@dataclass
class CacheEntry:
"""Cache entry with TTL support."""
data: Any
timestamp: float
ttl: float = 300
def is_expired(self) -> bool:
return time.time() - self.timestamp > self.ttl
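A quick sketch of how the TTL check behaves (timestamps are seconds from `time.time()`; `TTLEntry` is a standalone copy of `CacheEntry` for illustration):

```python
import time
from dataclasses import dataclass
from typing import Any

@dataclass
class TTLEntry:
    data: Any
    timestamp: float
    ttl: float = 300  # seconds

    def is_expired(self) -> bool:
        return time.time() - self.timestamp > self.ttl

fresh = TTLEntry(data="hit", timestamp=time.time())
stale = TTLEntry(data="miss", timestamp=time.time() - 301)
print(fresh.is_expired(), stale.is_expired())  # False True
```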
class PostgreSQLConversationStorage:
"""PostgreSQL-based conversation storage with intelligent caching and I/O optimization."""
def __init__(
self,
host: str = 'postgres',
port: int = 5432,
database: str = 'chatbot',
user: str = 'chatbot_user',
password: str = 'chatbot_password',
pool_size: int = 10,
cache_ttl: int = 300
):
"""Initialize PostgreSQL connection pool and caching.
Args:
host: PostgreSQL host
port: PostgreSQL port
database: Database name
user: Database user
password: Database password
pool_size: Connection pool size
cache_ttl: Cache TTL in seconds
"""
self.host = host
self.port = port
self.database = database
self.user = user
self.password = password
self.pool_size = pool_size
self.cache_ttl = cache_ttl
self.pool: Optional[asyncpg.Pool] = None
self._message_cache: Dict[str, CacheEntry] = {}
self._metadata_cache: Dict[str, CacheEntry] = {}
self._image_cache: Dict[str, CacheEntry] = {}
self._chat_list_cache: Optional[CacheEntry] = None
self._pending_saves: Dict[str, List[BaseMessage]] = {}
self._save_lock = asyncio.Lock()
self._batch_save_task: Optional[asyncio.Task] = None
self._cache_hits = 0
self._cache_misses = 0
self._db_operations = 0
async def init_pool(self) -> None:
"""Initialize the connection pool and create tables."""
try:
await self._ensure_database_exists()
self.pool = await asyncpg.create_pool(
host=self.host,
port=self.port,
database=self.database,
user=self.user,
password=self.password,
min_size=2,
max_size=self.pool_size,
command_timeout=30
)
await self._create_tables()
logger.debug("PostgreSQL connection pool initialized successfully")
self._batch_save_task = asyncio.create_task(self._batch_save_worker())
except Exception as e:
logger.error(f"Failed to initialize PostgreSQL pool: {e}")
raise
async def _ensure_database_exists(self) -> None:
"""Ensure the target database exists, create if it doesn't."""
try:
conn = await asyncpg.connect(
host=self.host,
port=self.port,
database='postgres',
user=self.user,
password=self.password
)
try:
result = await conn.fetchval(
"SELECT 1 FROM pg_database WHERE datname = $1",
self.database
)
if not result:
await conn.execute(f'CREATE DATABASE "{self.database}"')
logger.debug(f"Created database: {self.database}")
else:
logger.debug(f"Database {self.database} already exists")
finally:
await conn.close()
except Exception as e:
# Non-fatal: init_pool will fail loudly if the database is truly unreachable.
logger.error(f"Error ensuring database exists: {e}")
async def close(self) -> None:
"""Close the connection pool and cleanup."""
if self._batch_save_task:
self._batch_save_task.cancel()
try:
await self._batch_save_task
except asyncio.CancelledError:
pass
if self.pool:
await self.pool.close()
logger.debug("PostgreSQL connection pool closed")
async def _create_tables(self) -> None:
"""Create necessary tables if they don't exist."""
async with self.pool.acquire() as conn:
await conn.execute("""
CREATE TABLE IF NOT EXISTS conversations (
chat_id VARCHAR(255) PRIMARY KEY,
messages JSONB NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
message_count INTEGER DEFAULT 0
)
""")
await conn.execute("""
CREATE TABLE IF NOT EXISTS chat_metadata (
chat_id VARCHAR(255) PRIMARY KEY,
name VARCHAR(500),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (chat_id) REFERENCES conversations(chat_id) ON DELETE CASCADE
)
""")
await conn.execute("""
CREATE TABLE IF NOT EXISTS images (
image_id VARCHAR(255) PRIMARY KEY,
image_data TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP + INTERVAL '1 hour')
)
""")
await conn.execute("CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at)")
await conn.execute("CREATE INDEX IF NOT EXISTS idx_images_expires_at ON images(expires_at)")
await conn.execute("""
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ language 'plpgsql'
""")
await conn.execute("""
DROP TRIGGER IF EXISTS update_conversations_updated_at ON conversations
""")
await conn.execute("""
CREATE TRIGGER update_conversations_updated_at
BEFORE UPDATE ON conversations
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column()
""")
def _message_to_dict(self, message: BaseMessage) -> Dict:
"""Convert a message object to a dictionary for storage."""
result = {
"type": message.__class__.__name__,
"content": message.content,
}
if hasattr(message, "tool_calls") and message.tool_calls:
result["tool_calls"] = message.tool_calls
if isinstance(message, ToolMessage):
result["tool_call_id"] = getattr(message, "tool_call_id", None)
result["name"] = getattr(message, "name", None)
return result
def _dict_to_message(self, data: Dict) -> BaseMessage:
"""Convert a dictionary back to a message object."""
msg_type = data["type"]
content = data["content"]
if msg_type == "AIMessage":
msg = AIMessage(content=content)
if "tool_calls" in data:
msg.tool_calls = data["tool_calls"]
return msg
elif msg_type == "HumanMessage":
return HumanMessage(content=content)
elif msg_type == "SystemMessage":
return SystemMessage(content=content)
elif msg_type == "ToolMessage":
return ToolMessage(
content=content,
tool_call_id=data.get("tool_call_id", ""),
name=data.get("name", "")
)
else:
return HumanMessage(content=content)
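The two converters above form a round-trip between message objects and JSON-safe dicts. A dependency-free sketch of the same scheme (plain dataclasses stand in for the langchain message types):

```python
from dataclasses import dataclass

@dataclass
class HumanMsg:
    content: str

@dataclass
class AIMsg:
    content: str

TYPES = {"HumanMsg": HumanMsg, "AIMsg": AIMsg}

def message_to_dict(msg) -> dict:
    return {"type": msg.__class__.__name__, "content": msg.content}

def dict_to_message(data: dict):
    # Unknown types fall back to a human message, mirroring the storage class.
    return TYPES.get(data["type"], HumanMsg)(content=data["content"])

restored = dict_to_message(message_to_dict(AIMsg(content="hello")))
print(type(restored).__name__, restored.content)  # AIMsg hello
```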
def _get_cached_messages(self, chat_id: str) -> Optional[List[BaseMessage]]:
"""Get messages from cache if available and not expired."""
cache_entry = self._message_cache.get(chat_id)
if cache_entry and not cache_entry.is_expired():
self._cache_hits += 1
return cache_entry.data
self._cache_misses += 1
return None
def _cache_messages(self, chat_id: str, messages: List[BaseMessage]) -> None:
"""Cache messages with TTL."""
self._message_cache[chat_id] = CacheEntry(
data=messages.copy(),
timestamp=time.time(),
ttl=self.cache_ttl
)
def _invalidate_cache(self, chat_id: str) -> None:
"""Invalidate cache entries for a chat."""
self._message_cache.pop(chat_id, None)
self._metadata_cache.pop(chat_id, None)
self._chat_list_cache = None
async def exists(self, chat_id: str) -> bool:
"""Check if a conversation exists (with caching)."""
cached_messages = self._get_cached_messages(chat_id)
if cached_messages is not None:
return True  # a cached entry (even an empty message list) means the conversation exists
async with self.pool.acquire() as conn:
result = await conn.fetchval(
"SELECT EXISTS(SELECT 1 FROM conversations WHERE chat_id = $1)",
chat_id
)
self._db_operations += 1
return result
async def get_messages(self, chat_id: str, limit: Optional[int] = None) -> List[BaseMessage]:
"""Retrieve messages for a chat session with caching."""
cached_messages = self._get_cached_messages(chat_id)
if cached_messages is not None:
return cached_messages[-limit:] if limit else cached_messages
async with self.pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT messages FROM conversations WHERE chat_id = $1",
chat_id
)
self._db_operations += 1
if not row:
return []
messages_data = row['messages']
if isinstance(messages_data, str):
messages_data = json.loads(messages_data)
messages = [self._dict_to_message(msg_data) for msg_data in messages_data]
self._cache_messages(chat_id, messages)
return messages[-limit:] if limit else messages
async def save_messages(self, chat_id: str, messages: List[BaseMessage]) -> None:
"""Save messages with batching for performance."""
async with self._save_lock:
self._pending_saves[chat_id] = messages.copy()
self._cache_messages(chat_id, messages)
async def save_messages_immediate(self, chat_id: str, messages: List[BaseMessage]) -> None:
"""Save messages immediately without batching - for critical operations."""
serialized_messages = [self._message_to_dict(msg) for msg in messages]
async with self.pool.acquire() as conn:
await conn.execute("""
INSERT INTO conversations (chat_id, messages, message_count)
VALUES ($1, $2, $3)
ON CONFLICT (chat_id)
DO UPDATE SET
messages = EXCLUDED.messages,
message_count = EXCLUDED.message_count,
updated_at = CURRENT_TIMESTAMP
""", chat_id, json.dumps(serialized_messages), len(messages))
self._db_operations += 1
self._cache_messages(chat_id, messages)
self._chat_list_cache = None
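The `ON CONFLICT ... DO UPDATE` upsert is the core of both save paths. The same pattern in SQLite (placeholders are `?` instead of `$1`, and the conflicting incoming row is referenced as `excluded`):

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE conversations (chat_id TEXT PRIMARY KEY, messages TEXT, message_count INTEGER)"
)

def save(chat_id: str, messages: list) -> None:
    conn.execute("""
        INSERT INTO conversations (chat_id, messages, message_count)
        VALUES (?, ?, ?)
        ON CONFLICT (chat_id) DO UPDATE SET
            messages = excluded.messages,
            message_count = excluded.message_count
    """, (chat_id, json.dumps(messages), len(messages)))

save("c1", ["hi"])
save("c1", ["hi", "there"])  # second call updates the existing row in place
count = conn.execute("SELECT message_count FROM conversations WHERE chat_id = 'c1'").fetchone()[0]
print(count)  # 2
```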
async def _batch_save_worker(self) -> None:
"""Background worker to batch save operations."""
while True:
try:
await asyncio.sleep(1.0)
async with self._save_lock:
if not self._pending_saves:
continue
saves_to_process = self._pending_saves.copy()
self._pending_saves.clear()
async with self.pool.acquire() as conn:
async with conn.transaction():
for chat_id, messages in saves_to_process.items():
serialized_messages = [self._message_to_dict(msg) for msg in messages]
await conn.execute("""
INSERT INTO conversations (chat_id, messages, message_count)
VALUES ($1, $2, $3)
ON CONFLICT (chat_id)
DO UPDATE SET
messages = EXCLUDED.messages,
message_count = EXCLUDED.message_count,
updated_at = CURRENT_TIMESTAMP
""", chat_id, json.dumps(serialized_messages), len(messages))
self._db_operations += len(saves_to_process)
if saves_to_process:
logger.debug(f"Batch saved {len(saves_to_process)} conversations")
self._chat_list_cache = None
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in batch save worker: {e}")
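The worker above coalesces rapid successive saves into one write per interval. A minimal, self-contained sketch of that debounce pattern (the interval is shortened to 50 ms for the demo, and a list stands in for the database transaction):

```python
import asyncio

async def demo() -> list:
    pending: dict = {}
    flushed: list = []
    lock = asyncio.Lock()

    async def save(chat_id, messages):
        async with lock:
            pending[chat_id] = list(messages)  # later saves supersede earlier ones

    async def worker():
        while True:
            await asyncio.sleep(0.05)
            async with lock:
                if not pending:
                    continue
                batch = dict(pending)
                pending.clear()
            flushed.append(batch)  # stands in for one DB transaction

    task = asyncio.create_task(worker())
    await save("c1", ["a"])
    await save("c1", ["a", "b"])  # coalesced with the save above
    await asyncio.sleep(0.12)
    task.cancel()
    return flushed

flushed = asyncio.run(demo())
print(flushed)  # [{'c1': ['a', 'b']}]
```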
async def add_message(self, chat_id: str, message: BaseMessage) -> None:
"""Add a single message to conversation (optimized)."""
current_messages = await self.get_messages(chat_id)
current_messages.append(message)
await self.save_messages(chat_id, current_messages)
async def delete_conversation(self, chat_id: str) -> bool:
"""Delete a conversation by chat_id."""
try:
async with self.pool.acquire() as conn:
result = await conn.execute(
"DELETE FROM conversations WHERE chat_id = $1",
chat_id
)
self._db_operations += 1
self._invalidate_cache(chat_id)
return "DELETE 1" in result
except Exception as e:
logger.error(f"Error deleting conversation {chat_id}: {e}")
return False
async def list_conversations(self) -> List[str]:
"""List all conversation IDs with caching."""
if self._chat_list_cache and not self._chat_list_cache.is_expired():
self._cache_hits += 1
return self._chat_list_cache.data
async with self.pool.acquire() as conn:
rows = await conn.fetch(
"SELECT chat_id FROM conversations ORDER BY updated_at DESC"
)
self._db_operations += 1
chat_ids = [row['chat_id'] for row in rows]
self._chat_list_cache = CacheEntry(
data=chat_ids,
timestamp=time.time(),
ttl=60
)
self._cache_misses += 1
return chat_ids
async def store_image(self, image_id: str, image_base64: str) -> None:
"""Store base64 image data with TTL."""
async with self.pool.acquire() as conn:
await conn.execute("""
INSERT INTO images (image_id, image_data)
VALUES ($1, $2)
ON CONFLICT (image_id)
DO UPDATE SET
image_data = EXCLUDED.image_data,
created_at = CURRENT_TIMESTAMP,
expires_at = CURRENT_TIMESTAMP + INTERVAL '1 hour'
""", image_id, image_base64)
self._db_operations += 1
self._image_cache[image_id] = CacheEntry(
data=image_base64,
timestamp=time.time(),
ttl=3600
)
async def get_image(self, image_id: str) -> Optional[str]:
"""Retrieve base64 image data with caching."""
cache_entry = self._image_cache.get(image_id)
if cache_entry and not cache_entry.is_expired():
self._cache_hits += 1
return cache_entry.data
async with self.pool.acquire() as conn:
row = await conn.fetchrow("""
SELECT image_data FROM images
WHERE image_id = $1 AND expires_at > CURRENT_TIMESTAMP
""", image_id)
self._db_operations += 1
if row:
image_data = row['image_data']
self._image_cache[image_id] = CacheEntry(
data=image_data,
timestamp=time.time(),
ttl=3600
)
self._cache_misses += 1
return image_data
return None
async def get_chat_metadata(self, chat_id: str) -> Optional[Dict]:
"""Get chat metadata with caching."""
cache_entry = self._metadata_cache.get(chat_id)
if cache_entry and not cache_entry.is_expired():
self._cache_hits += 1
return cache_entry.data
async with self.pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT name, created_at FROM chat_metadata WHERE chat_id = $1",
chat_id
)
self._db_operations += 1
if row:
metadata = {
"name": row['name'],
"created_at": row['created_at'].isoformat()
}
else:
metadata = {"name": f"Chat {chat_id[:8]}"}
self._metadata_cache[chat_id] = CacheEntry(
data=metadata,
timestamp=time.time(),
ttl=self.cache_ttl
)
self._cache_misses += 1
return metadata
async def set_chat_metadata(self, chat_id: str, name: str) -> None:
"""Set chat metadata."""
async with self.pool.acquire() as conn:
await conn.execute("""
INSERT INTO chat_metadata (chat_id, name)
VALUES ($1, $2)
ON CONFLICT (chat_id)
DO UPDATE SET
name = EXCLUDED.name,
updated_at = CURRENT_TIMESTAMP
""", chat_id, name)
self._db_operations += 1
self._metadata_cache[chat_id] = CacheEntry(
data={"name": name},
timestamp=time.time(),
ttl=self.cache_ttl
)
async def cleanup_expired_images(self) -> int:
"""Clean up expired images and return count of deleted images."""
async with self.pool.acquire() as conn:
result = await conn.execute(
"DELETE FROM images WHERE expires_at < CURRENT_TIMESTAMP"
)
self._db_operations += 1
expired_keys = [
key for key, entry in self._image_cache.items()
if entry.is_expired()
]
for key in expired_keys:
del self._image_cache[key]
deleted_count = int(result.split()[-1]) if result else 0
if deleted_count > 0:
logger.debug(f"Cleaned up {deleted_count} expired images")
return deleted_count
def get_cache_stats(self) -> Dict[str, Any]:
"""Get cache performance statistics."""
total_requests = self._cache_hits + self._cache_misses
hit_rate = (self._cache_hits / total_requests * 100) if total_requests > 0 else 0
return {
"cache_hits": self._cache_hits,
"cache_misses": self._cache_misses,
"hit_rate_percent": round(hit_rate, 2),
"db_operations": self._db_operations,
"cached_conversations": len(self._message_cache),
"cached_metadata": len(self._metadata_cache),
"cached_images": len(self._image_cache)
}
def load_conversation_history(self, chat_id: str) -> asyncio.Task:
"""Legacy method - schedules the async load; await the returned Task to get the List[Dict] history. Must be called from a running event loop."""
return asyncio.create_task(self._load_conversation_history_dict(chat_id))
async def _load_conversation_history_dict(self, chat_id: str) -> List[Dict]:
"""Load conversation history in dict format for compatibility."""
messages = await self.get_messages(chat_id)
return [self._message_to_dict(msg) for msg in messages]
def save_conversation_history(self, chat_id: str, messages: List[Dict]) -> asyncio.Task:
"""Legacy method - schedules the async save and returns the Task. Must be called from a running event loop."""
message_objects = [self._dict_to_message(msg) for msg in messages]
return asyncio.create_task(self.save_messages(chat_id, message_objects))


@ -0,0 +1,152 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import jinja2
from typing import Dict
SUPERVISOR_AGENT_STR = """
You are a supervisor agent whose role is to be a helpful planner that can use tools to answer questions. DO NOT WRITE CODE YOURSELF, ALWAYS USE THE TOOLS.
{% if tools %}
IMPORTANT: You have access to these tools and you MUST use them when applicable and use tool response in your final answer:
{{ tools }}
CRITICAL RULES:
- **ALWAYS** use a tool when the user's request matches a tool's capability. For example:
- If the user asks to "generate code", "develop", "build", "create", "write a script", "make a website", "develop an app", etc. **MUST** use the write_code tool with appropriate programming_language parameter
- If the user asks to "search", "find", "summarize", "analyze documents/reports", "key points", etc. **MUST** use the search_documents tool with the query, don't add any other text to the query. You can assume that the user has already uploaded the document and just call the tool.
- If the user asks to analyze/describe/understand an image (e.g., "what's in this image", "describe the picture") **MUST** use the explain_image tool
- **NEVER EVER generate code yourself** - you are FORBIDDEN from writing code directly. ALWAYS use the write_code tool for ANY coding requests
- **DO NOT** try to answer questions from documents yourself - always use the search_documents tool
CODING KEYWORDS that REQUIRE write_code tool:
- "code", "develop", "build", "create", "write", "make", "implement", "program", "script", "website", "app", "function", "class", "HTML", "CSS", "JavaScript", "Python", "React", "component"
Batching policy:
- **Batch** when: (a) calls are independent (e.g., weather in two cities), (b) calls target different tools without dependency, or (c) multiple calls to the same tool with different arguments.
- **Do not batch** when: a call's arguments depend on a previous tool's output (e.g., writing code which depends on the output of a search_documents tool).
Output protocol:
- In the first assistant message of a turn, if tools are needed, **emit all tool calls together** (as multiple tool_calls). Do not include narrative text before the tool_calls unless required by the API.
- After the ToolMessages arrive, produce a single assistant message with the final answer incorporating all results. Do not call the tools again for the same purpose.
- **CRITICAL**: When you receive tool results, you MUST use them in your final response. Do NOT ignore successful tool results or claim you don't have information when tools have already provided it.
- If any tool call succeeds, base your answer on the successful results. Ignore failed tool calls if you have successful ones.
- Always present the information from successful tool calls as your definitive answer.
Few-shot examples:
# Direct coding request
User: Create a responsive personal website for my AI development business
Assistant (tool calls immediately):
- write_code({"query": "Create a responsive personal website for my AI development business", "programming_language": "HTML"})
# Batching independent calls
User: now, can you get the weather in egypt and the rain forecast in malibu?
Assistant (tool calls in one message):
- get_weather({"location": "Egypt"})
- get_rain_forecast({"location": "Malibu"})
# Staged dependent calls
User: Search my documents for design requirements then build a website based on those requirements
Assistant (first message; dependent plan):
- search_documents({"query": "design requirements website"})
# (Wait for ToolMessage)
Assistant (after ToolMessage):
- write_code({"query": "build a website based on these design requirements: <extracted information>", "programming_language": "HTML"})
# (Then produce final answer)
# Using successful tool results
User: Can you search NVIDIA's earnings document and summarize the key points?
Assistant (tool calls):
- search_documents({"query": "NVIDIA earnings document"})
# (Wait for ToolMessage with comprehensive earnings data)
Assistant (final response):
Based on NVIDIA's earnings document, here are the key highlights:
[...continues with the actual data from tool results...]
{% else %}
You do not have access to any tools right now.
{% endif %}
"""
PROMPT_TEMPLATES = {
"supervisor_agent": SUPERVISOR_AGENT_STR,
}
TEMPLATES: Dict[str, jinja2.Template] = {
name: jinja2.Template(template) for name, template in PROMPT_TEMPLATES.items()
}
class Prompts:
"""
A class providing access to prompt templates.
This class manages a collection of Jinja2 templates used for generating
various prompts in the process.
The templates are pre-compiled for efficiency and can be accessed either
through attribute access or the get_template class method.
Attributes:
None - Templates are stored in module-level constants
Methods:
__getattr__(name: str) -> str:
Dynamically retrieves prompt template strings by name
get_template(name: str) -> jinja2.Template:
Retrieves pre-compiled Jinja2 templates by name
"""
def __getattr__(self, name: str) -> str:
"""
Dynamically retrieve prompt templates by name.
Args:
name (str): Name of the prompt template to retrieve
Returns:
str: The prompt template string
Raises:
AttributeError: If the requested template name doesn't exist
"""
if name in PROMPT_TEMPLATES:
return PROMPT_TEMPLATES[name]
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
@classmethod
def get_template(cls, name: str) -> jinja2.Template:
"""
Get a pre-compiled Jinja2 template by name.
Args:
name (str): Name of the template to retrieve
Returns:
jinja2.Template: The pre-compiled Jinja2 template object
Raises:
KeyError: If the requested template name doesn't exist
"""
return TEMPLATES[name]
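The `{% if tools %}` guard in the supervisor template makes the tool instructions conditional on render-time context. A small sketch of that branch behavior with a toy template:

```python
import jinja2

template = jinja2.Template(
    "{% if tools %}You may use: {{ tools }}{% else %}No tools available.{% endif %}"
)
with_tools = template.render(tools="write_code, search_documents")
without_tools = template.render(tools=None)  # falsy -> else branch
print(with_tools)     # You may use: write_code, search_documents
print(without_tools)  # No tools available.
```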


@ -0,0 +1,27 @@
[project]
name = "backend"
version = "0.1.0"
description = "Backend API server for chatbot application"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"fastapi>=0.116.1",
"langchain>=0.3.27",
"langchain-milvus>=0.2.1",
"langchain-mcp-adapters>=0.1.0",
"langchain-nvidia-ai-endpoints>=0.3.13",
"langchain-openai>=0.3.28",
"langchain-text-splitters>=0.3.9",
"langchain-unstructured>=0.1.6",
"langgraph>=0.6.0",
"mcp>=0.1.0",
"pydantic>=2.11.7",
"pypdf2>=3.0.1",
"python-dotenv>=1.1.1",
"python-multipart>=0.0.20",
"asyncpg>=0.29.0",
"requests>=2.28.0",
"unstructured[pdf]>=0.18.11",
"uvicorn>=0.35.0",
"websockets>=15.0.1",
]


@ -0,0 +1,16 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


@ -0,0 +1,70 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from mcp.server.fastmcp import FastMCP
from openai import AsyncOpenAI
mcp = FastMCP("Code Generation")
model_name = "deepseek-coder:6.7b"
@mcp.tool()
async def write_code(query: str, programming_language: str):
"""This tool is used to write complete code.
Args:
query: The natural language description of the code to be generated.
programming_language: The programming language for the code generation (e.g., 'Python', 'JavaScript', 'HTML', 'CSS', 'Go').
Returns:
The generated code.
"""
model_client = AsyncOpenAI(
base_url="http://deepseek-coder:8000/v1",
api_key="ollama"
)
system_prompt = f"""You are an expert coder specializing in {programming_language}.
Given a user request, generate clean, efficient {programming_language} code that accomplishes the specified task.
Always provide the full code generation so the user can copy and paste a fully working example.
Return just the raw code, with no markdown formatting, explanations, or any other text.
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": query}
]
response = await model_client.chat.completions.create(
model=model_name,
messages=messages,
temperature=0.1,
)
generated_code = response.choices[0].message.content
return generated_code.strip()
if __name__ == "__main__":
print(f"Starting {mcp.name} MCP server...")
mcp.run(transport="stdio")
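The system prompt asks for raw code with no markdown formatting, but chat models frequently wrap their answer in fences anyway. A small post-processing helper (hypothetical, not part of the server above) would make the tool's return value robust to that:

```python
def strip_markdown_fences(text: str) -> str:
    """Remove a single wrapping ```lang ... ``` fence, if present."""
    stripped = text.strip()
    if stripped.startswith("```") and stripped.endswith("```"):
        lines = stripped.splitlines()
        # Drop the opening fence (which may carry a language tag) and the closing fence.
        return "\n".join(lines[1:-1]).strip()
    return stripped
```

Applied to `generated_code` before returning from `write_code`, this keeps the tool's contract even when the model ignores the formatting instruction.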


@ -0,0 +1,124 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
MCP server providing image understanding and analysis tools.
This server exposes a `process_image` tool that uses a vision language model to answer queries about images.
It supports multiple image input formats including URLs, file paths, and base64-encoded images.
"""
import asyncio
import base64
import os
import requests
import sys
from pathlib import Path
import time
from langchain_core.tools import tool, Tool
from langchain_mcp_adapters.tools import to_fastmcp
from mcp.server.fastmcp import FastMCP
from openai import AsyncOpenAI, OpenAI
project_root = Path(__file__).parent.parent.parent
sys.path.append(str(project_root))
from postgres_storage import PostgreSQLConversationStorage
mcp = FastMCP("image-understanding-server")
model_name = "Qwen2.5-VL-7B-Instruct"
model_client = OpenAI(
base_url="http://qwen2.5-vl:8000/v1",
api_key="api_key"
)
POSTGRES_HOST = os.getenv("POSTGRES_HOST", "postgres")
POSTGRES_PORT = int(os.getenv("POSTGRES_PORT", 5432))
POSTGRES_DB = os.getenv("POSTGRES_DB", "chatbot")
POSTGRES_USER = os.getenv("POSTGRES_USER", "chatbot_user")
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "chatbot_password")
postgres_storage = PostgreSQLConversationStorage(
host=POSTGRES_HOST,
port=POSTGRES_PORT,
database=POSTGRES_DB,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD
)
@mcp.tool()
def explain_image(query: str, image: str):
"""
This tool is used to understand an image. It will respond to the user's query based on the image.
...
"""
if not image:
raise ValueError('Error: explain_image tool received an empty image string.')
image_url_content = {}
if image.startswith("http://") or image.startswith("https://"):
image_url_content = {
"type": "image_url",
"image_url": {"url": image}
}
else:
if image.startswith("data:image/"):
metadata, b64_data = image.split(",", 1)
filetype = metadata.split(";")[0].split("/")[-1]
elif os.path.exists(image):
with open(image, "rb") as image_file:
filetype = image.split('.')[-1]
b64_data = base64.b64encode(image_file.read()).decode("utf-8")
else:
raise ValueError(f'Invalid image type -- could not be identified as a url or filepath: {image}')
image_url_content = {
"type": "image_url",
"image_url": {
"url": f"data:image/{filetype if filetype else 'jpeg'};base64,{b64_data}"
}
}
message = [
{
"role": "user",
"content": [
{"type": "text", "text": query},
image_url_content
]
}
]
try:
print(f"Sending request to vision model: {query}")
response = model_client.chat.completions.create(
model=model_name,
messages=message,
max_tokens=512,
temperature=0.1
)
print(f"Received response from vision model")
return response.choices[0].message.content
except Exception as e:
print(f"Error calling vision model: {e}")
raise RuntimeError(f"Failed to process image with vision model: {e}")
if __name__ == "__main__":
print(f'running {mcp.name} MCP server')
mcp.run(transport="stdio")
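The branching in `explain_image` can be distilled into a standalone helper that normalizes the three accepted input forms (HTTP URL, `data:` URL, local file path) into an OpenAI-style `image_url` content part. This is a sketch of that logic, not the server's own code:

```python
import base64
import os

def to_image_url_content(image: str) -> dict:
    """Normalize a URL, data: URL, or file path into an image_url content part
    (sketch of the branching in explain_image above)."""
    if image.startswith(("http://", "https://")):
        return {"type": "image_url", "image_url": {"url": image}}
    if image.startswith("data:image/"):
        # Already base64-encoded; recover the subtype from the metadata prefix.
        metadata, b64_data = image.split(",", 1)
        filetype = metadata.split(";")[0].split("/")[-1]
    elif os.path.exists(image):
        filetype = image.rsplit(".", 1)[-1]
        with open(image, "rb") as f:
            b64_data = base64.b64encode(f.read()).decode("utf-8")
    else:
        raise ValueError(f"Not a URL, data URL, or existing file: {image}")
    return {
        "type": "image_url",
        "image_url": {"url": f"data:image/{filetype or 'jpeg'};base64,{b64_data}"},
    }
```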


@ -0,0 +1,254 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""RAG MCP Server for Document Search and Question Answering.
This module implements an MCP server that provides document search capabilities using
a simple retrieval-augmented generation (RAG) pipeline. The server exposes a
search_documents tool that retrieves relevant document chunks and generates answers.
The simplified RAG workflow consists of:
- Document retrieval from a vector store
- Answer generation using retrieved context
"""
import asyncio
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Any, Annotated, Dict, List, Optional, Sequence, TypedDict
from langchain_core.documents import Document
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph, add_messages
from mcp.server.fastmcp import FastMCP
from openai import AsyncOpenAI
from pypdf import PdfReader
project_root = Path(__file__).parent.parent.parent
sys.path.append(str(project_root))
from config import ConfigManager
from vector_store import VectorStore, create_vector_store_with_config
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO
)
logger = logging.getLogger(__name__)
class RAGState(TypedDict, total=False):
"""Type definition for the simplified RAG agent state.
Attributes:
question: The user's question to be answered.
messages: Conversation history with automatic message aggregation.
context: Retrieved documents from the local vector store.
sources: Optional list of source filters for retrieval.
"""
question: str
messages: Annotated[Sequence[AnyMessage], add_messages]
context: Optional[List[Document]]
sources: Optional[List[str]]
class RAGAgent:
"""Simplified RAG Agent for fast document search and answer generation.
This agent manages a simple two-step pipeline:
1. Retrieve documents from the local vector store.
2. Generate an answer using the retrieved context.
"""
def __init__(self):
"""Initialize the RAG agent with model client, configuration, and graph."""
config_path = self._get_config_path()
self.config_manager = ConfigManager(config_path)
self.vector_store = create_vector_store_with_config(self.config_manager)
self.model_name = self.config_manager.get_selected_model()
self.model_client = AsyncOpenAI(
base_url=f"http://{self.model_name}:8000/v1",
api_key="api_key"
)
self.generation_prompt = self._get_generation_prompt()
self.graph = self._build_graph()
def _get_config_path(self):
"""Get the configuration file path and validate its existence."""
config_path = os.path.join(os.path.dirname(__file__), "../../config.json")
if not os.path.exists(config_path):
logger.error("ERROR: config.json not found")
return config_path
def _get_generation_prompt(self) -> str:
"""Get the system prompt template for the generation node."""
return """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If no context is provided, answer the question using your own knowledge, but state that you could not find relevant information in the provided documents.
Don't make up any information that is not provided in the context. Keep the answer concise.
Context:
{context}
"""
def retrieve(self, state: RAGState) -> Dict:
"""Retrieve relevant documents from the vector store."""
logger.info({"message": "Starting document retrieval"})
sources = state.get("sources", [])
if sources:
logger.info({"message": "Attempting retrieval with source filters", "sources": sources})
retrieved_docs = self.vector_store.get_documents(state["question"], sources=sources)
else:
logger.info({"message": "No sources specified, searching all documents"})
retrieved_docs = self.vector_store.get_documents(state["question"])
if not retrieved_docs and sources:
logger.info({"message": "No documents found with source filtering, trying without filters"})
retrieved_docs = self.vector_store.get_documents(state["question"])
if retrieved_docs:
sources_found = set(doc.metadata.get("source", "unknown") for doc in retrieved_docs)
logger.info({"message": "Document sources found", "sources": list(sources_found), "doc_count": len(retrieved_docs)})
else:
logger.warning({"message": "No documents retrieved", "query": state["question"], "attempted_sources": sources})
return {"context": retrieved_docs}
async def generate(self, state: RAGState) -> Dict:
"""Generate an answer using retrieved context."""
logger.info({
"message": "Generating answer",
"question": state['question']
})
context = state.get("context", [])
if not context:
logger.warning({"message": "No context available for generation", "question": state['question']})
docs_content = "No relevant documents were found."
else:
logger.info({"message": "Generating with context", "context_count": len(context)})
docs_content = self._hydrate_context(context)
system_prompt = self.generation_prompt.format(context=docs_content)
user_message = f"Question: {state['question']}"
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_message}
]
try:
response = await self.model_client.chat.completions.create(
model=self.model_name,
messages=messages,
)
response_content = response.choices[0].message.content
logger.info({
"message": "Generation completed",
"response_length": len(response_content),
"response_preview": response_content[:100] + "..."
})
return {
"messages": [HumanMessage(content=state["question"]), AIMessage(content=response_content)]
}
except Exception as e:
logger.error({"message": "Error during generation", "error": str(e)})
fallback_response = f"I apologize, but I encountered an error while processing your query about: {state['question']}"
return {
"messages": [HumanMessage(content=state["question"]), AIMessage(content=fallback_response)]
}
def _hydrate_context(self, context: List[Document]) -> str:
"""Extract text content from document objects."""
return "\n\n".join([doc.page_content for doc in context if doc.page_content])
def _build_graph(self):
"""Build and compile the simplified RAG workflow graph."""
workflow = StateGraph(RAGState)
workflow.add_node("retrieve", self.retrieve)
workflow.add_node("generate", self.generate)
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
return workflow.compile()
mcp = FastMCP("RAG")
rag_agent = RAGAgent()
vector_store = create_vector_store_with_config(rag_agent.config_manager)
@mcp.tool()
async def search_documents(query: str) -> str:
"""Search documents uploaded by the user to generate fast, grounded answers.
Performs a simple RAG pipeline that retrieves relevant documents and generates answers.
Args:
query: The question or query to search for.
Returns:
A concise answer based on the retrieved documents.
"""
config_obj = rag_agent.config_manager.read_config()
sources = config_obj.selected_sources or []
initial_state = {
"question": query,
"sources": sources,
"messages": []
}
thread_id = f"rag_session_{time.time()}"
result = await rag_agent.graph.ainvoke(initial_state)
if not result.get("messages"):
logger.error({"message": "No messages in RAG result", "query": query})
return "I apologize, but I encountered an error processing your query and no response was generated."
final_message = result["messages"][-1]
final_content = getattr(final_message, 'content', '') or ''
if not final_content.strip():
logger.warning({"message": "Empty content in final RAG message", "query": query, "message_type": type(final_message).__name__})
return f"I found relevant documents for your query '{query}' but was unable to generate a response. Please try rephrasing your question."
logger.info({"message": "RAG result", "content_length": len(final_content), "query": query})
return final_content
if __name__ == "__main__":
print(f"Starting {mcp.name} MCP server...")
mcp.run(transport="stdio")
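The fallback in `RAGAgent.retrieve` (filter by source, then retry unfiltered when nothing matches) is a reusable pattern. A minimal dependency-free sketch, with the retriever passed in as a callable rather than bound to a vector store:

```python
from typing import Callable, List, Optional

def retrieve_with_fallback(
    retrieve: Callable[..., List],
    question: str,
    sources: Optional[List[str]] = None,
) -> List:
    """Try source-filtered retrieval first; if it returns nothing, retry
    without filters (mirrors RAGAgent.retrieve above)."""
    docs = retrieve(question, sources=sources) if sources else retrieve(question)
    if not docs and sources:
        # The filter may exclude every document; fall back to an open search.
        docs = retrieve(question)
    return docs
```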


@ -0,0 +1,72 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Weather Test MCP Server
A lightweight Model Context Protocol (MCP) server designed for testing and demonstration purposes.
This module provides mock weather tools that return humorous, fake responses rather than real
weather data.
Features:
- Mock weather data retrieval for any location
- Example implementation of MCP tool registration with local functions
- Demonstration of FastMCP server setup
Tools provided:
- get_weather(location): Returns mock weather information
- get_rain_forecast(location): Returns mock rain forecast data
Usage:
- Usage in Chatbot Spark:
The server is run on project startup by MCPClient. See client.py for more details.
- Standalone Usage:
Run as a standalone script to start the MCP server:
$ python weather_test.py
"""
from mcp.server.fastmcp import FastMCP
mcp = FastMCP("weather-tools")
@mcp.tool()
def get_weather(location: str):
"""Call to get the weather from a specific location."""
if any([city in location.lower() for city in ["sf", "san francisco"]]):
return "It's sunny in San Francisco, but you better look out if you're a Gemini 😈."
else:
return f"The weather is spooky with a chance of gremlins in {location}"
@mcp.tool()
def get_rain_forecast(location: str):
"""Call to get the rain forecast from a specific location."""
if any([city in location.lower() for city in ["sf", "san francisco"]]):
return "It's going to rain cats and dogs in San Francisco tomorrow."
else:
return f"It is raining muffins in {location}"
if __name__ == "__main__":
print(f'running {mcp.name} MCP server')
mcp.run(transport="stdio")


@ -0,0 +1,182 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utility functions for file processing and message conversion."""
import json
import os
import time
from typing import List, Dict, Any
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, ToolCall
from logger import logger
from vector_store import VectorStore
async def process_and_ingest_files_background(
file_info: List[dict],
vector_store: VectorStore,
config_manager,
task_id: str,
indexing_tasks: Dict[str, str]
) -> None:
"""Process and ingest files in the background.
Args:
file_info: List of file dictionaries with 'filename' and 'content' keys
vector_store: VectorStore instance for document indexing
config_manager: ConfigManager instance for updating sources
task_id: Unique identifier for this processing task
indexing_tasks: Dictionary to track task status
"""
try:
logger.debug({
"message": "Starting background file processing",
"task_id": task_id,
"file_count": len(file_info)
})
indexing_tasks[task_id] = "saving_files"
permanent_dir = os.path.join("uploads", task_id)
os.makedirs(permanent_dir, exist_ok=True)
file_paths = []
file_names = []
for info in file_info:
try:
file_name = info["filename"]
content = info["content"]
file_path = os.path.join(permanent_dir, file_name)
with open(file_path, "wb") as f:
f.write(content)
file_paths.append(file_path)
file_names.append(file_name)
logger.debug({
"message": "Saved file",
"task_id": task_id,
"filename": file_name,
"path": file_path
})
except Exception as e:
logger.error({
"message": f"Error saving file {info['filename']}",
"task_id": task_id,
"filename": info['filename'],
"error": str(e)
}, exc_info=True)
indexing_tasks[task_id] = "loading_documents"
logger.debug({"message": "Loading documents", "task_id": task_id})
try:
documents = vector_store._load_documents(file_paths)
logger.debug({
"message": "Documents loaded, starting indexing",
"task_id": task_id,
"document_count": len(documents)
})
indexing_tasks[task_id] = "indexing_documents"
vector_store.index_documents(documents)
if file_names:
config = config_manager.read_config()
config_updated = False
for file_name in file_names:
if file_name not in config.sources:
config.sources.append(file_name)
config_updated = True
if config_updated:
config_manager.write_config(config)
logger.debug({
"message": "Updated config with new sources",
"task_id": task_id,
"sources": config.sources
})
indexing_tasks[task_id] = "completed"
logger.debug({
"message": "Background processing and indexing completed successfully",
"task_id": task_id
})
except Exception as e:
indexing_tasks[task_id] = f"failed_during_indexing: {str(e)}"
logger.error({
"message": "Error during document loading or indexing",
"task_id": task_id,
"error": str(e)
}, exc_info=True)
except Exception as e:
indexing_tasks[task_id] = f"failed: {str(e)}"
logger.error({
"message": "Error in background processing",
"task_id": task_id,
"error": str(e)
}, exc_info=True)
def convert_langgraph_messages_to_openai(messages: List) -> List[Dict[str, Any]]:
"""Convert LangGraph message objects to OpenAI API format.
Args:
messages: List of LangGraph message objects
Returns:
List of dictionaries in OpenAI API format
"""
openai_messages = []
for msg in messages:
if isinstance(msg, HumanMessage):
openai_messages.append({
"role": "user",
"content": msg.content
})
elif isinstance(msg, AIMessage):
openai_msg = {
"role": "assistant",
"content": msg.content or ""
}
if hasattr(msg, 'tool_calls') and msg.tool_calls:
openai_msg["tool_calls"] = []
for tc in msg.tool_calls:
openai_msg["tool_calls"].append({
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": json.dumps(tc["args"])
}
})
openai_messages.append(openai_msg)
elif isinstance(msg, ToolMessage):
openai_messages.append({
"role": "tool",
"content": msg.content,
"tool_call_id": msg.tool_call_id
})
return openai_messages
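The assistant-message branch above hinges on one detail: LangGraph stores tool-call arguments as a dict, while the OpenAI wire format expects a JSON-encoded string. That conversion can be sketched on its own, with plain dicts standing in for LangChain message objects:

```python
import json

def tool_calls_to_openai(tool_calls: list) -> list:
    """Convert LangGraph-style tool calls ({'id', 'name', 'args'}) into
    OpenAI's function-call format, JSON-encoding the arguments dict."""
    return [
        {
            "id": tc["id"],
            "type": "function",
            "function": {"name": tc["name"], "arguments": json.dumps(tc["args"])},
        }
        for tc in tool_calls
    ]
```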

File diff suppressed because it is too large


@ -0,0 +1,388 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import glob
from typing import List, Tuple
import os
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_milvus import Milvus
from langchain_core.documents import Document
from typing_extensions import List
from langchain_openai import OpenAIEmbeddings
from langchain_unstructured import UnstructuredLoader
from dotenv import load_dotenv
from logger import logger
from typing import Optional, Callable
import requests
class CustomEmbeddings:
"""Wraps qwen3 embedding model to match OpenAI format"""
def __init__(self, model: str = "Qwen3-Embedding-4B-Q8_0.gguf", host: str = "http://qwen3-embedding:8000"):
self.model = model
self.url = f"{host}/v1/embeddings"
def __call__(self, texts: list[str]) -> list[list[float]]:
embeddings = []
for text in texts:
response = requests.post(
self.url,
json={"input": text, "model": self.model},
headers={"Content-Type": "application/json"}
)
response.raise_for_status()
data = response.json()
embeddings.append(data["data"][0]["embedding"])
return embeddings
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of document texts. Required by Milvus library."""
return self.__call__(texts)
def embed_query(self, text: str) -> list[float]:
"""Embed a single query text. Required by Milvus library."""
return self.__call__([text])[0]
class VectorStore:
"""Vector store for document embedding and retrieval.
Decoupled from ConfigManager - uses optional callbacks for source management.
"""
def __init__(
self,
embeddings=None,
uri: str = "http://milvus:19530",
on_source_deleted: Optional[Callable[[str], None]] = None
):
"""Initialize the vector store.
Args:
embeddings: Embedding model to use (defaults to OllamaEmbeddings)
uri: Milvus connection URI
on_source_deleted: Optional callback when a source is deleted
"""
try:
self.embeddings = embeddings or CustomEmbeddings(model="qwen3-embedding-custom")
self.uri = uri
self.on_source_deleted = on_source_deleted
self._initialize_store()
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
)
logger.debug({
"message": "VectorStore initialized successfully"
})
except Exception as e:
logger.error({
"message": "Error initializing VectorStore",
"error": str(e)
}, exc_info=True)
raise
def _initialize_store(self):
self._store = Milvus(
embedding_function=self.embeddings,
collection_name="context",
connection_args={"uri": self.uri},
auto_id=True
)
logger.debug({
"message": "Milvus vector store initialized",
"uri": self.uri,
"collection": "context"
})
def _load_documents(self, file_paths: Optional[List[str]] = None, input_dir: Optional[str] = None) -> List[Document]:
try:
documents = []
source_name = None
if input_dir:
source_name = os.path.basename(os.path.normpath(input_dir))
logger.debug({
"message": "Loading files from directory",
"directory": input_dir,
"source": source_name
})
file_paths = glob.glob(os.path.join(input_dir, "**"), recursive=True)
file_paths = [f for f in file_paths if os.path.isfile(f)]
logger.info(f"Processing {len(file_paths)} files: {file_paths}")
for file_path in file_paths:
try:
if not source_name:
source_name = os.path.basename(file_path)
logger.info(f"Using filename as source: {source_name}")
logger.info(f"Loading file: {file_path}")
file_ext = os.path.splitext(file_path)[1].lower()
logger.info(f"File extension: {file_ext}")
try:
loader = UnstructuredLoader(file_path)
docs = loader.load()
logger.info(f"Successfully loaded {len(docs)} documents from {file_path}")
except Exception as pdf_error:
logger.error(f"UnstructuredLoader failed ({pdf_error}); falling back to manual extraction")
file_text = None
if file_ext == ".pdf":
logger.info("Attempting PyPDF text extraction fallback")
try:
from pypdf import PdfReader
reader = PdfReader(file_path)
extracted_pages = []
for page in reader.pages:
try:
extracted_pages.append(page.extract_text() or "")
except Exception as per_page_err:
logger.info(f"Warning: failed to extract a page: {per_page_err}")
extracted_pages.append("")
file_text = "\n\n".join(extracted_pages).strip()
except Exception as pypdf_error:
logger.info(f"PyPDF fallback failed: {pypdf_error}")
file_text = None
if not file_text:
logger.info("Falling back to raw text read of file contents")
try:
with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
file_text = f.read()
except Exception as read_error:
logger.info(f"Fallback read failed: {read_error}")
file_text = ""
if file_text and file_text.strip():
docs = [Document(
page_content=file_text,
metadata={
"source": source_name,
"file_path": file_path,
"filename": os.path.basename(file_path),
}
)]
else:
logger.info("Creating a simple document as fallback (no text extracted)")
docs = [Document(
page_content=f"Document: {os.path.basename(file_path)}",
metadata={
"source": source_name,
"file_path": file_path,
"filename": os.path.basename(file_path),
}
)]
for doc in docs:
if not doc.metadata:
doc.metadata = {}
cleaned_metadata = {}
cleaned_metadata["source"] = source_name
cleaned_metadata["file_path"] = file_path
cleaned_metadata["filename"] = os.path.basename(file_path)
for key, value in doc.metadata.items():
if key not in ["source", "file_path"]:
if isinstance(value, (list, dict, set)):
cleaned_metadata[key] = str(value)
elif value is not None:
cleaned_metadata[key] = str(value)
doc.metadata = cleaned_metadata
documents.extend(docs)
logger.debug({
"message": "Loaded documents from file",
"file_path": file_path,
"document_count": len(docs)
})
except Exception as e:
logger.error({
"message": "Error loading file",
"file_path": file_path,
"error": str(e)
}, exc_info=True)
continue
logger.info(f"Total documents loaded: {len(documents)}")
return documents
except Exception as e:
logger.error({
"message": "Error loading documents",
"error": str(e)
}, exc_info=True)
raise
def index_documents(self, documents: List[Document]) -> List[Document]:
try:
logger.debug({
"message": "Starting document indexing",
"document_count": len(documents)
})
splits = self.text_splitter.split_documents(documents)
logger.debug({
"message": "Split documents into chunks",
"chunk_count": len(splits)
})
self._store.add_documents(splits)
self.flush_store()
logger.debug({
"message": "Document indexing completed"
})
except Exception as e:
logger.error({
"message": "Error during document indexing",
"error": str(e)
}, exc_info=True)
raise
def flush_store(self):
"""
Flush the Milvus collection to ensure that all added documents are persisted to disk.
"""
try:
from pymilvus import connections
connections.connect(uri=self.uri)
from pymilvus import utility
utility.flush_all()
logger.debug({
"message": "Milvus store flushed (persisted to disk)"
})
except Exception as e:
logger.error({
"message": "Error flushing Milvus store",
"error": str(e)
}, exc_info=True)
def get_documents(self, query: str, k: int = 8, sources: List[str] = None) -> List[Document]:
"""
Get relevant documents using the retriever's invoke method.
"""
try:
search_kwargs = {"k": k}
if sources:
if len(sources) == 1:
filter_expr = f'source == "{sources[0]}"'
else:
source_conditions = [f'source == "{source}"' for source in sources]
filter_expr = " || ".join(source_conditions)
search_kwargs["expr"] = filter_expr
logger.debug({
"message": "Retrieving with filter",
"filter": filter_expr
})
retriever = self._store.as_retriever(
search_type="similarity",
search_kwargs=search_kwargs
)
docs = retriever.invoke(query)
logger.debug({
"message": "Retrieved documents",
"query": query,
"document_count": len(docs)
})
return docs
except Exception as e:
logger.error({
"message": "Error retrieving documents",
"error": str(e)
}, exc_info=True)
return []
def delete_collection(self, collection_name: str) -> bool:
"""
Delete a collection from Milvus.
Args:
collection_name: Name of the collection to delete
Returns:
bool: True if successful, False otherwise
"""
try:
from pymilvus import connections, Collection, utility
connections.connect(uri=self.uri)
if utility.has_collection(collection_name):
collection = Collection(name=collection_name)
collection.drop()
if self.on_source_deleted:
self.on_source_deleted(collection_name)
logger.debug({
"message": "Collection deleted successfully",
"collection_name": collection_name
})
return True
else:
logger.warning({
"message": "Collection not found",
"collection_name": collection_name
})
return False
except Exception as e:
logger.error({
"message": "Error deleting collection",
"collection_name": collection_name,
"error": str(e)
}, exc_info=True)
return False
def create_vector_store_with_config(config_manager, uri: str = "http://milvus:19530") -> VectorStore:
"""Factory function to create a VectorStore with ConfigManager integration.
Args:
config_manager: ConfigManager instance for source management
uri: Milvus connection URI
Returns:
VectorStore instance with source deletion callback
"""
def handle_source_deleted(source_name: str):
"""Handle source deletion by updating config."""
config = config_manager.read_config()
if hasattr(config, 'sources') and source_name in config.sources:
config.sources.remove(source_name)
config_manager.write_config(config)
return VectorStore(
uri=uri,
on_source_deleted=handle_source_deleted
)
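The boolean filter that `get_documents` hands to Milvus is plain string assembly. Factored out as a standalone helper (a sketch, not a method of the class above), it looks like:

```python
from typing import List, Optional

def build_source_filter(sources: Optional[List[str]]) -> Optional[str]:
    """Build a Milvus boolean expression matching any of the given sources,
    e.g. 'source == "a.pdf" || source == "b.pdf"'.
    Returns None when no sources are given (search everything)."""
    if not sources:
        return None
    return " || ".join(f'source == "{s}"' for s in sources)
```

The result would be passed as `search_kwargs["expr"]`, as `get_documents` does inline; note that source names containing double quotes would need escaping before being interpolated.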


@ -0,0 +1,160 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
x-build-base: &build-base
build:
context: .
dockerfile: Dockerfile.llamacpp
image: local/llama.cpp:server-cuda
services:
qwen2.5-vl:
image: nvcr.io/nvidia/tensorrt-llm/release:spark-single-gpu-dev
container_name: qwen2.5-vl
shm_size: '1g'
restart: unless-stopped
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
environment:
- TOKENIZERS_PARALLELISM=false
- NCCL_P2P_LEVEL=SYS
- NCCL_DEBUG=INFO
- UCX_TLS=tcp,sm,self
- UCX_MEMTYPE_CACHE=n
- CUDA_VISIBLE_DEVICES=0
command: >
trtllm-serve serve Qwen/Qwen2.5-VL-7B-Instruct
--backend pytorch
--host 0.0.0.0
--port 8000
--trust_remote_code
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 5
start_period: 60s
qwen3-embedding:
<<: *build-base
container_name: qwen3-embedding
volumes:
- ./models:/models
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
command:
- "-m"
- "/models/Qwen3-Embedding-4B-Q8_0.gguf"
- "--port"
- "8000"
- "--host"
- "0.0.0.0"
- "--jinja"
- "--embeddings"
# Uncomment next block if you want to use gpt-oss-20b
# gpt-oss-20b:
# <<: *build-base
# container_name: gpt-oss-20b
# volumes:
# - ./models:/models
# deploy:
# resources:
# reservations:
# devices:
# - driver: nvidia
# count: all
# capabilities: [gpu]
# command:
# - "-m"
# - "/models/gpt-oss-20b-mxfp4.gguf"
# - "--port"
# - "8000"
# - "--host"
# - "0.0.0.0"
# - "-n"
# - "2048"
# - "--n-gpu-layers"
# - "999"
# - "--jinja"
# Comment next block if you want to use gpt-oss-20b
gpt-oss-120b:
<<: *build-base
container_name: gpt-oss-120b
volumes:
- ./models:/models
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
command:
- "-m"
- "/models/gpt-oss-120b-mxfp4-00001-of-00003.gguf"
- "--port"
- "8000"
- "--host"
- "0.0.0.0"
- "-n"
- "4096"
- "--n-gpu-layers"
- "999"
- "--jinja"
deepseek-coder:
<<: *build-base
container_name: deepseek-coder
volumes:
- ./models:/models
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
command:
- "-m"
- "/models/deepseek-coder-6.7b-instruct.Q8_0.gguf"
- "--port"
- "8000"
- "--host"
- "0.0.0.0"
- "-n"
- "256"
- "--n-gpu-layers"
- "999"
- "--jinja"
volumes:
ollama-data:
networks:
default:
name: chatbot-net


@ -0,0 +1,135 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
services:
backend:
container_name: backend
build:
context: ./backend
dockerfile: Dockerfile
ports:
- "8000:8000"
volumes:
- ./backend:/app
depends_on:
- postgres
- etcd
- minio
restart: unless-stopped
environment:
- POSTGRES_HOST=postgres
- POSTGRES_DB=chatbot
- POSTGRES_USER=chatbot_user
- POSTGRES_PASSWORD=chatbot_password
- ETCD_ENDPOINTS=etcd:2379
- MINIO_ADDRESS=minio:9000
- MILVUS_ADDRESS=milvus:19530
- MODELS=gpt-oss-120b #,gpt-oss-20b
frontend:
container_name: frontend
build:
context: ./frontend
dockerfile: Dockerfile
ports:
- "3000:3000"
volumes:
- ./frontend:/app
- /app/node_modules
depends_on:
- backend
restart: unless-stopped
postgres:
container_name: postgres
image: postgres:15-alpine
ports:
- "5432:5432"
environment:
- POSTGRES_DB=chatbot
- POSTGRES_USER=chatbot_user
- POSTGRES_PASSWORD=chatbot_password
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U chatbot_user -d chatbot"]
interval: 30s
timeout: 10s
retries: 3
restart: unless-stopped
etcd:
container_name: milvus-etcd
image: quay.io/coreos/etcd:v3.5.5
environment:
- ETCD_AUTO_COMPACTION_MODE=revision
- ETCD_AUTO_COMPACTION_RETENTION=1000
- ETCD_QUOTA_BACKEND_BYTES=4294967296
- ETCD_SNAPSHOT_COUNT=50000
volumes:
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
healthcheck:
test: ["CMD", "etcdctl", "endpoint", "health"]
interval: 30s
timeout: 20s
retries: 3
minio:
container_name: milvus-minio
image: minio/minio:RELEASE.2023-03-20T20-16-18Z
environment:
MINIO_ACCESS_KEY: minioadmin
MINIO_SECRET_KEY: minioadmin
volumes:
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
command: minio server /minio_data
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
interval: 30s
timeout: 20s
retries: 3
milvus:
container_name: milvus-standalone
image: milvusdb/milvus:v2.5.15-20250718-3a3b374f-gpu-arm64
command: ["milvus", "run", "standalone"]
environment:
ETCD_ENDPOINTS: etcd:2379
MINIO_ADDRESS: minio:9000
volumes:
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
interval: 30s
start_period: 90s
timeout: 20s
retries: 3
ports:
- "19530:19530"
- "9091:9091"
depends_on:
- "etcd"
- "minio"
volumes:
postgres_data:
networks:
default:
name: chatbot-net


@ -0,0 +1,29 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
FROM node:20-alpine
WORKDIR /app
COPY package.json package-lock.json* ./
RUN npm install
COPY . .
EXPOSE 3000
CMD ["npm", "run", "dev"]


@ -0,0 +1,32 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
import { dirname } from "path";
import { fileURLToPath } from "url";
import { FlatCompat } from "@eslint/eslintrc";
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
const compat = new FlatCompat({
baseDirectory: __dirname,
});
const eslintConfig = [
...compat.extends("next/core-web-vitals", "next/typescript"),
];
export default eslintConfig;


@ -0,0 +1,21 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
/// <reference types="next" />
/// <reference types="next/image-types/global" />
// NOTE: This file should not be edited
// see https://nextjs.org/docs/app/api-reference/config/typescript for more information.


@ -0,0 +1,30 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
import type { NextConfig } from "next";
const nextConfig: NextConfig = {
async rewrites() {
return [
{
source: '/api/:path*',
destination: 'http://backend:8000/:path*',
},
];
},
};
export default nextConfig;

File diff suppressed because it is too large.


@ -0,0 +1,30 @@
{
"name": "frontend",
"version": "0.1.0",
"private": true,
"scripts": {
"dev": "next dev --turbopack",
"build": "next build",
"start": "next start",
"lint": "next lint"
},
"dependencies": {
"next": "15.1.7",
"react": "^19.0.0",
"react-dom": "^19.0.0",
"react-markdown": "^10.1.0",
"react-syntax-highlighter": "^15.6.1",
"remark-gfm": "^4.0.1"
},
"devDependencies": {
"@eslint/eslintrc": "^3",
"@types/node": "^20",
"@types/react": "^19",
"@types/react-dom": "^19",
"eslint": "^9",
"eslint-config-next": "15.1.7",
"postcss": "^8",
"tailwindcss": "^3.4.1",
"typescript": "^5"
}
}


@ -0,0 +1,24 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
/** @type {import('postcss-load-config').Config} */
const config = {
plugins: {
tailwindcss: {},
},
};
export default config;


@ -0,0 +1 @@
<svg fill="none" viewBox="0 0 16 16" xmlns="http://www.w3.org/2000/svg"><path d="M14.5 13.5V5.41a1 1 0 0 0-.3-.7L9.8.29A1 1 0 0 0 9.08 0H1.5v13.5A2.5 2.5 0 0 0 4 16h8a2.5 2.5 0 0 0 2.5-2.5m-1.5 0v-7H8v-5H3v12a1 1 0 0 0 1 1h8a1 1 0 0 0 1-1M9.5 5V2.12L12.38 5zM5.13 5h-.62v1.25h2.12V5zm-.62 3h7.12v1.25H4.5zm.62 3h-.62v1.25h7.12V11z" clip-rule="evenodd" fill="#666" fill-rule="evenodd"/></svg>


@ -0,0 +1 @@
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><g clip-path="url(#a)"><path fill-rule="evenodd" clip-rule="evenodd" d="M10.27 14.1a6.5 6.5 0 0 0 3.67-3.45q-1.24.21-2.7.34-.31 1.83-.97 3.1M8 16A8 8 0 1 0 8 0a8 8 0 0 0 0 16m.48-1.52a7 7 0 0 1-.96 0H7.5a4 4 0 0 1-.84-1.32q-.38-.89-.63-2.08a40 40 0 0 0 3.92 0q-.25 1.2-.63 2.08a4 4 0 0 1-.84 1.31zm2.94-4.76q1.66-.15 2.95-.43a7 7 0 0 0 0-2.58q-1.3-.27-2.95-.43a18 18 0 0 1 0 3.44m-1.27-3.54a17 17 0 0 1 0 3.64 39 39 0 0 1-4.3 0 17 17 0 0 1 0-3.64 39 39 0 0 1 4.3 0m1.1-1.17q1.45.13 2.69.34a6.5 6.5 0 0 0-3.67-3.44q.65 1.26.98 3.1M8.48 1.5l.01.02q.41.37.84 1.31.38.89.63 2.08a40 40 0 0 0-3.92 0q.25-1.2.63-2.08a4 4 0 0 1 .85-1.32 7 7 0 0 1 .96 0m-2.75.4a6.5 6.5 0 0 0-3.67 3.44 29 29 0 0 1 2.7-.34q.31-1.83.97-3.1M4.58 6.28q-1.66.16-2.95.43a7 7 0 0 0 0 2.58q1.3.27 2.95.43a18 18 0 0 1 0-3.44m.17 4.71q-1.45-.12-2.69-.34a6.5 6.5 0 0 0 3.67 3.44q-.65-1.27-.98-3.1" fill="#666"/></g><defs><clipPath id="a"><path fill="#fff" d="M0 0h16v16H0z"/></clipPath></defs></svg>


@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 394 80"><path fill="#000" d="M262 0h68.5v12.7h-27.2v66.6h-13.6V12.7H262V0ZM149 0v12.7H94v20.4h44.3v12.6H94v21h55v12.6H80.5V0h68.7zm34.3 0h-17.8l63.8 79.4h17.9l-32-39.7 32-39.6h-17.9l-23 28.6-23-28.6zm18.3 56.7-9-11-27.1 33.7h17.8l18.3-22.7z"/><path fill="#000" d="M81 79.3 17 0H0v79.3h13.6V17l50.2 62.3H81Zm252.6-.4c-1 0-1.8-.4-2.5-1s-1.1-1.6-1.1-2.6.3-1.8 1-2.5 1.6-1 2.6-1 1.8.3 2.5 1a3.4 3.4 0 0 1 .6 4.3 3.7 3.7 0 0 1-3 1.8zm23.2-33.5h6v23.3c0 2.1-.4 4-1.3 5.5a9.1 9.1 0 0 1-3.8 3.5c-1.6.8-3.5 1.3-5.7 1.3-2 0-3.7-.4-5.3-1s-2.8-1.8-3.7-3.2c-.9-1.3-1.4-3-1.4-5h6c.1.8.3 1.6.7 2.2s1 1.2 1.6 1.5c.7.4 1.5.5 2.4.5 1 0 1.8-.2 2.4-.6a4 4 0 0 0 1.6-1.8c.3-.8.5-1.8.5-3V45.5zm30.9 9.1a4.4 4.4 0 0 0-2-3.3 7.5 7.5 0 0 0-4.3-1.1c-1.3 0-2.4.2-3.3.5-.9.4-1.6 1-2 1.6a3.5 3.5 0 0 0-.3 4c.3.5.7.9 1.3 1.2l1.8 1 2 .5 3.2.8c1.3.3 2.5.7 3.7 1.2a13 13 0 0 1 3.2 1.8 8.1 8.1 0 0 1 3 6.5c0 2-.5 3.7-1.5 5.1a10 10 0 0 1-4.4 3.5c-1.8.8-4.1 1.2-6.8 1.2-2.6 0-4.9-.4-6.8-1.2-2-.8-3.4-2-4.5-3.5a10 10 0 0 1-1.7-5.6h6a5 5 0 0 0 3.5 4.6c1 .4 2.2.6 3.4.6 1.3 0 2.5-.2 3.5-.6 1-.4 1.8-1 2.4-1.7a4 4 0 0 0 .8-2.4c0-.9-.2-1.6-.7-2.2a11 11 0 0 0-2.1-1.4l-3.2-1-3.8-1c-2.8-.7-5-1.7-6.6-3.2a7.2 7.2 0 0 1-2.4-5.7 8 8 0 0 1 1.7-5 10 10 0 0 1 4.3-3.5c2-.8 4-1.2 6.4-1.2 2.3 0 4.4.4 6.2 1.2 1.8.8 3.2 2 4.3 3.4 1 1.4 1.5 3 1.5 5h-5.8z"/></svg>


@ -0,0 +1 @@
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1155 1000"><path d="m577.3 0 577.4 1000H0z" fill="#fff"/></svg>


@ -0,0 +1 @@
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path fill-rule="evenodd" clip-rule="evenodd" d="M1.5 2.5h13v10a1 1 0 0 1-1 1h-11a1 1 0 0 1-1-1zM0 1h16v11.5a2.5 2.5 0 0 1-2.5 2.5h-11A2.5 2.5 0 0 1 0 12.5zm3.75 4.5a.75.75 0 1 0 0-1.5.75.75 0 0 0 0 1.5M7 4.75a.75.75 0 1 1-1.5 0 .75.75 0 0 1 1.5 0m1.75.75a.75.75 0 1 0 0-1.5.75.75 0 0 0 0 1.5" fill="#666"/></svg>


Binary file not shown.


@ -0,0 +1,94 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
@tailwind base;
@tailwind components;
@tailwind utilities;
:root {
--background: white;
--foreground: #1e293b;
--primary: #76B900;
--primary-dark: #669f00;
--secondary: #f0f0f0;
--accent: #f56565;
--border: #e2e8f0;
--shimmer-base: #a3a3a3;
--shimmer-peak: #ffffff;
--placeholder-opacity-min: .6;
}
@media (prefers-color-scheme: dark) {
:root {
--background: #0f172a;
--foreground: #f8fafc;
--primary: #76B900;
--primary-dark: #669f00;
--secondary: #1e293b;
--accent: #ef4444;
--border: #334155;
--shimmer-base: #444;
--shimmer-peak: #222;
--placeholder-opacity-min: .4;
}
}
* {
box-sizing: border-box;
padding: 0;
margin: 0;
}
html,
body {
max-width: 100vw;
height: 100vh;
overflow: hidden;
}
body {
background-color: var(--background);
color: var(--foreground);
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
}
a {
color: inherit;
text-decoration: none;
}
button, input, select, textarea {
font-family: inherit;
}
/* Custom scrollbar */
::-webkit-scrollbar {
width: 6px;
height: 6px;
}
::-webkit-scrollbar-track {
background: transparent;
}
::-webkit-scrollbar-thumb {
background: #cbd5e0;
border-radius: 3px;
}
::-webkit-scrollbar-thumb:hover {
background: #a0aec0;
}


@ -0,0 +1,42 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
import type { Metadata } from "next";
import { Inter } from "next/font/google";
import "./globals.css";
import ThemeToggle from "@/components/ThemeToggle";
const inter = Inter({ subsets: ["latin"] });
export const metadata: Metadata = {
title: "Spark Chat",
description: "AI-powered chat interface",
};
export default function RootLayout({
children,
}: {
children: React.ReactNode;
}) {
return (
<html lang="en" className="h-full">
<body className={`${inter.className} h-full bg-white dark:bg-gray-900 text-gray-900 dark:text-white transition-colors duration-200`}>
<ThemeToggle />
{children}
</body>
</html>
);
}


@ -0,0 +1,132 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
"use client";
import { useState, useRef, useEffect } from 'react';
import QuerySection from '@/components/QuerySection';
import DocumentIngestion from '@/components/DocumentIngestion';
import Sidebar from '@/components/Sidebar';
import styles from '@/styles/Home.module.css';
export default function Home() {
const [query, setQuery] = useState("");
const [response, setResponse] = useState("[]");
const [files, setFiles] = useState<FileList | null>(null);
const [ingestMessage, setIngestMessage] = useState("");
const [isStreaming, setIsStreaming] = useState(false);
const [isIngesting, setIsIngesting] = useState(false);
const [showIngestion, setShowIngestion] = useState(false);
const [refreshTrigger, setRefreshTrigger] = useState(0);
const [currentChatId, setCurrentChatId] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
// Load initial chat ID
useEffect(() => {
const fetchCurrentChatId = async () => {
try {
const response = await fetch("/api/chat_id");
if (response.ok) {
const { chat_id } = await response.json();
setCurrentChatId(chat_id);
}
} catch (error) {
console.error("Error fetching current chat ID:", error);
}
};
fetchCurrentChatId();
}, []);
// Handle chat changes
const handleChatChange = async (newChatId: string) => {
try {
const response = await fetch("/api/chat_id", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ chat_id: newChatId })
});
if (response.ok) {
setCurrentChatId(newChatId);
setResponse("[]"); // Clear current chat messages with empty JSON array
}
} catch (error) {
console.error("Error updating chat ID:", error);
}
};
// Clean up any ongoing streams when component unmounts
useEffect(() => {
return () => {
if (abortControllerRef.current) {
abortControllerRef.current.abort();
}
};
}, []);
// Function to handle successful document ingestion
const handleSuccessfulIngestion = () => {
setRefreshTrigger(prev => prev + 1);
};
return (
<div className={styles.container}>
<Sidebar
showIngestion={showIngestion}
setShowIngestion={setShowIngestion}
refreshTrigger={refreshTrigger}
currentChatId={currentChatId}
onChatChange={handleChatChange}
/>
<div className={styles.mainContent}>
<QuerySection
query={query}
response={response}
isStreaming={isStreaming}
setQuery={setQuery}
setResponse={setResponse}
setIsStreaming={setIsStreaming}
abortControllerRef={abortControllerRef}
setShowIngestion={setShowIngestion}
currentChatId={currentChatId}
/>
</div>
{showIngestion && (
<>
<div className={styles.overlay} onClick={() => setShowIngestion(false)} />
<div className={styles.documentUploadContainer}>
<button
className={styles.closeButton}
onClick={() => setShowIngestion(false)}
>
×
</button>
<DocumentIngestion
files={files}
ingestMessage={ingestMessage}
isIngesting={isIngesting}
setFiles={setFiles}
setIngestMessage={setIngestMessage}
setIsIngesting={setIsIngesting}
onSuccessfulIngestion={handleSuccessfulIngestion}
/>
</div>
</>
)}
</div>
);
}


@ -0,0 +1,141 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
import { SetStateAction, useState } from 'react';
import styles from '@/styles/DocumentIngestion.module.css';
declare module 'react' {
interface InputHTMLAttributes<T> extends HTMLAttributes<T> {
webkitdirectory?: string;
directory?: string;
}
}
interface DocumentIngestionProps {
files: FileList | null;
ingestMessage: string;
isIngesting: boolean;
setFiles: (files: FileList | null) => void;
setIngestMessage: (message: string) => void;
setIsIngesting: (value: boolean) => void;
onSuccessfulIngestion?: () => void;
}
export default function DocumentIngestion({
files,
ingestMessage,
isIngesting,
setFiles,
setIngestMessage,
setIsIngesting,
onSuccessfulIngestion
}: DocumentIngestionProps) {
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
setFiles(e.target.files);
};
const handleIngestSubmit = async (e: { preventDefault: () => void }) => {
e.preventDefault();
setIsIngesting(true);
setIngestMessage("");
try {
if (files && files.length > 0) {
const formData = new FormData();
for (let i = 0; i < files.length; i++) {
formData.append("files", files[i]);
}
const res = await fetch("/api/ingest", {
method: "POST",
body: formData,
});
const data = await res.json();
setIngestMessage(data.message);
if (res.ok && onSuccessfulIngestion) {
onSuccessfulIngestion();
}
} else {
setIngestMessage("Please select files or specify a directory path.");
}
} catch (error) {
console.error("Error during ingestion:", error);
setIngestMessage("Error during ingestion. Please check the console for details.");
} finally {
setIsIngesting(false);
}
};
const handleDragOver = (e: React.DragEvent<HTMLDivElement>) => {
e.preventDefault();
e.stopPropagation();
};
const handleDrop = (e: React.DragEvent<HTMLDivElement>) => {
e.preventDefault();
e.stopPropagation();
if (e.dataTransfer.files && e.dataTransfer.files.length > 0) {
setFiles(e.dataTransfer.files);
}
};
return (
<div className={styles.section}>
<h1>Document Ingestion</h1>
<form onSubmit={handleIngestSubmit} className={styles.ingestForm}>
<div
className={styles.uploadSection}
onDragOver={handleDragOver}
onDrop={handleDrop}
>
<label htmlFor="file-upload" className={styles.customFileLabel}>
Choose Files
</label>
<input
id="file-upload"
type="file"
multiple
onChange={handleFileChange}
disabled={isIngesting}
className={styles.fileInput}
/>
<span className={styles.fileName}>
{files && files.length > 0 ? Array.from(files).map(f => f.name).join(', ') : "No file chosen"}
</span>
<p className={styles.helpText}>
Select files or drag and drop them here
</p>
</div>
<button
type="submit"
disabled={isIngesting || !files}
className={styles.ingestButton}
>
{isIngesting ? "Ingesting..." : "Ingest Documents"}
</button>
</form>
{ingestMessage && (
<div className={styles.messageContainer}>
<p>{ingestMessage}</p>
</div>
)}
</div>
);
}


@ -0,0 +1,657 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
import type React from "react";
import { useRef, useEffect, useState } from "react";
import styles from "@/styles/QuerySection.module.css";
import ReactMarkdown from 'react-markdown';
import remarkGfm from 'remark-gfm';
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
import { oneDark, oneLight } from "react-syntax-highlighter/dist/esm/styles/prism";
import WelcomeSection from "./WelcomeSection";
const isDark = typeof document !== "undefined" && document.documentElement.classList.contains("dark");
export function makeChatTheme(isDark: boolean) {
const base = isDark ? oneDark : oneLight;
const accents = isDark
? {
tag: "#E3E3E3",
prolog: "#E3E3E3",
doctype: "#E3E3E3",
punctuation:"#99CFCF",
}
: {
tag: "#9a6700",
prolog: "#7a6200",
doctype: "#7a6200",
punctuation:"#6b7280",
};
return {
...base,
'pre[class*="language-"]': {
...(base['pre[class*="language-"]'] || {}),
background: "transparent",
},
'code[class*="language-"]': {
...(base['code[class*="language-"]'] || {}),
background: "transparent",
},
tag: { ...(base.tag || {}), color: accents.tag },
prolog: { ...(base.prolog || {}), color: accents.prolog },
doctype: { ...(base.doctype || {}), color: accents.doctype },
punctuation: { ...(base.punctuation || {}), color: accents.punctuation },
'attr-name': { ...(base['attr-name'] || {}), color: isDark ? "#e6b450" : "#6b4f00" },
} as const;
}
const theme = makeChatTheme(isDark);
function CodeBlockWithCopy({ code, language }: { code: string; language: string }) {
const [copied, setCopied] = useState(false);
const handleCopy = async () => {
try {
await navigator.clipboard.writeText(code);
setCopied(true);
setTimeout(() => setCopied(false), 1200);
} catch (err) {
try {
const textarea = document.createElement("textarea");
textarea.value = code;
textarea.style.position = "fixed";
textarea.style.left = "-9999px";
document.body.appendChild(textarea);
textarea.focus();
textarea.select();
document.execCommand("copy");
document.body.removeChild(textarea);
setCopied(true);
setTimeout(() => setCopied(false), 1200);
} catch {}
}
};
return (
<div className={styles.codeBlock}>
<button
type="button"
className={styles.copyButton}
onClick={handleCopy}
aria-label="Copy code"
title={copied ? "Copied" : "Copy"}
>
<svg
className={styles.copyButtonIcon}
viewBox="0 0 460 460"
aria-hidden="true"
focusable="false"
fill="currentColor"
>
<g>
<g>
<g>
<path d="M425.934,0H171.662c-18.122,0-32.864,14.743-32.864,32.864v77.134h30V32.864c0-1.579,1.285-2.864,2.864-2.864h254.272
c1.579,0,2.864,1.285,2.864,2.864v254.272c0,1.58-1.285,2.865-2.864,2.865h-74.729v30h74.729
c18.121,0,32.864-14.743,32.864-32.865V32.864C458.797,14.743,444.055,0,425.934,0z"/>
<path d="M288.339,139.998H34.068c-18.122,0-32.865,14.743-32.865,32.865v254.272C1.204,445.257,15.946,460,34.068,460h254.272
c18.122,0,32.865-14.743,32.865-32.864V172.863C321.206,154.741,306.461,139.998,288.339,139.998z M288.341,430H34.068
c-1.58,0-2.865-1.285-2.865-2.864V172.863c0-1.58,1.285-2.865,2.865-2.865h254.272c1.58,0,2.865,1.285,2.865,2.865v254.273h0.001
C291.206,428.715,289.92,430,288.341,430z"/>
</g>
</g>
</g>
</svg>
<span className={styles.copyButtonLabel}>{copied ? "Copied" : "Copy"}</span>
</button>
<SyntaxHighlighter
language={language}
style={theme}
PreTag="div"
wrapLongLines
showLineNumbers={false}
customStyle={{ margin: "0.6rem 0", borderRadius: 10, background: "transparent" }}
>
{code}
</SyntaxHighlighter>
</div>
);
}
interface QuerySectionProps {
query: string;
response: string;
isStreaming: boolean;
setQuery: (value: string) => void;
setResponse: React.Dispatch<React.SetStateAction<string>>;
setIsStreaming: (value: boolean) => void;
abortControllerRef: React.RefObject<AbortController | null>;
setShowIngestion: (value: boolean) => void;
currentChatId: string | null;
}
interface Message {
type: "HumanMessage" | "AssistantMessage" | "ToolMessage";
content: string;
}
export default function QuerySection({
query,
response,
isStreaming,
setQuery,
setResponse,
setIsStreaming,
abortControllerRef,
setShowIngestion,
currentChatId,
}: QuerySectionProps) {
const messagesEndRef = useRef<HTMLDivElement>(null);
const chatContainerRef = useRef<HTMLDivElement>(null);
const [showButtons, setShowButtons] = useState(false);
const [showWelcome, setShowWelcome] = useState(true);
const [inferenceStats, setInferenceStats] = useState({
tokensReceived: 0,
startTime: Date.now(),
tokensPerSecond: 0
});
const [selectedSources, setSelectedSources] = useState<string[]>([]);
const wsRef = useRef<WebSocket | null>(null);
const [uploadedImage, setUploadedImage] = useState<string | null>(null);
const [imagePreview, setImagePreview] = useState<string | null>(null);
const [isDragging, setIsDragging] = useState(false);
const [toolOutput, setToolOutput] = useState("");
const [graphStatus, setGraphStatus] = useState("");
const [isPinnedToolOutputVisible, setPinnedToolOutputVisible] = useState(false);
const [isToolContentVisible, setIsToolContentVisible] = useState(false);
const [fadeIn, setFadeIn] = useState(false);
const firstTokenReceived = useRef(false);
useEffect(() => {
const timer = setTimeout(() => {
setShowButtons(true);
}, 800);
return () => clearTimeout(timer);
}, []);
useEffect(() => {
if (!isStreaming) {
setInferenceStats(prev => ({
...prev,
tokensReceived: 0,
startTime: 0
}));
}
}, [isStreaming]);
useEffect(() => {
const fetchSelectedSources = async () => {
try {
const response = await fetch("/api/selected_sources");
if (response.ok) {
const { sources } = await response.json();
setSelectedSources(sources);
}
} catch (error) {
console.error("Error fetching selected sources:", error);
}
};
fetchSelectedSources();
}, []);
useEffect(() => {
const initWebSocket = async () => {
if (!currentChatId) return;
try {
if (wsRef.current) {
wsRef.current.close();
}
// NOTE: endpoint is hardcoded for local development; derive it from
// window.location ('wss:' when served over https) for other deployments.
const wsProtocol = 'ws:';
const wsHost = 'localhost';
const wsPort = '8000';
const ws = new WebSocket(`${wsProtocol}//${wsHost}:${wsPort}/ws/chat/${currentChatId}`);
wsRef.current = ws;
ws.onmessage = (event) => {
const msg = JSON.parse(event.data);
const type = msg.type;
const text = msg.data ?? msg.token ?? "";
switch (type) {
case "history": {
console.log('history messages: ', msg.messages);
if (Array.isArray(msg.messages)) {
// const filtered = msg.messages.filter(m => m.type !== "ToolMessage"); // TODO: add this back in
setResponse(JSON.stringify(msg.messages));
setIsStreaming(false);
}
break;
}
case "tool_token": {
if (text !== undefined && text !== "undefined") {
setToolOutput(prev => prev + text);
}
break;
}
case "token": {
if (!text) break;
if (!firstTokenReceived.current) {
console.log('TTFT: ', new Date().toISOString());
firstTokenReceived.current = true;
setIsStreaming(false);
}
setResponse(prev => {
try {
const messages = JSON.parse(prev);
const last = messages[messages.length - 1];
if (last && last.type === "AssistantMessage") {
last.content = String(last.content || "") + text;
} else {
messages.push({ type: "AssistantMessage", content: text });
}
return JSON.stringify(messages);
} catch {
return String(prev || "") + text;
}
});
break;
}
case "node_start": {
if (msg?.data === "generate") {
setGraphStatus("Thinking...");
}
break;
}
case "tool_start": {
console.log(type, msg.data);
setGraphStatus(`calling tool: ${msg?.data}`);
break;
}
case "tool_end":
case "node_end": {
console.log(type, msg.data);
if (msg.data === 'generate') {
console.log('generate complete. time: ', new Date().toISOString());
}
setGraphStatus("");
break;
}
default: {
// ignore unknown events
}
}
};
ws.onclose = () => {
console.log("WebSocket connection closed");
setIsStreaming(false);
};
ws.onerror = (error) => {
console.error("WebSocket error:", error);
setIsStreaming(false);
};
} catch (error) {
console.error("Error initializing WebSocket:", error);
setIsStreaming(false);
}
};
initWebSocket();
return () => {
if (wsRef.current) {
wsRef.current.close();
setIsStreaming(false);
}
};
}, [currentChatId]);
useEffect(() => {
try {
const messages = JSON.parse(response);
setShowWelcome(messages.length === 0);
} catch {
// If response can't be parsed as JSON, check if it's empty
setShowWelcome(!response.trim());
}
}, [response]);
// Show/hide pinnedToolOutput with fade
useEffect(() => {
if (graphStatus) {
setPinnedToolOutputVisible(true);
} else if (isPinnedToolOutputVisible) {
// Delay hiding to allow fade-out
const timeout = setTimeout(() => {
setPinnedToolOutputVisible(false);
}, 800); // match CSS transition duration
return () => clearTimeout(timeout);
}
}, [graphStatus, isPinnedToolOutputVisible]);
// Replace the effect for fade logic with this minimal version
useEffect(() => {
if (isPinnedToolOutputVisible && graphStatus) {
setFadeIn(false);
const t = setTimeout(() => setFadeIn(true), 10); // next tick for fade-in
return () => clearTimeout(t);
} else {
setFadeIn(false);
}
}, [isPinnedToolOutputVisible, graphStatus]);
// Cleanup image preview URL on unmount
useEffect(() => {
return () => {
if (imagePreview) {
URL.revokeObjectURL(imagePreview);
}
};
}, [imagePreview]);
const handleDragEnter = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); setIsDragging(true); };
const handleDragLeave = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); setIsDragging(false); };
const handleDragOver = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); };
const handleDrop = async (e: React.DragEvent) => {
e.preventDefault();
e.stopPropagation();
setIsDragging(false);
const files = Array.from(e.dataTransfer.files);
const imageFile = files.find(file => file.type.startsWith('image/'));
if (imageFile) {
const previewUrl = URL.createObjectURL(imageFile);
setImagePreview(previewUrl);
const formData = new FormData();
formData.append('image', imageFile);
formData.append('chat_id', currentChatId || '');
try {
const response = await fetch('/api/upload-image', { method: 'POST', body: formData });
const result = await response.json();
setUploadedImage(result.image_id);
} catch (error) {
console.error('Error uploading image:', error);
URL.revokeObjectURL(previewUrl);
setImagePreview(null);
}
}
};
const handleQuerySubmit = async (e: React.FormEvent<HTMLFormElement>) => {
e.preventDefault();
const currentQuery = query.trim();
if (!currentQuery || isStreaming || !wsRef.current) return;
setQuery("");
setIsStreaming(true);
firstTokenReceived.current = false;
try {
console.log('sending uploaded image: ', uploadedImage, ' with query: ', currentQuery);
console.log('current time: ', new Date().toISOString());
wsRef.current.send(JSON.stringify({
message: currentQuery,
image_id: uploadedImage
}));
setResponse(prev => {
try {
const messages = JSON.parse(prev);
messages.push({
type: "HumanMessage",
content: currentQuery
});
return JSON.stringify(messages);
} catch {
return prev + `\n\nHuman: ${currentQuery}\n\nAssistant: `;
}
});
// Release the attached image once the message has been sent
if (imagePreview) {
URL.revokeObjectURL(imagePreview);
}
setUploadedImage(null);
setImagePreview(null);
} catch (error) {
console.error("Error sending message:", error);
setIsStreaming(false);
}
};
// NOTE: closing the socket stops the stream, but no new messages can be sent
// for this chat until the WebSocket is re-initialized (e.g. by reselecting the chat).
const handleCancelStream = () => {
if (wsRef.current) {
wsRef.current.close();
wsRef.current = null;
setIsStreaming(false);
}
};
// filter out all ToolMessages
const parseMessages = (response: string): Message[] => {
try {
const parsed = JSON.parse(response);
if (!Array.isArray(parsed)) return [];
return parsed
.map((msg: any): Message => ({
type: msg?.type === "HumanMessage"
? "HumanMessage"
: msg?.type === "ToolMessage"
? "ToolMessage"
: "AssistantMessage",
content: typeof msg?.content === "string" ? msg.content : String(msg?.content ?? "")
}))
.filter((msg) => msg.type !== "ToolMessage"); // discard ToolMessage completely
} catch {
if (!response?.trim()) return [];
return [{ type: "AssistantMessage", content: String(response) }];
}
};
return (
<div className={styles.chatContainer}>
{showWelcome && <WelcomeSection setQuery={setQuery}/>}
{/* Minimal fade-in/fade-out: always start hidden, fade in on next tick */}
<div className={`${styles.pinnedToolOutput} ${!fadeIn ? styles.pinnedToolOutputHidden : ""}`}>
{graphStatus && (
<div className={styles.toolHeader} onClick={() => setIsToolContentVisible(v => !v)} style={{ cursor: 'pointer' }}>
<span className={styles.toolLabel}> {graphStatus} </span>
</div>
)}
</div>
<div className={styles.messagesContainer} ref={chatContainerRef}>
{parseMessages(response).map((message, index) => {
const isHuman = message.type === "HumanMessage";
const key = `${message.type}-${index}`;
if (!message.content?.trim()) return null;
return (
<div
key={key}
className={`${styles.messageWrapper} ${isHuman ? styles.userMessageWrapper : styles.assistantMessageWrapper}`}
style={{
animationDelay: `${index * 0.1}s`
}}
>
<div className={`${styles.message} ${isHuman ? styles.userMessage : styles.assistantMessage}`}>
<div className={styles.markdown}>
<ReactMarkdown
remarkPlugins={[remarkGfm]}
components={{
code({ inline, className, children, ...props }) {
const match = /language-(\w+)/.exec(className || "");
const code = String(children ?? "").replace(/\n$/, "");
if (inline || !match) {
return (
<code className={className} {...props}>
{code}
</code>
);
}
return (
<CodeBlockWithCopy code={code} language={match[1]} />
);
},
}}
>
{message.content}
</ReactMarkdown>
</div>
</div>
</div>
);
})}
{isStreaming && (
<div
className={`${styles.messageWrapper} ${styles.assistantMessageWrapper}`}
style={{
animationDelay: `${parseMessages(response).length * 0.1}s`
}}
>
<div className={`${styles.message} ${styles.assistantMessage}`}>
<div className={styles.typingIndicator}>
<span></span>
<span></span>
<span></span>
</div>
</div>
</div>
)}
<div ref={messagesEndRef} />
</div>
<form onSubmit={handleQuerySubmit} className={styles.inputContainer}>
{/* Image preview, shown to the left of the input wrapper */}
{imagePreview && (
<div className={styles.imagePreview}>
<img
src={imagePreview}
alt="Image preview"
className={styles.previewImage}
/>
<button
className={styles.removeImageButton}
onClick={() => {
if (imagePreview) {
URL.revokeObjectURL(imagePreview);
}
setUploadedImage(null);
setImagePreview(null);
}}
>
</button>
</div>
)}
<div
className={`${styles.inputWrapper} ${isDragging ? styles.dragging : ''}`}
onDragEnter={handleDragEnter}
onDragLeave={handleDragLeave}
onDragOver={handleDragOver}
onDrop={handleDrop}
>
<button
type="button"
onClick={() => setShowIngestion(true)}
className={`${styles.uploadButton} ${showButtons ? styles.show : ''}`}
title="Upload Documents"
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
width="20"
height="20"
>
<path d="M21.44 11.05l-9.19 9.19a6 6 0 0 1-8.49-8.49l9.19-9.19a4 4 0 0 1 5.66 5.66l-9.2 9.19a2 2 0 0 1-2.83-2.83l8.49-8.48" />
</svg>
</button>
<textarea
rows={1}
value={query}
onChange={(e) => setQuery(e.target.value)}
placeholder="Send a message..."
disabled={isStreaming}
className={styles.messageInput}
onKeyDown={(e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
// Submit the enclosing form instead of force-casting the keyboard event
e.currentTarget.form?.requestSubmit();
}
}}
/>
</div>
{!isStreaming ? (
<button
type="submit"
className={`${styles.sendButton} ${showButtons ? styles.show : ''}`}
disabled={!query.trim()}
>
</button>
) : (
<button
type="button"
onClick={handleCancelStream}
className={`${styles.streamingCancelButton} ${showButtons ? styles.show : ''}`}
>
</button>
)}
</form>
{inferenceStats.tokensPerSecond > 0 && (
<div className={styles.inferenceStats}>
{inferenceStats.tokensPerSecond} tokens/sec
</div>
)}
</div>
);
}

/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
import React, { useState, useEffect, useRef } from 'react';
import styles from '@/styles/Sidebar.module.css';
interface Model {
id: string;
name: string;
}
interface ChatMetadata {
name: string;
}
interface SidebarProps {
showIngestion: boolean;
setShowIngestion: (value: boolean) => void;
refreshTrigger?: number;
currentChatId: string | null;
onChatChange: (chatId: string) => Promise<void>;
}
export default function Sidebar({
showIngestion,
setShowIngestion,
refreshTrigger = 0,
currentChatId,
onChatChange
}: SidebarProps) {
const [isVisible, setIsVisible] = useState(false);
const [isClosing, setIsClosing] = useState(false);
const [expandedSections, setExpandedSections] = useState<Set<string>>(new Set(["config", "history"]));
const [isLoading, setIsLoading] = useState(false);
const [availableSources, setAvailableSources] = useState<string[]>([]);
const [selectedSources, setSelectedSources] = useState<string[]>([]);
const [selectedModel, setSelectedModel] = useState<string>("");
const [isLoadingSources, setIsLoadingSources] = useState(false);
const [availableModels, setAvailableModels] = useState<Model[]>([]);
const [isLoadingModels, setIsLoadingModels] = useState(false);
const [chats, setChats] = useState<string[]>([]);
const [isLoadingChats, setIsLoadingChats] = useState(false);
const [chatMetadata, setChatMetadata] = useState<Record<string, ChatMetadata>>({});
// Add ref for chat list
const chatListRef = useRef<HTMLDivElement>(null);
// Load initial configuration
useEffect(() => {
const loadInitialConfig = async () => {
try {
setIsLoading(true);
// Get selected model
const modelResponse = await fetch("/api/selected_model");
if (modelResponse.ok) {
const { model } = await modelResponse.json();
setSelectedModel(model);
}
// Get selected sources
const sourcesResponse = await fetch("/api/selected_sources");
if (sourcesResponse.ok) {
const { sources } = await sourcesResponse.json();
setSelectedSources(sources);
}
// Get available models
await fetchAvailableModels();
// Get sources
await fetchSources();
// Get chats if history section is expanded (which it is by default)
if (expandedSections.has('history')) {
await fetchChats();
}
} catch (error) {
console.error("Error loading initial config:", error);
} finally {
setIsLoading(false);
}
};
loadInitialConfig();
}, []);
// Fetch available models
const fetchAvailableModels = async () => {
try {
setIsLoadingModels(true);
const response = await fetch("/api/available_models");
if (!response.ok) {
const errorText = await response.text();
console.error(`Error fetching available models: ${response.status} - ${errorText}`);
return;
}
const data = await response.json();
const models = data.models.map((modelId: string) => ({
id: modelId,
name: modelId.split('-').map(word => word.charAt(0).toUpperCase() + word.slice(1)).join(' ')
}));
setAvailableModels(models);
} catch (error) {
console.error("Error fetching available models:", error);
} finally {
setIsLoadingModels(false);
}
};
// Fetch available sources
const fetchSources = async () => {
try {
setIsLoadingSources(true);
console.log("Fetching sources...");
const response = await fetch("/api/sources");
if (!response.ok) {
const errorText = await response.text();
console.error(`Error fetching sources: ${response.status} - ${errorText}`);
setAvailableSources([]);
return;
}
const data = await response.json();
console.log("Sources fetched:", data.sources);
setAvailableSources(data.sources || []);
} catch (error) {
console.error("Error fetching sources:", error);
setAvailableSources([]);
} finally {
setIsLoadingSources(false);
}
};
// Get sources on initial load and when the context section is expanded
useEffect(() => {
if (expandedSections.has('context')) {
fetchSources();
}
}, [expandedSections]);
// Refresh sources when refreshTrigger changes (document ingestion)
useEffect(() => {
if (refreshTrigger > 0) { // Only refresh if not the initial render
fetchSources();
}
}, [refreshTrigger]);
// Add function to fetch chat metadata
const fetchChatMetadata = async (chatId: string) => {
try {
const response = await fetch(`/api/chat/${chatId}/metadata`);
if (response.ok) {
const metadata = await response.json();
setChatMetadata(prev => ({
...prev,
[chatId]: metadata
}));
}
} catch (error) {
console.error(`Error fetching metadata for chat ${chatId}:`, error);
}
};
// Update fetchChats to also fetch metadata
const fetchChats = async () => {
try {
console.log("fetchChats: Starting to fetch chats...");
setIsLoadingChats(true);
const response = await fetch("/api/chats");
if (response.ok) {
const data = await response.json();
console.log("fetchChats: Received chats:", data.chats);
setChats(data.chats);
// Fetch metadata for each chat
await Promise.all(data.chats.map(fetchChatMetadata));
console.log("fetchChats: Completed fetching all chat metadata");
} else {
console.error("fetchChats: Failed to fetch chats, status:", response.status);
}
} catch (error) {
console.error("Error fetching chats:", error);
} finally {
setIsLoadingChats(false);
}
};
// Fetch chats when history section is expanded
useEffect(() => {
if (expandedSections.has('history')) {
fetchChats();
}
}, [expandedSections]);
// Update highlight position when currentChatId changes
useEffect(() => {
if (currentChatId && chatListRef.current) {
const activeChat = chatListRef.current.querySelector(`.${styles.active}`) as HTMLElement;
if (activeChat) {
const offset = activeChat.offsetTop;
chatListRef.current.style.setProperty('--highlight-offset', `${offset}px`);
}
}
}, [currentChatId, chats]);
// Add new effect to handle initial position and chat list loading
useEffect(() => {
if (chatListRef.current && currentChatId && chats.length > 0) {
// Small delay to ensure DOM is ready
setTimeout(() => {
const activeChat = chatListRef.current?.querySelector(`.${styles.active}`) as HTMLElement;
if (activeChat) {
const offset = activeChat.offsetTop;
chatListRef.current?.style.setProperty('--highlight-offset', `${offset}px`);
}
}, 50);
}
}, [isVisible, expandedSections.has('history'), chats.length]);
const handleClose = () => {
setIsClosing(true);
setTimeout(() => {
setIsVisible(false);
setIsClosing(false);
}, 500); // Match the sidebar close-animation duration in CSS
};
const toggleSidebar = () => {
if (!isVisible) {
setIsVisible(true);
setIsClosing(false);
fetchSources();
// Also fetch chats when sidebar opens
if (expandedSections.has('history')) {
fetchChats();
}
} else {
handleClose();
}
};
const toggleSection = (section: string) => {
const newExpandedSections = new Set(expandedSections);
if (newExpandedSections.has(section)) {
newExpandedSections.delete(section);
} else {
newExpandedSections.add(section);
// Get sources when context section is expanded
if (section === 'context') {
fetchSources();
}
}
setExpandedSections(newExpandedSections);
};
const isSectionExpanded = (section: string) => {
return expandedSections.has(section);
};
const handleSourceToggle = async (source: string) => {
let newSelectedSources: string[];
if (selectedSources.includes(source)) {
// Remove source if already selected
newSelectedSources = selectedSources.filter(s => s !== source);
} else {
// Add source if not selected
newSelectedSources = [...selectedSources, source];
}
setSelectedSources(newSelectedSources);
try {
const response = await fetch("/api/selected_sources", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(newSelectedSources)
});
if (!response.ok) {
console.error("Failed to update selected sources");
// Revert the local state if the update failed
setSelectedSources(selectedSources);
}
} catch (error) {
console.error("Error updating selected sources:", error);
// Revert the local state if the update failed
setSelectedSources(selectedSources);
}
};
const handleChatSelect = async (chatId: string) => {
try {
await onChatChange(chatId);
// Close sidebar on mobile after selection
if (window.innerWidth < 768) {
handleClose();
}
} catch (error) {
console.error("Error selecting chat:", error);
}
};
const handleRenameChat = async (chatId: string, currentName: string) => {
const newName = prompt("Enter new chat name:", currentName);
if (newName && newName.trim() && newName !== currentName) {
try {
const response = await fetch("/api/chat/rename", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ chat_id: chatId, new_name: newName.trim() })
});
if (!response.ok) {
console.error("Failed to rename chat");
return;
}
// Fetch updated metadata for the renamed chat
await fetchChatMetadata(chatId);
} catch (error) {
console.error("Error renaming chat:", error);
}
}
};
const handleDeleteChat = async (chatId: string) => {
try {
// Delete the chat
const response = await fetch(`/api/chat/${chatId}`, {
method: "DELETE"
});
if (!response.ok) {
console.error("Failed to delete chat");
return;
}
// Refresh chat list
await fetchChats();
// If we deleted the current chat
if (currentChatId === chatId) {
// Get updated list of chats
const chatsResponse = await fetch("/api/chats");
const { chats: remainingChats } = await chatsResponse.json();
if (remainingChats.length > 0) {
// Switch to another chat
await onChatChange(remainingChats[0]);
} else {
// No chats left, create a new one
await handleNewChat();
}
}
} catch (error) {
console.error("Error deleting chat:", error);
}
};
const handleNewChat = async () => {
try {
console.log("handleNewChat: Starting new chat creation...");
// Create new chat using backend endpoint
const response = await fetch("/api/chat/new", {
method: "POST"
});
if (!response.ok) {
console.error("Failed to create new chat");
return;
}
const data = await response.json();
console.log("handleNewChat: Created new chat:", data.chat_id);
// First, refresh the chat list to ensure the new chat is available
console.log("handleNewChat: Refreshing chat list...");
await fetchChats();
console.log("handleNewChat: Chat list refreshed");
// Then change to the new chat
await onChatChange(data.chat_id);
console.log("handleNewChat: Changed to new chat");
// Close sidebar on mobile
if (window.innerWidth < 768) {
handleClose();
}
// Add a small delay to ensure the DOM has updated, then trigger highlight animation
setTimeout(() => {
if (chatListRef.current) {
const activeChat = chatListRef.current.querySelector(`.${styles.active}`) as HTMLElement;
if (activeChat) {
const offset = activeChat.offsetTop;
chatListRef.current.style.setProperty('--highlight-offset', `${offset}px`);
}
}
}, 100); // Increased delay for more reliability
} catch (error) {
console.error("Error creating new chat:", error);
}
};
const handleClearAllChats = async () => {
// Show confirmation dialog
const confirmClear = window.confirm(
`Are you sure you want to clear all ${chats.length} chat conversations? This action cannot be undone.`
);
if (!confirmClear) {
return;
}
try {
const response = await fetch("/api/chats/clear", {
method: "DELETE"
});
if (!response.ok) {
console.error("Failed to clear all chats");
alert("Failed to clear chats. Please try again.");
return;
}
const data = await response.json();
// Switch to the new chat created by the backend
await onChatChange(data.new_chat_id);
// Refresh chat list
await fetchChats();
// Close sidebar on mobile
if (window.innerWidth < 768) {
handleClose();
}
console.log(`Successfully cleared ${data.cleared_count} chats`);
} catch (error) {
console.error("Error clearing all chats:", error);
alert("An error occurred while clearing chats. Please try again.");
}
};
const handleModelChange = async (event: React.ChangeEvent<HTMLSelectElement>) => {
const newModel = event.target.value;
setSelectedModel(newModel);
try {
console.log("Updating selected model to:", newModel);
const response = await fetch("/api/selected_model", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ model: newModel })
});
if (!response.ok) {
console.error("Failed to update selected model");
// Revert the local state if the update failed
setSelectedModel(selectedModel);
}
} catch (error) {
console.error("Error updating selected model:", error);
// Revert the local state if the update failed
setSelectedModel(selectedModel);
}
};
return (
<>
<button
className={`${styles.toggleSidebarButton} ${isVisible && !isClosing ? styles.active : ''}`}
onClick={toggleSidebar}
>
</button>
{isVisible && (
<>
<div
className={`${styles.sidebarOverlay} ${isClosing ? styles.closing : ''}`}
onClick={handleClose}
/>
<div className={`${styles.sidebar} ${isClosing ? styles.closing : ''}`}>
<button
className={styles.closeSidebarButton}
onClick={handleClose}
>
×
</button>
<div className={styles.sidebarHeader}>
<h2 className={styles.title}>Spark Chat</h2>
</div>
{/* Model Selection */}
<div className={styles.sidebarSection}>
<div
className={styles.sectionHeader}
onClick={() => toggleSection('model')}
>
<h3>Model</h3>
<span className={isSectionExpanded('model') ? styles.arrowUp : styles.arrowDown}></span>
</div>
<div className={`${styles.sectionContent} ${isSectionExpanded('model') ? styles.expanded : ''}`}>
<div className={styles.configItem}>
<label htmlFor="model-select">Select Supervisor Model</label>
<select
id="model-select"
className={styles.modelSelect}
value={selectedModel}
onChange={handleModelChange}
disabled={isLoadingModels}
>
{isLoadingModels ? (
<option value="">Loading models...</option>
) : (
availableModels.map(model => (
<option key={model.id} value={model.id}>
{model.name}
</option>
))
)}
</select>
</div>
</div>
</div>
{/* Context */}
<div className={styles.sidebarSection}>
<div
className={styles.sectionHeader}
onClick={() => toggleSection('context')}
>
<h3>Context</h3>
<span className={isSectionExpanded('context') ? styles.arrowUp : styles.arrowDown}></span>
</div>
<div className={`${styles.sectionContent} ${isSectionExpanded('context') ? styles.expanded : ''}`}>
<div className={styles.configItem}>
<label>Select Sources</label>
<div className={styles.sourcesContainer}>
{availableSources.length === 0 ? (
<div className={styles.noSources}>No sources available</div>
) : (
availableSources.map(source => (
<div key={source} className={styles.sourceItem}>
<input
type="checkbox"
id={`source-${source}`}
checked={selectedSources.includes(source)}
onChange={() => handleSourceToggle(source)}
/>
<label htmlFor={`source-${source}`}>{source}</label>
</div>
))
)}
</div>
<button
className={styles.refreshButton}
onClick={(e) => {
e.preventDefault();
fetchSources();
}}
disabled={isLoadingSources}
>
{isLoadingSources ? "Loading..." : "Refresh Sources"}
</button>
</div>
</div>
</div>
{/* Chat History */}
<div className={styles.sidebarSection}>
<div
className={styles.sectionHeader}
onClick={() => toggleSection('history')}
>
<h3>Chat History</h3>
<span className={isSectionExpanded('history') ? styles.arrowUp : styles.arrowDown}></span>
</div>
<div className={`${styles.sectionContent} ${isSectionExpanded('history') ? styles.expanded : ''}`}>
<div className={styles.chatList} ref={chatListRef}>
{isLoadingChats ? (
<div className={styles.loadingText}>Loading chats...</div>
) : chats.length === 0 ? (
<div className={styles.noChatText}>No previous chats</div>
) : (
chats.map((chatId) => (
<div
key={chatId}
className={`${styles.chatItem} ${currentChatId === chatId ? styles.active : ''}`}
onClick={() => handleChatSelect(chatId)}
>
<div className={styles.chatName}>
{chatMetadata[chatId]?.name || chatId.slice(0, 8)}
</div>
<div className={styles.chatActions}>
<button
className={styles.chatActionButton}
onClick={(e) => {
e.stopPropagation();
handleRenameChat(chatId, chatMetadata[chatId]?.name || `Chat ${chatId.slice(0, 8)}`);
}}
>
</button>
<button
className={styles.chatActionButton}
onClick={(e) => {
e.stopPropagation();
handleDeleteChat(chatId);
}}
>
×
</button>
</div>
</div>
))
)}
</div>
</div>
</div>
{/* Chat Action Buttons at bottom */}
<div className={styles.chatButtonsContainer}>
<button
className={styles.newChatButton}
onClick={handleNewChat}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
>
<line x1="12" y1="5" x2="12" y2="19" />
<line x1="5" y1="12" x2="19" y2="12" />
</svg>
New Chat
</button>
<button
className={styles.clearChatsButton}
onClick={handleClearAllChats}
disabled={chats.length === 0}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
>
<path d="M3 6h18" />
<path d="M19 6v14c0 1-1 2-2 2H7c-1 0-2-1-2-2V6" />
<path d="M8 6V4c0-1 1-2 2-2h4c1 0 2 1 2 2v2" />
</svg>
Clear All
</button>
</div>
</div>
</>
)}
</>
);
}

/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
"use client";
import { useState, useEffect } from 'react';
export default function ThemeToggle() {
const [isDark, setIsDark] = useState(false);
useEffect(() => {
// Check if user has a theme preference
const isDarkMode = localStorage.getItem('theme') === 'dark' ||
(!localStorage.getItem('theme') && window.matchMedia('(prefers-color-scheme: dark)').matches);
setIsDark(isDarkMode);
if (isDarkMode) {
document.documentElement.classList.add('dark');
}
}, []);
const toggleTheme = () => {
setIsDark(!isDark);
if (!isDark) {
document.documentElement.classList.add('dark');
localStorage.setItem('theme', 'dark');
} else {
document.documentElement.classList.remove('dark');
localStorage.setItem('theme', 'light');
}
};
return (
<button
onClick={toggleTheme}
className="fixed top-4 right-4 p-2 rounded-full bg-white dark:bg-gray-800 shadow-lg hover:shadow-xl transition-all duration-200 z-50"
aria-label="Toggle theme"
>
{isDark ? (
<svg className="w-6 h-6 text-yellow-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M12 3v1m0 16v1m9-9h-1M4 12H3m15.364 6.364l-.707-.707M6.343 6.343l-.707-.707m12.728 0l-.707.707M6.343 17.657l-.707.707M16 12a4 4 0 11-8 0 4 4 0 018 0z" />
</svg>
) : (
<svg className="w-6 h-6 text-gray-700" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M20.354 15.354A9 9 0 018.646 3.646 9.003 9.003 0 0012 21a9.003 9.003 0 008.354-5.646z" />
</svg>
)}
</button>
);
}

/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
import styles from "@/styles/WelcomeSection.module.css";
interface WelcomeSectionProps {
setQuery: (value: string) => void;
}
export default function WelcomeSection({ setQuery }: WelcomeSectionProps) {
const promptTemplates = {
rag: "What is the Blackwell GB202 GPU according to the whitepaper document I uploaded?",
image: "Can you analyze the graphs in this image and tell me any surprising stats? https://menlovc.com/wp-content/uploads/2025/06/5-parents_power_users-062425.png",
code: `Can you generate code to develop a responsive personal website for my freelance AI dev business based on my personal brand palette?
My palette is:
#606C38
#283618
#FEFAE0
#DDA15E
#BC6C25`,
chat: "Hey Spark! Can you draft an email asking a product manager in distributed systems to a coffee chat?"
};
const handleCardClick = (promptKey: keyof typeof promptTemplates) => {
setQuery(promptTemplates[promptKey]);
};
return (
<div className={styles.welcomeContainer}>
<div className={styles.welcomeMessage}>
Hello! Send a message to start chatting with Spark.
</div>
<div className={styles.agentCards}>
<div
className={`${styles.agentCard} ${styles.animate1}`}
onClick={() => handleCardClick('rag')}
>
<div className={styles.agentIcon}>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" width="24" height="24">
<circle cx="11" cy="11" r="8"/>
<path d="m21 21-4.35-4.35"/>
</svg>
</div>
<h3 className={styles.agentTitle}>Search Documents</h3>
<p className={styles.agentSubtitle}>RAG Agent</p>
</div>
<div
className={`${styles.agentCard} ${styles.animate2}`}
onClick={() => handleCardClick('image')}
>
<div className={styles.agentIcon}>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" width="24" height="24">
<rect width="18" height="18" x="3" y="3" rx="2" ry="2"/>
<circle cx="9" cy="9" r="2"/>
<path d="m21 15-3.086-3.086a2 2 0 0 0-2.828 0L6 21"/>
</svg>
</div>
<h3 className={styles.agentTitle}>Image Processor</h3>
<p className={styles.agentSubtitle}>Image Understanding Agent</p>
</div>
<div
className={`${styles.agentCard} ${styles.animate3}`}
onClick={() => handleCardClick('code')}
>
<div className={styles.agentIcon}>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" width="24" height="24">
<polyline points="16,18 22,12 16,6"/>
<polyline points="8,6 2,12 8,18"/>
</svg>
</div>
<h3 className={styles.agentTitle}>Code Generation</h3>
<p className={styles.agentSubtitle}>Coding Agent</p>
</div>
<div
className={`${styles.agentCard} ${styles.animate4}`}
onClick={() => handleCardClick('chat')}
>
<div className={styles.agentIcon}>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" width="24" height="24">
<path d="m3 21 1.9-5.7a8.5 8.5 0 1 1 3.8 3.8z"/>
</svg>
</div>
<h3 className={styles.agentTitle}>Chat</h3>
<p className={styles.agentSubtitle}>Local LLM</p>
</div>
</div>
</div>
);
}


@ -0,0 +1,136 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
@keyframes fadeIn {
from {
opacity: 0;
}
  to {
    opacity: 1;
  }
}
.section {
background: white;
color: black;
border-radius: 12px;
padding: 1.5rem;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.08);
width: 350px;
flex-shrink: 0;
animation: fadeIn 0.8s cubic-bezier(0.4,0.2,0.2,1) both;
}
.ingestForm {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
.uploadSection {
padding: 1.5rem;
border: 2px dashed #e2e8f0;
border-radius: 8px;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
transition: border-color 0.2s;
/* background-color: #29526d; */
}
.uploadSection:hover {
border-color: #4299e1;
}
.fileInput {
display: none;
}
/* "Choose Files" button */
.customFileLabel {
padding: 0.25rem 1rem;
background-color: #f5f6fa;
color: #222;
border: 1px solid #d1d5db;
border-radius: 4px;
font-weight: 400;
cursor: pointer;
transition: background 0.2s, border 0.2s, color 0.2s;
font-size: 0.98rem;
margin-bottom: 0.5em;
display: inline-block;
text-align: center;
}
.customFileLabel:hover {
background-color: #ffffff;
color: #111;
border: 1px solid var(--primary);
}
.customFileLabel:active, .customFileLabel:focus {
outline: 2px solid #76B900;
}
.fileName {
display: block;
margin-top: 0.5em;
color: #222;
text-align: center;
font-size: 1rem;
margin-bottom: 0.5em;
}
.directoryInput {
width: 100%;
padding: 0.75rem;
border: 1px solid #e2e8f0;
border-radius: 2px;
font-size: 1rem;
}
.helpText {
color: #718096;
font-size: 0.875rem;
margin-top: 0.5rem;
text-align: center;
}
.ingestButton {
padding: 0.75rem 1.5rem;
background-color: #48bb78;
color: white;
border: none;
border-radius: 8px;
font-weight: 600;
cursor: pointer;
transition: background-color 0.2s;
}
.ingestButton:hover {
background-color: #38a169;
}
.ingestButton:disabled {
background-color: #cbd5e0;
cursor: not-allowed;
}
.messageContainer {
margin-top: 1.5rem;
padding: 1rem;
border: 1px solid #e2e8f0;
border-radius: 8px;
background-color: #f7fafc;
}


@ -0,0 +1,117 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
.container {
display: flex;
height: 100vh;
width: 100vw;
overflow: hidden;
position: relative;
}
.mainContent {
flex: 1;
display: flex;
overflow: hidden;
width: 100%;
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
}
.documentUploadContainer {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
z-index: 20;
background: white;
padding: 24px;
border-radius: 8px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
max-width: 500px;
width: 90%;
}
.overlay {
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: rgba(0, 0, 0, 0.5);
z-index: 15;
}
.closeButton {
position: absolute;
top: 12px;
right: 12px;
background: none;
border: none;
font-size: 1.25rem;
cursor: pointer;
color: #718096;
}
.closeButton:hover {
color: #4a5568;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 1.5rem;
}
.title {
color: #2d3748;
font-size: 2rem;
margin: 0;
}
.toggleButton {
background-color: #4299e1;
color: white;
border: none;
padding: 8px 16px;
border-radius: 6px;
cursor: pointer;
font-weight: 500;
transition: background-color 0.2s;
}
.toggleButton:hover {
background-color: #3182ce;
}
.mainContent > div:only-child {
width: 100%;
}
@media (max-width: 768px) {
.container {
flex-direction: column;
}
.header {
flex-direction: column;
gap: 1rem;
}
}

File diff suppressed because it is too large


@ -0,0 +1,781 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
.sidebar {
width: 320px;
height: auto;
max-height: 90vh;
background-color: white;
border-radius: 12px;
padding: 16px;
display: flex;
flex-direction: column;
overflow-y: auto;
position: absolute;
top: 60px;
left: 20px;
z-index: 5;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
pointer-events: auto;
transform-origin: top;
animation: menuExpand 0.6s cubic-bezier(0.34, 1.56, 0.64, 1) forwards;
}
:global(.dark) .sidebar {
background-color: #1f2937;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
}
.sidebar.closing {
animation: menuCollapse 0.5s cubic-bezier(0.34, 1.56, 0.64, 1) forwards;
}
@keyframes menuExpand {
0% {
opacity: 0;
transform: scaleY(0.3) translateY(-40px);
}
100% {
opacity: 1;
transform: scaleY(1) translateY(0);
}
}
@keyframes menuCollapse {
0% {
opacity: 1;
transform: scaleY(1) translateY(0);
}
100% {
opacity: 0;
transform: scaleY(0.3) translateY(40px);
}
}
.sidebarOverlay {
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: transparent;
z-index: 5;
display: block;
pointer-events: none;
animation: fadeIn 0.2s ease forwards;
}
.sidebarOverlay.closing {
animation: fadeOut 0.2s ease forwards;
}
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
@keyframes fadeOut {
from {
opacity: 1;
}
to {
opacity: 0;
}
}
.toggleSidebarButton {
position: absolute;
top: 20px;
left: 20px;
z-index: 10;
background-color: white;
border: none;
border-radius: 50%;
width: 40px;
height: 40px;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
font-size: 18px;
transition: transform 0.5s cubic-bezier(0.34, 1.56, 0.64, 1);
color: #1f2937;
}
:global(.dark) .toggleSidebarButton {
background-color: #1f2937;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3);
color: white;
}
.toggleSidebarButton.active {
transform: rotate(180deg);
}
.closeSidebarButton {
position: absolute;
top: 12px;
right: 12px;
background: none;
border: none;
font-size: 24px;
line-height: 1;
cursor: pointer;
color: #718096;
padding: 4px 8px;
border-radius: 4px;
}
.closeSidebarButton:hover {
color: #4a5568;
background-color: #f3f4f6;
}
:global(.dark) .closeSidebarButton {
color: #9ca3af;
}
:global(.dark) .closeSidebarButton:hover {
color: #e5e7eb;
background-color: #374151;
}
.sidebarHeader {
padding-bottom: 16px;
border-bottom: 1px solid #e2e8f0;
margin-bottom: 16px;
}
:global(.dark) .sidebarHeader {
border-bottom-color: #374151;
}
.title {
font-size: 1.25rem;
font-weight: 600;
color: #2d3748;
margin: 0;
}
:global(.dark) .title {
color: #e5e7eb;
}
.sidebarSection {
margin-bottom: 16px;
border-bottom: 1px solid #e2e8f0;
padding-bottom: 16px;
}
:global(.dark) .sidebarSection {
border-bottom-color: #374151;
}
.sidebarSection:last-child {
border-bottom: none;
}
.sectionHeader {
display: flex;
justify-content: space-between;
align-items: center;
cursor: pointer;
padding: 8px 0;
}
.sectionHeader h3 {
margin: 0;
font-size: 1rem;
font-weight: 600;
color: #4a5568;
}
:global(.dark) .sectionHeader h3 {
color: #e5e7eb;
}
.sectionContent {
padding-top: 8px;
max-height: 0;
overflow: hidden;
opacity: 0;
transform: translateY(-8px);
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
pointer-events: none;
}
.sectionContent.expanded {
max-height: 1000px;
opacity: 1;
transform: translateY(0);
pointer-events: auto;
}
.arrowDown, .arrowUp {
font-size: 12px;
color: #718096;
transition: transform 0.4s cubic-bezier(0.4, 0, 0.2, 1);
}
:global(.dark) .arrowDown, :global(.dark) .arrowUp {
color: #9ca3af;
}
.arrowUp {
transform: rotate(180deg);
}
.configItem {
margin-bottom: 12px;
display: flex;
flex-direction: column;
}
.configItem label {
font-size: 0.875rem;
color: #4a5568;
margin-bottom: 4px;
}
:global(.dark) .configItem label {
color: #e5e7eb;
}
.select {
padding: 8px 12px;
border-radius: 6px;
border: 1px solid #e2e8f0;
background-color: #f9fafb;
font-size: 0.875rem;
color: #4a5568;
width: 100%;
}
:global(.dark) .select {
background-color: #374151;
border-color: #4b5563;
color: #e5e7eb;
}
.toolItem {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 8px;
}
.toolItem input[type="checkbox"] {
margin-right: 8px;
}
.toolItem label {
font-size: 0.875rem;
color: #4a5568;
}
:global(.dark) .toolItem label {
color: #e5e7eb;
}
.toggleSwitch {
position: relative;
display: inline-block;
width: 48px;
height: 24px;
}
.toggleSwitch input {
opacity: 0;
width: 0;
height: 0;
}
.toggleSlider {
position: absolute;
cursor: pointer;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: #e2e8f0;
transition: .4s;
border-radius: 24px;
box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.1);
}
:global(.dark) .toggleSlider {
background-color: #374151;
box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.2);
}
.toggleSlider:before {
position: absolute;
content: "";
height: 18px;
width: 18px;
left: 3px;
bottom: 3px;
background-color: white;
transition: .4s;
border-radius: 50%;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
input:checked + .toggleSlider {
background-color: #76B900;
box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.2);
}
input:checked + .toggleSlider:before {
transform: translateX(24px);
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
}
.toolLabel {
font-size: 14px;
color: #1f2937;
}
:global(.dark) .toolLabel {
color: #e5e7eb;
}
.chatList {
position: relative;
display: flex;
flex-direction: column;
gap: 4px;
padding: 4px;
}
.chatList::before {
content: '';
position: absolute;
top: 4px;
left: 0;
right: 0;
height: 36px;
background: rgba(118, 185, 0, 0.08);
border: 1px solid rgba(118, 185, 0, 0.2);
border-radius: 8px;
transform: translateY(calc(var(--highlight-offset, 0) - 3px));
transition: transform 0.2s ease-out;
pointer-events: none;
z-index: 0;
}
.chatItem {
position: relative;
display: flex;
align-items: center;
justify-content: space-between;
padding: 6px 12px;
border-radius: 8px;
height: 36px;
cursor: pointer;
z-index: 1;
transition: background-color 0.2s ease;
}
.chatItem:hover {
background-color: rgba(0, 0, 0, 0.03);
}
.chatName {
flex: 1;
font-size: 14px;
font-weight: 500;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
transition: color 0.2s ease;
color: #1f2937;
line-height: 24px;
padding-top: 1px;
}
:global(.dark) .chatName {
color: #ffffff;
}
.chatItem.active .chatName {
color: #76B900;
font-weight: 600;
}
.chatActions {
display: flex;
gap: 4px;
opacity: 0;
transition: opacity 0.2s ease;
}
.chatItem:hover .chatActions {
opacity: 1;
}
.chatActionButton {
background: none;
border: none;
color: #718096;
padding: 4px;
cursor: pointer;
border-radius: 4px;
font-size: 14px;
line-height: 1;
display: flex;
align-items: center;
justify-content: center;
transition: all 0.2s ease;
}
.chatActionButton:hover {
background-color: rgba(0, 0, 0, 0.05);
color: #1f2937;
}
.loadingText, .noChatText {
color: #718096;
font-size: 0.875rem;
text-align: center;
padding: 8px;
}
:global(.dark) .loadingText, :global(.dark) .noChatText {
color: #9ca3af;
}
.activeIndicator {
width: 8px;
height: 8px;
border-radius: 50%;
background-color: #76B900;
margin-right: 8px;
flex-shrink: 0;
}
.chatInfo {
display: flex;
flex-direction: column;
flex: 1;
cursor: pointer;
}
.chatDate {
font-size: 0.75rem;
color: #718096;
}
:global(.dark) .chatDate {
color: #9ca3af;
}
.actionButton {
width: 100%;
padding: 10px 16px;
background-color: #76B900;
color: white;
border: none;
border-radius: 8px;
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: background-color 0.2s;
margin-top: 8px;
}
.actionButton:hover {
background-color: #669f00;
}
.refreshButton {
margin-top: 8px;
padding: 6px 12px;
background-color: #e2e8f0;
border: none;
border-radius: 4px;
font-size: 0.75rem;
cursor: pointer;
transition: background-color 0.2s;
}
:global(.dark) .refreshButton {
background-color: #374151;
color: #e5e7eb;
}
.refreshButton:hover {
background-color: #cbd5e0;
}
:global(.dark) .refreshButton:hover {
background-color: #4b5563;
}
.refreshButton:disabled {
opacity: 0.6;
cursor: not-allowed;
}
.sourcesContainer {
margin-top: 8px;
max-height: 200px;
overflow-y: auto;
border: 1px solid #e2e8f0;
border-radius: 4px;
padding: 8px;
}
:global(.dark) .sourcesContainer {
background-color: #374151;
border-color: #4b5563;
}
.sourceItem {
display: flex;
align-items: center;
margin-bottom: 6px;
}
.sourceItem input[type="checkbox"] {
margin-right: 8px;
}
.sourceItem label {
font-size: 14px;
cursor: pointer;
}
:global(.dark) .sourceItem label {
color: #e5e7eb;
}
.noSources {
color: #718096;
font-size: 14px;
padding: 8px 0;
}
:global(.dark) .noSources {
color: #9ca3af;
}
.modelSelect {
padding: 8px 12px;
border-radius: 6px;
border: 1px solid #e2e8f0;
background-color: #f9fafb;
font-size: 0.875rem;
color: #4a5568;
width: 100%;
margin-top: 4px;
cursor: pointer;
appearance: none;
background-image: url("data:image/svg+xml;charset=US-ASCII,%3Csvg%20xmlns%3D%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%20width%3D%22292.4%22%20height%3D%22292.4%22%3E%3Cpath%20fill%3D%22%23007CB2%22%20d%3D%22M287%2069.4a17.6%2017.6%200%200%200-13-5.4H18.4c-5%200-9.3%201.8-12.9%205.4A17.6%2017.6%200%200%200%200%2082.2c0%205%201.8%209.3%205.4%2012.9l128%20127.9c3.6%203.6%207.8%205.4%2012.8%205.4s9.2-1.8%2012.8-5.4L287%2095c3.5-3.5%205.4-7.8%205.4-12.8%200-5-1.9-9.2-5.5-12.8z%22%2F%3E%3C%2Fsvg%3E");
background-repeat: no-repeat;
background-position: right 12px top 50%;
background-size: 12px auto;
}
:global(.dark) .modelSelect {
background-color: #374151;
border-color: #4b5563;
color: #e5e7eb;
}
.modelSelect:hover {
border-color: #cbd5e0;
}
:global(.dark) .modelSelect:hover {
border-color: #6b7280;
}
.modelSelect:focus {
outline: none;
border-color: #76B900;
box-shadow: 0 0 0 3px rgba(118, 185, 0, 0.2);
}
:global(.dark) .modelSelect:focus {
box-shadow: 0 0 0 3px rgba(118, 185, 0, 0.4);
}
@media (max-width: 768px) {
.sidebar {
width: 280px;
}
}
.newChatButton {
width: 100%;
padding: 12px;
background-color: #76B900;
color: white;
border: none;
border-radius: 12px;
font-size: 0.875rem;
font-weight: 600;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
gap: 8px;
transition: all 0.3s ease;
box-shadow: 0 2px 8px rgba(118, 185, 0, 0.2);
position: relative;
overflow: hidden;
}
.newChatButton::before {
content: '';
position: absolute;
top: 50%;
left: 50%;
width: 100%;
height: 100%;
background: rgba(255, 255, 255, 0.1);
transform: translate(-50%, -50%) scale(0);
border-radius: 50%;
transition: transform 0.5s ease;
}
.newChatButton:hover {
background-color: #669f00;
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(118, 185, 0, 0.3);
}
.newChatButton:hover::before {
transform: translate(-50%, -50%) scale(2);
}
:global(.dark) .newChatButton {
box-shadow: 0 2px 8px rgba(118, 185, 0, 0.4);
}
:global(.dark) .newChatButton:hover {
box-shadow: 0 4px 12px rgba(118, 185, 0, 0.5);
}
.newChatButton svg {
width: 16px;
height: 16px;
transition: transform 0.3s ease;
}
.newChatButton:hover svg {
transform: rotate(90deg);
}
.chatButtonsContainer {
display: flex;
flex-direction: column;
gap: 8px;
margin-top: 16px;
}
.clearChatsButton {
width: 100%;
padding: 12px;
background-color: #dc2626;
color: white;
border: none;
border-radius: 12px;
font-size: 0.875rem;
font-weight: 600;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
gap: 8px;
transition: all 0.3s ease;
box-shadow: 0 2px 8px rgba(220, 38, 38, 0.2);
position: relative;
overflow: hidden;
}
.clearChatsButton:disabled {
background-color: #9ca3af;
cursor: not-allowed;
box-shadow: none;
color: #d1d5db;
}
.clearChatsButton:disabled:hover {
transform: none;
box-shadow: none;
}
.clearChatsButton::before {
content: '';
position: absolute;
top: 50%;
left: 50%;
width: 100%;
height: 100%;
background: rgba(255, 255, 255, 0.1);
transform: translate(-50%, -50%) scale(0);
border-radius: 50%;
transition: transform 0.5s ease;
}
.clearChatsButton:hover:not(:disabled) {
background-color: #b91c1c;
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.3);
}
.clearChatsButton:hover:not(:disabled)::before {
transform: translate(-50%, -50%) scale(2);
}
:global(.dark) .clearChatsButton {
box-shadow: 0 2px 8px rgba(220, 38, 38, 0.4);
}
:global(.dark) .clearChatsButton:hover:not(:disabled) {
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.5);
}
:global(.dark) .clearChatsButton:disabled {
background-color: #4b5563;
color: #6b7280;
}
.clearChatsButton svg {
width: 16px;
height: 16px;
transition: transform 0.3s ease;
}
.clearChatsButton:hover:not(:disabled) svg {
transform: scale(1.1);
}
:global(.dark) .chatItem:hover {
background-color: rgba(255, 255, 255, 0.03);
}
:global(.dark) .chatList::before {
background: rgba(118, 185, 0, 0.15);
border-color: rgba(118, 185, 0, 0.3);
}


@ -0,0 +1,217 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
.welcomeContainer {
position: absolute;
left: 50%;
top: 45%;
transform: translate(-50%, -50%);
text-align: center;
max-width: 900px;
width: 100%;
padding: 20px;
z-index: 10;
opacity: 1;
}
.welcomeMessage {
font-size: 42px;
font-weight: 200;
color: #1f2937;
max-width: 800px;
letter-spacing: -0.5px;
background: linear-gradient(45deg, #76B900, #669f00);
background-clip: text;
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin: 0 auto 40px auto;
opacity: 0;
animation: fadeInUp 1.2s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
}
:global(.dark) .welcomeMessage {
color: #e5e7eb;
background: linear-gradient(45deg, #76B900, #9be024);
background-clip: text;
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.agentCards {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 24px;
width: 100%;
max-width: 600px;
margin: 0 auto;
}
.agentCard {
background: white;
border-radius: 12px;
padding: 24px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
transition: all 0.3s ease;
cursor: pointer;
text-align: center;
border: 1px solid rgba(118, 185, 0, 0.1);
opacity: 0;
transform: translateY(30px);
}
.agentCard.animate1 {
animation: fadeInUp 0.6s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
animation-delay: 0.3s;
}
.agentCard.animate2 {
animation: fadeInUp 0.6s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
animation-delay: 0.4s;
}
.agentCard.animate3 {
animation: fadeInUp 0.6s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
animation-delay: 0.6s;
}
.agentCard.animate4 {
animation: fadeInUp 0.6s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
animation-delay: 0.7s;
}
.agentCard.animate1:hover,
.agentCard.animate2:hover,
.agentCard.animate3:hover,
.agentCard.animate4:hover {
transform: translateY(-4px);
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.15);
border-color: rgba(118, 185, 0, 0.3);
}
:global(.dark) .agentCard {
background: #1f2937;
border-color: rgba(118, 185, 0, 0.2);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
}
:global(.dark) .agentCard:hover {
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4);
border-color: rgba(118, 185, 0, 0.4);
}
.agentIcon {
width: 48px;
height: 48px;
margin: 0 auto 16px auto;
background: linear-gradient(45deg, #76B900, #669f00);
border-radius: 12px;
display: flex;
align-items: center;
justify-content: center;
color: white;
}
.agentTitle {
font-size: 16px;
font-weight: 600;
color: #1f2937;
margin: 0 0 8px 0;
line-height: 1.2;
}
:global(.dark) .agentTitle {
color: #e5e7eb;
}
.agentSubtitle {
font-size: 14px;
font-weight: 400;
color: #6b7280;
margin: 0;
line-height: 1.2;
}
:global(.dark) .agentSubtitle {
color: #9ca3af;
}
@media (max-width: 640px) {
.agentCards {
grid-template-columns: 1fr;
gap: 16px;
max-width: 400px;
}
.welcomeContainer {
max-width: 95%;
padding: 16px;
}
.welcomeMessage {
font-size: 32px;
margin-bottom: 32px;
}
.agentCard {
padding: 20px;
}
}
@keyframes fadeInUp {
0% {
opacity: 0;
transform: translateY(30px);
}
100% {
opacity: 1;
transform: translateY(0);
}
}
@keyframes welcomeBubble {
0% {
opacity: 0;
transform: translate(-50%, 30px) scale(0.8);
filter: blur(10px);
}
60% {
opacity: 0.8;
filter: blur(0px);
}
100% {
opacity: 1;
transform: translate(-50%, -50%) scale(1);
filter: blur(0px);
}
}
@keyframes welcomeDisappear {
0% {
opacity: 1;
transform: translate(-50%, -50%) scale(1);
filter: blur(0px);
}
100% {
opacity: 0;
transform: translate(-50%, -80px) scale(0.9);
filter: blur(10px);
}
}
.welcomeMessage.hide {
animation: welcomeDisappear 0.6s cubic-bezier(0.4, 0, 0.2, 1) forwards;
}


@ -0,0 +1,41 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
export interface ChatConfig {
id: string;
name: string;
model: string;
selectedSources: string[];
}
export interface AppConfig {
currentChatId: string;
chats: ChatConfig[];
sources: string[];
}
export const DEFAULT_CONFIG: AppConfig = {
currentChatId: "default-chat",
chats: [
{
id: "default-chat",
name: "New Chat",
model: "gpt-3.5-turbo",
selectedSources: []
}
],
sources: []
};


@ -0,0 +1,34 @@
/*
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
import type { Config } from "tailwindcss";
export default {
content: [
"./src/pages/**/*.{js,ts,jsx,tsx,mdx}",
"./src/components/**/*.{js,ts,jsx,tsx,mdx}",
"./src/app/**/*.{js,ts,jsx,tsx,mdx}",
],
theme: {
extend: {
colors: {
background: "var(--background)",
foreground: "var(--foreground)",
},
},
},
plugins: [],
} satisfies Config;


@ -0,0 +1,27 @@
{
"compilerOptions": {
"target": "ES2017",
"lib": ["dom", "dom.iterable", "esnext"],
"allowJs": true,
"skipLibCheck": true,
"strict": true,
"noEmit": true,
"esModuleInterop": true,
"module": "esnext",
"moduleResolution": "bundler",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
"incremental": true,
"plugins": [
{
"name": "next"
}
],
"paths": {
"@/*": ["./src/*"]
}
},
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
"exclude": ["node_modules"]
}


@ -0,0 +1,56 @@
#!/usr/bin/env bash
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
set -euo pipefail
ROOT_DIR="$(pwd)"
MODELS_DIR="$ROOT_DIR/models"
mkdir -p "$MODELS_DIR"
cd "$MODELS_DIR"
download_if_needed() {
    local url="$1"
    local file="$2"
if [ -f "$file" ]; then
echo "$file already exists, skipping."
else
curl -C - -L -o "$file" "$url"
fi
}
download_if_needed "https://huggingface.co/TheBloke/deepseek-coder-6.7B-instruct-GGUF/resolve/main/deepseek-coder-6.7b-instruct.Q8_0.gguf" "deepseek-coder-6.7b-instruct.Q8_0.gguf"
download_if_needed "https://huggingface.co/Qwen/Qwen3-Embedding-4B-GGUF/resolve/main/Qwen3-Embedding-4B-Q8_0.gguf" "Qwen3-Embedding-4B-Q8_0.gguf"
# Comment out the next three lines if you want to use gpt-oss-20b
download_if_needed "https://huggingface.co/ggml-org/gpt-oss-120b-GGUF/resolve/main/gpt-oss-120b-mxfp4-00001-of-00003.gguf" "gpt-oss-120b-mxfp4-00001-of-00003.gguf"
download_if_needed "https://huggingface.co/ggml-org/gpt-oss-120b-GGUF/resolve/main/gpt-oss-120b-mxfp4-00002-of-00003.gguf" "gpt-oss-120b-mxfp4-00002-of-00003.gguf"
download_if_needed "https://huggingface.co/ggml-org/gpt-oss-120b-GGUF/resolve/main/gpt-oss-120b-mxfp4-00003-of-00003.gguf" "gpt-oss-120b-mxfp4-00003-of-00003.gguf"
# Uncomment the next line if you want to use gpt-oss-20b
# download_if_needed "https://huggingface.co/ggml-org/gpt-oss-20b-GGUF/resolve/main/gpt-oss-20b-mxfp4.gguf" "gpt-oss-20b-mxfp4.gguf"
echo "All models downloaded."
cd "$ROOT_DIR"
echo "Starting Docker Compose services..."
docker compose -f docker-compose.yml -f docker-compose-models.yml up -d --build
echo "Docker Compose services are up and running."


@ -0,0 +1,168 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Dockerfile stage to build GPU-accelerated ffmpeg
FROM nvcr.io/nvidia/pytorch:25.09-py3 AS ffmpeg-builder
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
apt-get install -y --no-install-recommends \
yasm \
libx264-dev \
libfaac-dev \
libmp3lame-dev \
libtheora-dev \
libvorbis-dev \
libxvidcore-dev \
libxext-dev \
libxfixes-dev \
build-essential \
git \
pkg-config && \
apt-get update && \
apt-get install -y --no-install-recommends gcc-11 g++-11
ENV PATH=/usr/local/cuda/bin:${PATH} \
CUDA_HOME=/usr/local/cuda \
NVCC=/usr/local/cuda/bin/nvcc \
CC=/usr/bin/gcc-11 \
CXX=/usr/bin/g++-11 \
CUDAHOSTCXX=/usr/bin/g++-11 \
FFMPEG_VERSION=4.4.6 \
    LD_LIBRARY_PATH=/usr/local/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}
RUN mkdir -p /deps && \
cd /deps && \
wget -q https://ffmpeg.org/releases/ffmpeg-${FFMPEG_VERSION}.tar.xz && \
tar -xf ffmpeg-${FFMPEG_VERSION}.tar.xz && \
rm ffmpeg-${FFMPEG_VERSION}.tar.xz && \
cd /deps/ffmpeg-${FFMPEG_VERSION} && \
apt-get update && \
apt-get install -y libdrm-dev && \
./configure \
--prefix=/usr/local \
--enable-nonfree \
--enable-shared \
--disable-static \
--enable-libdrm \
--enable-v4l2-m2m \
--enable-gpl \
--enable-libx264 \
--extra-cflags="-I/usr/include/aarch64-linux-gnu" \
--extra-ldflags="-L/usr/lib/aarch64-linux-gnu/tegra" \
--nvccflags="-ccbin=/usr/bin/g++-11" \
|| (echo "---- tail ffbuild/config.log ----" && tail -n 200 ffbuild/config.log && exit 1) && \
make -j"$(nproc)" && \
make install && \
ldconfig && \
echo "✅ FFmpeg installed:" && \
ffmpeg -hide_banner -version | head -n 8 && \
rm -rf /var/lib/apt/lists/*
# Dockerfile stage to compile decord from source
FROM nvcr.io/nvidia/pytorch:25.09-py3 AS decord-builder
COPY --from=ffmpeg-builder /usr/local /usr/local
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
apt-get install -y --no-install-recommends \
build-essential \
git \
cmake \
ninja-build \
pkg-config \
python3-dev \
python3-pip \
gcc-11 \
g++-11 \
&& rm -rf /var/lib/apt/lists/*
ENV CC=/usr/bin/gcc-11 \
CXX=/usr/bin/g++-11 \
PATH=/usr/local/bin:/lib/aarch64-linux-gnu:${PATH} \
LD_LIBRARY_PATH=/usr/lib/aarch64-linux-gnu/:/lib/aarch64-linux-gnu/:/usr/local/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH} \
PKG_CONFIG_PATH=/usr/local/lib/pkgconfig:/usr/lib/aarch64-linux-gnu/pkgconfig${PKG_CONFIG_PATH:+:$PKG_CONFIG_PATH}
RUN ln -sf /usr/lib/aarch64-linux-gnu/libnvcuvid.so.1 /usr/lib/aarch64-linux-gnu/libnvcuvid.so && \
cp /usr/lib/aarch64-linux-gnu/libnvcuvid.so* /usr/local/cuda/lib64/ || true && \
ln -sf /usr/lib/aarch64-linux-gnu/libnvcuvid.so.1 /usr/local/cuda/lib64/libnvcuvid.so && \
echo "/usr/lib/aarch64-linux-gnu" > /etc/ld.so.conf.d/nvidia.conf && \
ldconfig && \
python3 -m pip install --no-cache-dir --upgrade pip wheel build numpy && \
python3 -m pip install --no-cache-dir --upgrade pip ninja && \
apt-get update && \
apt-get install -y --no-install-recommends libnvidia-decode-575 libnvidia-encode-575
RUN cd /workspace && \
git clone --recursive https://github.com/dmlc/decord && \
cmake -S /workspace/decord -B /workspace/decord/build \
-G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DUSE_CUDA=ON \
-DCMAKE_CUDA_HOST_COMPILER=/usr/bin/g++-11 \
-DCMAKE_C_COMPILER=/usr/bin/gcc-11 \
-DCMAKE_CXX_COMPILER=/usr/bin/g++-11 \
-DFFMPEG_ROOT=/usr/local \
-DUSE_VIDEO_CODEC=OFF && \
cd /workspace/decord/build && \
ninja -j"$(nproc)" && \
cd /workspace/decord/python && \
python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel build && \
python3 -m build --wheel
# Dockerfile for demo
FROM nvcr.io/nvidia/pytorch:25.09-py3
COPY --from=ffmpeg-builder /usr/local /usr/local
COPY --from=decord-builder /workspace/decord/python/dist/*.whl /tmp/wheels/
ARG HF_TOKEN
RUN pip install --no-cache-dir /tmp/wheels/*.whl && \
rm -rf /tmp/wheels && \
apt-get update && \
apt-get install -y libdrm2 libdrm-dev libx264-dev && \
pip install streamlit timm wandb && \
hf auth login --token $HF_TOKEN
# Set CUDA environment variables
ENV CUDA_HOME=/usr/local/cuda-13.0/
ENV CUDA_PATH=$CUDA_HOME
ENV PATH=$CUDA_HOME/bin:$PATH
ENV LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
ENV C_INCLUDE_PATH=$CUDA_HOME/include:$C_INCLUDE_PATH
ENV CPLUS_INCLUDE_PATH=$CUDA_HOME/include:$CPLUS_INCLUDE_PATH
# install triton from source for latest blackwell support
RUN git clone https://github.com/triton-lang/triton.git && \
cd triton && \
git checkout c5d671f91d90f40900027382f98b17a3e04045f6 && \
pip install -r python/requirements.txt && \
pip install . && \
cd ..
# install xformers from source for blackwell support
RUN git clone --depth=1 https://github.com/facebookresearch/xformers --recursive && \
cd xformers && \
export TORCH_CUDA_ARCH_LIST="12.1" && \
python setup.py install && \
cd ..
# install unsloth without dependencies so we can build them from source
RUN pip install unsloth unsloth_zoo bitsandbytes==0.48.0
CMD ["/bin/bash"]


@ -0,0 +1,52 @@
# VLM Fine-tuning Recipes
This repository contains comprehensive fine-tuning recipes for Vision-Language Models (VLMs), supporting both **image** and **video** understanding tasks with modern models and training techniques.
## 🎯 Available Recipes
### 📸 Image VLM Fine-tuning (`ui_image/`)
- **Model**: Qwen2.5-VL-7B-Instruct
- **Task**: Wildfire detection from satellite imagery
- **Training Method**: GRPO (Group Relative Policy Optimization) and LoRA (Low-Rank Adaptation)
### 🎥 Video VLM Fine-tuning (`ui_video/`)
- **Model**: InternVL3-8B
- **Task**: Dangerous driving detection and structured metadata generation
- **Training Method**: Supervised Fine-tuning on Multimodal Instructions
## 🚀 Quick Start
### 1. Build the Docker Container
```bash
# Build the VLM fine-tuning container
docker build --build-arg HF_TOKEN=$HF_TOKEN -t vlm_demo .
```
### 2. Launch the Container
```bash
# Run the container with GPU support
docker run -it \
--gpus=all \
--net=host \
--ipc=host \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
-v $(pwd):/vlm_finetuning \
-v $HOME/.cache/huggingface:/root/.cache/huggingface \
vlm_demo
# Enter the mounted directory
cd /vlm_finetuning
```
> **Note**: The same Docker container and launch commands work for both image and video VLM recipes. The container includes all necessary dependencies including FFmpeg, Decord, and optimized libraries for both workflows.
## 📚 Detailed Instructions
Each recipe includes comprehensive documentation:
- **[Image VLM README](ui_image/README.md)**: Complete guide for wildfire detection fine-tuning with Qwen2.5-VL, including dataset setup, GRPO training configuration, and interactive inference
- **[Video VLM README](ui_video/README.md)**: Full walkthrough for dangerous driving detection with InternVL3, covering video data preparation, training notebooks, and structured output generation


@ -0,0 +1,26 @@
#!/bin/bash
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
docker run -it \
--gpus=all \
--net=host \
--ipc=host \
-w $HOME \
-v $HOME:$HOME \
-v $HOME/.cache/huggingface:/root/.cache/huggingface \
vlm_demo


@ -0,0 +1,3 @@
[theme]
base="dark"
greenTextColor = "#76b900"


@ -0,0 +1,438 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unsloth import FastVisionModel
import os
import re
import gc
import yaml
import glob
import random
import subprocess
import wandb
import torch
from PIL import Image
import streamlit as st
REASONING_START = "<REASONING>"
REASONING_END = "</REASONING>"
SOLUTION_START = "<SOLUTION>"
SOLUTION_END = "</SOLUTION>"
def initialize_session_state(resources):
# Initialize page-specific session state
st.session_state["base"] = st.session_state.get("base", resources["base"])
st.session_state["finetuned"] = st.session_state.get("finetuned", resources["finetuned"])
st.session_state["current_image"] = st.session_state.get("current_image", glob.glob("assets/image_vlm/images/*/*")[0])
st.session_state["train_process"] = st.session_state.get("train_process", None)
def load_config():
config_key = "config"
if getattr(st.session_state, config_key, None) is None:
with open("src/image_vlm_config.yaml", "r") as f:
config = yaml.safe_load(f)
setattr(st.session_state, config_key, config)
else:
config = getattr(st.session_state, config_key)
return config
@st.cache_resource
def initialize_resources(inference_config):
base_model, base_tokenizer = load_model_for_inference(inference_config, "base")
finetuned_model, finetuned_tokenizer = load_model_for_inference(inference_config, "finetuned")
return {
"base": {"model": base_model, "tokenizer": base_tokenizer},
"finetuned": {"model": finetuned_model, "tokenizer": finetuned_tokenizer},
}
def main():
# set page ui
st.title("Image VLM Finetuning")
st.caption("A DGX Spark showcase for on-device VLM finetuning")
# st.page_link("https://github.com/your-username/your-repo", label="GitHub", icon=":material/github:")
# load css
with open("src/styles.css", "r") as f:
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# load resources
config = load_config()
if st.session_state.get("base", None) is None:
st.toast("Loading model", icon="", duration="short")
resource = initialize_resources(config["inference"])
if st.session_state.get("base", None) is None:
st.toast("Model loaded", icon="", duration="short")
initialize_session_state(resource)
# train section
st.markdown("---")
train_section()
# inference Section
st.markdown("---")
inference_section()
def train_section():
st.header("GRPO Training")
column_1, column_2, column_3 = st.columns(3, gap="large")
with column_1:
finetuning_method = st.selectbox(
"Finetuning Method:",
["LoRA", "QLoRA", "Full Fine-tuning"],
)
# update lora config
if finetuning_method in ("QLoRA", "LoRA"):
lora_config = st.session_state["config"]["train"]["model"]["lora_config"]
with column_2:
lora_rank = st.slider(
"LoRA Rank",
min_value=8,
max_value=64,
value=lora_config["rank"],
step=8,
)
with column_3:
lora_alpha = st.slider(
"LoRA Alpha",
min_value=8,
max_value=64,
value=lora_config["alpha"],
step=8,
)
st.session_state["config"]["train"]["model"]["lora_config"].update({
'rank': lora_rank,
'alpha': lora_alpha,
})
# update model config based on selection
st.session_state["config"]["train"]["model"]["use_lora"] = finetuning_method == "LoRA"
st.session_state["config"]["train"]["model"]["use_qlora"] = finetuning_method == "QLoRA"
# update train config
st.write("")
column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
with column_1:
finetune_vision_layers = st.toggle(
"Finetune Vision Layers",
value=st.session_state["config"]["train"]["model"]["finetune_vision_layers"])
with column_2:
finetune_language_layers = st.toggle(
"Finetune Language Layers",
value=st.session_state["config"]["train"]["model"]["finetune_language_layers"])
with column_3:
finetune_attention_modules = st.toggle(
"Finetune Attention Modules",
value=st.session_state["config"]["train"]["model"]["finetune_attention_modules"])
with column_4:
finetune_mlp_modules = st.toggle(
"Finetune MLP Modules",
value=st.session_state["config"]["train"]["model"]["finetune_mlp_modules"])
st.write("")
column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
with column_1:
epochs = st.slider(
"Epochs",
min_value=1,
max_value=100,
value=st.session_state["config"]["train"]["hyperparameters"]["epochs"])
with column_2:
batch_size = st.select_slider(
"Batch Size",
options=[1, 2, 4, 8, 16],
value=st.session_state["config"]["train"]["hyperparameters"]["batch_size"])
with column_3:
learning_rate = st.number_input(
"Learning Rate",
min_value=1e-6,
max_value=1e-2,
value=float(st.session_state["config"]["train"]["hyperparameters"]["learning_rate"]),
format="%.2e")
with column_4:
optimizer = st.selectbox(
"Optimizer",
options=["adamw_torch", "adafactor"])
st.session_state["config"]["train"]["hyperparameters"].update({
'epochs': epochs,
'batch_size': batch_size,
'learning_rate': learning_rate,
'optimizer': optimizer,
})
st.session_state["config"]["train"]["model"].update({
'finetune_vision_layers': finetune_vision_layers,
'finetune_language_layers': finetune_language_layers,
'finetune_attention_modules': finetune_attention_modules,
'finetune_mlp_modules': finetune_mlp_modules,
})
st.write("")
column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
with column_1:
enable_grpo = st.toggle(
"Enable GRPO",
value=st.session_state["config"]["train"]["hyperparameters"]["enable_grpo"],
disabled=True)
with column_2:
format_reward = st.number_input(
"Reward for reasoning format",
min_value=0.0,
max_value=5.0,
value=float(st.session_state["config"]["train"]["hyperparameters"]["format_reward"]),
format="%.2e")
with column_3:
correctness_reward = st.number_input(
"Reward for correct response",
min_value=0.0,
max_value=5.0,
value=float(st.session_state["config"]["train"]["hyperparameters"]["correctness_reward"]),
format="%.2e")
with column_4:
num_generations = st.number_input(
"Number of generations",
min_value=1,
max_value=16,
value=st.session_state["config"]["train"]["hyperparameters"]["num_generations"],
format="%d")
# Training control
st.write("")
column_1, column_2, column_3 = st.columns([4, 4, 1])
with column_1:
button_type = "secondary" if st.session_state["train_process"] else "primary"
if st.button("▶️ Start Finetuning", type=button_type, width="stretch", disabled=bool(st.session_state["train_process"])):
if st.session_state["train_process"] is None:
# store config
with open("src/train.yaml", "w") as f:
yaml.dump(st.session_state["config"]["train"], f, default_flow_style=False)
# start training
st.session_state["train_process"] = subprocess.Popen(
["python", "src/train_image_vlm.py"],
stdout=None, stderr=None
)
else:
st.toast("Training already in progress", icon="", duration="short")
with column_2:
button_type = "primary" if st.session_state["train_process"] else "secondary"
if st.button("⏹️ Stop Finetuning", type=button_type, width="stretch", disabled=not bool(st.session_state["train_process"])):
if st.session_state["train_process"] is not None:
st.session_state["train_process"].terminate()
st.session_state["train_process"] = None
st.toast("Training stopped", icon="", duration="short")
st.toast("Re-deploy the app with updated finetuned model", icon=":material/info:", duration="short")
else:
st.toast("No training to stop", icon="", duration="short")
with column_3:
if st.session_state["train_process"]:
st.badge("Running", icon=":material/hourglass_arrow_up:", color="green", width="stretch")
else:
st.badge("Idle", icon=":material/hourglass_disabled:", color="red", width="stretch")
# display wandb
runs = wandb.Api().runs(f"{os.environ.get('WANDB_ENTITY')}/{os.environ.get('WANDB_PROJECT')}")
if runs:
base_url = runs[0].url
loss_url = f"{base_url}?panelDisplayName=train%2Floss&panelSectionName=train"
memory_url = f"{base_url}?panelDisplayName=GPU+Memory+Allocated+%28%25%29&panelSectionName=System"
column_1, column_2 = st.columns(2)
with column_1:
st.markdown(f"""
<div class="wandb-wrapper">
<iframe src="{loss_url}" class="wandb-iframe"></iframe>
</div>
""", unsafe_allow_html=True)
with column_2:
st.markdown(f"""
<div class="wandb-wrapper">
<iframe src="{memory_url}" class="wandb-iframe"></iframe>
</div>
""", unsafe_allow_html=True)
def inference_section():
st.header("Image Inference")
columns = st.columns([3, 3, 1, 2])
with columns[1]:
with st.container(border=True, horizontal_alignment="center", vertical_alignment="center"):
image_holder = st.empty()
image_holder.image(st.session_state["current_image"])
with columns[3]:
if st.button("🎲 Test another sample"):
# pick a different sample; avoids looping forever if only one image exists
candidates = [p for p in glob.glob("assets/image_vlm/images/*/*") if p != st.session_state["current_image"]]
if candidates:
st.session_state["current_image"] = random.choice(candidates)
image_holder.image(st.session_state["current_image"])
columns = st.columns(2, gap="small")
with columns[0]:
with st.container(border=True):
st.write("##### :green[Base Qwen2.5-VL-7B]")
base_generation = st.empty()
base_generation.write("...")
with columns[1]:
with st.container(border=True):
st.write("##### :green[Finetuned Qwen2.5-VL-7B]")
finetuned_generation = st.empty()
finetuned_generation.write("...")
columns = st.columns([9, 1], gap="small")
with columns[0]:
prompt = st.text_input(
"Prompt Input",
label_visibility="collapsed",
key="prompt_input",
on_change=lambda: st.session_state.update(prompt=st.session_state["prompt_input"])
)
with columns[1]:
if st.button("Generate", width="stretch"):
# run inference only when the user has actually entered a prompt
if prompt:
st.session_state["prompt"] = prompt
with st.spinner("Running..."):
response = start_inference("base")
base_generation.markdown(response)
with st.spinner("Running..."):
response = start_inference("finetuned")
finetuned_generation.markdown(response)
def load_model_for_inference(config, model_type):
if model_type == "finetuned":
model_name = config["finetuned_model_id"]
elif model_type == "base":
model_name = config["model_id"]
else:
raise ValueError(f"Invalid model type: {model_type}")
model, tokenizer = FastVisionModel.from_pretrained(
model_name=model_name,
max_seq_length=config["max_seq_length"],
load_in_4bit=False,
)
FastVisionModel.for_inference(model)
return model, tokenizer
@torch.no_grad()
def start_inference(model_type):
# define prompt
prompt = st.session_state["prompt"]
if model_type == "finetuned":
prompt = (
f"{prompt}. Also first provide your reasoning or working out"\
f" on how you would go about identifying the presence of wildfire affected regions between {REASONING_START} and {REASONING_END}"
f" and then your final answer between {SOLUTION_START} and (put a simple Yes or No here) {SOLUTION_END}"
)
# load image
image = Image.open(st.session_state["current_image"])
if image.mode != "RGB":
image = image.convert("RGB")
# construct instruction prompt
prompt = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": prompt},
],
},
]
# apply chat template (session state stores the models under "base"/"finetuned")
prompt = st.session_state[model_type]["tokenizer"].apply_chat_template(
prompt,
tokenize=False,
add_generation_prompt=True,
)
# tokenize inputs
inputs = st.session_state[model_type]["tokenizer"](
image,
prompt,
add_special_tokens=False,
return_tensors="pt",
).to("cuda")
# perform inference
response = st.session_state[model_type]["model"].generate(
**inputs,
max_new_tokens=1024,
use_cache=True,
do_sample=False
)[0][inputs["input_ids"].shape[1]:]
# decode tokens
response = st.session_state[model_type]["tokenizer"].decode(response, skip_special_tokens=True)
# format response
if model_type == "finetuned":
response = response.replace(REASONING_START, "```")
response = response.replace(REASONING_END, "```")
# Handle solution formatting with proper newline handling
solution_pattern = f'{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}'
solution_match = re.search(solution_pattern, response, re.DOTALL)
if solution_match:
solution_content = solution_match.group(1).strip()
response = re.sub(solution_pattern, f"**{solution_content}**", response, flags=re.DOTALL)
return response
if __name__ == "__main__":
main()


@ -0,0 +1,137 @@
# Image VLM Fine-tuning with Qwen2.5-VL
This project demonstrates fine-tuning Vision-Language Models (VLMs) for image understanding tasks, using the Qwen2.5-VL-7B model for wildfire detection from satellite imagery with GRPO (Group Relative Policy Optimization).
## Overview
The project includes:
- **Interactive Training Interface**: Streamlit-based UI for configuring and monitoring VLM fine-tuning
- **GRPO Training**: Group-relative policy optimization that rewards structured reasoning
- **Multiple Fine-tuning Methods**: Support for LoRA, QLoRA, and Full Finetuning
- **Side-by-side Inference**: Compare base model vs fine-tuned model performance
## Getting Started
> **Note**: These instructions assume you are already inside the Docker container. For container setup, refer to the main project README at `vlm-finetuning/`.
### 1. Set Up Weights & Biases
Configure your wandb credentials for training monitoring:
```bash
export WANDB_PROJECT="vlm_finetuning"
export WANDB_ENTITY=<WANDB_USERNAME>
export WANDB_API_KEY=<WANDB_API_KEY>
```
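A quick preflight (illustrative, stdlib only) can confirm those variables are actually visible to the process before the UI tries to query wandb:

```python
import os

def missing_wandb_vars(env=os.environ):
    """Return which of the wandb variables the UI reads are unset."""
    required = ("WANDB_PROJECT", "WANDB_ENTITY", "WANDB_API_KEY")
    return [name for name in required if not env.get(name)]

if __name__ == "__main__":
    missing = missing_wandb_vars()
    print("wandb environment OK" if not missing else f"missing: {missing}")
```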
### 2. Launch the Application
```bash
# Start the Streamlit interface
streamlit run Image_VLM.py
```
The application will be available at `http://localhost:8501`
## Training
### Dataset
The project uses a **wildfire detection dataset** with satellite imagery for training the model to identify wildfire-affected regions. The dataset includes:
- Satellite and aerial imagery from wildfire-affected areas
- Binary classification: wildfire vs no wildfire
#### Dataset Setup
1. **Download from Kaggle**: Visit the [Wildfire Prediction Dataset](https://www.kaggle.com/datasets/abdelghaniaaba/wildfire-prediction-dataset) on Kaggle
2. **Get the curl command**: On the Kaggle dataset page, click the download button and copy the curl command provided
3. **Download and extract**: Run the following commands in your container:
```bash
mkdir data
cd data
# Paste the curl command from Kaggle here, and then continue to unzip the dataset
unzip -qq wildfire-prediction-dataset.zip
rm wildfire-prediction-dataset.zip
cd ..
```
> **Note**: You'll need to be logged into Kaggle and may need to accept the dataset terms before the download link works.
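Once extracted, a small stdlib check can confirm the split/class layout before training; the `train`/`valid`/`test` split names are an assumption about the Kaggle archive, while the class folder names match the labels the training script expects:

```python
from pathlib import Path

def class_counts(split_dir):
    """Count image files per class folder (e.g. wildfire / nowildfire) in one split."""
    return {c.name: sum(1 for f in c.iterdir() if f.is_file())
            for c in sorted(Path(split_dir).iterdir()) if c.is_dir()}

if __name__ == "__main__":
    for split in ("train", "valid", "test"):  # assumed split names
        if Path("data", split).is_dir():
            print(split, class_counts(Path("data", split)))
```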
### Training Configuration
Configure training through the interactive interface:
#### Model Settings
- **Base Model**: Qwen/Qwen2.5-VL-7B-Instruct
- **Fine-tuning Method**: Choose from LoRA, QLoRA, or Full Finetuning
- **LoRA Parameters**: Adjustable rank (8-64) and alpha (8-64)
#### Training Parameters
- **Epochs**: 1-100 (default: 10)
- **Batch Size**: 1, 2, 4, 8, or 16 (default: 2)
- **Learning Rate**: 1e-6 to 1e-2 (default: 1e-5)
- **Optimizer**: AdamW or Adafactor
#### GRPO Settings
- **Format Reward**: 2.0 (reward for proper reasoning format)
- **Correctness Reward**: 5.0 (reward for correct answers)
- **Number of Generations**: 4 (for preference optimization)
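Conceptually, each prompt produces N generations whose total reward (format plus correctness) is compared against its group; a minimal sketch of that group-relative normalization (illustrative, not the trl implementation):

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-8):
    """GRPO-style advantage: normalize each generation's reward against its group."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# rewards here would be format_reward + correctness_reward per generation
advantages = group_relative_advantages([7.0, 2.0, 7.0, 0.0])
```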
### Training Process
1. **Configure Parameters**: Use the web interface to set training hyperparameters
2. **Start Training**: Click "▶️ Start Finetuning" to begin GRPO training
3. **Monitor Progress**: View real-time loss curves and GPU utilization via embedded wandb charts
4. **Stop if Needed**: Use "⏹️ Stop Finetuning" to halt training early
> **Important**: After training completes, follow these steps:
> 1. **Stop the UI**: Use Ctrl+C to stop the Streamlit application
> 2. **Update Config**: Edit `src/image_vlm_config.yaml` and change the `finetuned_model_id` path to point to your newly trained model in the `saved_model/` directory
> 3. **Restart UI**: Launch the application again to test your fine-tuned model
## Inference
### Interactive Comparison
![Inference Interface](assets/inference_screenshot.png)
*Side-by-side comparison showing base model vs fine-tuned model performance on wildfire detection*
The inference section provides:
- **Sample Images**: Test on wildfire detection samples from both categories (wildfire/no wildfire)
- **Dual Inference**: Run both base and fine-tuned models simultaneously
- **Random Sampling**: Test different samples with the "🎲 Test another sample" button
- **Structured Reasoning**: Fine-tuned model provides reasoning in `<REASONING>` tags before final answer
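The tagged output can be split programmatically the same way the UI does; a stdlib sketch:

```python
import re

def parse_structured_answer(text):
    """Split a fine-tuned response into (reasoning, final Yes/No answer)."""
    reasoning = re.search(r"<REASONING>(.*?)</REASONING>", text, re.DOTALL)
    solution = re.search(r"<SOLUTION>(.*?)</SOLUTION>", text, re.DOTALL)
    return (reasoning.group(1).strip() if reasoning else None,
            solution.group(1).strip() if solution else None)
```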
### Sample Questions
The interface includes prompts for wildfire detection:
- "Identify if this region has been affected by a wildfire"
- The fine-tuned model provides structured reasoning followed by a Yes/No answer
## File Structure
```
ui_image/
├── Image_VLM_Finetuning.py          # Main Streamlit application
├── README.md                        # This file
├── src/
│   ├── image_vlm_config.yaml        # Configuration file (update finetuned_model_id after training)
│   └── styles.css                   # Custom UI styling
├── assets/
│   ├── inference_screenshot.png     # UI demonstration screenshot
│   └── image_vlm/
│       └── images/
│           ├── wildfire/            # Wildfire-affected images
│           └── nowildfire/          # Non-wildfire images
└── saved_model/                     # Training checkpoints directory (update config to point here)
```
## Configuration
The `src/image_vlm_config.yaml` file contains all training and inference settings.



@ -0,0 +1,16 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


@ -0,0 +1,50 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
inference:
model_id: unsloth/Qwen2.5-VL-7B-Instruct
finetuned_model_id: RLakshmi24/qwen_wildfire_qrpo_lora
max_seq_length: 16384
train:
model:
model_id: unsloth/Qwen2.5-VL-7B-Instruct
max_seq_length: 16384
use_lora: true
use_qlora: false
lora_config:
rank: 32
alpha: 64
dropout: 0.05
finetune_vision_layers: true
finetune_language_layers: true
finetune_attention_modules: true
finetune_mlp_modules: true
data:
dataset_id: data
hyperparameters:
epochs: 10
batch_size: 2
enable_grpo: true
num_generations: 4
format_reward: 2.0
learning_rate: 1e-5
correctness_reward: 5.0
optimizer: adamw_torch
output_dir: saved_model


@ -0,0 +1,178 @@
/*
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
*/
/* VLM Fine-tuning Streamlit App Styles */
:root {
--nvidia-green: #76b900;
}
/* Set maximum page width - multiple selectors for compatibility */
.main .block-container,
.block-container,
.main > div,
section.main > div {
max-width: 1200px !important;
margin: 0 auto !important;
padding-left: 1rem !important;
padding-right: 1rem !important;
}
/* h3 {
color:#76b900 !important;
} */
.main > div {
padding-top: 1rem;
}
/* Global button styling - applies to ALL buttons */
.stButton > button,
button[kind="primary"],
button[kind="secondary"] {
width: 100% !important;
font-size: 16px !important;
font-weight: 600 !important;
padding: 12px 20px !important;
border-radius: 8px !important;
border: none !important;
transition: all 0.3s ease !important;
cursor: pointer !important;
box-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
}
/* Primary button styling */
.stButton > button[kind="primary"],
button[kind="primary"] {
background: linear-gradient(135deg, #32864c 0%, #6edd93 100%) !important;
/* background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; */
color: white !important;
}
/* Primary button hover effect */
.stButton > button[kind="primary"]:hover,
button[kind="primary"]:hover {
transform: translateY(-2px) !important;
box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
}
/* Secondary button styling */
.stButton > button[kind="secondary"],
button[kind="secondary"] {
background: linear-gradient(135deg, #b6fb93 0%, #57f586 100%) !important;
/* background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important; */
/* color: white !important; */
color: darkslategray !important;
}
/* Regular button styling */
.stButton > button:not([kind]) {
background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%) !important;
color: white !important;
}
/* Button hover effect for all buttons */
.stButton > button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
}
/* Button active/pressed state */
.stButton > button:active {
transform: translateY(0px) !important;
}
/* Metric card styling */
.metric-card {
background-color: #f0f2f6;
padding: 1rem;
border-radius: 0.5rem;
border-left: 4px solid #ff6b6b;
}
/* Center images and maintain aspect ratio */
.stImage > img {
display: block;
margin-left: auto;
margin-right: auto;
max-width: 500px;
height: auto;
max-height: 300px
}
/* Larger text input for questions */
.stTextInput > div > div > input {
font-size: 18px !important;
padding: 10px !important;
}
/* Larger inference button */
.inference-button button,
.inference-button .stButton > button {
padding: 20px !important;
height: 200px !important;
min-height: 100% !important;
max-height: 100% !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
line-height: 1.2 !important;
white-space: nowrap !important;
font-size: large !important;
}
.card {
background-color: #1F2022;
color: #bdc3c7;
padding: 1.5rem;
border-radius: 0.5rem;
/* outline: 2px solid var(--nvidia-green); */
box-shadow: 0 2px 8px rgba(0,0,0,0.10);
text-align: left;
height: 300px;
overflow-y: auto;
margin-bottom: 1rem;
}
.card--empty {
text-align: center;
font-style: italic;
height: 300px;
display: flex;
align-items: center;
justify-content: center;
}
/* Wandb iframe wrapper for cropping */
.wandb-wrapper {
overflow: hidden;
position: relative;
border-radius: 15px;
height: var(--visible-height, 320px);
}
/* Wandb iframe styling */
.wandb-iframe {
border: none;
width: 100%;
position: absolute;
left: 0;
height: var(--iframe-height, 600px);
top: var(--shift-up, -230px);
}


@ -0,0 +1,25 @@
data:
dataset_id: data
hyperparameters:
batch_size: 1
correctness_reward: 5.0
enable_grpo: true
epochs: 2
format_reward: 2.0
learning_rate: 1.0e-05
num_generations: 4
optimizer: adamw_torch
output_dir: saved_model
model:
finetune_attention_modules: true
finetune_language_layers: true
finetune_mlp_modules: true
finetune_vision_layers: true
lora_config:
alpha: 64
dropout: 0.05
rank: 16
max_seq_length: 16384
model_id: unsloth/Qwen2.5-VL-7B-Instruct
use_lora: true
use_qlora: false


@ -0,0 +1,204 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unsloth import FastVisionModel
import re
import yaml
from PIL import ImageFile
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
ImageFile.LOAD_TRUNCATED_IMAGES = True
REASONING_START = "<REASONING>"
REASONING_END = "</REASONING>"
SOLUTION_START = "<SOLUTION>"
SOLUTION_END = "</SOLUTION>"
def load_model_for_train(config):
model, tokenizer = FastVisionModel.from_pretrained(
model_name=config["model"]["model_id"],
max_seq_length=config["model"]["max_seq_length"],
load_in_4bit=config["model"]["use_qlora"],
)
model = FastVisionModel.get_peft_model(
model,
finetune_vision_layers=config["model"]["finetune_vision_layers"],
finetune_language_layers=config["model"]["finetune_language_layers"],
finetune_attention_modules=config["model"]["finetune_attention_modules"],
finetune_mlp_modules=config["model"]["finetune_mlp_modules"],
r=config["model"]["lora_config"]["rank"],
lora_alpha=config["model"]["lora_config"]["alpha"],
lora_dropout=config["model"]["lora_config"]["dropout"],
bias="none",
random_state=42,
use_rslora=False,
loftq_config=None,
use_gradient_checkpointing="unsloth",
)
return model, tokenizer
def format_instruction(sample, label_dict):
label = label_dict.int2str(sample["label"])
if label == "nowildfire":
answer = "No"
else:
answer = "Yes"
# reasoning prompt
prompt = "Identify if this region has been affected by a wildfire"
    prompt = (
        f"{prompt}. First explain your reasoning on how you would identify"
        f" wildfire-affected regions between {REASONING_START} and {REASONING_END},"
        f" then give your final answer (a simple Yes or No) between {SOLUTION_START} and {SOLUTION_END}."
    )
# convert image format
image = sample["image"]
if image.mode != "RGB":
image = image.convert("RGB")
# construct instruction prompt
prompt = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": prompt},
],
},
]
    # return the RGB-converted image, not the original sample
    return {"prompt": prompt, "image": image, "answer": answer}
def load_wildfire_dataset(config, tokenizer):
# load dataset
train_dataset = load_dataset(config["data"]["dataset_id"])["train"]
# preprocess the dataset
train_dataset = train_dataset.map(
lambda sample: format_instruction(sample, train_dataset.features["label"]),
num_proc=8)
train_dataset = train_dataset.map(
lambda sample: {
"prompt": tokenizer.apply_chat_template(
sample["prompt"],
tokenize=False,
add_generation_prompt=True,
)
}, num_proc=8)
return train_dataset
def format_reward_func(completions, **kwargs):
thinking_pattern = f'{REASONING_START}(.*?){REASONING_END}'
answer_pattern = f'{SOLUTION_START}(.*?){SOLUTION_END}'
scores = []
for completion in completions:
score = 0
thinking_matches = re.findall(thinking_pattern, completion, re.DOTALL)
answer_matches = re.findall(answer_pattern, completion, re.DOTALL)
if len(thinking_matches) == 1:
score += 1.0
if len(answer_matches) == 1:
score += 1.0
# penalize excessive addCriterion predictions in qwen2.5-vl
if len(completion) != 0:
removal = completion.replace("addCriterion", "").replace("\n", "")
if (len(completion)-len(removal))/len(completion) >= 0.5:
score -= 1.0
scores.append(score)
return scores
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
answer_pattern = f'{SOLUTION_START}(.*?){SOLUTION_END}'
responses = [re.findall(answer_pattern, completion, re.DOTALL) for completion in completions]
q = prompts[0]
print("----------------------------------")
print(f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{completions[0]}")
return [
        5.0 if len(r) == 1 and a == r[0].strip() else 0.0
for r, a in zip(responses, answer)
]
def start_train(config):
# load base model for finetuning
model, tokenizer = load_model_for_train(config)
# load dataset
train_dataset = load_wildfire_dataset(config, tokenizer)
# define training arguments
training_args = GRPOConfig(
learning_rate=config["hyperparameters"]["learning_rate"],
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
optim="adamw_torch",
logging_steps=1,
log_completions=False,
per_device_train_batch_size=config["hyperparameters"]["batch_size"],
gradient_accumulation_steps=1,
num_generations=2,
max_prompt_length=config["model"]["max_seq_length"],
max_completion_length=config["model"]["max_seq_length"],
num_train_epochs=config["hyperparameters"]["epochs"],
save_steps=100,
max_grad_norm=0.1,
report_to="none",
output_dir=config["hyperparameters"]["output_dir"],
# importance_sampling_level="sequence",
# mask_truncated_completions=False,
# loss_type="dr_grpo",
)
# start training
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
train_dataset=train_dataset,
reward_funcs=[
format_reward_func,
correctness_reward_func,
],
args=training_args,
)
trainer.train()
if __name__ == "__main__":
with open("src/train.yaml", "r") as f:
config = yaml.safe_load(f)
start_train(config)


@ -0,0 +1,3 @@
[theme]
base="dark"
greenTextColor = "#76b900"


@ -0,0 +1,137 @@
# Video VLM Fine-tuning with InternVL3
This project demonstrates fine-tuning the InternVL3 model for video analysis, specifically for dangerous driving detection and structured metadata generation from driving videos.
## Workflow Overview
![Training Workflow](assets/training_video.png)
### Training Workflow Steps:
1. **🎥 Dashcam Footage**: Source clips from the Nexar Collision Prediction dataset
2. **Generate Structured Captions**: Use a large VLM (InternVL3-78B) to generate structured captions from the raw videos
3. **🧠 Train InternVL3 Model**: Perform supervised fine-tuning (SFT) on InternVL3-8B to extract structured metadata
4. **🚀 Fine-tuned VLM**: Trained model ready to analyse driver behaviour and risk factors
## Training
### Data Requirements
Your dataset should be structured as follows:
```
dataset/
├── videos/
│ ├── video1.mp4
│ ├── video2.mp4
│ └── ...
└── metadata.jsonl # Contains video paths and labels
```
Each line in `metadata.jsonl` should contain:
```json
{
"video": "videos/video1.mp4",
"caption": "Description of the video events",
"event_type": "collision" | "near_miss" | "no_incident",
"rule_violations": choose relevant items from ["speeding", "failure_to_yield", "ignoring_traffic_signs"],
"intended_driving_action": "turn_left" | "turn_right" | "change_lanes",
"traffic_density": "low" | "high",
"visibility": "good" | "bad"
}
```
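Since every enum value in `metadata.jsonl` must match the allowed options, it is worth validating records before training. A minimal sketch, assuming the field names and allowed values listed above (the helper name is mine, not part of the project):

```python
import json

# Allowed values per field, taken from the metadata.jsonl spec above.
ALLOWED = {
    "event_type": {"collision", "near_miss", "no_incident"},
    "rule_violations": {"speeding", "failure_to_yield", "ignoring_traffic_signs"},
    "intended_driving_action": {"turn_left", "turn_right", "change_lanes"},
    "traffic_density": {"low", "high"},
    "visibility": {"good", "bad"},
}

def validate_record(line: str) -> list[str]:
    """Return a list of problems found in one metadata.jsonl line."""
    record = json.loads(line)
    problems = []
    if "video" not in record or "caption" not in record:
        problems.append("missing video or caption")
    for field, allowed in ALLOWED.items():
        value = record.get(field)
        # rule_violations is a list; the other fields are single strings
        values = value if isinstance(value, list) else [value]
        for v in values:
            if v not in allowed:
                problems.append(f"{field}: {v!r} not in allowed set")
    return problems
```

Running this over every line of `metadata.jsonl` before launching the notebook catches schema drift early, which is cheaper than discovering it mid-training.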
### Running Training
1. **Update Dataset Path**: Edit the training notebook to point to your dataset:
```python
dataset_path = "/path/to/your/dataset"
```
2. **Run Training Notebook**:
```bash
# Inside the container, navigate to the training directory
cd ui_video/train
jupyter notebook video_vlm.ipynb
```
3. **Monitor Training**: Training progress and metrics are displayed directly in the notebook interface.
### Training Configuration
Key configurable training parameters:
- **Model**: InternVL3-8B
- **Video Frames**: 12 to 16 frames per video
- **Sampling Mode**: Uniform temporal sampling
- **LoRA Configuration**: Parameter-efficient updates for large-scale fine-tuning
- **Hyperparameters**: Learning rate, batch size, epochs, and related settings exposed for video VLM fine-tuning
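Uniform temporal sampling simply picks equally spaced frame indices across the clip (the inference code does this with `np.linspace`). A pure-Python sketch of the same idea, with a function name of my choosing:

```python
def uniform_frame_indices(total_frames: int, num_frames: int) -> list[int]:
    """Pick num_frames equally spaced frame indices in [0, total_frames - 1]."""
    if num_frames <= 1 or total_frames <= 1:
        return [0]
    # spacing between consecutive sampled frames
    step = (total_frames - 1) / (num_frames - 1)
    return [round(i * step) for i in range(num_frames)]

# e.g. a 5-second, 30 fps clip (150 frames) sampled at 16 frames
# always includes the first (0) and last (149) frame
```

This keeps temporal coverage stable regardless of clip length, which matters when dashcam clips vary in duration.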
## Inference
### Running Inference
1. **Streamlit Web Interface**:
```bash
# Start the interactive web interface
cd ui_video
streamlit run Video_VLM.py
```
The interface provides:
- Dashcam video gallery and playback
- Side-by-side comparison between base and finetuned model
- JSON output generation
- Tabular view of structured data extracted for analysis
2. **Configuration**: Edit `src/video_vlm_config.yaml` to modify model settings, frame count, and sampling strategy.
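The inference app reads `model_id`, `finetuned_model_id`, `num_frames`, and `sampling_mode` from an `inference` section of that file. A sketch of the expected shape, shown as the equivalent Python dict; the checkpoint ID and path here are placeholders, not the shipped defaults:

```python
# Illustrative structure for src/video_vlm_config.yaml (as a Python dict).
config = {
    "inference": {
        "model_id": "OpenGVLab/InternVL3-8B",        # base checkpoint (assumed ID)
        "finetuned_model_id": "/path/to/finetuned",  # hypothetical local path
        "num_frames": 16,                            # frames sampled per video
        "sampling_mode": "default",                  # or "real-time" for chunked inference
    }
}
```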
### Sample Output
The model generates structured JSON output like:
```json
{
"caption": "A vehicle makes a dangerous lane change without signaling while speeding on a highway during daytime with clear weather conditions.",
"event_type": "near_miss",
"cause_of_risk": ["speeding", "risky_maneuver"],
"presence_of_rule_violations": ["failure_to_use_turn_signals"],
"intended_driving_action": ["change_lanes"],
"traffic_density": "medium",
"driving_setting": ["highway"],
"time_of_day": "day",
"light_conditions": "normal",
"weather": "clear",
"scene": "highway"
}
```
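The model typically wraps this JSON in a markdown fence, so downstream code has to strip the fence before calling `json.loads`. A tolerant sketch (the function name is mine) that accepts both fenced and bare responses:

```python
import json
import re

def parse_model_json(response: str) -> dict:
    """Extract the JSON object from a model response, fenced or not."""
    # grab everything from the first '{' to the last '}'
    match = re.search(r"\{.*\}", response, re.DOTALL)
    if match is None:
        raise ValueError("no JSON object found in response")
    return json.loads(match.group(0))
```

Matching on the braces rather than slicing fixed offsets keeps the parser working even if the model omits the fence or adds leading text.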
Inference Screenshot
![WebUI Inference](assets/inference_screenshot.png)
## File Structure
```
ui_video/
├── README.md # This file
├── Video_VLM.py # Streamlit web interface for inference
├── src/
│ ├── styles.css # CSS styling for Streamlit app
│ └── video_vlm_config.yaml # Model and inference configuration
├── train/
│ └── video_vlm.ipynb # Jupyter notebook for model training
└── assets/
└── video_vlm/
├── videos/ # Sample video files
└── thumbnails/ # Video thumbnail previews
# Root directory also contains:
├── Dockerfile # Multi-stage Docker build with FFmpeg/Decord
└── launch.sh # Docker launch script
```
## Model Capabilities
The fine-tuned InternVL3 model can:
- **Video Analysis**: Process multi-frame dashcam footage for comprehensive scene understanding
- **Safety Detection**: Identify dangerous driving patterns, near-misses, and traffic violations
- **Structured Output**: Generate JSON metadata with standardized driving scene categories


@ -0,0 +1,413 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import yaml
import string
import random
import torch
import numpy as np
import pandas as pd
from PIL import Image
import streamlit as st
import torchvision.transforms as T
from decord import VideoReader, cpu
from transformers import AutoTokenizer, AutoModel
from torchvision.transforms.functional import InterpolationMode
SCAP_PROMPT = """You are a vision-language assistant analyzing driving videos. You will receive a 5-second video clip of a specific scene. {prompt}
---
### Task 1: Dense Caption
Generate a 2 sentence caption describing:
- Ego vehicle behavior
- Interactions with other vehicles or pedestrians
Focus on **what happens**, **when**, and **who/what is involved**, using only visible information and metadata.
---
### Task 2: Structured JSON
Generate the caption from the perspective of the ego vehicle in a structured JSON object with:
- `"caption"`: from Task 1
- `"event_type"`: "collision" | "near_miss" | "no_incident"
- `"rule_violations"`: choose relevant items from ["speeding", "failure_to_yield", "ignoring_traffic_signs"]
- `"intended_action"`: "turn_left" | "turn_right" | "change_lanes"
- `"traffic_density"`: "low" | "high"
- `"visibility"`: "good" | "bad"
- `"scene"`: "Urban" | "Sub-urban" | "Rural" | "Highway"
**Rules:**
1. Use only visible info and metadata.
2. Do not invent details.
3. Include all fields; enum values must match allowed options.
4. Output a single valid JSON object; no extra text or markdown.
"""
def random_id():
chars = string.ascii_letters + string.digits
return "".join(random.choices(chars, k=8)).lower()
def initialize_session_state(resources):
# Initialize page-specific session state
st.session_state["base_video_vlm"] = st.session_state.get("base_video_vlm", resources["base"])
st.session_state["finetuned_video_vlm"] = st.session_state.get("finetuned_video_vlm", resources["finetuned"])
st.session_state["current_sample"] = st.session_state.get("current_sample", None)
    # column order mirrors the SCAP JSON field order, since rows are
    # appended via list(response.values())
    st.session_state["df"] = st.session_state.get("df",
        pd.DataFrame(columns=[
            "Driver ID",
            "Event Type",
            "Rule Violations",
            "Intended Action",
            "Traffic Density",
            "Visibility",
            "Driving Scene"]))
def load_config():
config_key = "config_video_vlm"
if getattr(st.session_state, config_key, None) is None:
with open("src/video_vlm_config.yaml", "r") as f:
config = yaml.safe_load(f)
setattr(st.session_state, config_key, config)
else:
config = getattr(st.session_state, config_key)
return config
@st.cache_resource
def initialize_resources(inference_config):
base_model = load_model_for_inference(inference_config, "base")
finetuned_model = load_model_for_inference(inference_config, "finetuned")
return {
"base": {"model": base_model},
"finetuned": {"model": finetuned_model},
}
def main():
# set page ui
st.title(":green[Video VLM on DGX Spark]")
st.caption("Driver Behavior Analysis via Structured Video Captioning")
# st.page_link("https://github.com/your-username/your-repo", label="GitHub", icon=":material/github:")
# load css
with open("src/styles.css", "r") as f:
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# load resources
config = load_config()
if st.session_state.get("base_video_vlm", None) is None:
st.toast("Loading model", icon="", duration="short")
resource = initialize_resources(config["inference"])
if st.session_state.get("base_video_vlm", None) is None:
st.toast("Model loaded", icon="", duration="short")
initialize_session_state(resource)
# gallery section
st.markdown("---")
st.header("Dashcam Gallery")
columns = st.columns([4, 1, 1, 4, 1, 1, 4, 1], gap="small")
with columns[0]:
st.image("assets/video_vlm/thumbnails/1.png")
with columns[1]:
if st.button(":material/file_open:", key="video_1"):
st.session_state["current_sample"] = "assets/video_vlm/videos/1.mp4"
with columns[3]:
st.image("assets/video_vlm/thumbnails/2.png")
with columns[4]:
if st.button(":material/file_open:", key="video_2"):
st.session_state["current_sample"] = "assets/video_vlm/videos/2.mp4"
with columns[6]:
st.image("assets/video_vlm/thumbnails/3.png")
with columns[7]:
if st.button(":material/file_open:", key="video_3"):
st.session_state["current_sample"] = "assets/video_vlm/videos/3.mp4"
# inference section
st.markdown("---")
st.header("Video Inference")
with st.container(border=True):
if st.session_state["current_sample"]:
st.video(st.session_state["current_sample"], autoplay=True, loop=True, muted=True)
else:
st.write(":gray[Please select a video from the dashcam gallery.]")
columns = st.columns(2, gap="small")
with columns[0]:
with st.container(border=True):
st.write("##### :green[Base InternVL3-8B]")
base_generation = st.empty()
base_generation.write("...")
with columns[1]:
with st.container(border=True):
st.write("##### :green[Finetuned InternVL3-8B]")
finetuned_generation = st.empty()
finetuned_generation.write("...")
columns = st.columns([9, 1], gap="small")
with columns[0]:
prompt = st.text_input(
"Prompt Input",
label_visibility="collapsed",
key="prompt_input",
on_change=lambda: st.session_state.update(prompt=st.session_state.prompt_input)
)
with columns[1]:
if st.button("Generate", width="stretch"):
            if prompt:
                st.session_state["prompt"] = prompt
with st.spinner("Running..."):
response = start_inference("base")
base_generation.markdown(response)
with st.spinner("Running..."):
response = start_inference("finetuned")
finetuned_generation.markdown(response)
response = json.loads(response[7: -3].strip())
response["caption"] = random_id() # replace caption with driver id
st.session_state["df"].loc[len(st.session_state["df"])] = list(response.values())
# data analysis section
st.markdown("---")
st.header("Data Analysis")
with st.container(border=True):
st.dataframe(st.session_state["df"])
class InternVLVideoProcessor:
def __init__(self):
self.frame_size = 448
self.transform = T.Compose([
T.Resize(self.frame_size, interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * self.frame_size * self.frame_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(self, frame, min_num=1, max_num=12, use_thumbnail=False):
orig_width, orig_height = frame.size
aspect_ratio = orig_width / orig_height
        # enumerate candidate tiling grids (cols x rows) whose tile count falls in [min_num, max_num]
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = self.find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, self.frame_size)
# calculate the target width and height
target_width = self.frame_size * target_aspect_ratio[0]
target_height = self.frame_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = frame.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // self.frame_size)) * self.frame_size,
(i // (target_width // self.frame_size)) * self.frame_size,
((i % (target_width // self.frame_size)) + 1) * self.frame_size,
((i // (target_width // self.frame_size)) + 1) * self.frame_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = frame.resize((self.frame_size, self.frame_size))
processed_images.append(thumbnail_img)
return processed_images
def load_video(self, video_path, num_frames, start_frame=None, end_frame=None):
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
if start_frame is None:
start_frame = 0
if end_frame is None:
end_frame = len(vr) - 1
        # sample num_frames equally spaced frames from the requested range
frame_indices = np.linspace(
start_frame,
end_frame,
num_frames,
dtype=int
)
pixel_values_list, num_patches_list = [], []
for frame_index in frame_indices:
img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
img = self.dynamic_preprocess(img, use_thumbnail=True, max_num=1)
pixel_values = [self.transform(tile) for tile in img]
pixel_values = torch.stack(pixel_values)
num_patches_list.append(pixel_values.shape[0])
pixel_values_list.append(pixel_values)
pixel_values = torch.cat(pixel_values_list)
return pixel_values, num_patches_list
class InternVLModel:
def __init__(self, model_path):
self.model = AutoModel.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
use_flash_attn=False,
).eval()
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
use_fast=False
)
self.processor = InternVLVideoProcessor()
self.generation_config = dict(
max_new_tokens=1024,
do_sample=False,
temperature=0.0,
num_beams=1,
top_k=1,
top_p=1.0)
@torch.no_grad()
def infer(self, video_path, prompt, sampling_mode, num_frames=32, chunk_duration=2):
if sampling_mode == "default":
return self._infer_default(video_path, prompt, num_frames)
elif sampling_mode == "real-time":
return self._infer_realtime(video_path, prompt, num_frames, chunk_duration)
def _infer_default(self, video_path, prompt, num_frames):
pixel_values, num_patches_list = self.processor.load_video(video_path, num_frames)
pixel_values = pixel_values.to(device=self.model.device, dtype=torch.bfloat16)
video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
prompt = f"{video_prefix}{prompt}"
response = self.model.chat(
self.tokenizer,
pixel_values,
prompt,
self.generation_config,
num_patches_list=num_patches_list
)
return response
def _infer_realtime(self, video_path, prompt, num_frames, chunk_duration):
video = VideoReader(video_path, ctx=cpu(0), num_threads=1)
fps = video.get_avg_fps()
total_frames = len(video)
frames_per_chunk = int(chunk_duration * fps)
for start_frame in range(0, total_frames, frames_per_chunk):
end_frame = start_frame + frames_per_chunk
if end_frame > total_frames:
return
pixel_values, num_patches_list = self.processor.load_video(video_path, num_frames, start_frame, end_frame)
pixel_values = pixel_values.to(device=self.model.device, dtype=torch.bfloat16)
video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
prompt = f"{video_prefix}{prompt}"
response = self.model.chat(
self.tokenizer,
pixel_values,
prompt,
self.generation_config,
num_patches_list=num_patches_list
)
yield response
def load_model_for_inference(config, model_type):
if model_type == "finetuned":
model_name = config["finetuned_model_id"]
elif model_type == "base":
model_name = config["model_id"]
else:
raise ValueError(f"Invalid model type: {model_type}")
return InternVLModel(model_name)
@torch.no_grad()
def start_inference(model_type):
# define prompt
prompt = st.session_state["prompt"]
if model_type == "finetuned":
prompt = SCAP_PROMPT.format(prompt=prompt)
response = st.session_state[f"{model_type}_video_vlm"]["model"].infer(
st.session_state["current_sample"],
prompt,
num_frames=st.session_state["config_video_vlm"]["inference"]["num_frames"],
sampling_mode=st.session_state["config_video_vlm"]["inference"]["sampling_mode"]
)
if model_type == "finetuned":
response = f"```json\n{json.dumps(json.loads(response), indent=2)}\n```"
return response
if __name__ == "__main__":
main()


Some files were not shown because too many files have changed in this diff.