chore: Regenerate all playbooks
39
nvidia/flux-finetuning/assets/Dockerfile.inference
Normal file
@ -0,0 +1,39 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
FROM nvcr.io/nvidia/pytorch:25.09-py3
|
||||
|
||||
ARG HF_TOKEN
|
||||
|
||||
RUN cd /workspace/ && \
|
||||
git clone https://github.com/comfyanonymous/ComfyUI.git && \
|
||||
cd ComfyUI && \
|
||||
git checkout 4ffea0e864275301329ddb5ecc3fbc7211d7a802 && \
|
||||
sed -i '/torch/d' requirements.txt && \
|
||||
pip install -r requirements.txt && \
|
||||
pip install torchsde && \
|
||||
mkdir -p /workspace/ComfyUI/user/default/workflows/
|
||||
|
||||
COPY . /workspace/sd-scripts
|
||||
|
||||
RUN hf download black-forest-labs/FLUX.1-dev ae.safetensors --local-dir models/vae && \
|
||||
hf download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models/checkpoints && \
|
||||
hf download comfyanonymous/flux_text_encoders clip_l.safetensors --local-dir models/text_encoders && \
|
||||
hf download comfyanonymous/flux_text_encoders t5xxl_fp16.safetensors --local-dir models/text_encoders && \
|
||||
hf download RLakshmi24/flux-dreambooth-lora-tj-spark flux_dreambooth.safetensors --local-dir models/loras && \
|
||||
cp /workspace/sd-scripts/workflows/finetuned_flux.json /workspace/ComfyUI/user/default/workflows/
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
42
nvidia/flux-finetuning/assets/Dockerfile.train
Normal file
@ -0,0 +1,42 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
FROM nvcr.io/nvidia/pytorch:25.09-py3
|
||||
|
||||
ARG HF_TOKEN
|
||||
|
||||
RUN cd /workspace/ && \
|
||||
git clone https://github.com/kohya-ss/sd-scripts.git && \
|
||||
cd sd-scripts && \
|
||||
git checkout sd3 && \
|
||||
pip install -r requirements.txt && \
|
||||
apt update && \
|
||||
apt install -y libgl1-mesa-dev
|
||||
|
||||
COPY . /workspace/sd-scripts
|
||||
|
||||
RUN hf auth login --token $HF_TOKEN
|
||||
|
||||
RUN cd /workspace/sd-scripts/ && \
|
||||
hf download black-forest-labs/FLUX.1-dev ae.safetensors --local-dir models && \
|
||||
hf download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models && \
|
||||
hf download comfyanonymous/flux_text_encoders clip_l.safetensors --local-dir models && \
|
||||
hf download comfyanonymous/flux_text_encoders t5xxl_fp16.safetensors --local-dir models
|
||||
|
||||
RUN cd /workspace/sd-scripts
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
203
nvidia/flux-finetuning/assets/README.md
Normal file
@ -0,0 +1,203 @@
|
||||
# FLUX.1 Fine-tuning with LoRA
|
||||
|
||||
This project demonstrates fine-tuning the FLUX.1-dev 11B model using Dreambooth LoRA (Low-Rank Adaptation) for custom image generation. The demo includes training on custom concepts and inference through both command-line scripts and ComfyUI.
|
||||
|
||||
## Results
|
||||
|
||||
Fine-tuning FLUX.1 with custom concepts enables the model to generate images with your specific objects and styles:
|
||||
|
||||
<figure>
|
||||
<img src="assets/before_finetuning.png" alt="Before Fine-tuning" width="400"/>
|
||||
<figcaption>Base FLUX.1 model without custom concept knowledge</figcaption>
|
||||
</figure>
|
||||
|
||||
<br>
|
||||
|
||||
<figure>
|
||||
<img src="assets/after_finetuning.png" alt="After Fine-tuning" width="400"/>
|
||||
<figcaption>FLUX.1 model after LoRA fine-tuning with custom "tjtoy" and "sparkgpu" concepts</figcaption>
|
||||
</figure>
|
||||
|
||||
## Overview
|
||||
|
||||
The project includes:
|
||||
- **FLUX.1-dev Fine-tuning**: LoRA-based fine-tuning using sd-scripts
|
||||
- **Custom Concept Training**: Train on "tjtoy" toy and "sparkgpu" GPU
|
||||
- **Command-line Inference**: Generate images using trained LoRA weights
|
||||
- **ComfyUI Integration**: Intuitive workflows for inference with custom models
|
||||
- **Docker Support**: Complete containerized environment
|
||||
|
||||
## Training
|
||||
|
||||
### 1. Build Docker Image by providing `HF_TOKEN`
|
||||
|
||||
```bash
|
||||
# Build the Docker image (this will download FLUX models automatically)
|
||||
docker build -f Dockerfile.train --build-arg HF_TOKEN=$HF_TOKEN -t flux-training .
|
||||
```
|
||||
|
||||
**Note**: The Docker build automatically downloads the required FLUX models:
|
||||
- `flux1-dev.safetensors` (~23GB)
|
||||
- `ae.safetensors` (~335MB)
|
||||
- `clip_l.safetensors` (~246MB)
|
||||
- `t5xxl_fp16.safetensors` (~9.8GB)
|
||||
|
||||
### 2. Run Docker Container
|
||||
|
||||
```bash
|
||||
# Run with GPU support and mount current directory
|
||||
docker run -it \
|
||||
--gpus all \
|
||||
--ipc=host \
|
||||
--ulimit memlock=-1 \
|
||||
--ulimit stack=67108864 \
|
||||
--net=host \
|
||||
flux-training
|
||||
```
|
||||
|
||||
### 3. Train the Model
|
||||
|
||||
```bash
|
||||
# Inside the container, navigate to sd-scripts and run training
|
||||
cd /workspace/sd-scripts
|
||||
sh train.sh
|
||||
```
|
||||
|
||||
### 4. Run Inference
|
||||
The `inference.sh` script generates 9 images with different seeds.
|
||||
|
||||
After training, you can generate images using the learned concepts. For example:
|
||||
- `"tjtoy toy"` - Your custom toy concept
|
||||
- `"sparkgpu gpu"` - Your custom GPU concept
|
||||
- Combine them: `"tjtoy toy holding sparkgpu gpu"`
|
||||
|
||||
```bash
|
||||
# Generate images using the trained LoRA
|
||||
sh inference.sh
|
||||
```
|
||||
|
||||
### Dataset Structure
|
||||
|
||||
The training data is organized in the `data/` directory:
|
||||
|
||||
```
|
||||
data/
|
||||
├── data.toml # Training configuration
|
||||
├── tjtoy/ # Custom toy concept images (6 images)
|
||||
│ ├── 1.png
|
||||
│ ├── 2.jpg
|
||||
│ ├── 3.png
|
||||
│ ├── 4.png
|
||||
│ ├── 5.png
|
||||
│ └── 6.png
|
||||
└── sparkgpu/ # Custom GPU concept images (7 images)
|
||||
├── 1.jpeg
|
||||
├── 2.jpg
|
||||
├── 3.jpg
|
||||
├── 4.jpg
|
||||
├── 6.png
|
||||
├── 7.png
|
||||
└── 8.png
|
||||
```
|
||||
|
||||
### Training Parameters
|
||||
|
||||
Key training settings in `train.sh`:
|
||||
- **Network Type**: LoRA with dimension 256
|
||||
- **Learning Rate**: 1.0 (with Prodigy optimizer)
|
||||
- **Epochs**: 100 (saves every 25 epochs)
|
||||
- **Resolution**: 1024x1024
|
||||
- **Mixed Precision**: bfloat16
|
||||
- **Optimizations**: Torch compile, gradient checkpointing, cached latents
|
||||
|
||||
## ComfyUI
|
||||
|
||||
ComfyUI provides an intuitive visual interface for using your fine-tuned LoRA models. The beauty of LoRA fine-tuning is that you can easily add your custom concepts to any FLUX workflow with just a single node.
|
||||
|
||||
### 1. Build Docker Image by providing `HF_TOKEN`
|
||||
|
||||
```bash
|
||||
# Build the Docker image (this will download FLUX models automatically)
|
||||
docker build -f Dockerfile.inference --build-arg HF_TOKEN=$HF_TOKEN -t flux-comfyui .
|
||||
```
|
||||
|
||||
### 2. Run Docker Container
|
||||
|
||||
```bash
|
||||
# Run with GPU support and mount current directory
|
||||
docker run -it \
|
||||
--gpus all \
|
||||
--ipc=host \
|
||||
--ulimit memlock=-1 \
|
||||
--ulimit stack=67108864 \
|
||||
--net=host \
|
||||
flux-comfyui
|
||||
```
|
||||
|
||||
### 3. Running ComfyUI
|
||||
|
||||
```bash
|
||||
# Start ComfyUI server
|
||||
cd /workspace/ComfyUI
|
||||
python main.py
|
||||
```
|
||||
|
||||
Access ComfyUI at `http://localhost:8188`
|
||||
|
||||
### 4. ComfyUI Workflow Example
|
||||
|
||||

|
||||
*ComfyUI workflow showing how easily LoRA can be integrated into the base FLUX model*
|
||||
|
||||
The workflow demonstrates the simplicity of LoRA integration:
|
||||
1. **Load Checkpoint**: Base FLUX.1-dev model remains unchanged
|
||||
2. **Load LoRA**: Simply add your trained LoRA file (`flux_dreambooth.safetensors`)
|
||||
3. **Adjust Strength**: Fine-tune the influence of your custom concepts (0.8-1.2 typically works well)
|
||||
4. **Generate**: Use your custom trigger words (`tjtoy toy`, `sparkgpu gpu`) in prompts
|
||||
|
||||
This modular approach means you can:
|
||||
- **Preserve base model quality**: The original FLUX capabilities remain intact
|
||||
- **Easy experimentation**: Quickly swap different LoRA models or adjust strengths
|
||||
- **Combine concepts**: Mix multiple LoRA models or use them with other techniques
|
||||
- **Minimal storage**: LoRA files are typically 100-200MB vs 23GB+ for full models
|
||||
|
||||
### ComfyUI Model Structure
|
||||
|
||||
Organize models in ComfyUI as follows:
|
||||
|
||||
```
|
||||
ComfyUI/models/
|
||||
├── checkpoints/
|
||||
│ └── flux1-dev.safetensors # Main FLUX model
|
||||
├── vae/
|
||||
│ └── ae.safetensors # FLUX VAE
|
||||
├── clip/
|
||||
│ ├── clip_l.safetensors # CLIP text encoder
|
||||
│ └── t5xxl_fp16.safetensors # T5 text encoder
|
||||
└── loras/
|
||||
└── flux_dreambooth.safetensors # Your trained LoRA
|
||||
```
|
||||
|
||||
## Custom Concepts
|
||||
|
||||
The fine-tuning process teaches FLUX.1 to understand two custom concepts:
|
||||
|
||||
### TJToy Concept
|
||||
- **Trigger phrase**: `tjtoy toy`
|
||||
- **Training images**: 6 high-quality images of custom toy figures
|
||||
- **Use case**: Generate images featuring the specific toy character in various scenes
|
||||
|
||||
### SparkGPU Concept
|
||||
- **Trigger phrase**: `sparkgpu gpu`
|
||||
- **Training images**: 7 images of custom GPU hardware
|
||||
- **Use case**: Generate images featuring the specific GPU design in different contexts
|
||||
|
||||
### Combined Usage
|
||||
You can combine both concepts in prompts:
|
||||
- `"tjtoy toy holding sparkgpu gpu"`
|
||||
- `"tjtoy toy standing next to sparkgpu gpu in a data center"`
|
||||
- `"sparkgpu gpu being examined by tjtoy toy"`
|
||||
|
||||
## Credits
|
||||
|
||||
This project uses [sd-scripts](https://github.com/kohya-ss/sd-scripts) repository by `kohya-ss` for FLUX.1 fine-tuning.
|
||||
38
nvidia/flux-finetuning/assets/flux_data/data.toml
Executable file
@ -0,0 +1,38 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
[general]
|
||||
shuffle_caption = false
|
||||
keep_tokens = 2
|
||||
|
||||
[[datasets]]
|
||||
resolution = 1024
|
||||
batch_size = 1
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "flux_data/tjtoy"
|
||||
class_tokens = "tjtoy toy"
|
||||
num_repeats = 1
|
||||
is_reg = false
|
||||
flip_aug = true
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "flux_data/sparkgpu"
|
||||
class_tokens = "sparkgpu gpu"
|
||||
num_repeats = 2
|
||||
is_reg = false
|
||||
flip_aug = true
|
||||
BIN
nvidia/flux-finetuning/assets/flux_data/sparkgpu/1.jpeg
Executable file
|
After Width: | Height: | Size: 36 KiB |
BIN
nvidia/flux-finetuning/assets/flux_data/sparkgpu/2.jpg
Executable file
|
After Width: | Height: | Size: 101 KiB |
BIN
nvidia/flux-finetuning/assets/flux_data/sparkgpu/3.jpg
Executable file
|
After Width: | Height: | Size: 59 KiB |
BIN
nvidia/flux-finetuning/assets/flux_data/sparkgpu/4.jpg
Executable file
|
After Width: | Height: | Size: 112 KiB |
BIN
nvidia/flux-finetuning/assets/flux_data/sparkgpu/6.png
Executable file
|
After Width: | Height: | Size: 3.0 MiB |
BIN
nvidia/flux-finetuning/assets/flux_data/sparkgpu/7.png
Executable file
|
After Width: | Height: | Size: 2.0 MiB |
BIN
nvidia/flux-finetuning/assets/flux_data/sparkgpu/8.png
Executable file
|
After Width: | Height: | Size: 1.9 MiB |
BIN
nvidia/flux-finetuning/assets/flux_data/tjtoy/1.png
Executable file
|
After Width: | Height: | Size: 2.8 MiB |
BIN
nvidia/flux-finetuning/assets/flux_data/tjtoy/2.jpg
Executable file
|
After Width: | Height: | Size: 274 KiB |
BIN
nvidia/flux-finetuning/assets/flux_data/tjtoy/3.png
Executable file
|
After Width: | Height: | Size: 945 KiB |
BIN
nvidia/flux-finetuning/assets/flux_data/tjtoy/4.png
Executable file
|
After Width: | Height: | Size: 529 KiB |
BIN
nvidia/flux-finetuning/assets/flux_data/tjtoy/5.png
Executable file
|
After Width: | Height: | Size: 1.9 MiB |
BIN
nvidia/flux-finetuning/assets/flux_data/tjtoy/6.png
Executable file
|
After Width: | Height: | Size: 1.6 MiB |
36
nvidia/flux-finetuning/assets/inference.sh
Normal file
@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
for SEED in $(seq 0 7); do
|
||||
python flux_minimal_inference.py \
|
||||
--ckpt_path="models/flux1-dev.safetensors" \
|
||||
--model_type="flux" \
|
||||
--clip_l="models/clip_l.safetensors" \
|
||||
--t5xxl="models/t5xxl_fp16.safetensors" \
|
||||
--ae="models/ae.safetensors" \
|
||||
--output_dir="outputs" \
|
||||
--lora_weights="saved_models/flux_dreambooth.safetensors" \
|
||||
--merge_lora_weights \
|
||||
--prompt="tjtoy toy holding sparkgpu gpu in a datacenter" \
|
||||
--width=1024 \
|
||||
--height=1024 \
|
||||
--steps=50 \
|
||||
--guidance=1.0 \
|
||||
--cfg_scale=1.0 \
|
||||
--seed=$SEED \
|
||||
--dtype="bfloat16"
|
||||
done
|
||||
52
nvidia/flux-finetuning/assets/train.sh
Normal file
@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
accelerate launch \
|
||||
--num_processes=1 --num_machines=1 --mixed_precision=bf16 \
|
||||
--main_process_ip=127.0.0.1 --main_process_port=29500 \
|
||||
--num_cpu_threads_per_process=2 \
|
||||
flux_train_network.py \
|
||||
--pretrained_model_name_or_path="models/flux1-dev.safetensors" \
|
||||
--clip_l="models/clip_l.safetensors" \
|
||||
--t5xxl="models/t5xxl_fp16.safetensors" \
|
||||
--ae="models/ae.safetensors" \
|
||||
--dataset_config="flux_data/data.toml" \
|
||||
--output_dir="saved_models" \
|
||||
--prior_loss_weight=1.0 \
|
||||
--output_name="flux_dreambooth" \
|
||||
--save_model_as=safetensors \
|
||||
--network_module=networks.lora_flux \
|
||||
--network_dim=256 \
|
||||
--network_alpha=256 \
|
||||
--learning_rate=1.0 \
|
||||
--optimizer_type="Prodigy" \
|
||||
--lr_scheduler="cosine_with_restarts" \
|
||||
--sdpa \
|
||||
--max_train_epochs=100 \
|
||||
--save_every_n_epochs=25 \
|
||||
--mixed_precision="bf16" \
|
||||
--guidance_scale=1.0 \
|
||||
--timestep_sampling="flux_shift" \
|
||||
--model_prediction_type="raw" \
|
||||
--torch_compile \
|
||||
--persistent_data_loader_workers \
|
||||
--cache_latents \
|
||||
--cache_latents_to_disk \
|
||||
--cache_text_encoder_outputs \
|
||||
--cache_text_encoder_outputs_to_disk \
|
||||
--gradient_checkpointing
|
||||
@ -45,7 +45,7 @@ GPU acceleration and performance optimization capabilities.
|
||||
[ ] Docker or container runtime installed
|
||||
[ ] NVIDIA Container Toolkit configured
|
||||
[ ] Verify GPU access: `nvidia-smi`
|
||||
[ ] Verify Docker GPU support: `docker run --gpus all nvidia/cuda:12.0-base-ubuntu20.04 nvidia-smi`
|
||||
[ ] Verify Docker GPU support: `docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi`
|
||||
[ ] Port 8080 available for marimo notebook access
|
||||
|
||||
## Ancillary files
|
||||
@ -92,17 +92,23 @@ If the `docker` command fails with a permission error, you can either
|
||||
|
||||
To add yourself to the `docker` group, first run `sudo usermod -aG docker $USER`. Then, as your user account, either run `newgrp docker` or log out and log back in.
|
||||
|
||||
## Step 2. Build a Docker image
|
||||
## Step 3. Clone the playbook repository
|
||||
|
||||
```bash
|
||||
git clone https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets/-/blob/main
|
||||
```
|
||||
|
||||
## Step 3. Build the Docker image
|
||||
|
||||
|
||||
> **Warning:** This command will download a base image and build a container locally to support this environment
|
||||
|
||||
```bash
|
||||
cd jax-assets
|
||||
cd jax/assets
|
||||
docker build -t jax-on-spark .
|
||||
```
|
||||
|
||||
## Step 3. Launch Docker container
|
||||
## Step 4. Launch Docker container
|
||||
|
||||
Run the JAX development environment in a Docker container with GPU support and port forwarding for marimo access.
|
||||
|
||||
@ -113,7 +119,7 @@ docker run --gpus all --rm -it \
|
||||
jax-on-spark
|
||||
```
|
||||
|
||||
## Step 4. Access marimo interface
|
||||
## Step 5. Access marimo interface
|
||||
|
||||
Connect to the marimo notebook server to begin the JAX tutorial.
|
||||
|
||||
@ -124,7 +130,7 @@ Connect to the marimo notebook server to begin the JAX tutorial.
|
||||
|
||||
The interface will load a table-of-contents display and brief introduction to marimo.
|
||||
|
||||
## Step 5. Complete JAX introduction tutorial
|
||||
## Step 6. Complete JAX introduction tutorial
|
||||
|
||||
Work through the introductory material to understand JAX programming model differences from NumPy.
|
||||
|
||||
@ -133,7 +139,7 @@ Navigate to and complete the JAX introduction notebook, which covers:
|
||||
- Key differences from NumPy
|
||||
- Performance evaluation techniques
|
||||
|
||||
## Step 6. Implement NumPy baseline
|
||||
## Step 7. Implement NumPy baseline
|
||||
|
||||
Complete the NumPy-based self-organized map (SOM) implementation to establish a performance
|
||||
baseline.
|
||||
@ -143,7 +149,7 @@ Work through the NumPy SOM notebook to:
|
||||
- Implement the algorithm using familiar NumPy operations
|
||||
- Record performance metrics for comparison
|
||||
|
||||
## Step 7. Optimize with JAX implementations
|
||||
## Step 8. Optimize with JAX implementations
|
||||
|
||||
Progress through the iteratively refined JAX implementations to see performance improvements.
|
||||
|
||||
@ -153,7 +159,7 @@ Complete the JAX SOM notebook sections:
|
||||
- GPU-accelerated parallel JAX implementation
|
||||
- Compare performance across all versions
|
||||
|
||||
## Step 8. Validate performance gains
|
||||
## Step 9. Validate performance gains
|
||||
|
||||
The notebooks will show you how to check the performance of each SOM training implementation; you'll see that that JAX implementations show performance improvements over NumPy baseline (and some will be quite a lot faster).
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ def _(mo):
|
||||
The following video shows an example of training a self-organizing map from color data using a batch algorithm, from a random initial map to a relatively converged set of colorful clusters.
|
||||
"""
|
||||
))
|
||||
mo.output.append(mo.image("batch-som.mp4"))
|
||||
mo.output.append(mo.video("batch-som.mp4"))
|
||||
mo.output.append(mo.md(
|
||||
r"""
|
||||
|
||||
|
||||
76
nvidia/multi-agent-chatbot/assets/Dockerfile.llamacpp
Normal file
@ -0,0 +1,76 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
|
||||
ARG CUDA_VERSION=13.0.1
|
||||
ARG BASE_CUDA_DEV_CONTAINER=nvcr.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
ARG BASE_CUDA_RUN_CONTAINER=nvcr.io/nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
FROM ${BASE_CUDA_DEV_CONTAINER} AS build
|
||||
|
||||
ARG CUDA_DOCKER_ARCH="121"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential cmake python3 python3-pip git libcurl4-openssl-dev libgomp1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN git clone https://github.com/ggml-org/llama.cpp.git .
|
||||
|
||||
RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
|
||||
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=121"; \
|
||||
fi && \
|
||||
cmake -B build -DGGML_CUDA_ENABLE_UNIFIED_MEMORY=1 -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DGGML_BACKEND_DL=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
find build -name "*.so" -exec cp {} /app/lib \;
|
||||
|
||||
RUN mkdir -p /app/full \
|
||||
&& cp build/bin/* /app/full \
|
||||
&& cp *.py /app/full \
|
||||
&& cp -r gguf-py /app/full \
|
||||
&& cp -r requirements /app/full \
|
||||
&& cp requirements.txt /app/full \
|
||||
&& cp .devops/tools.sh /app/full/tools.sh
|
||||
|
||||
FROM ${BASE_CUDA_RUN_CONTAINER} AS base
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl\
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
&& find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
|
||||
&& find /var/cache -type f -delete
|
||||
|
||||
COPY --from=build /app/lib/ /app
|
||||
|
||||
|
||||
FROM base
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8000/health" ]
|
||||
|
||||
ENTRYPOINT [ "/app/llama-server" ]
|
||||
CMD ["--host", "0.0.0.0", "--port", "8000"]
|
||||
105
nvidia/multi-agent-chatbot/assets/README.md
Normal file
@ -0,0 +1,105 @@
|
||||
# Chatbot Spark: A Local Multi-Agent System for DGX Spark
|
||||
|
||||
## Project Overview
|
||||
|
||||
Chatbot Spark is a fully local multi-agent system built on DGX Spark. With 128GB of unified memory, DGX Spark can run multiple LLMs and VLMs in parallel — enabling interactions across agents.
|
||||
|
||||
At the core is a supervisor agent powered by GPT-OSS-120B, orchestrating specialized downstream agents for coding, retrieval-augmented generation (RAG), and image understanding. Thanks to DGX Spark’s out-of-the-box support for popular AI frameworks and libraries, development and prototyping were fast and frictionless. Together, these components demonstrate how complex, multimodal workflows can be executed efficiently on local, high-performance hardware.
|
||||
|
||||
This project was built to be customizable, serving as a framework that developers can customize.
|
||||
|
||||
## Key Features
|
||||
- **MCP Server Integration**: Chatbot Spark also showcases the ability to connect to custom MCP servers through a simple and customizable multi-server client
|
||||
|
||||
- **Tool Calling**: This project uses an agents-as-tools framework and showcases the ability to create additional agents connected as tools. General tools can also be added.
|
||||
|
||||
- **Easily Swappable Models**: Models are loaded and served using Llama CPP and Ollama and served through the OpenAI API. Any OpenAI-compatible model can be integrated into the project.
|
||||
|
||||
- **Vector Indexing & Retrieval**: GPU-accelerated Milvus for high-performance document retrieval.
|
||||
|
||||
- **Real-time LLM Streaming**: We present custom LLM-streaming infrastructure, making it easy for developers to stream supervisor responses from any OpenAI compatible model.
|
||||
|
||||
- **gpt-oss Integration**: The default chat/tool-calling model is gpt-oss:120b, providing seamless integration with OpenAI's latest open sorce tool-calling model.
|
||||
|
||||
|
||||
## System Overview
|
||||
<img src="assets/system-diagram.png" alt="System Diagram" style="max-width:600px;border-radius:5px;justify-content:center">
|
||||
|
||||
## Default Models
|
||||
| Model | Quantization | Model Type | VRAM |
|
||||
|------------------------------|--------------|------------|-------------|
|
||||
| GPT-OSS:120B | MXFP4 | Chat | ~ 63.5 GB |
|
||||
| Deepseek-Coder:6.7B-Instruct | Q8 | Coding | ~ 9.5 GB |
|
||||
| Qwen2.5-VL:7B-Instruct | BF16 | Image | ~ 35.4 GB |
|
||||
| Qwen3-Embedding-4B | Q8 | Embedding | ~ 5.39 GB |
|
||||
|
||||
**Total VRAM required:** ~114 GB
|
||||
|
||||
> **Warning**:
|
||||
> Since the default models use majority of available VRAM, ensure that you don't have anything already running on DGX Spark using `nvidia-smi`. If you do, switch to `gpt-oss-20b` following [this guide](#using-different-models).
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
#### 1. Clone the repository and change directories to the multi-agent chatbot directory.
|
||||
|
||||
#### 2. Run the setup script
|
||||
The setup script will take care of pulling model GGUF files from HuggingFace, building base llama cpp server images and starting all the required docker services to serve models, the backend API server as well as the frontend UI.
|
||||
```bash
|
||||
chmod +x setup.sh
|
||||
./setup.sh
|
||||
```
|
||||
Wait for all the containers to become ready and healthy.
|
||||
```bash
|
||||
watch 'docker ps --format "table {{.ID}}\t{{.Names}}\t{{.Status}}"'
|
||||
```
|
||||
> Note: Downloading model files may take ~10 minutes and starting containers may take another 10 minutes depending on network speed. Look for "server is listening on http://0.0.0.0:8000" in the logs of model server containers.
|
||||
|
||||
|
||||
#### 3. Access the frontend UI
|
||||
|
||||
Open your browser and go to: [http://localhost:3000](http://localhost:3000)
|
||||
|
||||
> Note: If you are running this on a remote GPU via an ssh connection, in a new terminal window, you need to run to be able to access the UI at localhost:3000 and for the UI to be able to communicate to the backend at localhost:8000:
|
||||
>```bash
|
||||
> ssh -L 3000:localhost:3000 -L 8000:localhost:8000 username@IP-address
|
||||
>```
|
||||
|
||||
You should see the following UI in your browser:
|
||||
<img src="assets/multi-agent-chatbot.png" alt="Frontend UI" style="max-width:600px;border-radius:5px;justify-content:center">
|
||||
|
||||
### 4. Try out the sample prompts
|
||||
Click on any of the tiles on the frontend to try out the supervisor and the other agents.
|
||||
|
||||
#### RAG Agent:
|
||||
Before trying out the RAG agent, upload the example PDF document [NVIDIA Blackwell Whitepaper](https://images.nvidia.com/aem-dam/Solutions/geforce/blackwell/nvidia-rtx-blackwell-gpu-architecture.pdf) as context by clicking on the "Attach" icon in the text input space at the botton of the UI and then make sure to check the box in the "Select Sources" section on the left side of the UI.
|
||||
<img src="assets/upload-image.png" alt="Upload Image" style="max-width:300px;border-radius:5px;justify-content:center">
|
||||
<img src="assets/document-ingestion.png" alt="Ingest Documents" style="max-width:300px;border-radius:5px;justify-content:center">
|
||||
|
||||
#### Image Understanding Agent:
|
||||
You can either provide URLs or drag and drop images.
|
||||
|
||||
**Example Prompt:**
|
||||
|
||||
|
||||
Describe this image: https://en.wikipedia.org/wiki/London_Bridge#/media/File:London_Bridge_from_St_Olaf_Stairs.jpg
|
||||
|
||||
|
||||
## Customizations
|
||||
|
||||
### Using different models
|
||||
|
||||
You can use swap the model that the supervisor agent is using, for example to gpt-oss-20b.
|
||||
|
||||
1. In `setup.sh`, uncomment the line to download gpt-oss-20b.
|
||||
> Note: If you already have the model files downloaded, you can skip to step 2.
|
||||
2. In `docker-compose-models.yml`, uncomment the block for gpt-oss-20b.
|
||||
> Note: Since the default models use all of the existing VRAM, you will need to comment out the block for gpt-oss-120b in `docker-compose-models.yml`.
|
||||
3. In `docker-compose.yml`, add `gpt-oss-20b` to the `MODELS` environment variable (line 40).
|
||||
> Note: This name should match the container name that you set for this model in `docker-compose-models.yml`.
|
||||
|
||||
### Adding MCP servers and tools
|
||||
|
||||
1. You can add more MCP servers and tools under [backend/tools/mcp_servers](backend/tools/mcp_servers/) following existing examples.
|
||||
|
||||
2. If you added an MCP server, remember to add it to the server configs in [backend/client.py](backend/client.py)
|
||||
@ -0,0 +1 @@
|
||||
3.10
|
||||
33
nvidia/multi-agent-chatbot/assets/backend/Dockerfile
Normal file
@ -0,0 +1,33 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
FROM python:3.12
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl && \
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN . $HOME/.local/bin/env
|
||||
ENV PATH="/root/.local/bin:$PATH"
|
||||
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY . .
|
||||
|
||||
RUN uv sync
|
||||
|
||||
CMD ["uv", "run", "--", "uvicorn", "main:app", "--reload", "--host", "0.0.0.0", "--port", "8000"]
|
||||
1
nvidia/multi-agent-chatbot/assets/backend/README.md
Normal file
@ -0,0 +1 @@
|
||||
# Chatbot Backend API Server
|
||||
16
nvidia/multi-agent-chatbot/assets/backend/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
586
nvidia/multi-agent-chatbot/assets/backend/agent.py
Normal file
@ -0,0 +1,586 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""ChatAgent implementation for LLM-powered conversational AI with tool calling."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
from typing import AsyncIterator, List, Dict, Any, TypedDict, Optional, Callable, Awaitable
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage, AnyMessage, SystemMessage, ToolMessage, ToolCall
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from client import MCPClient
|
||||
from logger import logger
|
||||
from prompts import Prompts
|
||||
from postgres_storage import PostgreSQLConversationStorage
|
||||
from utils import convert_langgraph_messages_to_openai
|
||||
|
||||
|
||||
memory = MemorySaver()
|
||||
SENTINEL = object()
|
||||
StreamCallback = Callable[[Dict[str, Any]], Awaitable[None]]
|
||||
|
||||
|
||||
class State(TypedDict, total=False):
|
||||
iterations: int
|
||||
messages: List[AnyMessage]
|
||||
chat_id: Optional[str]
|
||||
image_data: Optional[str]
|
||||
|
||||
|
||||
class ChatAgent:
|
||||
"""Main conversational agent with tool calling and agent delegation capabilities.
|
||||
|
||||
This agent orchestrates conversation flow using a LangGraph state machine that can:
|
||||
- Generate responses using LLMs
|
||||
- Execute tool calls (including MCP tools)
|
||||
- Handle image processing
|
||||
- Manage conversation history via Redis
|
||||
"""
|
||||
|
||||
def __init__(self, vector_store, config_manager, postgres_storage: PostgreSQLConversationStorage):
|
||||
"""Initialize the chat agent.
|
||||
|
||||
Args:
|
||||
vector_store: VectorStore instance for document retrieval
|
||||
config_manager: ConfigManager for reading configuration
|
||||
postgres_storage: PostgreSQL storage for conversation persistence
|
||||
"""
|
||||
self.vector_store = vector_store
|
||||
self.config_manager = config_manager
|
||||
self.conversation_store = postgres_storage
|
||||
self.current_model = None
|
||||
|
||||
self.current_model = None
|
||||
self.max_iterations = 3
|
||||
|
||||
self.mcp_client = None
|
||||
self.openai_tools = None
|
||||
self.tools_by_name = None
|
||||
self.system_prompt = None
|
||||
|
||||
self.graph = self._build_graph()
|
||||
self.stream_callback = None
|
||||
self.last_state = None
|
||||
|
||||
@classmethod
|
||||
async def create(cls, vector_store, config_manager, postgres_storage: PostgreSQLConversationStorage):
|
||||
"""
|
||||
Asynchronously creates and initializes a ChatAgent instance.
|
||||
|
||||
This factory method ensures that all async setup, like loading tools,
|
||||
is completed before the agent is ready to be used.
|
||||
"""
|
||||
agent = cls(vector_store, config_manager, postgres_storage)
|
||||
await agent.init_tools()
|
||||
|
||||
available_tools = list(agent.tools_by_name.values()) if agent.tools_by_name else []
|
||||
template_vars = {
|
||||
"tools": "\n".join([f"- {tool.name}: {tool.description}" for tool in available_tools]) if available_tools else "No tools available",
|
||||
}
|
||||
agent.system_prompt = Prompts.get_template("supervisor_agent").render(template_vars)
|
||||
|
||||
logger.debug(f"Agent initialized with {len(available_tools)} tools.")
|
||||
agent.set_current_model(config_manager.get_selected_model())
|
||||
return agent
|
||||
|
||||
async def init_tools(self) -> None:
|
||||
"""Initialize MCP client and tools with retry logic.
|
||||
|
||||
Sets up the MCP client, retrieves available tools, converts them to OpenAI format,
|
||||
and initializes specialized agents like the coding agent.
|
||||
"""
|
||||
self.mcp_client = await MCPClient().init()
|
||||
|
||||
base_delay, max_retries = 0.1, 10
|
||||
mcp_tools = []
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
mcp_tools = await self.mcp_client.get_tools()
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"MCP tools initialization attempt {attempt + 1} failed: {e}")
|
||||
if attempt == max_retries - 1:
|
||||
logger.error(f"MCP servers not ready after {max_retries} attempts, continuing without MCP tools")
|
||||
mcp_tools = []
|
||||
break
|
||||
wait_time = base_delay * (2 ** attempt)
|
||||
await asyncio.sleep(wait_time)
|
||||
logger.info(f"MCP servers not ready, retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})")
|
||||
|
||||
self.tools_by_name = {tool.name: tool for tool in mcp_tools}
|
||||
logger.debug(f"Loaded {len(mcp_tools)} MCP tools: {list(self.tools_by_name.keys())}")
|
||||
|
||||
if mcp_tools:
|
||||
mcp_tools_openai = [convert_to_openai_tool(tool) for tool in mcp_tools]
|
||||
logger.debug(f"MCP tools converted to OpenAI format: {mcp_tools_openai}")
|
||||
|
||||
self.openai_tools = [
|
||||
{"type": "function", "function": tool['function']}
|
||||
for tool in mcp_tools_openai
|
||||
]
|
||||
logger.debug(f"Final OpenAI tools format: {self.openai_tools}")
|
||||
else:
|
||||
self.openai_tools = []
|
||||
logger.warning("No MCP tools available - agent will run with limited functionality")
|
||||
|
||||
def set_current_model(self, model_name: str) -> None:
|
||||
"""Set the current model for completions.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to use
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not available
|
||||
"""
|
||||
available_models = self.config_manager.get_available_models()
|
||||
|
||||
try:
|
||||
if model_name in available_models:
|
||||
self.current_model = model_name
|
||||
logger.info(f"Switched to model: {model_name}")
|
||||
self.model_client = AsyncOpenAI(
|
||||
base_url=f"http://{self.current_model}:8000/v1",
|
||||
api_key="api_key"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model {model_name} is not available. Available models: {available_models}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting current model: {e}")
|
||||
raise ValueError(f"Model {model_name} is not available. Available models: {available_models}")
|
||||
|
||||
def should_continue(self, state: State) -> str:
|
||||
"""Determine whether to continue the tool calling loop.
|
||||
|
||||
Args:
|
||||
state: Current graph state
|
||||
|
||||
Returns:
|
||||
"end" if no more tool calls or max iterations reached, "continue" otherwise
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return "end"
|
||||
|
||||
last_message = messages[-1]
|
||||
iterations = state.get("iterations", 0)
|
||||
has_tool_calls = bool(last_message.tool_calls) if hasattr(last_message, 'tool_calls') else False
|
||||
|
||||
logger.debug({
|
||||
"message": "GRAPH: should_continue decision",
|
||||
"chat_id": state.get("chat_id"),
|
||||
"iterations": iterations,
|
||||
"max_iterations": self.max_iterations,
|
||||
"has_tool_calls": has_tool_calls,
|
||||
"tool_calls_count": len(last_message.tool_calls) if has_tool_calls else 0
|
||||
})
|
||||
|
||||
if iterations >= self.max_iterations:
|
||||
logger.debug({
|
||||
"message": "GRAPH: should_continue → END (max iterations reached)",
|
||||
"chat_id": state.get("chat_id"),
|
||||
"final_message_preview": str(last_message)[:100] + "..." if len(str(last_message)) > 100 else str(last_message)
|
||||
})
|
||||
return "end"
|
||||
|
||||
if not has_tool_calls:
|
||||
logger.debug({"message": "GRAPH: should_continue → END (no tool calls)", "chat_id": state.get("chat_id")})
|
||||
return "end"
|
||||
|
||||
logger.debug({"message": "GRAPH: should_continue → CONTINUE (has tool calls)", "chat_id": state.get("chat_id")})
|
||||
return "continue"
|
||||
|
||||
async def tool_node(self, state: State) -> Dict[str, Any]:
|
||||
"""Execute tools from the last AI message's tool calls.
|
||||
|
||||
Args:
|
||||
state: Current graph state
|
||||
|
||||
Returns:
|
||||
Updated state with tool results and incremented iteration count
|
||||
"""
|
||||
logger.debug({
|
||||
"message": "GRAPH: ENTERING NODE - action/tool_node",
|
||||
"chat_id": state.get("chat_id"),
|
||||
"iterations": state.get("iterations", 0)
|
||||
})
|
||||
await self.stream_callback({'type': 'node_start', 'data': 'tool_node'})
|
||||
|
||||
outputs = []
|
||||
messages = state.get("messages", [])
|
||||
last_message = messages[-1]
|
||||
for i, tool_call in enumerate(last_message.tool_calls):
|
||||
logger.debug(f'Executing tool {i+1}/{len(last_message.tool_calls)}: {tool_call["name"]} with args: {tool_call["args"]}')
|
||||
await self.stream_callback({'type': 'tool_start', 'data': tool_call["name"]})
|
||||
|
||||
try:
|
||||
if tool_call["name"] == "explain_image" and state.get("image_data"):
|
||||
tool_args = tool_call["args"].copy()
|
||||
tool_args["image"] = state["image_data"]
|
||||
logger.info(f'Executing tool {tool_call["name"]} with args: {tool_args}')
|
||||
tool_result = await self.tools_by_name[tool_call["name"]].ainvoke(tool_args)
|
||||
state["process_image_used"] = True
|
||||
else:
|
||||
tool_result = await self.tools_by_name[tool_call["name"]].ainvoke(tool_call["args"])
|
||||
if "code" in tool_call["name"]:
|
||||
content = str(tool_result)
|
||||
elif isinstance(tool_result, str):
|
||||
content = tool_result
|
||||
else:
|
||||
content = json.dumps(tool_result)
|
||||
except Exception as e:
|
||||
logger.error(f'Error executing tool {tool_call["name"]}: {str(e)}', exc_info=True)
|
||||
content = f"Error executing tool '{tool_call['name']}': {str(e)}"
|
||||
|
||||
await self.stream_callback({'type': 'tool_end', 'data': tool_call["name"]})
|
||||
|
||||
outputs.append(
|
||||
ToolMessage(
|
||||
content=content,
|
||||
name=tool_call["name"],
|
||||
tool_call_id=tool_call["id"],
|
||||
)
|
||||
)
|
||||
|
||||
state["iterations"] = state.get("iterations", 0) + 1
|
||||
|
||||
logger.debug({
|
||||
"message": "GRAPH: EXITING NODE - action/tool_node",
|
||||
"chat_id": state.get("chat_id"),
|
||||
"iterations": state.get("iterations"),
|
||||
"tools_executed": len(outputs),
|
||||
"next_step": "→ returning to generate"
|
||||
})
|
||||
await self.stream_callback({'type': 'node_end', 'data': 'tool_node'})
|
||||
return {"messages": messages + outputs, "iterations": state.get("iterations", 0) + 1}
|
||||
|
||||
async def generate(self, state: State) -> Dict[str, Any]:
|
||||
"""Generate AI response using the current model.
|
||||
|
||||
Args:
|
||||
state: Current graph state
|
||||
|
||||
Returns:
|
||||
Updated state with new AI message
|
||||
"""
|
||||
messages = convert_langgraph_messages_to_openai(state.get("messages", []))
|
||||
logger.debug({
|
||||
"message": "GRAPH: ENTERING NODE - generate",
|
||||
"chat_id": state.get("chat_id"),
|
||||
"iterations": state.get("iterations", 0),
|
||||
"current_model": self.current_model,
|
||||
"message_count": len(state.get("messages", []))
|
||||
})
|
||||
await self.stream_callback({'type': 'node_start', 'data': 'generate'})
|
||||
|
||||
supports_tools = self.current_model in {"gpt-oss-20b", "gpt-oss-120b"}
|
||||
has_tools = supports_tools and self.openai_tools and len(self.openai_tools) > 0
|
||||
|
||||
logger.debug({
|
||||
"message": "Tool calling debug info",
|
||||
"chat_id": state.get("chat_id"),
|
||||
"current_model": self.current_model,
|
||||
"supports_tools": supports_tools,
|
||||
"openai_tools_count": len(self.openai_tools) if self.openai_tools else 0,
|
||||
"openai_tools": self.openai_tools,
|
||||
"has_tools": has_tools
|
||||
})
|
||||
|
||||
tool_params = {}
|
||||
if has_tools:
|
||||
tool_params = {
|
||||
"tools": self.openai_tools,
|
||||
"tool_choice": "auto"
|
||||
}
|
||||
|
||||
stream = await self.model_client.chat.completions.create(
|
||||
model=self.current_model,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
stream=True,
|
||||
**tool_params
|
||||
)
|
||||
|
||||
llm_output_buffer, tool_calls_buffer = await self._stream_response(stream, self.stream_callback)
|
||||
tool_calls = self._format_tool_calls(tool_calls_buffer)
|
||||
raw_output = "".join(llm_output_buffer)
|
||||
|
||||
logger.debug({
|
||||
"message": "Tool call generation results",
|
||||
"chat_id": state.get("chat_id"),
|
||||
"tool_calls_buffer": tool_calls_buffer,
|
||||
"formatted_tool_calls": tool_calls,
|
||||
"tool_calls_count": len(tool_calls),
|
||||
"raw_output_length": len(raw_output),
|
||||
"raw_output": raw_output[:200] + "..." if len(raw_output) > 200 else raw_output
|
||||
})
|
||||
|
||||
response = AIMessage(
|
||||
content=raw_output,
|
||||
**({"tool_calls": tool_calls} if tool_calls else {})
|
||||
)
|
||||
|
||||
logger.debug({
|
||||
"message": "GRAPH: EXITING NODE - generate",
|
||||
"chat_id": state.get("chat_id"),
|
||||
"iterations": state.get("iterations", 0),
|
||||
"response_length": len(response.content) if response.content else 0,
|
||||
"tool_calls_generated": len(tool_calls),
|
||||
"tool_calls_names": [tc["name"] for tc in tool_calls] if tool_calls else [],
|
||||
"next_step": "→ should_continue decision"
|
||||
})
|
||||
await self.stream_callback({'type': 'node_end', 'data': 'generate'})
|
||||
return {"messages": state.get("messages", []) + [response]}
|
||||
|
||||
def _build_graph(self) -> StateGraph:
|
||||
"""Build the LangGraph state machine for conversation flow.
|
||||
|
||||
Returns:
|
||||
Compiled StateGraph with nodes and conditional edges
|
||||
"""
|
||||
workflow = StateGraph(State)
|
||||
|
||||
workflow.add_node("generate", self.generate)
|
||||
workflow.add_node("action", self.tool_node)
|
||||
workflow.add_edge(START, "generate")
|
||||
workflow.add_conditional_edges(
|
||||
"generate",
|
||||
self.should_continue,
|
||||
{
|
||||
"continue": "action",
|
||||
"end": END,
|
||||
},
|
||||
)
|
||||
workflow.add_edge("action", "generate")
|
||||
|
||||
return workflow.compile(checkpointer=memory)
|
||||
|
||||
def _format_tool_calls(self, tool_calls_buffer: Dict[int, Dict[str, str]]) -> List[ToolCall]:
|
||||
"""Parse streamed tool call buffer into ToolCall objects.
|
||||
|
||||
Args:
|
||||
tool_calls_buffer: Buffer of streamed tool call data
|
||||
|
||||
Returns:
|
||||
List of formatted ToolCall objects
|
||||
"""
|
||||
if not tool_calls_buffer:
|
||||
return []
|
||||
|
||||
tool_calls = []
|
||||
for i in sorted(tool_calls_buffer):
|
||||
item = tool_calls_buffer[i]
|
||||
try:
|
||||
parsed_args = json.loads(item["arguments"] or "{}")
|
||||
except json.JSONDecodeError:
|
||||
parsed_args = {}
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
name=item["name"],
|
||||
args=parsed_args,
|
||||
id=item["id"] or f"call_{i}",
|
||||
)
|
||||
)
|
||||
return tool_calls
|
||||
|
||||
async def _stream_response(self, stream, stream_callback: StreamCallback) -> tuple[List[str], Dict[int, Dict[str, str]]]:
|
||||
"""Process streaming LLM response and extract content and tool calls.
|
||||
|
||||
Args:
|
||||
stream: Async stream from LLM
|
||||
stream_callback: Callback for streaming events
|
||||
|
||||
Returns:
|
||||
Tuple of (content_buffer, tool_calls_buffer)
|
||||
"""
|
||||
llm_output_buffer = []
|
||||
tool_calls_buffer = {}
|
||||
saw_tool_finish = False
|
||||
|
||||
async for chunk in stream:
|
||||
for choice in getattr(chunk, "choices", []) or []:
|
||||
delta = getattr(choice, "delta", None)
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
content = getattr(delta, "content", None)
|
||||
if content:
|
||||
await stream_callback({"type": "token", "data": content})
|
||||
llm_output_buffer.append(content)
|
||||
for tc in getattr(delta, "tool_calls", []) or []:
|
||||
idx = getattr(tc, "index", None)
|
||||
if idx is None:
|
||||
idx = 0 if not tool_calls_buffer else max(tool_calls_buffer) + 1
|
||||
entry = tool_calls_buffer.setdefault(idx, {"id": None, "name": None, "arguments": ""})
|
||||
|
||||
if getattr(tc, "id", None):
|
||||
entry["id"] = tc.id
|
||||
|
||||
fn = getattr(tc, "function", None)
|
||||
if fn:
|
||||
if getattr(fn, "name", None):
|
||||
entry["name"] = fn.name
|
||||
if getattr(fn, "arguments", None):
|
||||
entry["arguments"] += fn.arguments
|
||||
|
||||
finish_reason = getattr(choice, "finish_reason", None)
|
||||
if finish_reason == "tool_calls":
|
||||
saw_tool_finish = True
|
||||
break
|
||||
|
||||
if saw_tool_finish:
|
||||
break
|
||||
|
||||
return llm_output_buffer, tool_calls_buffer
|
||||
|
||||
async def query(self, query_text: str, chat_id: str, image_data: str = None) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Process user query and stream response tokens.
|
||||
|
||||
Args:
|
||||
query_text: User's input text
|
||||
chat_id: Unique chat identifier
|
||||
|
||||
Yields:
|
||||
Streaming events and tokens
|
||||
"""
|
||||
logger.debug({
|
||||
"message": "GRAPH: STARTING EXECUTION",
|
||||
"chat_id": chat_id,
|
||||
"query": query_text[:100] + "..." if len(query_text) > 100 else query_text,
|
||||
"graph_flow": "START → generate → should_continue → action → generate → END"
|
||||
})
|
||||
|
||||
config = {"configurable": {"thread_id": chat_id}}
|
||||
|
||||
try:
|
||||
existing_messages = await self.conversation_store.get_messages(chat_id, limit=10)
|
||||
|
||||
base_system_prompt = self.system_prompt
|
||||
if image_data:
|
||||
image_context = "\n\nIMAGE CONTEXT: The user has uploaded an image with their message. You MUST use the explain_image tool to analyze it."
|
||||
system_prompt_with_image = base_system_prompt + image_context
|
||||
messages_to_process = [SystemMessage(content=system_prompt_with_image)]
|
||||
else:
|
||||
messages_to_process = [SystemMessage(content=base_system_prompt)]
|
||||
|
||||
if existing_messages:
|
||||
for msg in existing_messages:
|
||||
if not isinstance(msg, SystemMessage):
|
||||
messages_to_process.append(msg)
|
||||
|
||||
messages_to_process.append(HumanMessage(content=query_text))
|
||||
|
||||
config_obj = self.config_manager.read_config()
|
||||
|
||||
initial_state = {
|
||||
"iterations": 0,
|
||||
"chat_id": chat_id,
|
||||
"messages": messages_to_process,
|
||||
"image_data": image_data if image_data else None,
|
||||
"process_image_used": False
|
||||
}
|
||||
|
||||
|
||||
model_name = self.config_manager.get_selected_model()
|
||||
if self.current_model != model_name:
|
||||
self.set_current_model(model_name)
|
||||
|
||||
logger.debug({
|
||||
"message": "GRAPH: LAUNCHING EXECUTION",
|
||||
"chat_id": chat_id,
|
||||
"initial_state": {
|
||||
"iterations": initial_state["iterations"],
|
||||
"message_count": len(initial_state["messages"]),
|
||||
}
|
||||
})
|
||||
|
||||
self.last_state = None
|
||||
token_q: asyncio.Queue[Any] = asyncio.Queue()
|
||||
self.stream_callback = lambda event: self._queue_writer(event, token_q)
|
||||
runner = asyncio.create_task(self._run_graph(initial_state, config, chat_id, token_q))
|
||||
|
||||
try:
|
||||
while True:
|
||||
item = await token_q.get()
|
||||
if item is SENTINEL:
|
||||
break
|
||||
yield item
|
||||
except Exception as stream_error:
|
||||
logger.error({"message": "Error in streaming", "error": str(stream_error)}, exc_info=True)
|
||||
finally:
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await runner
|
||||
|
||||
logger.debug({
|
||||
"message": "GRAPH: EXECUTION COMPLETED",
|
||||
"chat_id": chat_id,
|
||||
"final_iterations": self.last_state.get("iterations", 0) if self.last_state else 0
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error({"message": "GRAPH: EXECUTION FAILED", "error": str(e), "chat_id": chat_id}, exc_info=True)
|
||||
yield {"type": "error", "data": f"Error performing query: {str(e)}"}
|
||||
|
||||
|
||||
async def _queue_writer(self, event: Dict[str, Any], token_q: asyncio.Queue) -> None:
|
||||
"""Write events to the streaming queue.
|
||||
|
||||
Args:
|
||||
event: Event data to queue
|
||||
token_q: Queue for streaming events
|
||||
"""
|
||||
await token_q.put(event)
|
||||
|
||||
async def _run_graph(self, initial_state: Dict[str, Any], config: Dict[str, Any], chat_id: str, token_q: asyncio.Queue) -> None:
|
||||
"""Run the graph execution in background task.
|
||||
|
||||
Args:
|
||||
initial_state: Starting state for graph
|
||||
config: LangGraph configuration
|
||||
chat_id: Chat identifier
|
||||
token_q: Queue for streaming events
|
||||
"""
|
||||
try:
|
||||
async for final_state in self.graph.astream(
|
||||
initial_state,
|
||||
config=config,
|
||||
stream_mode="values",
|
||||
stream_writer=lambda event: self._queue_writer(event, token_q)
|
||||
):
|
||||
self.last_state = final_state
|
||||
finally:
|
||||
try:
|
||||
if self.last_state and self.last_state.get("messages"):
|
||||
final_msg = self.last_state["messages"][-1]
|
||||
try:
|
||||
logger.debug(f'Saving messages to conversation store for chat: {chat_id}')
|
||||
await self.conversation_store.save_messages(chat_id, self.last_state["messages"])
|
||||
except Exception as save_err:
|
||||
logger.warning({"message": "Failed to persist conversation", "chat_id": chat_id, "error": str(save_err)})
|
||||
|
||||
content = getattr(final_msg, "content", None)
|
||||
if content:
|
||||
await token_q.put(content)
|
||||
finally:
|
||||
await token_q.put(SENTINEL)
|
||||
93
nvidia/multi-agent-chatbot/assets/backend/client.py
Normal file
@ -0,0 +1,93 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Multi-Server MCP Client for connecting to multiple MCP servers.
|
||||
|
||||
This module provides a unified client interface for connecting to and managing
|
||||
multiple Model Context Protocol (MCP) servers. It handles server configuration,
|
||||
initialization, and tool retrieval across different server types.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from mcp.types import Tool
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Client for managing connections to multiple MCP servers.
|
||||
|
||||
Provides a unified interface for connecting to and interacting with
|
||||
various MCP servers including RAG, image understanding, and weather services.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the MCP client with predefined server configurations."""
|
||||
self.server_configs = {
|
||||
"image-understanding-server": {
|
||||
"command": "python",
|
||||
"args": ["tools/mcp_servers/image_understanding.py"],
|
||||
"transport": "stdio",
|
||||
},
|
||||
"code-generation-server": {
|
||||
"command": "python",
|
||||
"args": ["tools/mcp_servers/code_generation.py"],
|
||||
"transport": "stdio",
|
||||
},
|
||||
"rag-server": {
|
||||
"command": "python",
|
||||
"args": ["tools/mcp_servers/rag.py"],
|
||||
"transport": "stdio",
|
||||
},
|
||||
"weather-server": {
|
||||
"command": "python",
|
||||
"args": ["tools/mcp_servers/weather_test.py"],
|
||||
"transport": "stdio",
|
||||
}
|
||||
}
|
||||
self.mcp_client: MultiServerMCPClient | None = None
|
||||
|
||||
async def init(self):
|
||||
"""Initialize the multi-server MCP client.
|
||||
|
||||
Returns:
|
||||
MCPClient: Self for method chaining
|
||||
|
||||
Raises:
|
||||
Exception: If client initialization fails
|
||||
"""
|
||||
self.mcp_client = MultiServerMCPClient(self.server_configs)
|
||||
return self
|
||||
|
||||
async def get_tools(self):
|
||||
"""Retrieve available tools from all connected MCP servers.
|
||||
|
||||
Returns:
|
||||
List[Tool]: List of available tools from all servers
|
||||
|
||||
Raises:
|
||||
RuntimeError: If client is not initialized
|
||||
Exception: If tool retrieval fails
|
||||
"""
|
||||
if not self.mcp_client:
|
||||
raise RuntimeError("MCP client not initialized. Call `await init()` first.")
|
||||
|
||||
try:
|
||||
tools = await self.mcp_client.get_tools()
|
||||
return tools
|
||||
except Exception as error:
|
||||
print("Error encountered connecting to MCP server. Is the server running? Is your config server path correct?\n")
|
||||
raise error
|
||||
165
nvidia/multi-agent-chatbot/assets/backend/config.py
Normal file
@ -0,0 +1,165 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""ConfigManager for managing the configuration of the chat application."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
from typing import List
|
||||
|
||||
from logger import logger
|
||||
from models import ChatConfig
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
def __init__(self, config_path: str):
|
||||
"""Initialize the ConfigManager"""
|
||||
self.config_path = config_path
|
||||
self.config = None
|
||||
self._last_modified = 0
|
||||
self._lock = threading.Lock()
|
||||
self._ensure_config_exists()
|
||||
self.read_config()
|
||||
|
||||
def _ensure_config_exists(self) -> None:
|
||||
"""Ensure config.json exists, creating it with default values if not."""
|
||||
models = []
|
||||
models = os.getenv("MODELS", "")
|
||||
|
||||
if models:
|
||||
models = [model.strip() for model in models.split(",") if model.strip()]
|
||||
else:
|
||||
logger.warning("MODELS environment variable not set, using empty models list")
|
||||
|
||||
if not os.path.exists(self.config_path):
|
||||
logger.debug(f"Config file {self.config_path} not found, creating default config")
|
||||
default_config = ChatConfig(
|
||||
sources=[],
|
||||
models=models,
|
||||
selected_model=models[0] if models else None,
|
||||
selected_sources=[],
|
||||
current_chat_id=None
|
||||
)
|
||||
|
||||
with open(self.config_path, "w") as f:
|
||||
json.dump(default_config.model_dump(), f, indent=2)
|
||||
else:
|
||||
try:
|
||||
with open(self.config_path, "r") as f:
|
||||
data = json.load(f)
|
||||
existing_config = ChatConfig(**data)
|
||||
|
||||
if models:
|
||||
existing_config.models = models
|
||||
if not existing_config.selected_model or existing_config.selected_model not in models:
|
||||
existing_config.selected_model = models[0]
|
||||
|
||||
with open(self.config_path, "w") as f:
|
||||
json.dump(existing_config.model_dump(), f, indent=2)
|
||||
|
||||
logger.debug(f"Updated existing config with models: {models}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating existing config: {e}")
|
||||
default_config = ChatConfig(
|
||||
sources=[],
|
||||
models=models,
|
||||
selected_model=models[0] if models else None,
|
||||
selected_sources=[],
|
||||
current_chat_id=None
|
||||
)
|
||||
with open(self.config_path, "w") as f:
|
||||
json.dump(default_config.model_dump(), f, indent=2)
|
||||
|
||||
def read_config(self) -> ChatConfig:
|
||||
"""Read config from file, but only if it has changed since last read."""
|
||||
with self._lock:
|
||||
try:
|
||||
current_mtime = os.path.getmtime(self.config_path)
|
||||
if self.config is None or current_mtime > self._last_modified:
|
||||
with open(self.config_path, "r") as f:
|
||||
data = json.load(f)
|
||||
self.config = ChatConfig(**data)
|
||||
self._last_modified = current_mtime
|
||||
return self.config
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading config: {e}")
|
||||
if self.config is None:
|
||||
models = []
|
||||
models = os.getenv("MODELS", "")
|
||||
if models:
|
||||
models = [model.strip() for model in models.split(",") if model.strip()]
|
||||
|
||||
self.config = ChatConfig(
|
||||
sources=[],
|
||||
models=models,
|
||||
selected_model=models[0] if models else "gpt-oss-120b",
|
||||
selected_sources=[],
|
||||
current_chat_id="1"
|
||||
)
|
||||
return self.config
|
||||
|
||||
def write_config(self, new_config: ChatConfig) -> None:
|
||||
"""Thread-safe write config to file."""
|
||||
with self._lock:
|
||||
with open(self.config_path, "w") as f:
|
||||
json.dump(new_config.model_dump(), f, indent=2)
|
||||
self.config = new_config
|
||||
self._last_modified = os.path.getmtime(self.config_path)
|
||||
|
||||
def get_sources(self) -> List[str]:
|
||||
"""Return list of available sources."""
|
||||
self.config = self.read_config()
|
||||
return self.config.sources
|
||||
|
||||
def get_selected_sources(self) -> List[str]:
|
||||
"""Return list of selected sources."""
|
||||
self.config = self.read_config()
|
||||
return self.config.selected_sources
|
||||
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""Return list of available models."""
|
||||
self.config = self.read_config()
|
||||
return self.config.models
|
||||
|
||||
def get_selected_model(self) -> str:
|
||||
"""Return the selected model."""
|
||||
self.config = self.read_config()
|
||||
logger.debug(f"Selected model: {self.config.selected_model}")
|
||||
return self.config.selected_model
|
||||
|
||||
def get_current_chat_id(self) -> str:
|
||||
"""Return the current chat id."""
|
||||
self.config = self.read_config()
|
||||
return self.config.current_chat_id
|
||||
|
||||
|
||||
def updated_selected_sources(self, new_sources: List[str]) -> None:
|
||||
"""Update the selected sources in the config."""
|
||||
self.config = self.read_config().model_copy(update={"selected_sources": new_sources})
|
||||
self.write_config(self.config)
|
||||
|
||||
def updated_selected_model(self, new_model: str) -> None:
|
||||
"""Update the selected model in the config."""
|
||||
self.config = self.read_config().model_copy(update={"selected_model": new_model})
|
||||
logger.debug(f"Updated selected model to: {new_model}")
|
||||
self.write_config(self.config)
|
||||
|
||||
def updated_current_chat_id(self, new_chat_id: str) -> None:
|
||||
"""Update the current chat id in the config."""
|
||||
self.config = self.read_config().model_copy(update={"current_chat_id": new_chat_id})
|
||||
self.write_config(self.config)
|
||||
145
nvidia/multi-agent-chatbot/assets/backend/logger.py
Normal file
@ -0,0 +1,145 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
"""
|
||||
Formatter that outputs JSON strings after parsing the log record.
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
self.default_keys = {
|
||||
'timestamp': 'asctime',
|
||||
'level': 'levelname',
|
||||
'message': 'message'
|
||||
}
|
||||
self.default_keys.update(kwargs)
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""
|
||||
Format the log record as JSON.
|
||||
"""
|
||||
log_record = {}
|
||||
|
||||
log_record['timestamp'] = datetime.utcnow().isoformat() + 'Z'
|
||||
log_record['level'] = record.levelname
|
||||
log_record['logger'] = record.name
|
||||
|
||||
if isinstance(record.msg, dict):
|
||||
log_record['message'] = record.msg.get('message', '')
|
||||
for key, value in record.msg.items():
|
||||
if key != 'message':
|
||||
log_record[key] = value
|
||||
else:
|
||||
log_record['message'] = record.getMessage()
|
||||
|
||||
if record.exc_info:
|
||||
log_record['exception'] = {
|
||||
'type': record.exc_info[0].__name__,
|
||||
'message': str(record.exc_info[1]),
|
||||
'traceback': traceback.format_exception(*record.exc_info)
|
||||
}
|
||||
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in ['msg', 'args', 'exc_info', 'exc_text', 'stack_info', 'lineno',
|
||||
'funcName', 'created', 'msecs', 'relativeCreated', 'levelname',
|
||||
'levelno', 'pathname', 'filename', 'module', 'name', 'thread',
|
||||
'threadName', 'processName', 'process']:
|
||||
log_record[key] = value
|
||||
|
||||
return json.dumps(log_record)
|
||||
|
||||
def setup_logger(name: str = 'backend',
|
||||
level: int = logging.INFO,
|
||||
log_file: Optional[str] = 'app.log') -> logging.Logger:
|
||||
"""
|
||||
Set up a JSON logger with console and file handlers.
|
||||
|
||||
Args:
|
||||
name: Logger name
|
||||
level: Logging level
|
||||
log_file: Path to log file (None for no file logging)
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
|
||||
logger.propagate = False
|
||||
|
||||
for handler in logger.handlers[:]:
|
||||
logger.removeHandler(handler)
|
||||
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(JsonFormatter())
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
if log_file:
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(JsonFormatter())
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def log_request(request_data: Dict[str, Any], endpoint: str) -> None:
|
||||
"""
|
||||
Log an API request with structured data.
|
||||
"""
|
||||
logger.info({
|
||||
'message': f'API request to {endpoint}',
|
||||
'endpoint': endpoint,
|
||||
'request_data': request_data
|
||||
})
|
||||
|
||||
|
||||
def log_response(response_data: Dict[str, Any], endpoint: str, status_code: int = 200) -> None:
|
||||
"""
|
||||
Log an API response with structured data.
|
||||
"""
|
||||
logger.info({
|
||||
'message': f'API response from {endpoint}',
|
||||
'endpoint': endpoint,
|
||||
'status_code': status_code,
|
||||
'response_data': response_data
|
||||
})
|
||||
|
||||
|
||||
def log_error(error: Exception, endpoint: str = None, request_data: Dict[str, Any] = None) -> None:
|
||||
"""
|
||||
Log an error with structured data.
|
||||
"""
|
||||
error_data = {
|
||||
'message': f'Error: {str(error)}',
|
||||
'error_type': error.__class__.__name__,
|
||||
}
|
||||
|
||||
if endpoint:
|
||||
error_data['endpoint'] = endpoint
|
||||
|
||||
if request_data:
|
||||
error_data['request_data'] = request_data
|
||||
|
||||
logger.error(error_data, exc_info=True)
|
||||
516
nvidia/multi-agent-chatbot/assets/backend/main.py
Normal file
@ -0,0 +1,516 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""FastAPI backend server for the chatbot application.
|
||||
|
||||
This module provides the main HTTP API endpoints and WebSocket connections for:
|
||||
- Real-time chat via WebSocket
|
||||
- File upload and document ingestion
|
||||
- Configuration management (models, sources, chat settings)
|
||||
- Chat history management
|
||||
- Vector store operations
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from fastapi import FastAPI, File, Form, UploadFile, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from agent import ChatAgent
|
||||
from config import ConfigManager
|
||||
from logger import logger, log_request, log_response, log_error
|
||||
from models import ChatIdRequest, ChatRenameRequest, SelectedModelRequest
|
||||
from postgres_storage import PostgreSQLConversationStorage
|
||||
from utils import process_and_ingest_files_background
|
||||
from vector_store import create_vector_store_with_config
|
||||
|
||||
POSTGRES_HOST = os.getenv("POSTGRES_HOST", "postgres")
|
||||
POSTGRES_PORT = int(os.getenv("POSTGRES_PORT", 5432))
|
||||
POSTGRES_DB = os.getenv("POSTGRES_DB", "chatbot")
|
||||
POSTGRES_USER = os.getenv("POSTGRES_USER", "chatbot_user")
|
||||
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "chatbot_password")
|
||||
|
||||
config_manager = ConfigManager("./config.json")
|
||||
|
||||
postgres_storage = PostgreSQLConversationStorage(
|
||||
host=POSTGRES_HOST,
|
||||
port=POSTGRES_PORT,
|
||||
database=POSTGRES_DB,
|
||||
user=POSTGRES_USER,
|
||||
password=POSTGRES_PASSWORD
|
||||
)
|
||||
|
||||
vector_store = create_vector_store_with_config(config_manager)
|
||||
|
||||
vector_store._initialize_store()
|
||||
|
||||
agent: ChatAgent | None = None
|
||||
indexing_tasks: Dict[str, str] = {}
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager for startup and shutdown tasks."""
|
||||
global agent
|
||||
logger.debug("Initializing PostgreSQL storage and agent...")
|
||||
|
||||
try:
|
||||
await postgres_storage.init_pool()
|
||||
logger.info("PostgreSQL storage initialized successfully")
|
||||
logger.debug("Initializing ChatAgent...")
|
||||
agent = await ChatAgent.create(
|
||||
vector_store=vector_store,
|
||||
config_manager=config_manager,
|
||||
postgres_storage=postgres_storage
|
||||
)
|
||||
logger.info("ChatAgent initialized successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PostgreSQL storage: {e}")
|
||||
raise
|
||||
|
||||
yield
|
||||
|
||||
try:
|
||||
await postgres_storage.close()
|
||||
logger.debug("PostgreSQL storage closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing PostgreSQL storage: {e}")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Chatbot API",
|
||||
description="Backend API for LLM-powered chatbot with RAG capabilities",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:3000"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.websocket("/ws/chat/{chat_id}")
|
||||
async def websocket_endpoint(websocket: WebSocket, chat_id: str):
|
||||
"""WebSocket endpoint for real-time chat communication.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
chat_id: Unique chat identifier
|
||||
"""
|
||||
logger.debug(f"WebSocket connection attempt for chat_id: {chat_id}")
|
||||
try:
|
||||
await websocket.accept()
|
||||
logger.debug(f"WebSocket connection accepted for chat_id: {chat_id}")
|
||||
|
||||
history_messages = await postgres_storage.get_messages(chat_id)
|
||||
history = [postgres_storage._message_to_dict(msg) for i, msg in enumerate(history_messages) if i != 0]
|
||||
await websocket.send_json({"type": "history", "messages": history})
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
client_message = json.loads(data)
|
||||
new_message = client_message.get("message")
|
||||
image_id = client_message.get("image_id")
|
||||
|
||||
image_data = None
|
||||
if image_id:
|
||||
image_data = await postgres_storage.get_image(image_id)
|
||||
logger.debug(f"Retrieved image data for image_id: {image_id}, data length: {len(image_data) if image_data else 0}")
|
||||
|
||||
try:
|
||||
async for event in agent.query(query_text=new_message, chat_id=chat_id, image_data=image_data):
|
||||
await websocket.send_json(event)
|
||||
except Exception as query_error:
|
||||
logger.error(f"Error in agent.query: {str(query_error)}", exc_info=True)
|
||||
await websocket.send_json({"type": "error", "content": f"Error processing request: {str(query_error)}"})
|
||||
|
||||
final_messages = await postgres_storage.get_messages(chat_id)
|
||||
final_history = [postgres_storage._message_to_dict(msg) for i, msg in enumerate(final_messages) if i != 0]
|
||||
await websocket.send_json({"type": "history", "messages": final_history})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.debug(f"Client disconnected from chat {chat_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error for chat {chat_id}: {str(e)}", exc_info=True)
|
||||
|
||||
|
||||
@app.post("/upload-image")
|
||||
async def upload_image(image: UploadFile = File(...), chat_id: str = Form(...)):
|
||||
"""Upload and store an image for chat processing.
|
||||
|
||||
Args:
|
||||
image: Uploaded image file
|
||||
chat_id: Chat identifier for context
|
||||
|
||||
Returns:
|
||||
Dictionary with generated image_id
|
||||
"""
|
||||
image_data = await image.read()
|
||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||
data_uri = f"data:{image.content_type};base64,{image_base64}"
|
||||
image_id = str(uuid.uuid4())
|
||||
await postgres_storage.store_image(image_id, data_uri)
|
||||
return {"image_id": image_id}
|
||||
|
||||
|
||||
@app.post("/ingest")
|
||||
async def ingest_files(files: Optional[List[UploadFile]] = File(None), background_tasks: BackgroundTasks = None):
|
||||
"""Ingest documents for vector search and RAG.
|
||||
|
||||
Args:
|
||||
files: List of uploaded files to process
|
||||
background_tasks: FastAPI background tasks manager
|
||||
|
||||
Returns:
|
||||
Task information for tracking ingestion progress
|
||||
"""
|
||||
try:
|
||||
log_request({"file_count": len(files) if files else 0}, "/ingest")
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
file_info = []
|
||||
for file in files:
|
||||
content = await file.read()
|
||||
file_info.append({
|
||||
"filename": file.filename,
|
||||
"content": content
|
||||
})
|
||||
|
||||
indexing_tasks[task_id] = "queued"
|
||||
|
||||
background_tasks.add_task(
|
||||
process_and_ingest_files_background,
|
||||
file_info,
|
||||
vector_store,
|
||||
config_manager,
|
||||
task_id,
|
||||
indexing_tasks
|
||||
)
|
||||
|
||||
response = {
|
||||
"message": f"Files queued for processing. Indexing {len(files)} files in the background.",
|
||||
"files": [file.filename for file in files],
|
||||
"status": "queued",
|
||||
"task_id": task_id
|
||||
}
|
||||
|
||||
log_response(response, "/ingest")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
log_error(e, "/ingest")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error queuing files for ingestion: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@app.get("/ingest/status/{task_id}")
|
||||
async def get_indexing_status(task_id: str):
|
||||
"""Get the status of a file ingestion task.
|
||||
|
||||
Args:
|
||||
task_id: Unique task identifier
|
||||
|
||||
Returns:
|
||||
Current task status
|
||||
"""
|
||||
if task_id in indexing_tasks:
|
||||
return {"status": indexing_tasks[task_id]}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
|
||||
@app.get("/sources")
|
||||
async def get_sources():
|
||||
"""Get all available document sources."""
|
||||
try:
|
||||
config = config_manager.read_config()
|
||||
return {"sources": config.sources}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting sources: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/selected_sources")
|
||||
async def get_selected_sources():
|
||||
"""Get currently selected document sources for RAG."""
|
||||
try:
|
||||
config = config_manager.read_config()
|
||||
return {"sources": config.selected_sources}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting selected sources: {str(e)}")
|
||||
|
||||
|
||||
@app.post("/selected_sources")
|
||||
async def update_selected_sources(selected_sources: List[str]):
|
||||
"""Update the selected document sources for RAG.
|
||||
|
||||
Args:
|
||||
selected_sources: List of source names to use for retrieval
|
||||
"""
|
||||
try:
|
||||
config_manager.updated_selected_sources(selected_sources)
|
||||
return {"status": "success", "message": "Selected sources updated successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error updating selected sources: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/selected_model")
|
||||
async def get_selected_model():
|
||||
"""Get the currently selected LLM model."""
|
||||
try:
|
||||
model = config_manager.get_selected_model()
|
||||
return {"model": model}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting selected model: {str(e)}")
|
||||
|
||||
|
||||
@app.post("/selected_model")
|
||||
async def update_selected_model(request: SelectedModelRequest):
|
||||
"""Update the selected LLM model.
|
||||
|
||||
Args:
|
||||
request: Model selection request with model name
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Updating selected model to: {request.model}")
|
||||
config_manager.updated_selected_model(request.model)
|
||||
return {"status": "success", "message": "Selected model updated successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error updating selected model: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/available_models")
|
||||
async def get_available_models():
|
||||
"""Get list of all available LLM models."""
|
||||
try:
|
||||
models = config_manager.get_available_models()
|
||||
return {"models": models}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting available models: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/chats")
|
||||
async def list_chats():
|
||||
"""Get list of all chat conversations."""
|
||||
try:
|
||||
chat_ids = await postgres_storage.list_conversations()
|
||||
return {"chats": chat_ids}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error listing chats: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/chat_id")
|
||||
async def get_chat_id():
|
||||
"""Get the current active chat ID, creating a conversation if it doesn't exist."""
|
||||
try:
|
||||
config = config_manager.read_config()
|
||||
current_chat_id = config.current_chat_id
|
||||
|
||||
if current_chat_id and await postgres_storage.exists(current_chat_id):
|
||||
return {
|
||||
"status": "success",
|
||||
"chat_id": current_chat_id
|
||||
}
|
||||
|
||||
new_chat_id = str(uuid.uuid4())
|
||||
|
||||
await postgres_storage.save_messages_immediate(new_chat_id, [])
|
||||
await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
|
||||
|
||||
config_manager.updated_current_chat_id(new_chat_id)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"chat_id": new_chat_id
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error getting chat ID: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@app.post("/chat_id")
|
||||
async def update_chat_id(request: ChatIdRequest):
|
||||
"""Update the current active chat ID.
|
||||
|
||||
Args:
|
||||
request: Chat ID update request
|
||||
"""
|
||||
try:
|
||||
config_manager.updated_current_chat_id(request.chat_id)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Current chat ID updated to {request.chat_id}",
|
||||
"chat_id": request.chat_id
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error updating chat ID: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@app.get("/chat/{chat_id}/metadata")
|
||||
async def get_chat_metadata(chat_id: str):
|
||||
"""Get metadata for a specific chat.
|
||||
|
||||
Args:
|
||||
chat_id: Unique chat identifier
|
||||
|
||||
Returns:
|
||||
Chat metadata including name
|
||||
"""
|
||||
try:
|
||||
metadata = await postgres_storage.get_chat_metadata(chat_id)
|
||||
return metadata
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error getting chat metadata: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@app.post("/chat/rename")
|
||||
async def rename_chat(request: ChatRenameRequest):
|
||||
"""Rename a chat conversation.
|
||||
|
||||
Args:
|
||||
request: Chat rename request with chat_id and new_name
|
||||
"""
|
||||
try:
|
||||
await postgres_storage.set_chat_metadata(request.chat_id, request.new_name)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Chat {request.chat_id} renamed to {request.new_name}"
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error renaming chat: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@app.post("/chat/new")
|
||||
async def create_new_chat():
|
||||
"""Create a new chat conversation and set it as current."""
|
||||
try:
|
||||
new_chat_id = str(uuid.uuid4())
|
||||
await postgres_storage.save_messages_immediate(new_chat_id, [])
|
||||
await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
|
||||
|
||||
config_manager.updated_current_chat_id(new_chat_id)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "New chat created",
|
||||
"chat_id": new_chat_id
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error creating new chat: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@app.delete("/chat/{chat_id}")
|
||||
async def delete_chat(chat_id: str):
|
||||
"""Delete a specific chat and its messages.
|
||||
|
||||
Args:
|
||||
chat_id: Unique chat identifier to delete
|
||||
"""
|
||||
try:
|
||||
success = await postgres_storage.delete_conversation(chat_id)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Chat {chat_id} deleted successfully"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Chat {chat_id} not found"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error deleting chat: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@app.delete("/chats/clear")
|
||||
async def clear_all_chats():
|
||||
"""Clear all chat conversations and create a new default chat."""
|
||||
try:
|
||||
chat_ids = await postgres_storage.list_conversations()
|
||||
cleared_count = 0
|
||||
|
||||
for chat_id in chat_ids:
|
||||
if await postgres_storage.delete_conversation(chat_id):
|
||||
cleared_count += 1
|
||||
|
||||
new_chat_id = str(uuid.uuid4())
|
||||
await postgres_storage.save_messages_immediate(new_chat_id, [])
|
||||
await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
|
||||
|
||||
config_manager.updated_current_chat_id(new_chat_id)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Cleared {cleared_count} chats and created new chat",
|
||||
"new_chat_id": new_chat_id,
|
||||
"cleared_count": cleared_count
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error clearing all chats: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@app.delete("/collections/{collection_name}")
|
||||
async def delete_collection(collection_name: str):
|
||||
"""Delete a document collection from the vector store.
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection to delete
|
||||
"""
|
||||
try:
|
||||
success = vector_store.delete_collection(collection_name)
|
||||
if success:
|
||||
return {"status": "success", "message": f"Collection '{collection_name}' deleted successfully"}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found or could not be deleted")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=False)
|
||||
35
nvidia/multi-agent-chatbot/assets/backend/models.py
Normal file
@ -0,0 +1,35 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
|
||||
class ChatConfig(BaseModel):
|
||||
sources: List[str]
|
||||
models : List[str]
|
||||
selected_model: Optional[str] = None
|
||||
selected_sources: Optional[List[str]] = None
|
||||
current_chat_id: Optional[str] = None
|
||||
|
||||
class ChatIdRequest(BaseModel):
|
||||
chat_id: str
|
||||
|
||||
class ChatRenameRequest(BaseModel):
|
||||
chat_id: str
|
||||
new_name: str
|
||||
|
||||
class SelectedModelRequest(BaseModel):
|
||||
model: str
|
||||
571
nvidia/multi-agent-chatbot/assets/backend/postgres_storage.py
Normal file
@ -0,0 +1,571 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""PostgreSQL-based conversation storage with caching and I/O optimization."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio
|
||||
import asyncpg
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage, ToolMessage
|
||||
|
||||
from logger import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""Cache entry with TTL support."""
|
||||
data: Any
|
||||
timestamp: float
|
||||
ttl: float = 300
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
return time.time() - self.timestamp > self.ttl
|
||||
|
||||
|
||||
class PostgreSQLConversationStorage:
|
||||
"""PostgreSQL-based conversation storage with intelligent caching and I/O optimization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = 'postgres',
|
||||
port: int = 5432,
|
||||
database: str = 'chatbot',
|
||||
user: str = 'chatbot_user',
|
||||
password: str = 'chatbot_password',
|
||||
pool_size: int = 10,
|
||||
cache_ttl: int = 300
|
||||
):
|
||||
"""Initialize PostgreSQL connection pool and caching.
|
||||
|
||||
Args:
|
||||
host: PostgreSQL host
|
||||
port: PostgreSQL port
|
||||
database: Database name
|
||||
user: Database user
|
||||
password: Database password
|
||||
pool_size: Connection pool size
|
||||
cache_ttl: Cache TTL in seconds
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.database = database
|
||||
self.user = user
|
||||
self.password = password
|
||||
self.pool_size = pool_size
|
||||
self.cache_ttl = cache_ttl
|
||||
|
||||
self.pool: Optional[asyncpg.Pool] = None
|
||||
|
||||
self._message_cache: Dict[str, CacheEntry] = {}
|
||||
self._metadata_cache: Dict[str, CacheEntry] = {}
|
||||
self._image_cache: Dict[str, CacheEntry] = {}
|
||||
self._chat_list_cache: Optional[CacheEntry] = None
|
||||
|
||||
self._pending_saves: Dict[str, List[BaseMessage]] = {}
|
||||
self._save_lock = asyncio.Lock()
|
||||
self._batch_save_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._cache_hits = 0
|
||||
self._cache_misses = 0
|
||||
self._db_operations = 0
|
||||
|
||||
async def init_pool(self) -> None:
|
||||
"""Initialize the connection pool and create tables."""
|
||||
try:
|
||||
await self._ensure_database_exists()
|
||||
|
||||
self.pool = await asyncpg.create_pool(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
database=self.database,
|
||||
user=self.user,
|
||||
password=self.password,
|
||||
min_size=2,
|
||||
max_size=self.pool_size,
|
||||
command_timeout=30
|
||||
)
|
||||
|
||||
await self._create_tables()
|
||||
logger.debug("PostgreSQL connection pool initialized successfully")
|
||||
|
||||
self._batch_save_task = asyncio.create_task(self._batch_save_worker())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PostgreSQL pool: {e}")
|
||||
raise
|
||||
|
||||
async def _ensure_database_exists(self) -> None:
|
||||
"""Ensure the target database exists, create if it doesn't."""
|
||||
try:
|
||||
conn = await asyncpg.connect(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
database='postgres',
|
||||
user=self.user,
|
||||
password=self.password
|
||||
)
|
||||
|
||||
try:
|
||||
result = await conn.fetchval(
|
||||
"SELECT 1 FROM pg_database WHERE datname = $1",
|
||||
self.database
|
||||
)
|
||||
|
||||
if not result:
|
||||
await conn.execute(f'CREATE DATABASE "{self.database}"')
|
||||
logger.debug(f"Created database: {self.database}")
|
||||
else:
|
||||
logger.debug(f"Database {self.database} already exists")
|
||||
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring database exists: {e}")
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the connection pool and cleanup."""
|
||||
if self._batch_save_task:
|
||||
self._batch_save_task.cancel()
|
||||
try:
|
||||
await self._batch_save_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self.pool:
|
||||
await self.pool.close()
|
||||
logger.debug("PostgreSQL connection pool closed")
|
||||
|
||||
async def _create_tables(self) -> None:
|
||||
"""Create necessary tables if they don't exist."""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
chat_id VARCHAR(255) PRIMARY KEY,
|
||||
messages JSONB NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
message_count INTEGER DEFAULT 0
|
||||
)
|
||||
""")
|
||||
|
||||
await conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_metadata (
|
||||
chat_id VARCHAR(255) PRIMARY KEY,
|
||||
name VARCHAR(500),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (chat_id) REFERENCES conversations(chat_id) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
|
||||
await conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
image_id VARCHAR(255) PRIMARY KEY,
|
||||
image_data TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP + INTERVAL '1 hour')
|
||||
)
|
||||
""")
|
||||
|
||||
await conn.execute("CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at)")
|
||||
await conn.execute("CREATE INDEX IF NOT EXISTS idx_images_expires_at ON images(expires_at)")
|
||||
|
||||
await conn.execute("""
|
||||
CREATE OR REPLACE FUNCTION update_updated_at_column()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = CURRENT_TIMESTAMP;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ language 'plpgsql'
|
||||
""")
|
||||
|
||||
await conn.execute("""
|
||||
DROP TRIGGER IF EXISTS update_conversations_updated_at ON conversations
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE TRIGGER update_conversations_updated_at
|
||||
BEFORE UPDATE ON conversations
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column()
|
||||
""")
|
||||
|
||||
def _message_to_dict(self, message: BaseMessage) -> Dict:
|
||||
"""Convert a message object to a dictionary for storage."""
|
||||
result = {
|
||||
"type": message.__class__.__name__,
|
||||
"content": message.content,
|
||||
}
|
||||
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
result["tool_calls"] = message.tool_calls
|
||||
|
||||
if isinstance(message, ToolMessage):
|
||||
result["tool_call_id"] = getattr(message, "tool_call_id", None)
|
||||
result["name"] = getattr(message, "name", None)
|
||||
|
||||
return result
|
||||
|
||||
def _dict_to_message(self, data: Dict) -> BaseMessage:
|
||||
"""Convert a dictionary back to a message object."""
|
||||
msg_type = data["type"]
|
||||
content = data["content"]
|
||||
|
||||
if msg_type == "AIMessage":
|
||||
msg = AIMessage(content=content)
|
||||
if "tool_calls" in data:
|
||||
msg.tool_calls = data["tool_calls"]
|
||||
return msg
|
||||
elif msg_type == "HumanMessage":
|
||||
return HumanMessage(content=content)
|
||||
elif msg_type == "SystemMessage":
|
||||
return SystemMessage(content=content)
|
||||
elif msg_type == "ToolMessage":
|
||||
return ToolMessage(
|
||||
content=content,
|
||||
tool_call_id=data.get("tool_call_id", ""),
|
||||
name=data.get("name", "")
|
||||
)
|
||||
else:
|
||||
return HumanMessage(content=content)
|
||||
|
||||
def _get_cached_messages(self, chat_id: str) -> Optional[List[BaseMessage]]:
|
||||
"""Get messages from cache if available and not expired."""
|
||||
cache_entry = self._message_cache.get(chat_id)
|
||||
if cache_entry and not cache_entry.is_expired():
|
||||
self._cache_hits += 1
|
||||
return cache_entry.data
|
||||
|
||||
self._cache_misses += 1
|
||||
return None
|
||||
|
||||
def _cache_messages(self, chat_id: str, messages: List[BaseMessage]) -> None:
|
||||
"""Cache messages with TTL."""
|
||||
self._message_cache[chat_id] = CacheEntry(
|
||||
data=messages.copy(),
|
||||
timestamp=time.time(),
|
||||
ttl=self.cache_ttl
|
||||
)
|
||||
|
||||
def _invalidate_cache(self, chat_id: str) -> None:
|
||||
"""Invalidate cache entries for a chat."""
|
||||
self._message_cache.pop(chat_id, None)
|
||||
self._metadata_cache.pop(chat_id, None)
|
||||
self._chat_list_cache = None
|
||||
|
||||
async def exists(self, chat_id: str) -> bool:
|
||||
"""Check if a conversation exists (with caching)."""
|
||||
cached_messages = self._get_cached_messages(chat_id)
|
||||
if cached_messages is not None:
|
||||
return len(cached_messages) > 0
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
result = await conn.fetchval(
|
||||
"SELECT EXISTS(SELECT 1 FROM conversations WHERE chat_id = $1)",
|
||||
chat_id
|
||||
)
|
||||
self._db_operations += 1
|
||||
return result
|
||||
|
||||
async def get_messages(self, chat_id: str, limit: Optional[int] = None) -> List[BaseMessage]:
|
||||
"""Retrieve messages for a chat session with caching."""
|
||||
cached_messages = self._get_cached_messages(chat_id)
|
||||
if cached_messages is not None:
|
||||
return cached_messages[-limit:] if limit else cached_messages
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT messages FROM conversations WHERE chat_id = $1",
|
||||
chat_id
|
||||
)
|
||||
self._db_operations += 1
|
||||
|
||||
if not row:
|
||||
return []
|
||||
|
||||
messages_data = row['messages']
|
||||
if isinstance(messages_data, str):
|
||||
messages_data = json.loads(messages_data)
|
||||
messages = [self._dict_to_message(msg_data) for msg_data in messages_data]
|
||||
|
||||
self._cache_messages(chat_id, messages)
|
||||
|
||||
return messages[-limit:] if limit else messages
|
||||
|
||||
async def save_messages(self, chat_id: str, messages: List[BaseMessage]) -> None:
|
||||
"""Save messages with batching for performance."""
|
||||
async with self._save_lock:
|
||||
self._pending_saves[chat_id] = messages.copy()
|
||||
|
||||
self._cache_messages(chat_id, messages)
|
||||
|
||||
async def save_messages_immediate(self, chat_id: str, messages: List[BaseMessage]) -> None:
|
||||
"""Save messages immediately without batching - for critical operations."""
|
||||
serialized_messages = [self._message_to_dict(msg) for msg in messages]
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute("""
|
||||
INSERT INTO conversations (chat_id, messages, message_count)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (chat_id)
|
||||
DO UPDATE SET
|
||||
messages = EXCLUDED.messages,
|
||||
message_count = EXCLUDED.message_count,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
""", chat_id, json.dumps(serialized_messages), len(messages))
|
||||
self._db_operations += 1
|
||||
|
||||
self._cache_messages(chat_id, messages)
|
||||
self._chat_list_cache = None
|
||||
|
||||
async def _batch_save_worker(self) -> None:
|
||||
"""Background worker to batch save operations."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
async with self._save_lock:
|
||||
if not self._pending_saves:
|
||||
continue
|
||||
|
||||
saves_to_process = self._pending_saves.copy()
|
||||
self._pending_saves.clear()
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
for chat_id, messages in saves_to_process.items():
|
||||
serialized_messages = [self._message_to_dict(msg) for msg in messages]
|
||||
|
||||
await conn.execute("""
|
||||
INSERT INTO conversations (chat_id, messages, message_count)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (chat_id)
|
||||
DO UPDATE SET
|
||||
messages = EXCLUDED.messages,
|
||||
message_count = EXCLUDED.message_count,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
""", chat_id, json.dumps(serialized_messages), len(messages))
|
||||
|
||||
self._db_operations += len(saves_to_process)
|
||||
if saves_to_process:
|
||||
logger.debug(f"Batch saved {len(saves_to_process)} conversations")
|
||||
self._chat_list_cache = None
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch save worker: {e}")
|
||||
|
||||
async def add_message(self, chat_id: str, message: BaseMessage) -> None:
|
||||
"""Add a single message to conversation (optimized)."""
|
||||
current_messages = await self.get_messages(chat_id)
|
||||
current_messages.append(message)
|
||||
|
||||
await self.save_messages(chat_id, current_messages)
|
||||
|
||||
async def delete_conversation(self, chat_id: str) -> bool:
|
||||
"""Delete a conversation by chat_id."""
|
||||
try:
|
||||
async with self.pool.acquire() as conn:
|
||||
result = await conn.execute(
|
||||
"DELETE FROM conversations WHERE chat_id = $1",
|
||||
chat_id
|
||||
)
|
||||
self._db_operations += 1
|
||||
|
||||
self._invalidate_cache(chat_id)
|
||||
|
||||
return "DELETE 1" in result
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting conversation {chat_id}: {e}")
|
||||
return False
|
||||
|
||||
async def list_conversations(self) -> List[str]:
|
||||
"""List all conversation IDs with caching."""
|
||||
if self._chat_list_cache and not self._chat_list_cache.is_expired():
|
||||
self._cache_hits += 1
|
||||
return self._chat_list_cache.data
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"SELECT chat_id FROM conversations ORDER BY updated_at DESC"
|
||||
)
|
||||
self._db_operations += 1
|
||||
|
||||
chat_ids = [row['chat_id'] for row in rows]
|
||||
|
||||
self._chat_list_cache = CacheEntry(
|
||||
data=chat_ids,
|
||||
timestamp=time.time(),
|
||||
ttl=60
|
||||
)
|
||||
self._cache_misses += 1
|
||||
|
||||
return chat_ids
|
||||
|
||||
async def store_image(self, image_id: str, image_base64: str) -> None:
|
||||
"""Store base64 image data with TTL."""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute("""
|
||||
INSERT INTO images (image_id, image_data)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (image_id)
|
||||
DO UPDATE SET
|
||||
image_data = EXCLUDED.image_data,
|
||||
created_at = CURRENT_TIMESTAMP,
|
||||
expires_at = CURRENT_TIMESTAMP + INTERVAL '1 hour'
|
||||
""", image_id, image_base64)
|
||||
self._db_operations += 1
|
||||
|
||||
self._image_cache[image_id] = CacheEntry(
|
||||
data=image_base64,
|
||||
timestamp=time.time(),
|
||||
ttl=3600
|
||||
)
|
||||
|
||||
async def get_image(self, image_id: str) -> Optional[str]:
|
||||
"""Retrieve base64 image data with caching."""
|
||||
cache_entry = self._image_cache.get(image_id)
|
||||
if cache_entry and not cache_entry.is_expired():
|
||||
self._cache_hits += 1
|
||||
return cache_entry.data
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
row = await conn.fetchrow("""
|
||||
SELECT image_data FROM images
|
||||
WHERE image_id = $1 AND expires_at > CURRENT_TIMESTAMP
|
||||
""", image_id)
|
||||
self._db_operations += 1
|
||||
|
||||
if row:
|
||||
image_data = row['image_data']
|
||||
self._image_cache[image_id] = CacheEntry(
|
||||
data=image_data,
|
||||
timestamp=time.time(),
|
||||
ttl=3600
|
||||
)
|
||||
self._cache_misses += 1
|
||||
return image_data
|
||||
|
||||
return None
|
||||
|
||||
async def get_chat_metadata(self, chat_id: str) -> Optional[Dict]:
|
||||
"""Get chat metadata with caching."""
|
||||
cache_entry = self._metadata_cache.get(chat_id)
|
||||
if cache_entry and not cache_entry.is_expired():
|
||||
self._cache_hits += 1
|
||||
return cache_entry.data
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT name, created_at FROM chat_metadata WHERE chat_id = $1",
|
||||
chat_id
|
||||
)
|
||||
self._db_operations += 1
|
||||
|
||||
if row:
|
||||
metadata = {
|
||||
"name": row['name'],
|
||||
"created_at": row['created_at'].isoformat()
|
||||
}
|
||||
else:
|
||||
metadata = {"name": f"Chat {chat_id[:8]}"}
|
||||
|
||||
self._metadata_cache[chat_id] = CacheEntry(
|
||||
data=metadata,
|
||||
timestamp=time.time(),
|
||||
ttl=self.cache_ttl
|
||||
)
|
||||
self._cache_misses += 1
|
||||
|
||||
return metadata
|
||||
|
||||
async def set_chat_metadata(self, chat_id: str, name: str) -> None:
|
||||
"""Set chat metadata."""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute("""
|
||||
INSERT INTO chat_metadata (chat_id, name)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (chat_id)
|
||||
DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
""", chat_id, name)
|
||||
self._db_operations += 1
|
||||
|
||||
self._metadata_cache[chat_id] = CacheEntry(
|
||||
data={"name": name},
|
||||
timestamp=time.time(),
|
||||
ttl=self.cache_ttl
|
||||
)
|
||||
|
||||
async def cleanup_expired_images(self) -> int:
|
||||
"""Clean up expired images and return count of deleted images."""
|
||||
async with self.pool.acquire() as conn:
|
||||
result = await conn.execute(
|
||||
"DELETE FROM images WHERE expires_at < CURRENT_TIMESTAMP"
|
||||
)
|
||||
self._db_operations += 1
|
||||
|
||||
expired_keys = [
|
||||
key for key, entry in self._image_cache.items()
|
||||
if entry.is_expired()
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._image_cache[key]
|
||||
|
||||
deleted_count = int(result.split()[-1]) if result else 0
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"Cleaned up {deleted_count} expired images")
|
||||
|
||||
return deleted_count
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""Get cache performance statistics."""
|
||||
total_requests = self._cache_hits + self._cache_misses
|
||||
hit_rate = (self._cache_hits / total_requests * 100) if total_requests > 0 else 0
|
||||
|
||||
return {
|
||||
"cache_hits": self._cache_hits,
|
||||
"cache_misses": self._cache_misses,
|
||||
"hit_rate_percent": round(hit_rate, 2),
|
||||
"db_operations": self._db_operations,
|
||||
"cached_conversations": len(self._message_cache),
|
||||
"cached_metadata": len(self._metadata_cache),
|
||||
"cached_images": len(self._image_cache)
|
||||
}
|
||||
|
||||
def load_conversation_history(self, chat_id: str) -> List[Dict]:
|
||||
"""Legacy method - converts to async call."""
|
||||
import asyncio
|
||||
return asyncio.create_task(self._load_conversation_history_dict(chat_id))
|
||||
|
||||
async def _load_conversation_history_dict(self, chat_id: str) -> List[Dict]:
|
||||
"""Load conversation history in dict format for compatibility."""
|
||||
messages = await self.get_messages(chat_id)
|
||||
return [self._message_to_dict(msg) for msg in messages]
|
||||
|
||||
def save_conversation_history(self, chat_id: str, messages: List[Dict]) -> None:
|
||||
"""Legacy method - converts to async call."""
|
||||
import asyncio
|
||||
message_objects = [self._dict_to_message(msg) for msg in messages]
|
||||
return asyncio.create_task(self.save_messages(chat_id, message_objects))
|
||||
152
nvidia/multi-agent-chatbot/assets/backend/prompts.py
Normal file
@ -0,0 +1,152 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import jinja2
|
||||
from typing import Dict
|
||||
|
||||
|
||||
SUPERVISOR_AGENT_STR = """
|
||||
You are a supervisor agent whose role is to be a helpful planner that can use tools to answer questions. DO NOT WRITE CODE YOURSELF, ALWAYS USE THE TOOLS.
|
||||
|
||||
{% if tools %}
|
||||
IMPORTANT: You have access to these tools and you MUST use them when applicable and use tool response in your final answer:
|
||||
{{ tools }}
|
||||
|
||||
CRITICAL RULES:
|
||||
- **ALWAYS** use a tool when the user's request matches a tool's capability. For example:
|
||||
- If the user asks to "generate code", "develop", "build", "create", "write a script", "make a website", "develop an app", etc. → **MUST** use the write_code tool with appropriate programming_language parameter
|
||||
- If the user asks to "search", "find", "summarize", "analyze documents/reports", "key points", etc. → **MUST** use the search_documents tool with the query, don't add any other text to the query. You can assume that the user has already uploaded the document and just call the tool.
|
||||
- If the user asks to analyze/describe/understand an image (e.g., "what's in this image", "describe the picture") → **MUST** use the explain_image tool
|
||||
|
||||
- **NEVER EVER generate code yourself** - you are FORBIDDEN from writing code directly. ALWAYS use the write_code tool for ANY coding requests
|
||||
- **DO NOT** try to answer questions from documents yourself - always use the search_documents tool
|
||||
|
||||
CODING KEYWORDS that REQUIRE write_code tool:
|
||||
- "code", "develop", "build", "create", "write", "make", "implement", "program", "script", "website", "app", "function", "class", "HTML", "CSS", "JavaScript", "Python", "React", "component"
|
||||
|
||||
|
||||
Batching policy:
|
||||
- **Batch** when: (a) calls are independent (e.g., weather in two cities), (b) calls target different tools without dependency, or (c) multiple calls to the same tool with different arguments.
|
||||
- **Do not batch** when: a call’s arguments depend on a previous tool’s output (e.g., writing code which depends on the output of a search_documents tool).
|
||||
|
||||
Output protocol:
|
||||
- In the first assistant message of a turn, if tools are needed, **emit all tool calls together** (as multiple tool_calls). Do not include narrative text before the tool_calls unless required by the API.
|
||||
- After the ToolMessages arrive, produce a single assistant message with the final answer incorporating all results. Do not call the tools again for the same purpose.
|
||||
- **CRITICAL**: When you receive tool results, you MUST use them in your final response. Do NOT ignore successful tool results or claim you don't have information when tools have already provided it.
|
||||
- If any tool call succeeds, base your answer on the successful results. Ignore failed tool calls if you have successful ones.
|
||||
- Always present the information from successful tool calls as your definitive answer.
|
||||
|
||||
|
||||
Few-shot examples:
|
||||
# Direct coding request
|
||||
User: Create a responsive personal website for my AI development business
|
||||
Assistant (tool calls immediately):
|
||||
- write_code({"query": "Create a responsive personal website for my AI development business", "programming_language": "HTML"})
|
||||
|
||||
# Batching independent calls
|
||||
User: now, can you get the weather in egypt and the rain forecast in malibu?
|
||||
Assistant (tool calls in one message):
|
||||
- get_weather({"location": "Egypt"})
|
||||
- get_rain_forecast({"location": "Malibu"})
|
||||
|
||||
# Staged dependent calls
|
||||
User: Search my documents for design requirements then build a website based on those requirements
|
||||
Assistant (first message; dependent plan):
|
||||
- search_documents({"query": "design requirements website"})
|
||||
# (Wait for ToolMessage)
|
||||
Assistant (after ToolMessage):
|
||||
- write_code({"query": "build a website based on these design requirements: <extracted information>", "programming_language": "HTML"})
|
||||
# (Then produce final answer)
|
||||
|
||||
# Using successful tool results
|
||||
User: Can you search NVIDIA's earnings document and summarize the key points?
|
||||
Assistant (tool calls):
|
||||
- search_documents({"query": "NVIDIA earnings document"})
|
||||
# (Wait for ToolMessage with comprehensive earnings data)
|
||||
Assistant (final response):
|
||||
Based on NVIDIA's earnings document, here are the key highlights:
|
||||
[...continues with the actual data from tool results...]
|
||||
|
||||
|
||||
{% else %}
|
||||
You do not have access to any tools right now.
|
||||
{% endif %}
|
||||
|
||||
"""
|
||||
|
||||
|
||||
PROMPT_TEMPLATES = {
|
||||
"supervisor_agent": SUPERVISOR_AGENT_STR,
|
||||
}
|
||||
|
||||
|
||||
TEMPLATES: Dict[str, jinja2.Template] = {
|
||||
name: jinja2.Template(template) for name, template in PROMPT_TEMPLATES.items()
|
||||
}
|
||||
|
||||
|
||||
class Prompts:
|
||||
"""
|
||||
A class providing access to prompt templates.
|
||||
|
||||
This class manages a collection of Jinja2 templates used for generating
|
||||
various prompts in the process.
|
||||
|
||||
The templates are pre-compiled for efficiency and can be accessed either
|
||||
through attribute access or the get_template class method.
|
||||
|
||||
Attributes:
|
||||
None - Templates are stored in module-level constants
|
||||
|
||||
Methods:
|
||||
__getattr__(name: str) -> str:
|
||||
Dynamically retrieves prompt template strings by name
|
||||
get_template(name: str) -> jinja2.Template:
|
||||
Retrieves pre-compiled Jinja2 templates by name
|
||||
"""
|
||||
|
||||
def __getattr__(self, name: str) -> str:
|
||||
"""
|
||||
Dynamically retrieve prompt templates by name.
|
||||
|
||||
Args:
|
||||
name (str): Name of the prompt template to retrieve
|
||||
|
||||
Returns:
|
||||
str: The prompt template string
|
||||
|
||||
Raises:
|
||||
AttributeError: If the requested template name doesn't exist
|
||||
"""
|
||||
if name in PROMPT_TEMPLATES:
|
||||
return PROMPT_TEMPLATES[name]
|
||||
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
|
||||
|
||||
@classmethod
|
||||
def get_template(cls, name: str) -> jinja2.Template:
|
||||
"""
|
||||
Get a pre-compiled Jinja2 template by name.
|
||||
|
||||
Args:
|
||||
name (str): Name of the template to retrieve
|
||||
|
||||
Returns:
|
||||
jinja2.Template: The pre-compiled Jinja2 template object
|
||||
|
||||
Raises:
|
||||
KeyError: If the requested template name doesn't exist
|
||||
"""
|
||||
return TEMPLATES[name]
|
||||
27
nvidia/multi-agent-chatbot/assets/backend/pyproject.toml
Normal file
@ -0,0 +1,27 @@
|
||||
[project]
|
||||
name = "backend"
|
||||
version = "0.1.0"
|
||||
description = "Backend API server for chatbot application"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"fastapi>=0.116.1",
|
||||
"langchain>=0.3.27",
|
||||
"langchain-milvus>=0.2.1",
|
||||
"langchain-mcp-adapters>=0.1.0",
|
||||
"langchain-nvidia-ai-endpoints>=0.3.13",
|
||||
"langchain-openai>=0.3.28",
|
||||
"langchain-text-splitters>=0.3.9",
|
||||
"langchain-unstructured>=0.1.6",
|
||||
"langgraph>=0.6.0",
|
||||
"mcp>=0.1.0",
|
||||
"pydantic>=2.11.7",
|
||||
"pypdf2>=3.0.1",
|
||||
"python-dotenv>=1.1.1",
|
||||
"python-multipart>=0.0.20",
|
||||
"asyncpg>=0.29.0",
|
||||
"requests>=2.28.0",
|
||||
"unstructured[pdf]>=0.18.11",
|
||||
"uvicorn>=0.35.0",
|
||||
"websockets>=15.0.1",
|
||||
]
|
||||
16
nvidia/multi-agent-chatbot/assets/backend/tools/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
@ -0,0 +1,70 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from typing import Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
mcp = FastMCP("Code Generation")
|
||||
model_name = "deepseek-coder:6.7b"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def write_code(query: str, programming_language: str):
|
||||
"""This tool is used to write complete code.
|
||||
|
||||
Args:
|
||||
query: The natural language description of the code to be generated.
|
||||
programming_language: The programming language for the code generation (e.g., 'Python', 'JavaScript', 'HTML', 'CSS', 'Go').
|
||||
|
||||
Returns:
|
||||
The generated code.
|
||||
"""
|
||||
model_client = AsyncOpenAI(
|
||||
base_url="http://deepseek-coder:8000/v1",
|
||||
api_key="ollama"
|
||||
)
|
||||
|
||||
system_prompt = f"""You are an expert coder specializing in {programming_language}.
|
||||
Given a user request, generate clean, efficient {programming_language} code that accomplishes the specified task.
|
||||
Always provide the full code generation so the user can copy and paste a fully working example.
|
||||
Return just the raw code, with no markdown formatting, explanations, or any other text.
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": query}
|
||||
]
|
||||
|
||||
response = await model_client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
generated_code = response.choices[0].message.content
|
||||
return generated_code.strip()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"Starting {mcp.name} MCP server...")
|
||||
mcp.run(transport="stdio")
|
||||
@ -0,0 +1,124 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
MCP server providing image understanding and analysis tools.
|
||||
|
||||
This server exposes a `process_image` tool that uses a vision language model to answer queries about images.
|
||||
It supports multiple image input formats including URLs, file paths, and base64-encoded images.
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import requests
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
from langchain_core.tools import tool, Tool
|
||||
from langchain_mcp_adapters.tools import to_fastmcp
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(project_root))
|
||||
from postgres_storage import PostgreSQLConversationStorage
|
||||
|
||||
|
||||
mcp = FastMCP("image-understanding-server")
|
||||
|
||||
|
||||
model_name = "Qwen2.5-VL-7B-Instruct"
|
||||
model_client = OpenAI(
|
||||
base_url=f"http://qwen2.5-vl:8000/v1",
|
||||
api_key="api_key"
|
||||
)
|
||||
POSTGRES_HOST = os.getenv("POSTGRES_HOST", "postgres")
|
||||
POSTGRES_PORT = int(os.getenv("POSTGRES_PORT", 5432))
|
||||
POSTGRES_DB = os.getenv("POSTGRES_DB", "chatbot")
|
||||
POSTGRES_USER = os.getenv("POSTGRES_USER", "chatbot_user")
|
||||
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "chatbot_password")
|
||||
|
||||
postgres_storage = PostgreSQLConversationStorage(
|
||||
host=POSTGRES_HOST,
|
||||
port=POSTGRES_PORT,
|
||||
database=POSTGRES_DB,
|
||||
user=POSTGRES_USER,
|
||||
password=POSTGRES_PASSWORD
|
||||
)
|
||||
|
||||
@mcp.tool()
|
||||
def explain_image(query: str, image: str):
|
||||
"""
|
||||
This tool is used to understand an image. It will respond to the user's query based on the image.
|
||||
...
|
||||
"""
|
||||
if not image:
|
||||
raise ValueError('Error: explain_image tool received an empty image string.')
|
||||
|
||||
image_url_content = {}
|
||||
|
||||
if image.startswith("http://") or image.startswith("https://"):
|
||||
image_url_content = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image}
|
||||
}
|
||||
else:
|
||||
if image.startswith("data:image/"):
|
||||
metadata, b64_data = image.split(",", 1)
|
||||
filetype = metadata.split(";")[0].split("/")[-1]
|
||||
elif os.path.exists(image):
|
||||
with open(image, "rb") as image_file:
|
||||
filetype = image.split('.')[-1]
|
||||
b64_data = base64.b64encode(image_file.read()).decode("utf-8")
|
||||
else:
|
||||
raise ValueError(f'Invalid image type -- could not be identified as a url or filepath: {image}')
|
||||
|
||||
image_url_content = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/{filetype if filetype else 'jpeg'};base64,{b64_data}"
|
||||
}
|
||||
}
|
||||
|
||||
message = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": query},
|
||||
image_url_content
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
print(f"Sending request to vision model: {query}")
|
||||
response = model_client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=message,
|
||||
max_tokens=512,
|
||||
temperature=0.1
|
||||
)
|
||||
print(f"Received response from vision model")
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
print(f"Error calling vision model: {e}")
|
||||
raise RuntimeError(f"Failed to process image with vision model: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f'running {mcp.name} MCP server')
|
||||
mcp.run(transport="stdio")
|
||||
@ -0,0 +1,254 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""RAG MCP Server for Document Search and Question Answering.
|
||||
|
||||
This module implements an MCP server that provides document search capabilities using
|
||||
a simple retrieval-augmented generation (RAG) pipeline. The server exposes a
|
||||
search_documents tool that retrieves relevant document chunks and generates answers.
|
||||
|
||||
The simplified RAG workflow consists of:
|
||||
- Document retrieval from a vector store
|
||||
- Answer generation using retrieved context
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Annotated, Dict, List, Optional, Sequence, TypedDict
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.graph import END, START, StateGraph, add_messages
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from openai import AsyncOpenAI
|
||||
from pypdf import PdfReader
|
||||
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(project_root))
|
||||
|
||||
from config import ConfigManager
|
||||
from vector_store import VectorStore, create_vector_store_with_config
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
level=logging.INFO
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RAGState(TypedDict, total=False):
|
||||
"""Type definition for the simplified RAG agent state.
|
||||
|
||||
Attributes:
|
||||
question: The user's question to be answered.
|
||||
messages: Conversation history with automatic message aggregation.
|
||||
context: Retrieved documents from the local vector store.
|
||||
sources: Optional list of source filters for retrieval.
|
||||
"""
|
||||
question: str
|
||||
messages: Annotated[Sequence[AnyMessage], add_messages]
|
||||
context: Optional[List[Document]]
|
||||
sources: Optional[List[str]]
|
||||
|
||||
|
||||
class RAGAgent:
|
||||
"""Simplified RAG Agent for fast document search and answer generation.
|
||||
|
||||
This agent manages a simple two-step pipeline:
|
||||
1. Retrieve documents from the local vector store.
|
||||
2. Generate an answer using the retrieved context.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the RAG agent with model client, configuration, and graph."""
|
||||
config_path = self._get_config_path()
|
||||
self.config_manager = ConfigManager(config_path)
|
||||
self.vector_store = create_vector_store_with_config(self.config_manager)
|
||||
self.model_name = self.config_manager.get_selected_model()
|
||||
self.model_client = AsyncOpenAI(
|
||||
base_url=f"http://{self.model_name}:8000/v1",
|
||||
api_key="api_key"
|
||||
)
|
||||
|
||||
self.generation_prompt = self._get_generation_prompt()
|
||||
|
||||
|
||||
self.graph = self._build_graph()
|
||||
|
||||
def _get_config_path(self):
|
||||
"""Get the configuration file path and validate its existence."""
|
||||
config_path = os.path.join(os.path.dirname(__file__), "../../config.json")
|
||||
if not os.path.exists(config_path):
|
||||
logger.error("ERROR: config.json not found")
|
||||
return config_path
|
||||
|
||||
def _get_generation_prompt(self) -> str:
|
||||
"""Get the system prompt template for the generation node."""
|
||||
return """You are an assistant for question-answering tasks.
|
||||
Use the following pieces of retrieved context to answer the question.
|
||||
If no context is provided, answer the question using your own knowledge, but state that you could not find relevant information in the provided documents.
|
||||
Don't make up any information that is not provided in the context. Keep the answer concise.
|
||||
|
||||
Context:
|
||||
{context}
|
||||
"""
|
||||
|
||||
def retrieve(self, state: RAGState) -> Dict:
|
||||
"""Retrieve relevant documents from the vector store."""
|
||||
logger.info({"message": "Starting document retrieval"})
|
||||
sources = state.get("sources", [])
|
||||
|
||||
if sources:
|
||||
logger.info({"message": "Attempting retrieval with source filters", "sources": sources})
|
||||
retrieved_docs = self.vector_store.get_documents(state["question"], sources=sources)
|
||||
else:
|
||||
logger.info({"message": "No sources specified, searching all documents"})
|
||||
retrieved_docs = self.vector_store.get_documents(state["question"])
|
||||
|
||||
if not retrieved_docs and sources:
|
||||
logger.info({"message": "No documents found with source filtering, trying without filters"})
|
||||
retrieved_docs = self.vector_store.get_documents(state["question"])
|
||||
|
||||
if retrieved_docs:
|
||||
sources_found = set(doc.metadata.get("source", "unknown") for doc in retrieved_docs)
|
||||
logger.info({"message": "Document sources found", "sources": list(sources_found), "doc_count": len(retrieved_docs)})
|
||||
else:
|
||||
logger.warning({"message": "No documents retrieved", "query": state["question"], "attempted_sources": sources})
|
||||
|
||||
return {"context": retrieved_docs}
|
||||
|
||||
|
||||
async def generate(self, state: RAGState) -> Dict:
|
||||
"""Generate an answer using retrieved context."""
|
||||
logger.info({
|
||||
"message": "Generating answer",
|
||||
"question": state['question']
|
||||
})
|
||||
|
||||
context = state.get("context", [])
|
||||
|
||||
if not context:
|
||||
logger.warning({"message": "No context available for generation", "question": state['question']})
|
||||
docs_content = "No relevant documents were found."
|
||||
else:
|
||||
logger.info({"message": "Generating with context", "context_count": len(context)})
|
||||
docs_content = self._hydrate_context(context)
|
||||
|
||||
system_prompt = self.generation_prompt.format(context=docs_content)
|
||||
user_message = f"Question: {state['question']}"
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_message}
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.model_client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
)
|
||||
response_content = response.choices[0].message.content
|
||||
|
||||
logger.info({
|
||||
"message": "Generation completed",
|
||||
"response_length": len(response_content),
|
||||
"response_preview": response_content[:100] + "..."
|
||||
})
|
||||
|
||||
return {
|
||||
"messages": [HumanMessage(content=state["question"]), AIMessage(content=response_content)]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error({"message": "Error during generation", "error": str(e)})
|
||||
fallback_response = f"I apologize, but I encountered an error while processing your query about: {state['question']}"
|
||||
return {
|
||||
"messages": [HumanMessage(content=state["question"]), AIMessage(content=fallback_response)]
|
||||
}
|
||||
|
||||
def _hydrate_context(self, context: List[Document]) -> str:
|
||||
"""Extract text content from document objects."""
|
||||
return "\n\n".join([doc.page_content for doc in context if doc.page_content])
|
||||
|
||||
def _build_graph(self):
|
||||
"""Build and compile the simplified RAG workflow graph."""
|
||||
workflow = StateGraph(RAGState)
|
||||
|
||||
workflow.add_node("retrieve", self.retrieve)
|
||||
workflow.add_node("generate", self.generate)
|
||||
|
||||
workflow.set_entry_point("retrieve")
|
||||
workflow.add_edge("retrieve", "generate")
|
||||
workflow.add_edge("generate", END)
|
||||
|
||||
return workflow.compile()
|
||||
|
||||
|
||||
mcp = FastMCP("RAG")
|
||||
rag_agent = RAGAgent()
|
||||
vector_store = create_vector_store_with_config(rag_agent.config_manager)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def search_documents(query: str) -> str:
|
||||
"""Search documents uploaded by the user to generate fast, grounded answers.
|
||||
|
||||
Performs a simple RAG pipeline that retrieves relevant documents and generates answers.
|
||||
|
||||
Args:
|
||||
query: The question or query to search for.
|
||||
|
||||
Returns:
|
||||
A concise answer based on the retrieved documents.
|
||||
"""
|
||||
config_obj = rag_agent.config_manager.read_config()
|
||||
sources = config_obj.selected_sources or []
|
||||
|
||||
initial_state = {
|
||||
"question": query,
|
||||
"sources": sources,
|
||||
"messages": []
|
||||
}
|
||||
|
||||
thread_id = f"rag_session_{time.time()}"
|
||||
|
||||
result = await rag_agent.graph.ainvoke(initial_state)
|
||||
|
||||
if not result.get("messages"):
|
||||
logger.error({"message": "No messages in RAG result", "query": query})
|
||||
return "I apologize, but I encountered an error processing your query and no response was generated."
|
||||
|
||||
final_message = result["messages"][-1]
|
||||
final_content = getattr(final_message, 'content', '') or ''
|
||||
|
||||
if not final_content.strip():
|
||||
logger.warning({"message": "Empty content in final RAG message", "query": query, "message_type": type(final_message).__name__})
|
||||
return f"I found relevant documents for your query '{query}' but was unable to generate a response. Please try rephrasing your question."
|
||||
|
||||
logger.info({"message": "RAG result", "content_length": len(final_content), "query": query})
|
||||
return final_content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"Starting {mcp.name} MCP server...")
|
||||
mcp.run(transport="stdio")
|
||||
@ -0,0 +1,72 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
Weather Test MCP Server
|
||||
|
||||
A lightweight Model Context Protocol (MCP) server designed for testing and demonstration purposes.
|
||||
This module provides mock weather tools that return humorous, fake responses rather than real
|
||||
weather data.
|
||||
|
||||
Features:
|
||||
- Mock weather data retreival for any location
|
||||
- Example implementation of MCP tool registration with local functions
|
||||
- Demonstration of FastMCP server setup
|
||||
|
||||
Tools provided:
|
||||
- get_weather(location): Returns mock weather information
|
||||
- get_rain_forecast(location): Returns mock rain forecast data
|
||||
|
||||
Usage:
|
||||
- Usage in Chatbot Spark:
|
||||
The server is ran on project startup by MCPClient. SEe client.py for more details
|
||||
- Standalone Usage:
|
||||
Run as standalone script to start the MCP server:
|
||||
$ python weather_test.py
|
||||
"""
|
||||
import time
|
||||
import os
|
||||
|
||||
from langchain_core.tools import tool, Tool
|
||||
from langchain_mcp_adapters.tools import to_fastmcp
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
|
||||
mcp = FastMCP("weather-tools")
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def get_weather(location: str):
|
||||
"""Call to get the weather from a specific location."""
|
||||
if any([city in location.lower() for city in ["sf", "san francisco"]]):
|
||||
return "It's sunny in San Francisco, but you better look out if you're a Gemini 😈."
|
||||
else:
|
||||
return f"The weather is spooky with a chance of gremlins in {location}"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def get_rain_forecast(location: str):
|
||||
"""Call to get the rain forecast from a specific location."""
|
||||
if any([city in location.lower() for city in ["sf", "san francisco"]]):
|
||||
return "It's going to rain cats and dogs in San Francisco tomorrow."
|
||||
else:
|
||||
return f"It is raining muffins in {location}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f'running {mcp.name} MCP server')
|
||||
mcp.run(transport="stdio")
|
||||
182
nvidia/multi-agent-chatbot/assets/backend/utils.py
Normal file
@ -0,0 +1,182 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Utility functions for file processing and message conversion."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, ToolCall
|
||||
|
||||
from logger import logger
|
||||
from vector_store import VectorStore
|
||||
|
||||
|
||||
async def process_and_ingest_files_background(
|
||||
file_info: List[dict],
|
||||
vector_store: VectorStore,
|
||||
config_manager,
|
||||
task_id: str,
|
||||
indexing_tasks: Dict[str, str]
|
||||
) -> None:
|
||||
"""Process and ingest files in the background.
|
||||
|
||||
Args:
|
||||
file_info: List of file dictionaries with 'filename' and 'content' keys
|
||||
vector_store: VectorStore instance for document indexing
|
||||
config_manager: ConfigManager instance for updating sources
|
||||
task_id: Unique identifier for this processing task
|
||||
indexing_tasks: Dictionary to track task status
|
||||
"""
|
||||
try:
|
||||
logger.debug({
|
||||
"message": "Starting background file processing",
|
||||
"task_id": task_id,
|
||||
"file_count": len(file_info)
|
||||
})
|
||||
|
||||
indexing_tasks[task_id] = "saving_files"
|
||||
|
||||
permanent_dir = os.path.join("uploads", task_id)
|
||||
os.makedirs(permanent_dir, exist_ok=True)
|
||||
|
||||
file_paths = []
|
||||
file_names = []
|
||||
|
||||
for info in file_info:
|
||||
try:
|
||||
file_name = info["filename"]
|
||||
content = info["content"]
|
||||
|
||||
file_path = os.path.join(permanent_dir, file_name)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
file_paths.append(file_path)
|
||||
file_names.append(file_name)
|
||||
|
||||
logger.debug({
|
||||
"message": "Saved file",
|
||||
"task_id": task_id,
|
||||
"filename": file_name,
|
||||
"path": file_path
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error({
|
||||
"message": f"Error saving file {info['filename']}",
|
||||
"task_id": task_id,
|
||||
"filename": info['filename'],
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
|
||||
indexing_tasks[task_id] = "loading_documents"
|
||||
logger.debug({"message": "Loading documents", "task_id": task_id})
|
||||
|
||||
try:
|
||||
documents = vector_store._load_documents(file_paths)
|
||||
|
||||
logger.debug({
|
||||
"message": "Documents loaded, starting indexing",
|
||||
"task_id": task_id,
|
||||
"document_count": len(documents)
|
||||
})
|
||||
|
||||
indexing_tasks[task_id] = "indexing_documents"
|
||||
vector_store.index_documents(documents)
|
||||
|
||||
if file_names:
|
||||
config = config_manager.read_config()
|
||||
|
||||
config_updated = False
|
||||
for file_name in file_names:
|
||||
if file_name not in config.sources:
|
||||
config.sources.append(file_name)
|
||||
config_updated = True
|
||||
|
||||
if config_updated:
|
||||
config_manager.write_config(config)
|
||||
logger.debug({
|
||||
"message": "Updated config with new sources",
|
||||
"task_id": task_id,
|
||||
"sources": config.sources
|
||||
})
|
||||
|
||||
indexing_tasks[task_id] = "completed"
|
||||
logger.debug({
|
||||
"message": "Background processing and indexing completed successfully",
|
||||
"task_id": task_id
|
||||
})
|
||||
except Exception as e:
|
||||
indexing_tasks[task_id] = f"failed_during_indexing: {str(e)}"
|
||||
logger.error({
|
||||
"message": "Error during document loading or indexing",
|
||||
"task_id": task_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
|
||||
except Exception as e:
|
||||
indexing_tasks[task_id] = f"failed: {str(e)}"
|
||||
logger.error({
|
||||
"message": "Error in background processing",
|
||||
"task_id": task_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
|
||||
|
||||
def convert_langgraph_messages_to_openai(messages: List) -> List[Dict[str, Any]]:
|
||||
"""Convert LangGraph message objects to OpenAI API format.
|
||||
|
||||
Args:
|
||||
messages: List of LangGraph message objects
|
||||
|
||||
Returns:
|
||||
List of dictionaries in OpenAI API format
|
||||
"""
|
||||
openai_messages = []
|
||||
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
openai_messages.append({
|
||||
"role": "user",
|
||||
"content": msg.content
|
||||
})
|
||||
elif isinstance(msg, AIMessage):
|
||||
openai_msg = {
|
||||
"role": "assistant",
|
||||
"content": msg.content or ""
|
||||
}
|
||||
if hasattr(msg, 'tool_calls') and msg.tool_calls:
|
||||
openai_msg["tool_calls"] = []
|
||||
for tc in msg.tool_calls:
|
||||
openai_msg["tool_calls"].append({
|
||||
"id": tc["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": json.dumps(tc["args"])
|
||||
}
|
||||
})
|
||||
openai_messages.append(openai_msg)
|
||||
elif isinstance(msg, ToolMessage):
|
||||
openai_messages.append({
|
||||
"role": "tool",
|
||||
"content": msg.content,
|
||||
"tool_call_id": msg.tool_call_id
|
||||
})
|
||||
|
||||
return openai_messages
|
||||
4865
nvidia/multi-agent-chatbot/assets/backend/uv.lock
generated
Normal file
388
nvidia/multi-agent-chatbot/assets/backend/vector_store.py
Normal file
@ -0,0 +1,388 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import glob
|
||||
from typing import List, Tuple
|
||||
import os
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from langchain_milvus import Milvus
|
||||
from langchain_core.documents import Document
|
||||
from typing_extensions import List
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
from dotenv import load_dotenv
|
||||
from logger import logger
|
||||
from typing import Optional, Callable
|
||||
import requests
|
||||
|
||||
|
||||
class CustomEmbeddings:
|
||||
"""Wraps qwen3 embedding model to match OpenAI format"""
|
||||
def __init__(self, model: str = "Qwen3-Embedding-4B-Q8_0.gguf", host: str = "http://qwen3-embedding:8000"):
|
||||
self.model = model
|
||||
self.url = f"{host}/v1/embeddings"
|
||||
|
||||
def __call__(self, texts: list[str]) -> list[list[float]]:
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
response = requests.post(
|
||||
self.url,
|
||||
json={"input": text, "model": self.model},
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
embeddings.append(data["data"][0]["embedding"])
|
||||
return embeddings
|
||||
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed a list of document texts. Required by Milvus library."""
|
||||
return self.__call__(texts)
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed a single query text. Required by Milvus library."""
|
||||
return self.__call__([text])[0]
|
||||
|
||||
|
||||
class VectorStore:
|
||||
"""Vector store for document embedding and retrieval.
|
||||
|
||||
Decoupled from ConfigManager - uses optional callbacks for source management.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings=None,
|
||||
uri: str = "http://milvus:19530",
|
||||
on_source_deleted: Optional[Callable[[str], None]] = None
|
||||
):
|
||||
"""Initialize the vector store.
|
||||
|
||||
Args:
|
||||
embeddings: Embedding model to use (defaults to OllamaEmbeddings)
|
||||
uri: Milvus connection URI
|
||||
on_source_deleted: Optional callback when a source is deleted
|
||||
"""
|
||||
try:
|
||||
self.embeddings = embeddings or CustomEmbeddings(model="qwen3-embedding-custom")
|
||||
self.uri = uri
|
||||
self.on_source_deleted = on_source_deleted
|
||||
self._initialize_store()
|
||||
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=1000,
|
||||
chunk_overlap=200
|
||||
)
|
||||
|
||||
logger.debug({
|
||||
"message": "VectorStore initialized successfully"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error({
|
||||
"message": "Error initializing VectorStore",
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
raise
|
||||
|
||||
def _initialize_store(self):
|
||||
self._store = Milvus(
|
||||
embedding_function=self.embeddings,
|
||||
collection_name="context",
|
||||
connection_args={"uri": self.uri},
|
||||
auto_id=True
|
||||
)
|
||||
logger.debug({
|
||||
"message": "Milvus vector store initialized",
|
||||
"uri": self.uri,
|
||||
"collection": "context"
|
||||
})
|
||||
|
||||
def _load_documents(self, file_paths: List[str] = None, input_dir: str = None) -> List[str]:
|
||||
try:
|
||||
documents = []
|
||||
source_name = None
|
||||
|
||||
if input_dir:
|
||||
source_name = os.path.basename(os.path.normpath(input_dir))
|
||||
logger.debug({
|
||||
"message": "Loading files from directory",
|
||||
"directory": input_dir,
|
||||
"source": source_name
|
||||
})
|
||||
file_paths = glob.glob(os.path.join(input_dir, "**"), recursive=True)
|
||||
file_paths = [f for f in file_paths if os.path.isfile(f)]
|
||||
|
||||
logger.info(f"Processing {len(file_paths)} files: {file_paths}")
|
||||
|
||||
for file_path in file_paths:
|
||||
try:
|
||||
if not source_name:
|
||||
source_name = os.path.basename(file_path)
|
||||
logger.info(f"Using filename as source: {source_name}")
|
||||
|
||||
logger.info(f"Loading file: {file_path}")
|
||||
|
||||
file_ext = os.path.splitext(file_path)[1].lower()
|
||||
logger.info(f"File extension: {file_ext}")
|
||||
|
||||
try:
|
||||
loader = UnstructuredLoader(file_path)
|
||||
docs = loader.load()
|
||||
logger.info(f"Successfully loaded {len(docs)} documents from {file_path}")
|
||||
except Exception as pdf_error:
|
||||
logger.error(f'error with unstructured loader, trying to load from scratch')
|
||||
file_text = None
|
||||
if file_ext == ".pdf":
|
||||
logger.info("Attempting PyPDF text extraction fallback")
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
reader = PdfReader(file_path)
|
||||
extracted_pages = []
|
||||
for page in reader.pages:
|
||||
try:
|
||||
extracted_pages.append(page.extract_text() or "")
|
||||
except Exception as per_page_err:
|
||||
logger.info(f"Warning: failed to extract a page: {per_page_err}")
|
||||
extracted_pages.append("")
|
||||
file_text = "\n\n".join(extracted_pages).strip()
|
||||
except Exception as pypdf_error:
|
||||
logger.info(f"PyPDF fallback failed: {pypdf_error}")
|
||||
file_text = None
|
||||
|
||||
if not file_text:
|
||||
logger.info("Falling back to raw text read of file contents")
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
|
||||
file_text = f.read()
|
||||
except Exception as read_error:
|
||||
logger.info(f"Fallback read failed: {read_error}")
|
||||
file_text = ""
|
||||
|
||||
if file_text and file_text.strip():
|
||||
docs = [Document(
|
||||
page_content=file_text,
|
||||
metadata={
|
||||
"source": source_name,
|
||||
"file_path": file_path,
|
||||
"filename": os.path.basename(file_path),
|
||||
}
|
||||
)]
|
||||
else:
|
||||
logger.info("Creating a simple document as fallback (no text extracted)")
|
||||
docs = [Document(
|
||||
page_content=f"Document: {os.path.basename(file_path)}",
|
||||
metadata={
|
||||
"source": source_name,
|
||||
"file_path": file_path,
|
||||
"filename": os.path.basename(file_path),
|
||||
}
|
||||
)]
|
||||
|
||||
for doc in docs:
|
||||
if not doc.metadata:
|
||||
doc.metadata = {}
|
||||
|
||||
cleaned_metadata = {}
|
||||
cleaned_metadata["source"] = source_name
|
||||
cleaned_metadata["file_path"] = file_path
|
||||
cleaned_metadata["filename"] = os.path.basename(file_path)
|
||||
|
||||
for key, value in doc.metadata.items():
|
||||
if key not in ["source", "file_path"]:
|
||||
if isinstance(value, (list, dict, set)):
|
||||
cleaned_metadata[key] = str(value)
|
||||
elif value is not None:
|
||||
cleaned_metadata[key] = str(value)
|
||||
|
||||
doc.metadata = cleaned_metadata
|
||||
documents.extend(docs)
|
||||
logger.debug({
|
||||
"message": "Loaded documents from file",
|
||||
"file_path": file_path,
|
||||
"document_count": len(docs)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error({
|
||||
"message": "Error loading file",
|
||||
"file_path": file_path,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
continue
|
||||
|
||||
logger.info(f"Total documents loaded: {len(documents)}")
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
logger.error({
|
||||
"message": "Error loading documents",
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
raise
|
||||
|
||||
def index_documents(self, documents: List[Document]) -> List[Document]:
|
||||
try:
|
||||
logger.debug({
|
||||
"message": "Starting document indexing",
|
||||
"document_count": len(documents)
|
||||
})
|
||||
|
||||
splits = self.text_splitter.split_documents(documents)
|
||||
logger.debug({
|
||||
"message": "Split documents into chunks",
|
||||
"chunk_count": len(splits)
|
||||
})
|
||||
|
||||
self._store.add_documents(splits)
|
||||
self.flush_store()
|
||||
|
||||
logger.debug({
|
||||
"message": "Document indexing completed"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error({
|
||||
"message": "Error during document indexing",
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
raise
|
||||
|
||||
def flush_store(self):
|
||||
"""
|
||||
Flush the Milvus collection to ensure that all added documents are persisted to disk.
|
||||
"""
|
||||
try:
|
||||
from pymilvus import connections
|
||||
|
||||
connections.connect(uri=self.uri)
|
||||
|
||||
|
||||
from pymilvus import utility
|
||||
utility.flush_all()
|
||||
|
||||
logger.debug({
|
||||
"message": "Milvus store flushed (persisted to disk)"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error({
|
||||
"message": "Error flushing Milvus store",
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
|
||||
|
||||
def get_documents(self, query: str, k: int = 8, sources: List[str] = None) -> List[Document]:
|
||||
"""
|
||||
Get relevant documents using the retriever's invoke method.
|
||||
"""
|
||||
try:
|
||||
search_kwargs = {"k": k}
|
||||
|
||||
if sources:
|
||||
if len(sources) == 1:
|
||||
filter_expr = f'source == "{sources[0]}"'
|
||||
else:
|
||||
source_conditions = [f'source == "{source}"' for source in sources]
|
||||
filter_expr = " || ".join(source_conditions)
|
||||
|
||||
search_kwargs["expr"] = filter_expr
|
||||
logger.debug({
|
||||
"message": "Retrieving with filter",
|
||||
"filter": filter_expr
|
||||
})
|
||||
|
||||
retriever = self._store.as_retriever(
|
||||
search_type="similarity",
|
||||
search_kwargs=search_kwargs
|
||||
)
|
||||
|
||||
docs = retriever.invoke(query)
|
||||
logger.debug({
|
||||
"message": "Retrieved documents",
|
||||
"query": query,
|
||||
"document_count": len(docs)
|
||||
})
|
||||
|
||||
return docs
|
||||
except Exception as e:
|
||||
logger.error({
|
||||
"message": "Error retrieving documents",
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
return []
|
||||
|
||||
def delete_collection(self, collection_name: str) -> bool:
|
||||
"""
|
||||
Delete a collection from Milvus.
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection to delete
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
from pymilvus import connections, Collection, utility
|
||||
|
||||
connections.connect(uri=self.uri)
|
||||
|
||||
if utility.has_collection(collection_name):
|
||||
collection = Collection(name=collection_name)
|
||||
|
||||
collection.drop()
|
||||
|
||||
if self.on_source_deleted:
|
||||
self.on_source_deleted(collection_name)
|
||||
|
||||
logger.debug({
|
||||
"message": "Collection deleted successfully",
|
||||
"collection_name": collection_name
|
||||
})
|
||||
return True
|
||||
else:
|
||||
logger.warning({
|
||||
"message": "Collection not found",
|
||||
"collection_name": collection_name
|
||||
})
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error({
|
||||
"message": "Error deleting collection",
|
||||
"collection_name": collection_name,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def create_vector_store_with_config(config_manager, uri: str = "http://milvus:19530") -> VectorStore:
|
||||
"""Factory function to create a VectorStore with ConfigManager integration.
|
||||
|
||||
Args:
|
||||
config_manager: ConfigManager instance for source management
|
||||
uri: Milvus connection URI
|
||||
|
||||
Returns:
|
||||
VectorStore instance with source deletion callback
|
||||
"""
|
||||
def handle_source_deleted(source_name: str):
|
||||
"""Handle source deletion by updating config."""
|
||||
config = config_manager.read_config()
|
||||
if hasattr(config, 'sources') and source_name in config.sources:
|
||||
config.sources.remove(source_name)
|
||||
config_manager.write_config(config)
|
||||
|
||||
return VectorStore(
|
||||
uri=uri,
|
||||
on_source_deleted=handle_source_deleted
|
||||
)
|
||||
160
nvidia/multi-agent-chatbot/assets/docker-compose-models.yml
Normal file
@ -0,0 +1,160 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
x-build-base: &build-base
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.llamacpp
|
||||
image: local/llama.cpp:server-cuda
|
||||
|
||||
services:
|
||||
qwen2.5-vl:
|
||||
image: nvcr.io/nvidia/tensorrt-llm/release:spark-single-gpu-dev
|
||||
container_name: qwen2.5-vl
|
||||
shm_size: '1g'
|
||||
restart: unless-stopped
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
environment:
|
||||
- TOKENIZERS_PARALLELISM=false
|
||||
- NCCL_P2P_LEVEL=SYS
|
||||
- NCCL_DEBUG=INFO
|
||||
- UCX_TLS=tcp,sm,self
|
||||
- UCX_MEMTYPE_CACHE=n
|
||||
- CUDA_VISIBLE_DEVICES=0
|
||||
command: >
|
||||
trtllm-serve serve Qwen/Qwen2-VL-7B-Instruct
|
||||
--backend pytorch
|
||||
--host 0.0.0.0
|
||||
--port 8000
|
||||
--trust_remote_code
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
start_period: 60s
|
||||
|
||||
qwen3-embedding:
|
||||
<<: *build-base
|
||||
container_name: qwen3-embedding
|
||||
volumes:
|
||||
- ./models:/models
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
command:
|
||||
- "-m"
|
||||
- "/models/Qwen3-Embedding-4B-Q8_0.gguf"
|
||||
- "--port"
|
||||
- "8000"
|
||||
- "--host"
|
||||
- "0.0.0.0"
|
||||
- "--jinja"
|
||||
- "--embeddings"
|
||||
|
||||
# Uncomment next block if you want to use gpt-oss-20b
|
||||
# gpt-oss-20b:
|
||||
# <<: *build-base
|
||||
# container_name: gpt-oss-20b
|
||||
# volumes:
|
||||
# - ./models:/models
|
||||
# deploy:
|
||||
# resources:
|
||||
# reservations:
|
||||
# devices:
|
||||
# - driver: nvidia
|
||||
# count: all
|
||||
# capabilities: [gpu]
|
||||
# command:
|
||||
# - "-m"
|
||||
# - "/models/gpt-oss-20b-mxfp4.gguf"
|
||||
# - "--port"
|
||||
# - "8000"
|
||||
# - "--host"
|
||||
# - "0.0.0.0"
|
||||
# - "-n"
|
||||
# - "2048"
|
||||
# - "--n-gpu-layers"
|
||||
# - "999"
|
||||
# - "--jinja"
|
||||
|
||||
# Comment next block if you want to use gpt-oss-20b
|
||||
gpt-oss-120b:
|
||||
<<: *build-base
|
||||
container_name: gpt-oss-120b
|
||||
volumes:
|
||||
- ./models:/models
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
command:
|
||||
- "-m"
|
||||
- "/models/gpt-oss-120b-mxfp4-00001-of-00003.gguf"
|
||||
- "--port"
|
||||
- "8000"
|
||||
- "--host"
|
||||
- "0.0.0.0"
|
||||
- "-n"
|
||||
- "4096"
|
||||
- "--n-gpu-layers"
|
||||
- "999"
|
||||
- "--jinja"
|
||||
|
||||
deepseek-coder:
|
||||
<<: *build-base
|
||||
container_name: deepseek-coder
|
||||
volumes:
|
||||
- ./models:/models
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
command:
|
||||
- "-m"
|
||||
- "/models/deepseek-coder-6.7b-instruct.Q8_0.gguf"
|
||||
- "--port"
|
||||
- "8000"
|
||||
- "--host"
|
||||
- "0.0.0.0"
|
||||
- "-n"
|
||||
- "256"
|
||||
- "--n-gpu-layers"
|
||||
- "999"
|
||||
- "--jinja"
|
||||
|
||||
volumes:
|
||||
ollama-data:
|
||||
|
||||
networks:
|
||||
default:
|
||||
name: chatbot-net
|
||||
135
nvidia/multi-agent-chatbot/assets/docker-compose.yml
Normal file
@ -0,0 +1,135 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
services:
|
||||
backend:
|
||||
container_name: backend
|
||||
build:
|
||||
context: ./backend
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- ./backend:/app
|
||||
depends_on:
|
||||
- postgres
|
||||
- etcd
|
||||
- minio
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- POSTGRES_HOST=postgres
|
||||
- POSTGRES_DB=chatbot
|
||||
- POSTGRES_USER=chatbot_user
|
||||
- POSTGRES_PASSWORD=chatbot_password
|
||||
- ETCD_ENDPOINTS=etcd:2379
|
||||
- MINIO_ADDRESS=minio:9000
|
||||
- MILVUS_ADDRESS=milvus:19530
|
||||
- MODELS=gpt-oss-120b #,gpt-oss-20b
|
||||
|
||||
frontend:
|
||||
container_name: frontend
|
||||
build:
|
||||
context: ./frontend
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- /frontend:/app/frontend
|
||||
- /app/node_modules
|
||||
depends_on:
|
||||
- backend
|
||||
restart: unless-stopped
|
||||
|
||||
|
||||
postgres:
|
||||
container_name: postgres
|
||||
image: postgres:15-alpine
|
||||
ports:
|
||||
- "5432:5432"
|
||||
environment:
|
||||
- POSTGRES_DB=chatbot
|
||||
- POSTGRES_USER=chatbot_user
|
||||
- POSTGRES_PASSWORD=chatbot_password
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U chatbot_user -d chatbot"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
restart: unless-stopped
|
||||
|
||||
etcd:
|
||||
container_name: milvus-etcd
|
||||
image: quay.io/coreos/etcd:v3.5.5
|
||||
environment:
|
||||
- ETCD_AUTO_COMPACTION_MODE=revision
|
||||
- ETCD_AUTO_COMPACTION_RETENTION=1000
|
||||
- ETCD_QUOTA_BACKEND_BYTES=4294967296
|
||||
- ETCD_SNAPSHOT_COUNT=50000
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
|
||||
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
|
||||
healthcheck:
|
||||
test: ["CMD", "etcdctl", "endpoint", "health"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
minio:
|
||||
container_name: milvus-minio
|
||||
image: minio/minio:RELEASE.2023-03-20T20-16-18Z
|
||||
environment:
|
||||
MINIO_ACCESS_KEY: minioadmin
|
||||
MINIO_SECRET_KEY: minioadmin
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
|
||||
command: minio server /minio_data
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
milvus:
|
||||
container_name: milvus-standalone
|
||||
image: milvusdb/milvus:v2.5.15-20250718-3a3b374f-gpu-arm64
|
||||
command: ["milvus", "run", "standalone"]
|
||||
environment:
|
||||
ETCD_ENDPOINTS: etcd:2379
|
||||
MINIO_ADDRESS: minio:9000
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
|
||||
interval: 30s
|
||||
start_period: 90s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
ports:
|
||||
- "19530:19530"
|
||||
- "9091:9091"
|
||||
depends_on:
|
||||
- "etcd"
|
||||
- "minio"
|
||||
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
|
||||
networks:
|
||||
default:
|
||||
name: chatbot-net
|
||||
29
nvidia/multi-agent-chatbot/assets/frontend/Dockerfile
Normal file
@ -0,0 +1,29 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
FROM node:20-alpine
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY package.json package-lock.json* ./
|
||||
|
||||
RUN npm install
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 3000
|
||||
|
||||
CMD ["npm", "run", "dev"]
|
||||
32
nvidia/multi-agent-chatbot/assets/frontend/eslint.config.mjs
Normal file
@ -0,0 +1,32 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
import { dirname } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
import { FlatCompat } from "@eslint/eslintrc";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
||||
const compat = new FlatCompat({
|
||||
baseDirectory: __dirname,
|
||||
});
|
||||
|
||||
const eslintConfig = [
|
||||
...compat.extends("next/core-web-vitals", "next/typescript"),
|
||||
];
|
||||
|
||||
export default eslintConfig;
|
||||
21
nvidia/multi-agent-chatbot/assets/frontend/next-env.d.ts
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
/// <reference types="next" />
|
||||
/// <reference types="next/image-types/global" />
|
||||
|
||||
// NOTE: This file should not be edited
|
||||
// see https://nextjs.org/docs/app/api-reference/config/typescript for more information.
|
||||
30
nvidia/multi-agent-chatbot/assets/frontend/next.config.ts
Normal file
@ -0,0 +1,30 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
import type { NextConfig } from "next";
|
||||
|
||||
const nextConfig: NextConfig = {
|
||||
async rewrites() {
|
||||
return [
|
||||
{
|
||||
source: '/api/:path*',
|
||||
destination: 'http://backend:8000/:path*',
|
||||
},
|
||||
];
|
||||
},
|
||||
};
|
||||
|
||||
export default nextConfig;
|
||||
7679
nvidia/multi-agent-chatbot/assets/frontend/package-lock.json
generated
Normal file
30
nvidia/multi-agent-chatbot/assets/frontend/package.json
Normal file
@ -0,0 +1,30 @@
|
||||
{
|
||||
"name": "frontend",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "next dev --turbopack",
|
||||
"build": "next build",
|
||||
"start": "next start",
|
||||
"lint": "next lint"
|
||||
},
|
||||
"dependencies": {
|
||||
"next": "15.1.7",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-syntax-highlighter": "^15.6.1",
|
||||
"remark-gfm": "^4.0.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/eslintrc": "^3",
|
||||
"@types/node": "^20",
|
||||
"@types/react": "^19",
|
||||
"@types/react-dom": "^19",
|
||||
"eslint": "^9",
|
||||
"eslint-config-next": "15.1.7",
|
||||
"postcss": "^8",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"typescript": "^5"
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,24 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
/** @type {import('postcss-load-config').Config} */
|
||||
const config = {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
},
|
||||
};
|
||||
|
||||
export default config;
|
||||
@ -0,0 +1 @@
|
||||
<svg fill="none" viewBox="0 0 16 16" xmlns="http://www.w3.org/2000/svg"><path d="M14.5 13.5V5.41a1 1 0 0 0-.3-.7L9.8.29A1 1 0 0 0 9.08 0H1.5v13.5A2.5 2.5 0 0 0 4 16h8a2.5 2.5 0 0 0 2.5-2.5m-1.5 0v-7H8v-5H3v12a1 1 0 0 0 1 1h8a1 1 0 0 0 1-1M9.5 5V2.12L12.38 5zM5.13 5h-.62v1.25h2.12V5zm-.62 3h7.12v1.25H4.5zm.62 3h-.62v1.25h7.12V11z" clip-rule="evenodd" fill="#666" fill-rule="evenodd"/></svg>
|
||||
|
After Width: | Height: | Size: 391 B |
@ -0,0 +1 @@
|
||||
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><g clip-path="url(#a)"><path fill-rule="evenodd" clip-rule="evenodd" d="M10.27 14.1a6.5 6.5 0 0 0 3.67-3.45q-1.24.21-2.7.34-.31 1.83-.97 3.1M8 16A8 8 0 1 0 8 0a8 8 0 0 0 0 16m.48-1.52a7 7 0 0 1-.96 0H7.5a4 4 0 0 1-.84-1.32q-.38-.89-.63-2.08a40 40 0 0 0 3.92 0q-.25 1.2-.63 2.08a4 4 0 0 1-.84 1.31zm2.94-4.76q1.66-.15 2.95-.43a7 7 0 0 0 0-2.58q-1.3-.27-2.95-.43a18 18 0 0 1 0 3.44m-1.27-3.54a17 17 0 0 1 0 3.64 39 39 0 0 1-4.3 0 17 17 0 0 1 0-3.64 39 39 0 0 1 4.3 0m1.1-1.17q1.45.13 2.69.34a6.5 6.5 0 0 0-3.67-3.44q.65 1.26.98 3.1M8.48 1.5l.01.02q.41.37.84 1.31.38.89.63 2.08a40 40 0 0 0-3.92 0q.25-1.2.63-2.08a4 4 0 0 1 .85-1.32 7 7 0 0 1 .96 0m-2.75.4a6.5 6.5 0 0 0-3.67 3.44 29 29 0 0 1 2.7-.34q.31-1.83.97-3.1M4.58 6.28q-1.66.16-2.95.43a7 7 0 0 0 0 2.58q1.3.27 2.95.43a18 18 0 0 1 0-3.44m.17 4.71q-1.45-.12-2.69-.34a6.5 6.5 0 0 0 3.67 3.44q-.65-1.27-.98-3.1" fill="#666"/></g><defs><clipPath id="a"><path fill="#fff" d="M0 0h16v16H0z"/></clipPath></defs></svg>
|
||||
|
After Width: | Height: | Size: 1.0 KiB |
@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 394 80"><path fill="#000" d="M262 0h68.5v12.7h-27.2v66.6h-13.6V12.7H262V0ZM149 0v12.7H94v20.4h44.3v12.6H94v21h55v12.6H80.5V0h68.7zm34.3 0h-17.8l63.8 79.4h17.9l-32-39.7 32-39.6h-17.9l-23 28.6-23-28.6zm18.3 56.7-9-11-27.1 33.7h17.8l18.3-22.7z"/><path fill="#000" d="M81 79.3 17 0H0v79.3h13.6V17l50.2 62.3H81Zm252.6-.4c-1 0-1.8-.4-2.5-1s-1.1-1.6-1.1-2.6.3-1.8 1-2.5 1.6-1 2.6-1 1.8.3 2.5 1a3.4 3.4 0 0 1 .6 4.3 3.7 3.7 0 0 1-3 1.8zm23.2-33.5h6v23.3c0 2.1-.4 4-1.3 5.5a9.1 9.1 0 0 1-3.8 3.5c-1.6.8-3.5 1.3-5.7 1.3-2 0-3.7-.4-5.3-1s-2.8-1.8-3.7-3.2c-.9-1.3-1.4-3-1.4-5h6c.1.8.3 1.6.7 2.2s1 1.2 1.6 1.5c.7.4 1.5.5 2.4.5 1 0 1.8-.2 2.4-.6a4 4 0 0 0 1.6-1.8c.3-.8.5-1.8.5-3V45.5zm30.9 9.1a4.4 4.4 0 0 0-2-3.3 7.5 7.5 0 0 0-4.3-1.1c-1.3 0-2.4.2-3.3.5-.9.4-1.6 1-2 1.6a3.5 3.5 0 0 0-.3 4c.3.5.7.9 1.3 1.2l1.8 1 2 .5 3.2.8c1.3.3 2.5.7 3.7 1.2a13 13 0 0 1 3.2 1.8 8.1 8.1 0 0 1 3 6.5c0 2-.5 3.7-1.5 5.1a10 10 0 0 1-4.4 3.5c-1.8.8-4.1 1.2-6.8 1.2-2.6 0-4.9-.4-6.8-1.2-2-.8-3.4-2-4.5-3.5a10 10 0 0 1-1.7-5.6h6a5 5 0 0 0 3.5 4.6c1 .4 2.2.6 3.4.6 1.3 0 2.5-.2 3.5-.6 1-.4 1.8-1 2.4-1.7a4 4 0 0 0 .8-2.4c0-.9-.2-1.6-.7-2.2a11 11 0 0 0-2.1-1.4l-3.2-1-3.8-1c-2.8-.7-5-1.7-6.6-3.2a7.2 7.2 0 0 1-2.4-5.7 8 8 0 0 1 1.7-5 10 10 0 0 1 4.3-3.5c2-.8 4-1.2 6.4-1.2 2.3 0 4.4.4 6.2 1.2 1.8.8 3.2 2 4.3 3.4 1 1.4 1.5 3 1.5 5h-5.8z"/></svg>
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
@ -0,0 +1 @@
|
||||
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1155 1000"><path d="m577.3 0 577.4 1000H0z" fill="#fff"/></svg>
|
||||
|
After Width: | Height: | Size: 128 B |
@ -0,0 +1 @@
|
||||
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path fill-rule="evenodd" clip-rule="evenodd" d="M1.5 2.5h13v10a1 1 0 0 1-1 1h-11a1 1 0 0 1-1-1zM0 1h16v11.5a2.5 2.5 0 0 1-2.5 2.5h-11A2.5 2.5 0 0 1 0 12.5zm3.75 4.5a.75.75 0 1 0 0-1.5.75.75 0 0 0 0 1.5M7 4.75a.75.75 0 1 1-1.5 0 .75.75 0 0 1 1.5 0m1.75.75a.75.75 0 1 0 0-1.5.75.75 0 0 0 0 1.5" fill="#666"/></svg>
|
||||
|
After Width: | Height: | Size: 385 B |
BIN
nvidia/multi-agent-chatbot/assets/frontend/src/app/favicon.ico
Normal file
|
After Width: | Height: | Size: 25 KiB |
@ -0,0 +1,94 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
||||
|
||||
:root {
|
||||
--background: white;
|
||||
--foreground: #1e293b;
|
||||
--primary: #76B900;
|
||||
--primary-dark: #669f00;
|
||||
--secondary: #f0f0f0;
|
||||
--accent: #f56565;
|
||||
--border: #e2e8f0;
|
||||
--shimmer-base: #a3a3a3;
|
||||
--shimmer-peak: #ffffff;
|
||||
--placeholder-opacity-min: .6;
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
:root {
|
||||
--background: #0f172a;
|
||||
--foreground: #f8fafc;
|
||||
--primary: #76B900;
|
||||
--primary-dark: #669f00;
|
||||
--secondary: #1e293b;
|
||||
--accent: #ef4444;
|
||||
--border: #334155;
|
||||
--shimmer-base: #444;
|
||||
--shimmer-peak: #222;
|
||||
--placeholder-opacity-min: .4;
|
||||
}
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
html,
|
||||
body {
|
||||
max-width: 100vw;
|
||||
height: 100vh;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
body {
|
||||
background-color: var(--background);
|
||||
color: var(--foreground);
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
|
||||
}
|
||||
|
||||
a {
|
||||
color: inherit;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
button, input, select, textarea {
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
/* Custom scrollbar */
|
||||
::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb {
|
||||
background: #cbd5e0;
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb:hover {
|
||||
background: #a0aec0;
|
||||
}
|
||||
@ -0,0 +1,42 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
import type { Metadata } from "next";
|
||||
import { Inter } from "next/font/google";
|
||||
import "./globals.css";
|
||||
import ThemeToggle from "@/components/ThemeToggle";
|
||||
|
||||
const inter = Inter({ subsets: ["latin"] });
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "Spark Chat",
|
||||
description: "AI-powered chat interface",
|
||||
};
|
||||
|
||||
export default function RootLayout({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<html lang="en" className="h-full">
|
||||
<body className={`${inter.className} h-full bg-white dark:bg-gray-900 text-gray-900 dark:text-white transition-colors duration-200`}>
|
||||
<ThemeToggle />
|
||||
{children}
|
||||
</body>
|
||||
</html>
|
||||
);
|
||||
}
|
||||
132
nvidia/multi-agent-chatbot/assets/frontend/src/app/page.tsx
Normal file
@ -0,0 +1,132 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
"use client";
|
||||
import { useState, useRef, useEffect } from 'react';
|
||||
import QuerySection from '@/components/QuerySection';
|
||||
import DocumentIngestion from '@/components/DocumentIngestion';
|
||||
import Sidebar from '@/components/Sidebar';
|
||||
import styles from '@/styles/Home.module.css';
|
||||
|
||||
export default function Home() {
|
||||
const [query, setQuery] = useState("");
|
||||
const [response, setResponse] = useState("[]");
|
||||
const [files, setFiles] = useState<FileList | null>(null);
|
||||
const [ingestMessage, setIngestMessage] = useState("");
|
||||
const [isStreaming, setIsStreaming] = useState(false);
|
||||
const [isIngesting, setIsIngesting] = useState(false);
|
||||
const [showIngestion, setShowIngestion] = useState(false);
|
||||
const [refreshTrigger, setRefreshTrigger] = useState(0);
|
||||
const [currentChatId, setCurrentChatId] = useState<string | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
// Load initial chat ID
|
||||
useEffect(() => {
|
||||
const fetchCurrentChatId = async () => {
|
||||
try {
|
||||
const response = await fetch("/api/chat_id");
|
||||
if (response.ok) {
|
||||
const { chat_id } = await response.json();
|
||||
setCurrentChatId(chat_id);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching current chat ID:", error);
|
||||
}
|
||||
};
|
||||
fetchCurrentChatId();
|
||||
}, []);
|
||||
|
||||
// Handle chat changes
|
||||
const handleChatChange = async (newChatId: string) => {
|
||||
try {
|
||||
const response = await fetch("/api/chat_id", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ chat_id: newChatId })
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
setCurrentChatId(newChatId);
|
||||
setResponse("[]"); // Clear current chat messages with empty JSON array
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error updating chat ID:", error);
|
||||
}
|
||||
};
|
||||
|
||||
// Clean up any ongoing streams when component unmounts
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Function to handle successful document ingestion
|
||||
const handleSuccessfulIngestion = () => {
|
||||
setRefreshTrigger(prev => prev + 1);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={styles.container}>
|
||||
<Sidebar
|
||||
showIngestion={showIngestion}
|
||||
setShowIngestion={setShowIngestion}
|
||||
refreshTrigger={refreshTrigger}
|
||||
currentChatId={currentChatId}
|
||||
onChatChange={handleChatChange}
|
||||
/>
|
||||
|
||||
<div className={styles.mainContent}>
|
||||
<QuerySection
|
||||
query={query}
|
||||
response={response}
|
||||
isStreaming={isStreaming}
|
||||
setQuery={setQuery}
|
||||
setResponse={setResponse}
|
||||
setIsStreaming={setIsStreaming}
|
||||
abortControllerRef={abortControllerRef}
|
||||
setShowIngestion={setShowIngestion}
|
||||
currentChatId={currentChatId}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{showIngestion && (
|
||||
<>
|
||||
<div className={styles.overlay} onClick={() => setShowIngestion(false)} />
|
||||
<div className={styles.documentUploadContainer}>
|
||||
<button
|
||||
className={styles.closeButton}
|
||||
onClick={() => setShowIngestion(false)}
|
||||
>
|
||||
×
|
||||
</button>
|
||||
<DocumentIngestion
|
||||
files={files}
|
||||
ingestMessage={ingestMessage}
|
||||
isIngesting={isIngesting}
|
||||
setFiles={setFiles}
|
||||
setIngestMessage={setIngestMessage}
|
||||
setIsIngesting={setIsIngesting}
|
||||
onSuccessfulIngestion={handleSuccessfulIngestion}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -0,0 +1,141 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
import { SetStateAction, useState } from 'react';
|
||||
import styles from '@/styles/DocumentIngestion.module.css';
|
||||
|
||||
declare module 'react' {
|
||||
interface InputHTMLAttributes<T> extends HTMLAttributes<T> {
|
||||
webkitdirectory?: string;
|
||||
directory?: string;
|
||||
}
|
||||
}
|
||||
|
||||
interface DocumentIngestionProps {
|
||||
files: FileList | null;
|
||||
ingestMessage: string;
|
||||
isIngesting: boolean;
|
||||
setFiles: (files: FileList | null) => void;
|
||||
setIngestMessage: (message: string) => void;
|
||||
setIsIngesting: (value: boolean) => void;
|
||||
onSuccessfulIngestion?: () => void;
|
||||
}
|
||||
|
||||
export default function DocumentIngestion({
|
||||
files,
|
||||
ingestMessage,
|
||||
isIngesting,
|
||||
setFiles,
|
||||
setIngestMessage,
|
||||
setIsIngesting,
|
||||
onSuccessfulIngestion
|
||||
}: DocumentIngestionProps) {
|
||||
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setFiles(e.target.files);
|
||||
};
|
||||
|
||||
const handleIngestSubmit = async (e: { preventDefault: () => void }) => {
|
||||
e.preventDefault();
|
||||
setIsIngesting(true);
|
||||
setIngestMessage("");
|
||||
|
||||
try {
|
||||
if (files && files.length > 0) {
|
||||
const formData = new FormData();
|
||||
for (let i = 0; i < files.length; i++) {
|
||||
formData.append("files", files[i]);
|
||||
}
|
||||
|
||||
const res = await fetch("/api/ingest", {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
});
|
||||
|
||||
const data = await res.json();
|
||||
setIngestMessage(data.message);
|
||||
|
||||
if (res.ok && onSuccessfulIngestion) {
|
||||
onSuccessfulIngestion();
|
||||
}
|
||||
} else {
|
||||
setIngestMessage("Please select files or specify a directory path.");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error during ingestion:", error);
|
||||
setIngestMessage("Error during ingestion. Please check the console for details.");
|
||||
} finally {
|
||||
setIsIngesting(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDragOver = (e: React.DragEvent<HTMLDivElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
};
|
||||
|
||||
const handleDrop = (e: React.DragEvent<HTMLDivElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
if (e.dataTransfer.files && e.dataTransfer.files.length > 0) {
|
||||
setFiles(e.dataTransfer.files);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={styles.section}>
|
||||
<h1>Document Ingestion</h1>
|
||||
<form onSubmit={handleIngestSubmit} className={styles.ingestForm}>
|
||||
<div
|
||||
className={styles.uploadSection}
|
||||
onDragOver={handleDragOver}
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
<label htmlFor="file-upload" className={styles.customFileLabel}>
|
||||
Choose Files
|
||||
</label>
|
||||
<input
|
||||
id="file-upload"
|
||||
type="file"
|
||||
multiple
|
||||
onChange={handleFileChange}
|
||||
disabled={isIngesting}
|
||||
className={styles.fileInput}
|
||||
/>
|
||||
<span className={styles.fileName}>
|
||||
{files && files.length > 0 ? Array.from(files).map(f => f.name).join(', ') : "No file chosen"}
|
||||
</span>
|
||||
<p className={styles.helpText}>
|
||||
Select files or drag and drop them here
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isIngesting || !files}
|
||||
className={styles.ingestButton}
|
||||
>
|
||||
{isIngesting ? "Ingesting..." : "Ingest Documents"}
|
||||
</button>
|
||||
</form>
|
||||
|
||||
{ingestMessage && (
|
||||
<div className={styles.messageContainer}>
|
||||
<p>{ingestMessage}</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -0,0 +1,657 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
import type React from "react";
|
||||
import { useRef, useEffect, useState } from "react";
|
||||
import styles from "@/styles/QuerySection.module.css";
|
||||
import ReactMarkdown from 'react-markdown'; // NEW
|
||||
import remarkGfm from 'remark-gfm'; // NEW
|
||||
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter"; // NEW
|
||||
import { oneDark, oneLight, Dark, Light } from "react-syntax-highlighter/dist/esm/styles/prism"; // NEW
|
||||
import WelcomeSection from "./WelcomeSection";
|
||||
const isDark = typeof document !== "undefined" && document.documentElement.classList.contains("dark");
|
||||
|
||||
|
||||
export function makeChatTheme(isDark: boolean) {
|
||||
const base = isDark ? oneDark : oneLight;
|
||||
|
||||
const accents = isDark
|
||||
? {
|
||||
tag: "#E3E3E3",
|
||||
prolog: "#E3E3E3",
|
||||
doctype: "#E3E3E3",
|
||||
punctuation:"#99CFCF",
|
||||
}
|
||||
: {
|
||||
tag: "#9a6700",
|
||||
prolog: "#7a6200",
|
||||
doctype: "#7a6200",
|
||||
punctuation:"#6b7280",
|
||||
};
|
||||
|
||||
return {
|
||||
...base,
|
||||
|
||||
'pre[class*="language-"]': {
|
||||
...(base['pre[class*="language-"]'] || {}),
|
||||
background: "transparent",
|
||||
},
|
||||
'code[class*="language-"]': {
|
||||
...(base['code[class*="language-"]'] || {}),
|
||||
background: "transparent",
|
||||
},
|
||||
|
||||
tag: { ...(base.tag || {}), color: accents.tag },
|
||||
prolog: { ...(base.prolog || {}), color: accents.prolog },
|
||||
doctype: { ...(base.doctype || {}), color: accents.doctype },
|
||||
punctuation: { ...(base.punctuation || {}), color: accents.punctuation },
|
||||
|
||||
'attr-name': { ...(base['attr-name'] || {}), color: isDark ? "#e6b450" : "#6b4f00" },
|
||||
} as const;
|
||||
}
|
||||
const theme = makeChatTheme(isDark);
|
||||
|
||||
function CodeBlockWithCopy({ code, language }: { code: string; language: string }) {
|
||||
const [copied, setCopied] = useState(false);
|
||||
|
||||
const handleCopy = async () => {
|
||||
try {
|
||||
await navigator.clipboard.writeText(code);
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 1200);
|
||||
} catch (err) {
|
||||
try {
|
||||
const textarea = document.createElement("textarea");
|
||||
textarea.value = code;
|
||||
textarea.style.position = "fixed";
|
||||
textarea.style.left = "-9999px";
|
||||
document.body.appendChild(textarea);
|
||||
textarea.focus();
|
||||
textarea.select();
|
||||
document.execCommand("copy");
|
||||
document.body.removeChild(textarea);
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 1200);
|
||||
} catch {}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={styles.codeBlock}>
|
||||
<button
|
||||
type="button"
|
||||
className={styles.copyButton}
|
||||
onClick={handleCopy}
|
||||
aria-label="Copy code"
|
||||
title={copied ? "Copied" : "Copy"}
|
||||
>
|
||||
<svg
|
||||
className={styles.copyButtonIcon}
|
||||
viewBox="0 0 460 460"
|
||||
aria-hidden="true"
|
||||
focusable="false"
|
||||
fill="currentColor"
|
||||
>
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<path d="M425.934,0H171.662c-18.122,0-32.864,14.743-32.864,32.864v77.134h30V32.864c0-1.579,1.285-2.864,2.864-2.864h254.272
|
||||
c1.579,0,2.864,1.285,2.864,2.864v254.272c0,1.58-1.285,2.865-2.864,2.865h-74.729v30h74.729
|
||||
c18.121,0,32.864-14.743,32.864-32.865V32.864C458.797,14.743,444.055,0,425.934,0z"/>
|
||||
<path d="M288.339,139.998H34.068c-18.122,0-32.865,14.743-32.865,32.865v254.272C1.204,445.257,15.946,460,34.068,460h254.272
|
||||
c18.122,0,32.865-14.743,32.865-32.864V172.863C321.206,154.741,306.461,139.998,288.339,139.998z M288.341,430H34.068
|
||||
c-1.58,0-2.865-1.285-2.865-2.864V172.863c0-1.58,1.285-2.865,2.865-2.865h254.272c1.58,0,2.865,1.285,2.865,2.865v254.273h0.001
|
||||
C291.206,428.715,289.92,430,288.341,430z"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
<span className={styles.copyButtonLabel}>{copied ? "Copied" : "Copy"}</span>
|
||||
</button>
|
||||
<SyntaxHighlighter
|
||||
language={language}
|
||||
style={theme}
|
||||
PreTag="div"
|
||||
wrapLongLines
|
||||
showLineNumbers={false}
|
||||
customStyle={{ margin: "0.6rem 0", borderRadius: 10, background: "transparent" }}
|
||||
>
|
||||
{code}
|
||||
</SyntaxHighlighter>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
interface QuerySectionProps {
|
||||
query: string;
|
||||
response: string;
|
||||
isStreaming: boolean;
|
||||
setQuery: (value: string) => void;
|
||||
setResponse: React.Dispatch<React.SetStateAction<string>>;
|
||||
setIsStreaming: (value: boolean) => void;
|
||||
abortControllerRef: React.RefObject<AbortController | null>;
|
||||
setShowIngestion: (value: boolean) => void;
|
||||
currentChatId: string | null;
|
||||
}
|
||||
|
||||
interface Message {
|
||||
type: "HumanMessage" | "AssistantMessage" | "ToolMessage";
|
||||
content: string;
|
||||
}
|
||||
|
||||
|
||||
|
||||
export default function QuerySection({
|
||||
query,
|
||||
response,
|
||||
isStreaming,
|
||||
setQuery,
|
||||
setResponse,
|
||||
setIsStreaming,
|
||||
abortControllerRef,
|
||||
setShowIngestion,
|
||||
currentChatId,
|
||||
}: QuerySectionProps) {
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
const chatContainerRef = useRef<HTMLDivElement>(null);
|
||||
const [showButtons, setShowButtons] = useState(false);
|
||||
const [showWelcome, setShowWelcome] = useState(true);
|
||||
const [inferenceStats, setInferenceStats] = useState({
|
||||
tokensReceived: 0,
|
||||
startTime: Date.now(),
|
||||
tokensPerSecond: 0
|
||||
});
|
||||
const [selectedSources, setSelectedSources] = useState<string[]>([]);
|
||||
const wsRef = useRef<WebSocket | null>(null);
|
||||
const [uploadedImage, setUploadedImage] = useState<string | null>(null);
|
||||
const [imagePreview, setImagePreview] = useState<string | null>(null);
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
const [toolOutput, setToolOutput] = useState("");
|
||||
const [graphStatus, setGraphStatus] = useState("");
|
||||
const [isPinnedToolOutputVisible, setPinnedToolOutputVisible] = useState(false);
|
||||
const [isToolContentVisible, setIsToolContentVisible] = useState(false);
|
||||
const [fadeIn, setFadeIn] = useState(false);
|
||||
const firstTokenReceived = useRef(false);
|
||||
|
||||
useEffect(() => {
|
||||
const timer = setTimeout(() => {
|
||||
setShowButtons(true);
|
||||
}, 800);
|
||||
return () => clearTimeout(timer);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isStreaming) {
|
||||
setInferenceStats(prev => ({
|
||||
...prev,
|
||||
tokensReceived: 0,
|
||||
startTime: 0
|
||||
}));
|
||||
}
|
||||
}, [isStreaming]);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchSelectedSources = async () => {
|
||||
try {
|
||||
const response = await fetch("/api/selected_sources");
|
||||
if (response.ok) {
|
||||
const { sources } = await response.json();
|
||||
setSelectedSources(sources);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching selected sources:", error);
|
||||
}
|
||||
};
|
||||
fetchSelectedSources();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const initWebSocket = async () => {
|
||||
if (!currentChatId) return;
|
||||
|
||||
try {
|
||||
if (wsRef.current) {
|
||||
wsRef.current.close();
|
||||
}
|
||||
|
||||
const wsProtocol = 'ws:';
|
||||
const wsHost = 'localhost';
|
||||
const wsPort = '8000';
|
||||
const ws = new WebSocket(`${wsProtocol}//${wsHost}:${wsPort}/ws/chat/${currentChatId}`);
|
||||
wsRef.current = ws;
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
const msg = JSON.parse(event.data);
|
||||
const type = msg.type
|
||||
const text = msg.data ?? msg.token ?? "";
|
||||
|
||||
switch (type) {
|
||||
case "history": {
|
||||
console.log('history messages: ', msg.messages);
|
||||
if (Array.isArray(msg.messages)) {
|
||||
// const filtered = msg.messages.filter(m => m.type !== "ToolMessage"); // TODO: add this back in
|
||||
setResponse(JSON.stringify(msg.messages));
|
||||
setIsStreaming(false);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "tool_token": {
|
||||
if (text !== undefined && text !== "undefined") {
|
||||
setToolOutput(prev => prev + text);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "token": {
|
||||
if (!text) break;
|
||||
if (!firstTokenReceived.current) {
|
||||
console.log('TTFT: ', new Date().toISOString());
|
||||
firstTokenReceived.current = true;
|
||||
setIsStreaming(false);
|
||||
}
|
||||
setResponse(prev => {
|
||||
try {
|
||||
const messages = JSON.parse(prev);
|
||||
const last = messages[messages.length - 1];
|
||||
if (last && last.type === "AssistantMessage") {
|
||||
last.content = String(last.content || "") + text;
|
||||
} else {
|
||||
messages.push({ type: "AssistantMessage", content: text });
|
||||
}
|
||||
return JSON.stringify(messages);
|
||||
} catch {
|
||||
return String(prev || "") + text;
|
||||
}
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "node_start": {
|
||||
if (msg?.data === "generate") {
|
||||
setGraphStatus("Thinking...");
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "tool_start": {
|
||||
console.log(type, msg.data);
|
||||
setGraphStatus(`calling tool: ${msg?.data}`);
|
||||
break;
|
||||
}
|
||||
case "tool_end":
|
||||
case "node_end": {
|
||||
console.log(type, msg.data);
|
||||
if (msg.data === 'generate') {
|
||||
console.log('generate complete. time: ', new Date().toISOString());
|
||||
}
|
||||
setGraphStatus("");
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
// ignore unknown events
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ws.onclose = () => {
|
||||
console.log("WebSocket connection closed");
|
||||
setIsStreaming(false);
|
||||
};
|
||||
|
||||
ws.onerror = (error) => {
|
||||
console.error("WebSocket error:", error);
|
||||
setIsStreaming(false);
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("Error initializing WebSocket:", error);
|
||||
setIsStreaming(false);
|
||||
}
|
||||
};
|
||||
|
||||
initWebSocket();
|
||||
|
||||
return () => {
|
||||
if (wsRef.current) {
|
||||
wsRef.current.close();
|
||||
setIsStreaming(false);
|
||||
}
|
||||
};
|
||||
}, [currentChatId]);
|
||||
|
||||
useEffect(() => {
|
||||
try {
|
||||
const messages = JSON.parse(response);
|
||||
setShowWelcome(messages.length === 0);
|
||||
} catch {
|
||||
// If response can't be parsed as JSON, check if it's empty
|
||||
setShowWelcome(!response.trim());
|
||||
}
|
||||
}, [response]);
|
||||
|
||||
// Show/hide pinnedToolOutput with fade
|
||||
useEffect(() => {
|
||||
if (graphStatus) {
|
||||
setPinnedToolOutputVisible(true);
|
||||
} else if (isPinnedToolOutputVisible) {
|
||||
// Delay hiding to allow fade-out
|
||||
const timeout = setTimeout(() => {
|
||||
setPinnedToolOutputVisible(false);
|
||||
}, 800); // match CSS transition duration
|
||||
return () => clearTimeout(timeout);
|
||||
}
|
||||
}, [graphStatus]);
|
||||
|
||||
// Replace the effect for fade logic with this minimal version
|
||||
useEffect(() => {
|
||||
if (isPinnedToolOutputVisible && graphStatus) {
|
||||
setFadeIn(false);
|
||||
const t = setTimeout(() => setFadeIn(true), 10); // next tick for fade-in
|
||||
return () => clearTimeout(t);
|
||||
} else {
|
||||
setFadeIn(false);
|
||||
}
|
||||
}, [isPinnedToolOutputVisible, graphStatus]);
|
||||
|
||||
// Cleanup image preview URL on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (imagePreview) {
|
||||
URL.revokeObjectURL(imagePreview);
|
||||
}
|
||||
};
|
||||
}, [imagePreview]);
|
||||
|
||||
const programmaticScroll = useRef(false);
|
||||
const scrollTimeout = useRef<number | null>(null);
|
||||
|
||||
|
||||
const handleDragEnter = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); setIsDragging(true); };
|
||||
const handleDragLeave = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); setIsDragging(false); };
|
||||
const handleDragOver = (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); };
|
||||
|
||||
const handleDrop = async (e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDragging(false);
|
||||
|
||||
const files = Array.from(e.dataTransfer.files);
|
||||
const imageFile = files.find(file => file.type.startsWith('image/'));
|
||||
|
||||
if (imageFile) {
|
||||
const previewUrl = URL.createObjectURL(imageFile);
|
||||
setImagePreview(previewUrl);
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('image', imageFile);
|
||||
formData.append('chat_id', currentChatId || '');
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/upload-image', { method: 'POST', body: formData });
|
||||
const result = await response.json();
|
||||
setUploadedImage(result.image_id);
|
||||
} catch (error) {
|
||||
console.error('Error uploading image:', error);
|
||||
URL.revokeObjectURL(previewUrl);
|
||||
setImagePreview(null);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleQuerySubmit = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
const currentQuery = query.trim();
|
||||
if (!currentQuery || isStreaming || !wsRef.current) return;
|
||||
|
||||
setQuery("");
|
||||
setIsStreaming(true);
|
||||
firstTokenReceived.current = false;
|
||||
|
||||
try {
|
||||
console.log('sending uploaded image: ', uploadedImage, ' with query: ', currentQuery)
|
||||
console.log('current time: ', new Date().toISOString());
|
||||
wsRef.current.send(JSON.stringify({
|
||||
message: currentQuery,
|
||||
image_id: uploadedImage
|
||||
}));
|
||||
|
||||
setResponse(prev => {
|
||||
try {
|
||||
const messages = JSON.parse(prev);
|
||||
messages.push({
|
||||
type: "HumanMessage",
|
||||
content: currentQuery
|
||||
});
|
||||
return JSON.stringify(messages);
|
||||
} catch {
|
||||
return prev + `\n\nHuman: ${currentQuery}\n\nAssistant: `;
|
||||
}
|
||||
});
|
||||
|
||||
// NEW CODE
|
||||
if (imagePreview) {
|
||||
URL.revokeObjectURL(imagePreview);
|
||||
}
|
||||
setUploadedImage(null);
|
||||
setImagePreview(null);
|
||||
// NEW CODE
|
||||
} catch (error) {
|
||||
console.error("Error sending message:", error);
|
||||
setIsStreaming(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleCancelStream = () => {
|
||||
if (wsRef.current) {
|
||||
wsRef.current.close();
|
||||
wsRef.current = null;
|
||||
setIsStreaming(false);
|
||||
}
|
||||
};
|
||||
|
||||
// filter out all ToolMessages
|
||||
const parseMessages = (response: string): Message[] => {
|
||||
try {
|
||||
const parsed = JSON.parse(response);
|
||||
if (!Array.isArray(parsed)) return [];
|
||||
|
||||
return parsed
|
||||
.map((msg: any): Message => ({
|
||||
type: msg?.type === "HumanMessage"
|
||||
? "HumanMessage"
|
||||
: msg?.type === "ToolMessage"
|
||||
? "ToolMessage"
|
||||
: "AssistantMessage",
|
||||
content: typeof msg?.content === "string" ? msg.content : String(msg?.content ?? "")
|
||||
}))
|
||||
.filter((msg) => msg.type !== "ToolMessage"); // discard ToolMessage completely
|
||||
} catch {
|
||||
if (!response?.trim()) return [];
|
||||
|
||||
return [{ type: "AssistantMessage", content: String(response) }];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
return (
|
||||
<div className={styles.chatContainer}>
|
||||
{showWelcome && <WelcomeSection setQuery={setQuery}/>}
|
||||
|
||||
{/* Minimal fade-in/fade-out: always start hidden, fade in on next tick */}
|
||||
{/* {isPinnedToolOutputVisible && ( )}*/}
|
||||
<div className={`${styles.pinnedToolOutput} ${!fadeIn ? styles.pinnedToolOutputHidden : ""}`}>
|
||||
{graphStatus && (
|
||||
<div className={styles.toolHeader} onClick={() => setIsToolContentVisible(v => !v)} style={{ cursor: 'pointer' }}>
|
||||
<span className={styles.toolLabel}> {graphStatus} </span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className={styles.messagesContainer} ref={chatContainerRef}>
|
||||
{parseMessages(response).map((message, index) => {
|
||||
const isHuman = message.type === "HumanMessage";
|
||||
const key = `${message.type}-${index}`;
|
||||
|
||||
if (!message.content?.trim()) return null;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={key}
|
||||
className={`${styles.messageWrapper} ${isHuman ? styles.userMessageWrapper : styles.assistantMessageWrapper}`}
|
||||
style={{
|
||||
animationDelay: `${index * 0.1}s`
|
||||
}}
|
||||
>
|
||||
|
||||
<div className={`${styles.message} ${isHuman ? styles.userMessage : styles.assistantMessage}`}>
|
||||
<div className={styles.markdown}>
|
||||
<ReactMarkdown
|
||||
remarkPlugins={[remarkGfm]}
|
||||
components={{
|
||||
code({ inline, className, children, ...props }) {
|
||||
const match = /language-(\w+)/.exec(className || "");
|
||||
const code = String(children ?? "").replace(/\n$/, "");
|
||||
|
||||
if (inline || !match) {
|
||||
return (
|
||||
<code className={className} {...props}>
|
||||
{code}
|
||||
</code>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<CodeBlockWithCopy code={code} language={match[1]} />
|
||||
);
|
||||
},
|
||||
}}
|
||||
>
|
||||
{message.content}
|
||||
</ReactMarkdown>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
||||
{isStreaming && (
|
||||
<div
|
||||
className={`${styles.messageWrapper} ${styles.assistantMessageWrapper}`}
|
||||
style={{
|
||||
animationDelay: `${parseMessages(response).length * 0.1}s`
|
||||
}}
|
||||
>
|
||||
<div className={`${styles.message} ${styles.assistantMessage}`}>
|
||||
<div className={styles.typingIndicator}>
|
||||
<span></span>
|
||||
<span></span>
|
||||
<span></span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleQuerySubmit} className={styles.inputContainer}>
|
||||
{/* NEW CODE - Image preview moved to the left of inputWrapper */}
|
||||
{imagePreview && (
|
||||
<div className={styles.imagePreview}>
|
||||
<img
|
||||
src={imagePreview}
|
||||
alt="Image preview"
|
||||
className={styles.previewImage}
|
||||
/>
|
||||
<button
|
||||
className={styles.removeImageButton}
|
||||
onClick={() => {
|
||||
if (imagePreview) {
|
||||
URL.revokeObjectURL(imagePreview);
|
||||
}
|
||||
setUploadedImage(null);
|
||||
setImagePreview(null);
|
||||
}}
|
||||
>
|
||||
✕
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{/* NEW CODE */}
|
||||
<div
|
||||
className={`${styles.inputWrapper} ${isDragging ? styles.dragging : ''}`}
|
||||
onDragEnter={handleDragEnter}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDragOver={handleDragOver}
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowIngestion(true)}
|
||||
className={`${styles.uploadButton} ${showButtons ? styles.show : ''}`}
|
||||
title="Upload Documents"
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
width="20"
|
||||
height="20"
|
||||
>
|
||||
<path d="M21.44 11.05l-9.19 9.19a6 6 0 0 1-8.49-8.49l9.19-9.19a4 4 0 0 1 5.66 5.66l-9.2 9.19a2 2 0 0 1-2.83-2.83l8.49-8.48" />
|
||||
</svg>
|
||||
</button>
|
||||
<textarea
|
||||
rows={1}
|
||||
value={query}
|
||||
onChange={(e) => setQuery(e.target.value)}
|
||||
placeholder="Send a message..."
|
||||
disabled={isStreaming}
|
||||
className={styles.messageInput}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
handleQuerySubmit(e as any);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{!isStreaming ? (
|
||||
<button
|
||||
type="submit"
|
||||
className={`${styles.sendButton} ${showButtons ? styles.show : ''}`}
|
||||
disabled={!query.trim()}
|
||||
>
|
||||
→
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleCancelStream}
|
||||
className={`${styles.streamingCancelButton} ${showButtons ? styles.show : ''}`}
|
||||
>
|
||||
✕
|
||||
</button>
|
||||
)}
|
||||
</form>
|
||||
|
||||
{inferenceStats.tokensPerSecond > 0 && (
|
||||
<div className={styles.inferenceStats}>
|
||||
{inferenceStats.tokensPerSecond} tokens/sec
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -0,0 +1,696 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
import React, { useState, useEffect, useRef } from 'react';
|
||||
import styles from '@/styles/Sidebar.module.css';
|
||||
|
||||
interface Model {
|
||||
id: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
interface ChatMetadata {
|
||||
name: string;
|
||||
}
|
||||
|
||||
interface SidebarProps {
|
||||
showIngestion: boolean;
|
||||
setShowIngestion: (value: boolean) => void;
|
||||
refreshTrigger?: number;
|
||||
currentChatId: string | null;
|
||||
onChatChange: (chatId: string) => Promise<void>;
|
||||
}
|
||||
|
||||
|
||||
export default function Sidebar({
|
||||
showIngestion,
|
||||
setShowIngestion,
|
||||
refreshTrigger = 0,
|
||||
currentChatId,
|
||||
onChatChange
|
||||
}: SidebarProps) {
|
||||
const [isVisible, setIsVisible] = useState(false);
|
||||
const [isClosing, setIsClosing] = useState(false);
|
||||
const [expandedSections, setExpandedSections] = useState<Set<string>>(new Set(["config", "history"]));
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [availableSources, setAvailableSources] = useState<string[]>([]);
|
||||
const [selectedSources, setSelectedSources] = useState<string[]>([]);
|
||||
const [selectedModel, setSelectedModel] = useState<string>("");
|
||||
const [isLoadingSources, setIsLoadingSources] = useState(false);
|
||||
const [availableModels, setAvailableModels] = useState<Model[]>([]);
|
||||
const [isLoadingModels, setIsLoadingModels] = useState(false);
|
||||
const [chats, setChats] = useState<string[]>([]);
|
||||
const [isLoadingChats, setIsLoadingChats] = useState(false);
|
||||
const [chatMetadata, setChatMetadata] = useState<Record<string, ChatMetadata>>({});
|
||||
|
||||
// Add ref for chat list
|
||||
const chatListRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Load initial configuration
|
||||
useEffect(() => {
|
||||
const loadInitialConfig = async () => {
|
||||
try {
|
||||
setIsLoading(true);
|
||||
|
||||
// Get selected model
|
||||
const modelResponse = await fetch("/api/selected_model");
|
||||
if (modelResponse.ok) {
|
||||
const { model } = await modelResponse.json();
|
||||
setSelectedModel(model);
|
||||
}
|
||||
|
||||
// Get selected sources
|
||||
const sourcesResponse = await fetch("/api/selected_sources");
|
||||
if (sourcesResponse.ok) {
|
||||
const { sources } = await sourcesResponse.json();
|
||||
setSelectedSources(sources);
|
||||
}
|
||||
|
||||
|
||||
// Get available models
|
||||
await fetchAvailableModels();
|
||||
|
||||
// Get sources
|
||||
await fetchSources();
|
||||
|
||||
// Get chats if history section is expanded (which it is by default)
|
||||
if (expandedSections.has('history')) {
|
||||
await fetchChats();
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error loading initial config:", error);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
loadInitialConfig();
|
||||
}, []);
|
||||
|
||||
// Fetch available models
|
||||
const fetchAvailableModels = async () => {
|
||||
try {
|
||||
setIsLoadingModels(true);
|
||||
const response = await fetch("/api/available_models");
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
console.error(`Error fetching available models: ${response.status} - ${errorText}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const models = data.models.map((modelId: string) => ({
|
||||
id: modelId,
|
||||
name: modelId.split('-').map(word => word.charAt(0).toUpperCase() + word.slice(1)).join(' ')
|
||||
}));
|
||||
setAvailableModels(models);
|
||||
} catch (error) {
|
||||
console.error("Error fetching available models:", error);
|
||||
} finally {
|
||||
setIsLoadingModels(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch available sources
|
||||
const fetchSources = async () => {
|
||||
try {
|
||||
setIsLoadingSources(true);
|
||||
console.log("Fetching sources...");
|
||||
const response = await fetch("/api/sources");
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
console.error(`Error fetching sources: ${response.status} - ${errorText}`);
|
||||
setAvailableSources([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log("Sources fetched:", data.sources);
|
||||
setAvailableSources(data.sources || []);
|
||||
} catch (error) {
|
||||
console.error("Error fetching sources:", error);
|
||||
setAvailableSources([]);
|
||||
} finally {
|
||||
setIsLoadingSources(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Get sources on initial load and when the context section is expanded
|
||||
useEffect(() => {
|
||||
if (expandedSections.has('context')) {
|
||||
fetchSources();
|
||||
}
|
||||
}, [expandedSections]);
|
||||
|
||||
// Refresh sources when refreshTrigger changes (document ingestion)
|
||||
useEffect(() => {
|
||||
if (refreshTrigger > 0) { // Only refresh if not the initial render
|
||||
fetchSources();
|
||||
}
|
||||
}, [refreshTrigger]);
|
||||
|
||||
// Add function to fetch chat metadata
|
||||
const fetchChatMetadata = async (chatId: string) => {
|
||||
try {
|
||||
const response = await fetch(`/api/chat/${chatId}/metadata`);
|
||||
if (response.ok) {
|
||||
const metadata = await response.json();
|
||||
setChatMetadata(prev => ({
|
||||
...prev,
|
||||
[chatId]: metadata
|
||||
}));
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error fetching metadata for chat ${chatId}:`, error);
|
||||
}
|
||||
};
|
||||
|
||||
// Update fetchChats to also fetch metadata
|
||||
const fetchChats = async () => {
|
||||
try {
|
||||
console.log("fetchChats: Starting to fetch chats...");
|
||||
setIsLoadingChats(true);
|
||||
const response = await fetch("/api/chats");
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
console.log("fetchChats: Received chats:", data.chats);
|
||||
setChats(data.chats);
|
||||
// Fetch metadata for each chat
|
||||
await Promise.all(data.chats.map(fetchChatMetadata));
|
||||
console.log("fetchChats: Completed fetching all chat metadata");
|
||||
} else {
|
||||
console.error("fetchChats: Failed to fetch chats, status:", response.status);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching chats:", error);
|
||||
} finally {
|
||||
setIsLoadingChats(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch chats when history section is expanded
|
||||
useEffect(() => {
|
||||
if (expandedSections.has('history')) {
|
||||
fetchChats();
|
||||
}
|
||||
}, [expandedSections]);
|
||||
|
||||
|
||||
// Update highlight position when currentChatId changes
|
||||
useEffect(() => {
|
||||
if (currentChatId && chatListRef.current) {
|
||||
const activeChat = chatListRef.current.querySelector(`.${styles.active}`) as HTMLElement;
|
||||
if (activeChat) {
|
||||
const offset = activeChat.offsetTop;
|
||||
chatListRef.current.style.setProperty('--highlight-offset', `${offset}px`);
|
||||
}
|
||||
}
|
||||
}, [currentChatId, chats]);
|
||||
|
||||
// Add new effect to handle initial position and chat list loading
|
||||
useEffect(() => {
|
||||
if (chatListRef.current && currentChatId && chats.length > 0) {
|
||||
// Small delay to ensure DOM is ready
|
||||
setTimeout(() => {
|
||||
const activeChat = chatListRef.current?.querySelector(`.${styles.active}`) as HTMLElement;
|
||||
if (activeChat) {
|
||||
const offset = activeChat.offsetTop;
|
||||
chatListRef.current?.style.setProperty('--highlight-offset', `${offset}px`);
|
||||
}
|
||||
}, 50);
|
||||
}
|
||||
}, [isVisible, expandedSections.has('history'), chats.length]);
|
||||
|
||||
const handleClose = () => {
|
||||
setIsClosing(true);
|
||||
setTimeout(() => {
|
||||
setIsVisible(false);
|
||||
setIsClosing(false);
|
||||
}, 500); // Match the new animation duration
|
||||
};
|
||||
|
||||
const toggleSidebar = () => {
|
||||
if (!isVisible) {
|
||||
setIsVisible(true);
|
||||
setIsClosing(false);
|
||||
fetchSources();
|
||||
// Also fetch chats when sidebar opens
|
||||
if (expandedSections.has('history')) {
|
||||
fetchChats();
|
||||
}
|
||||
} else {
|
||||
handleClose();
|
||||
}
|
||||
};
|
||||
|
||||
const toggleSection = (section: string) => {
|
||||
const newExpandedSections = new Set(expandedSections);
|
||||
if (newExpandedSections.has(section)) {
|
||||
newExpandedSections.delete(section);
|
||||
} else {
|
||||
newExpandedSections.add(section);
|
||||
// Get sources when context section is expanded
|
||||
if (section === 'context') {
|
||||
fetchSources();
|
||||
}
|
||||
}
|
||||
setExpandedSections(newExpandedSections);
|
||||
};
|
||||
|
||||
const isSectionExpanded = (section: string) => {
|
||||
return expandedSections.has(section);
|
||||
};
|
||||
|
||||
const handleSourceToggle = async (source: string) => {
|
||||
let newSelectedSources: string[];
|
||||
|
||||
if (selectedSources.includes(source)) {
|
||||
// Remove source if already selected
|
||||
newSelectedSources = selectedSources.filter(s => s !== source);
|
||||
} else {
|
||||
// Add source if not selected
|
||||
newSelectedSources = [...selectedSources, source];
|
||||
}
|
||||
|
||||
setSelectedSources(newSelectedSources);
|
||||
|
||||
try {
|
||||
const response = await fetch("/api/selected_sources", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(newSelectedSources)
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.error("Failed to update selected sources");
|
||||
// Revert the local state if the update failed
|
||||
setSelectedSources(selectedSources);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error updating selected sources:", error);
|
||||
// Revert the local state if the update failed
|
||||
setSelectedSources(selectedSources);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
const handleChatSelect = async (chatId: string) => {
|
||||
try {
|
||||
await onChatChange(chatId);
|
||||
|
||||
// Close sidebar on mobile after selection
|
||||
if (window.innerWidth < 768) {
|
||||
handleClose();
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error selecting chat:", error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleRenameChat = async (chatId: string, currentName: string) => {
|
||||
const newName = prompt("Enter new chat name:", currentName);
|
||||
if (newName && newName.trim() && newName !== currentName) {
|
||||
try {
|
||||
const response = await fetch("/api/chat/rename", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ chat_id: chatId, new_name: newName.trim() })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.error("Failed to rename chat");
|
||||
return;
|
||||
}
|
||||
|
||||
// Fetch updated metadata for the renamed chat
|
||||
await fetchChatMetadata(chatId);
|
||||
} catch (error) {
|
||||
console.error("Error renaming chat:", error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleDeleteChat = async (chatId: string) => {
|
||||
try {
|
||||
// Delete the chat
|
||||
const response = await fetch(`/api/chat/${chatId}`, {
|
||||
method: "DELETE"
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.error("Failed to delete chat");
|
||||
return;
|
||||
}
|
||||
|
||||
// Refresh chat list
|
||||
await fetchChats();
|
||||
|
||||
// If we deleted the current chat
|
||||
if (currentChatId === chatId) {
|
||||
// Get updated list of chats
|
||||
const chatsResponse = await fetch("/api/chats");
|
||||
const { chats: remainingChats } = await chatsResponse.json();
|
||||
|
||||
if (remainingChats.length > 0) {
|
||||
// Switch to another chat
|
||||
await onChatChange(remainingChats[0]);
|
||||
} else {
|
||||
// No chats left, create a new one
|
||||
await handleNewChat();
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error deleting chat:", error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleNewChat = async () => {
|
||||
try {
|
||||
console.log("handleNewChat: Starting new chat creation...");
|
||||
|
||||
// Create new chat using backend endpoint
|
||||
const response = await fetch("/api/chat/new", {
|
||||
method: "POST"
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.error("Failed to create new chat");
|
||||
return;
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log("handleNewChat: Created new chat:", data.chat_id);
|
||||
|
||||
// First, refresh the chat list to ensure the new chat is available
|
||||
console.log("handleNewChat: Refreshing chat list...");
|
||||
await fetchChats();
|
||||
console.log("handleNewChat: Chat list refreshed");
|
||||
|
||||
// Then change to the new chat
|
||||
await onChatChange(data.chat_id);
|
||||
console.log("handleNewChat: Changed to new chat");
|
||||
|
||||
// Close sidebar on mobile
|
||||
if (window.innerWidth < 768) {
|
||||
handleClose();
|
||||
}
|
||||
|
||||
// Add a small delay to ensure the DOM has updated, then trigger highlight animation
|
||||
setTimeout(() => {
|
||||
if (chatListRef.current) {
|
||||
const activeChat = chatListRef.current.querySelector(`.${styles.active}`) as HTMLElement;
|
||||
if (activeChat) {
|
||||
const offset = activeChat.offsetTop;
|
||||
chatListRef.current.style.setProperty('--highlight-offset', `${offset}px`);
|
||||
}
|
||||
}
|
||||
}, 100); // Increased delay for more reliability
|
||||
} catch (error) {
|
||||
console.error("Error creating new chat:", error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleClearAllChats = async () => {
|
||||
// Show confirmation dialog
|
||||
const confirmClear = window.confirm(
|
||||
`Are you sure you want to clear all ${chats.length} chat conversations? This action cannot be undone.`
|
||||
);
|
||||
|
||||
if (!confirmClear) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch("/api/chats/clear", {
|
||||
method: "DELETE"
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.error("Failed to clear all chats");
|
||||
alert("Failed to clear chats. Please try again.");
|
||||
return;
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
// Switch to the new chat created by the backend
|
||||
await onChatChange(data.new_chat_id);
|
||||
|
||||
// Refresh chat list
|
||||
await fetchChats();
|
||||
|
||||
// Close sidebar on mobile
|
||||
if (window.innerWidth < 768) {
|
||||
handleClose();
|
||||
}
|
||||
|
||||
console.log(`Successfully cleared ${data.cleared_count} chats`);
|
||||
} catch (error) {
|
||||
console.error("Error clearing all chats:", error);
|
||||
alert("An error occurred while clearing chats. Please try again.");
|
||||
}
|
||||
};
|
||||
|
||||
const handleModelChange = async (event: React.ChangeEvent<HTMLSelectElement>) => {
|
||||
const newModel = event.target.value;
|
||||
const newModelLower = newModel.toLowerCase();
|
||||
setSelectedModel(newModel);
|
||||
|
||||
try {
|
||||
console.log("Updating selected model to:", newModel);
|
||||
|
||||
const response = await fetch("/api/selected_model", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ model: newModel })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.error("Failed to update selected model");
|
||||
// Revert the local state if the update failed
|
||||
setSelectedModel(selectedModel);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error updating selected model:", error);
|
||||
// Revert the local state if the update failed
|
||||
setSelectedModel(selectedModel);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<button
|
||||
className={`${styles.toggleSidebarButton} ${isVisible && !isClosing ? styles.active : ''}`}
|
||||
onClick={toggleSidebar}
|
||||
>
|
||||
☰
|
||||
</button>
|
||||
|
||||
{isVisible && (
|
||||
<>
|
||||
<div
|
||||
className={`${styles.sidebarOverlay} ${isClosing ? styles.closing : ''}`}
|
||||
onClick={handleClose}
|
||||
/>
|
||||
<div className={`${styles.sidebar} ${isClosing ? styles.closing : ''}`}>
|
||||
<button
|
||||
className={styles.closeSidebarButton}
|
||||
onClick={handleClose}
|
||||
>
|
||||
×
|
||||
</button>
|
||||
<div className={styles.sidebarHeader}>
|
||||
<h2 className={styles.title}>Spark Chat</h2>
|
||||
</div>
|
||||
|
||||
{/* Model Selection */}
|
||||
<div className={styles.sidebarSection}>
|
||||
<div
|
||||
className={styles.sectionHeader}
|
||||
onClick={() => toggleSection('model')}
|
||||
>
|
||||
<h3>Model</h3>
|
||||
<span className={isSectionExpanded('model') ? styles.arrowUp : styles.arrowDown}>▼</span>
|
||||
</div>
|
||||
<div className={`${styles.sectionContent} ${isSectionExpanded('model') ? styles.expanded : ''}`}>
|
||||
<div className={styles.configItem}>
|
||||
<label htmlFor="model-select">Select Supervisor Model</label>
|
||||
<select
|
||||
id="model-select"
|
||||
className={styles.modelSelect}
|
||||
value={selectedModel}
|
||||
onChange={handleModelChange}
|
||||
disabled={isLoadingModels}
|
||||
>
|
||||
{isLoadingModels ? (
|
||||
<option value="">Loading models...</option>
|
||||
) : (
|
||||
availableModels.map(model => (
|
||||
<option key={model.id} value={model.id}>
|
||||
{model.name}
|
||||
</option>
|
||||
))
|
||||
)}
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Context */}
|
||||
<div className={styles.sidebarSection}>
|
||||
<div
|
||||
className={styles.sectionHeader}
|
||||
onClick={() => toggleSection('context')}
|
||||
>
|
||||
<h3>Context</h3>
|
||||
<span className={isSectionExpanded('context') ? styles.arrowUp : styles.arrowDown}>▼</span>
|
||||
</div>
|
||||
<div className={`${styles.sectionContent} ${isSectionExpanded('context') ? styles.expanded : ''}`}>
|
||||
<div className={styles.configItem}>
|
||||
<label>Select Sources</label>
|
||||
<div className={styles.sourcesContainer}>
|
||||
{availableSources.length === 0 ? (
|
||||
<div className={styles.noSources}>No sources available</div>
|
||||
) : (
|
||||
availableSources.map(source => (
|
||||
<div key={source} className={styles.sourceItem}>
|
||||
<input
|
||||
type="checkbox"
|
||||
id={`source-${source}`}
|
||||
checked={selectedSources.includes(source)}
|
||||
onChange={() => handleSourceToggle(source)}
|
||||
/>
|
||||
<label htmlFor={`source-${source}`}>{source}</label>
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
<button
|
||||
className={styles.refreshButton}
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
fetchSources();
|
||||
}}
|
||||
disabled={isLoadingSources}
|
||||
>
|
||||
{isLoadingSources ? "Loading..." : "Refresh Sources"}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Chat History */}
|
||||
<div className={styles.sidebarSection}>
|
||||
<div
|
||||
className={styles.sectionHeader}
|
||||
onClick={() => toggleSection('history')}
|
||||
>
|
||||
<h3>Chat History</h3>
|
||||
<span className={isSectionExpanded('history') ? styles.arrowUp : styles.arrowDown}>▼</span>
|
||||
</div>
|
||||
<div className={`${styles.sectionContent} ${isSectionExpanded('history') ? styles.expanded : ''}`}>
|
||||
<div className={styles.chatList} ref={chatListRef}>
|
||||
{isLoadingChats ? (
|
||||
<div className={styles.loadingText}>Loading chats...</div>
|
||||
) : chats.length === 0 ? (
|
||||
<div className={styles.noChatText}>No previous chats</div>
|
||||
) : (
|
||||
chats.map((chatId) => (
|
||||
<div
|
||||
key={chatId}
|
||||
className={`${styles.chatItem} ${currentChatId === chatId ? styles.active : ''}`}
|
||||
onClick={() => handleChatSelect(chatId)}
|
||||
>
|
||||
<div className={styles.chatName}>
|
||||
{chatMetadata[chatId]?.name || chatId.slice(0, 8)}
|
||||
</div>
|
||||
<div className={styles.chatActions}>
|
||||
<button
|
||||
className={styles.chatActionButton}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleRenameChat(chatId, chatMetadata[chatId]?.name || `Chat ${chatId.slice(0, 8)}`);
|
||||
}}
|
||||
>
|
||||
✎
|
||||
</button>
|
||||
<button
|
||||
className={styles.chatActionButton}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleDeleteChat(chatId);
|
||||
}}
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Chat Action Buttons at bottom */}
|
||||
<div className={styles.chatButtonsContainer}>
|
||||
<button
|
||||
className={styles.newChatButton}
|
||||
onClick={handleNewChat}
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<line x1="12" y1="5" x2="12" y2="19" />
|
||||
<line x1="5" y1="12" x2="19" y2="12" />
|
||||
</svg>
|
||||
New Chat
|
||||
</button>
|
||||
|
||||
<button
|
||||
className={styles.clearChatsButton}
|
||||
onClick={handleClearAllChats}
|
||||
disabled={chats.length === 0}
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<path d="M3 6h18" />
|
||||
<path d="M19 6v14c0 1-1 2-2 2H7c-1 0-2-1-2-2V6" />
|
||||
<path d="M8 6V4c0-1 1-2 2-2h4c1 0 2 1 2 2v2" />
|
||||
</svg>
|
||||
Clear All
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@ -0,0 +1,63 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from 'react';
|
||||
|
||||
export default function ThemeToggle() {
|
||||
const [isDark, setIsDark] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
// Check if user has a theme preference
|
||||
const isDarkMode = localStorage.getItem('theme') === 'dark' ||
|
||||
(!localStorage.getItem('theme') && window.matchMedia('(prefers-color-scheme: dark)').matches);
|
||||
|
||||
setIsDark(isDarkMode);
|
||||
if (isDarkMode) {
|
||||
document.documentElement.classList.add('dark');
|
||||
}
|
||||
}, []);
|
||||
|
||||
const toggleTheme = () => {
|
||||
setIsDark(!isDark);
|
||||
if (!isDark) {
|
||||
document.documentElement.classList.add('dark');
|
||||
localStorage.setItem('theme', 'dark');
|
||||
} else {
|
||||
document.documentElement.classList.remove('dark');
|
||||
localStorage.setItem('theme', 'light');
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<button
|
||||
onClick={toggleTheme}
|
||||
className="fixed top-4 right-4 p-2 rounded-full bg-white dark:bg-gray-800 shadow-lg hover:shadow-xl transition-all duration-200 z-50"
|
||||
aria-label="Toggle theme"
|
||||
>
|
||||
{isDark ? (
|
||||
<svg className="w-6 h-6 text-yellow-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M12 3v1m0 16v1m9-9h-1M4 12H3m15.364 6.364l-.707-.707M6.343 6.343l-.707-.707m12.728 0l-.707.707M6.343 17.657l-.707.707M16 12a4 4 0 11-8 0 4 4 0 018 0z" />
|
||||
</svg>
|
||||
) : (
|
||||
<svg className="w-6 h-6 text-gray-700" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M20.354 15.354A9 9 0 018.646 3.646 9.003 9.003 0 0012 21a9.003 9.003 0 008.354-5.646z" />
|
||||
</svg>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
@ -0,0 +1,103 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
import styles from "@/styles/WelcomeSection.module.css";
|
||||
|
||||
interface WelcomeSectionProps {
|
||||
setQuery: (value: string) => void;
|
||||
}
|
||||
|
||||
export default function WelcomeSection({ setQuery }: WelcomeSectionProps) {
|
||||
const promptTemplates = {
|
||||
rag: "What is the Blackwell GB202 GPU according the whitepaper document i uploaded?",
|
||||
image: "Can you analyze the graphs in this image and tell me any surprising stats? https://menlovc.com/wp-content/uploads/2025/06/5-parents_power_users-062425.png",
|
||||
code: `Can you generate code to develop a responsive personal website for my freelance AI dev business based on my personal brand palette?
|
||||
|
||||
My palette is:
|
||||
#606C38
|
||||
#283618
|
||||
#FEFAE0
|
||||
#DDA15E
|
||||
#BC6C25`,
|
||||
chat: "Hey Spark! Can you draft an email asking a product manager in distributed systems to a coffee chat?"
|
||||
};
|
||||
|
||||
const handleCardClick = (promptKey: keyof typeof promptTemplates) => {
|
||||
setQuery(promptTemplates[promptKey]);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={styles.welcomeContainer}>
|
||||
<div className={styles.welcomeMessage}>
|
||||
Hello! Send a message to start chatting with Spark.
|
||||
</div>
|
||||
<div className={styles.agentCards}>
|
||||
<div
|
||||
className={`${styles.agentCard} ${styles.animate1}`}
|
||||
onClick={() => handleCardClick('rag')}
|
||||
>
|
||||
<div className={styles.agentIcon}>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" width="24" height="24">
|
||||
<circle cx="11" cy="11" r="8"/>
|
||||
<path d="m21 21-4.35-4.35"/>
|
||||
</svg>
|
||||
</div>
|
||||
<h3 className={styles.agentTitle}>Search Documents</h3>
|
||||
<p className={styles.agentSubtitle}>RAG Agent</p>
|
||||
</div>
|
||||
<div
|
||||
className={`${styles.agentCard} ${styles.animate2}`}
|
||||
onClick={() => handleCardClick('image')}
|
||||
>
|
||||
<div className={styles.agentIcon}>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" width="24" height="24">
|
||||
<rect width="18" height="18" x="3" y="3" rx="2" ry="2"/>
|
||||
<circle cx="9" cy="9" r="2"/>
|
||||
<path d="m21 15-3.086-3.086a2 2 0 0 0-2.828 0L6 21"/>
|
||||
</svg>
|
||||
</div>
|
||||
<h3 className={styles.agentTitle}>Image Processor</h3>
|
||||
<p className={styles.agentSubtitle}>Image Understanding Agent</p>
|
||||
</div>
|
||||
<div
|
||||
className={`${styles.agentCard} ${styles.animate3}`}
|
||||
onClick={() => handleCardClick('code')}
|
||||
>
|
||||
<div className={styles.agentIcon}>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" width="24" height="24">
|
||||
<polyline points="16,18 22,12 16,6"/>
|
||||
<polyline points="8,6 2,12 8,18"/>
|
||||
</svg>
|
||||
</div>
|
||||
<h3 className={styles.agentTitle}>Code Generation</h3>
|
||||
<p className={styles.agentSubtitle}>Coding Agent</p>
|
||||
</div>
|
||||
<div
|
||||
className={`${styles.agentCard} ${styles.animate4}`}
|
||||
onClick={() => handleCardClick('chat')}
|
||||
>
|
||||
<div className={styles.agentIcon}>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" width="24" height="24">
|
||||
<path d="m3 21 1.9-5.7a8.5 8.5 0 1 1 3.8 3.8z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<h3 className={styles.agentTitle}>Chat</h3>
|
||||
<p className={styles.agentSubtitle}>Local LLM</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -0,0 +1,136 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
}}
|
||||
|
||||
.section {
|
||||
background: white;
|
||||
color: black;
|
||||
border-radius: 12px;
|
||||
padding: 1.5rem;
|
||||
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.08);
|
||||
width: 350px;
|
||||
flex-shrink: 0;
|
||||
animation: fadeIn 0.8s cubic-bezier(0.4,0.2,0.2,1) both;
|
||||
}
|
||||
|
||||
.ingestForm {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.uploadSection {
|
||||
padding: 1.5rem;
|
||||
border: 2px dashed #e2e8f0;
|
||||
border-radius: 8px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: border-color 0.2s;
|
||||
/* background-color: #29526d; */
|
||||
}
|
||||
|
||||
.uploadSection:hover {
|
||||
border-color: #4299e1;
|
||||
}
|
||||
|
||||
.fileInput {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* choose files button*/
|
||||
.customFileLabel {
|
||||
padding: 0.25rem 1rem;
|
||||
background-color: #f5f6fa;
|
||||
color: #222;
|
||||
border: 1px solid #d1d5db;
|
||||
border-radius: 4px;
|
||||
font-weight: 400;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s, border 0.2s, color 0.2s;
|
||||
font-size: 0.98rem;
|
||||
margin-bottom: 0.5em;
|
||||
display: inline-block;
|
||||
text-align: center;
|
||||
}
|
||||
.customFileLabel:hover {
|
||||
background-color: #ffffff;
|
||||
color: #111;
|
||||
border: 1px solid var(--primary);
|
||||
}
|
||||
.customFileLabel:active, .customFileLabel:focus {
|
||||
outline: 2px solid #76B900;
|
||||
}
|
||||
|
||||
.fileName {
|
||||
display: block;
|
||||
margin-top: 0.5em;
|
||||
color: #222;
|
||||
text-align: center;
|
||||
font-size: 1rem;
|
||||
margin-bottom: 0.5em;
|
||||
}
|
||||
|
||||
.directoryInput {
|
||||
width: 100%;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 2px;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
.helpText {
|
||||
color: #718096;
|
||||
font-size: 0.875rem;
|
||||
margin-top: 0.5rem;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.ingestButton {
|
||||
padding: 0.75rem 1.5rem;
|
||||
background-color: #48bb78;
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.ingestButton:hover {
|
||||
background-color: #38a169;
|
||||
}
|
||||
|
||||
.ingestButton:disabled {
|
||||
background-color: #cbd5e0;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.messageContainer {
|
||||
margin-top: 1.5rem;
|
||||
padding: 1rem;
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 8px;
|
||||
background-color: #f7fafc;
|
||||
}
|
||||
@ -0,0 +1,117 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
.container {
|
||||
display: flex;
|
||||
height: 100vh;
|
||||
width: 100vw;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.mainContent {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
overflow: hidden;
|
||||
width: 100%;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
}
|
||||
|
||||
.documentUploadContainer {
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
z-index: 20;
|
||||
background: white;
|
||||
padding: 24px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
|
||||
max-width: 500px;
|
||||
width: 90%;
|
||||
}
|
||||
|
||||
.overlay {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: rgba(0, 0, 0, 0.5);
|
||||
z-index: 15;
|
||||
}
|
||||
|
||||
.closeButton {
|
||||
position: absolute;
|
||||
top: 12px;
|
||||
right: 12px;
|
||||
background: none;
|
||||
border: none;
|
||||
font-size: 1.25rem;
|
||||
cursor: pointer;
|
||||
color: #718096;
|
||||
}
|
||||
|
||||
.closeButton:hover {
|
||||
color: #4a5568;
|
||||
}
|
||||
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.title {
|
||||
color: #2d3748;
|
||||
font-size: 2rem;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.toggleButton {
|
||||
background-color: #4299e1;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 8px 16px;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
font-weight: 500;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.toggleButton:hover {
|
||||
background-color: #3182ce;
|
||||
}
|
||||
|
||||
.mainContent > div:only-child {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.container {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.header {
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,781 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
.sidebar {
|
||||
width: 320px;
|
||||
height: auto;
|
||||
max-height: 90vh;
|
||||
background-color: white;
|
||||
border-radius: 12px;
|
||||
padding: 16px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow-y: auto;
|
||||
position: absolute;
|
||||
top: 60px;
|
||||
left: 20px;
|
||||
z-index: 5;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
|
||||
pointer-events: auto;
|
||||
transform-origin: top;
|
||||
animation: menuExpand 0.6s cubic-bezier(0.34, 1.56, 0.64, 1) forwards;
|
||||
}
|
||||
|
||||
:global(.dark) .sidebar {
|
||||
background-color: #1f2937;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
|
||||
.sidebar.closing {
|
||||
animation: menuCollapse 0.5s cubic-bezier(0.34, 1.56, 0.64, 1) forwards;
|
||||
}
|
||||
|
||||
@keyframes menuExpand {
|
||||
0% {
|
||||
opacity: 0;
|
||||
transform: scaleY(0.3) translateY(-40px);
|
||||
}
|
||||
100% {
|
||||
opacity: 1;
|
||||
transform: scaleY(1) translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes menuCollapse {
|
||||
0% {
|
||||
opacity: 1;
|
||||
transform: scaleY(1) translateY(0);
|
||||
}
|
||||
100% {
|
||||
opacity: 0;
|
||||
transform: scaleY(0.3) translateY(40px);
|
||||
}
|
||||
}
|
||||
|
||||
.sidebarOverlay {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: transparent;
|
||||
z-index: 5;
|
||||
display: block;
|
||||
pointer-events: none;
|
||||
animation: fadeIn 0.2s ease forwards;
|
||||
}
|
||||
|
||||
.sidebarOverlay.closing {
|
||||
animation: fadeOut 0.2s ease forwards;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes fadeOut {
|
||||
from {
|
||||
opacity: 1;
|
||||
}
|
||||
to {
|
||||
opacity: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.toggleSidebarButton {
|
||||
position: absolute;
|
||||
top: 20px;
|
||||
left: 20px;
|
||||
z-index: 10;
|
||||
background-color: white;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
cursor: pointer;
|
||||
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
|
||||
font-size: 18px;
|
||||
transition: transform 0.5s cubic-bezier(0.34, 1.56, 0.64, 1);
|
||||
color: #1f2937;
|
||||
}
|
||||
|
||||
:global(.dark) .toggleSidebarButton {
|
||||
background-color: #1f2937;
|
||||
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.toggleSidebarButton.active {
|
||||
transform: rotate(180deg);
|
||||
}
|
||||
|
||||
.closeSidebarButton {
|
||||
position: absolute;
|
||||
top: 12px;
|
||||
right: 12px;
|
||||
background: none;
|
||||
border: none;
|
||||
font-size: 24px;
|
||||
line-height: 1;
|
||||
cursor: pointer;
|
||||
color: #718096;
|
||||
padding: 4px 8px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.closeSidebarButton:hover {
|
||||
color: #4a5568;
|
||||
background-color: #f3f4f6;
|
||||
}
|
||||
|
||||
:global(.dark) .closeSidebarButton {
|
||||
color: #9ca3af;
|
||||
}
|
||||
|
||||
:global(.dark) .closeSidebarButton:hover {
|
||||
color: #e5e7eb;
|
||||
background-color: #374151;
|
||||
}
|
||||
|
||||
.sidebarHeader {
|
||||
padding-bottom: 16px;
|
||||
border-bottom: 1px solid #e2e8f0;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
:global(.dark) .sidebarHeader {
|
||||
border-bottom-color: #374151;
|
||||
}
|
||||
|
||||
.title {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 600;
|
||||
color: #2d3748;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
:global(.dark) .title {
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.sidebarSection {
|
||||
margin-bottom: 16px;
|
||||
border-bottom: 1px solid #e2e8f0;
|
||||
padding-bottom: 16px;
|
||||
}
|
||||
|
||||
:global(.dark) .sidebarSection {
|
||||
border-bottom-color: #374151;
|
||||
}
|
||||
|
||||
.sidebarSection:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
.sectionHeader {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
cursor: pointer;
|
||||
padding: 8px 0;
|
||||
}
|
||||
|
||||
.sectionHeader h3 {
|
||||
margin: 0;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: #4a5568;
|
||||
}
|
||||
|
||||
:global(.dark) .sectionHeader h3 {
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.sectionContent {
|
||||
padding-top: 8px;
|
||||
max-height: 0;
|
||||
overflow: hidden;
|
||||
opacity: 0;
|
||||
transform: translateY(-8px);
|
||||
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.sectionContent.expanded {
|
||||
max-height: 1000px;
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
pointer-events: auto;
|
||||
}
|
||||
|
||||
.arrowDown, .arrowUp {
|
||||
font-size: 12px;
|
||||
color: #718096;
|
||||
transition: transform 0.4s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
}
|
||||
|
||||
:global(.dark) .arrowDown, :global(.dark) .arrowUp {
|
||||
color: #9ca3af;
|
||||
}
|
||||
|
||||
.arrowUp {
|
||||
transform: rotate(180deg);
|
||||
}
|
||||
|
||||
.configItem {
|
||||
margin-bottom: 12px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.configItem label {
|
||||
font-size: 0.875rem;
|
||||
color: #4a5568;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
:global(.dark) .configItem label {
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.select {
|
||||
padding: 8px 12px;
|
||||
border-radius: 6px;
|
||||
border: 1px solid #e2e8f0;
|
||||
background-color: #f9fafb;
|
||||
font-size: 0.875rem;
|
||||
color: #4a5568;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
:global(.dark) .select {
|
||||
background-color: #374151;
|
||||
border-color: #4b5563;
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.toolItem {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.toolItem input[type="checkbox"] {
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
.toolItem label {
|
||||
font-size: 0.875rem;
|
||||
color: #4a5568;
|
||||
}
|
||||
|
||||
:global(.dark) .toolItem label {
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.toggleSwitch {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
width: 48px;
|
||||
height: 24px;
|
||||
}
|
||||
|
||||
.toggleSwitch input {
|
||||
opacity: 0;
|
||||
width: 0;
|
||||
height: 0;
|
||||
}
|
||||
|
||||
.toggleSlider {
|
||||
position: absolute;
|
||||
cursor: pointer;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: #e2e8f0;
|
||||
transition: .4s;
|
||||
border-radius: 24px;
|
||||
box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
:global(.dark) .toggleSlider {
|
||||
background-color: #374151;
|
||||
box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
.toggleSlider:before {
|
||||
position: absolute;
|
||||
content: "";
|
||||
height: 18px;
|
||||
width: 18px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
transition: .4s;
|
||||
border-radius: 50%;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
input:checked + .toggleSlider {
|
||||
background-color: #76B900;
|
||||
box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
input:checked + .toggleSlider:before {
|
||||
transform: translateX(24px);
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
.toolLabel {
|
||||
font-size: 14px;
|
||||
color: #1f2937;
|
||||
}
|
||||
|
||||
:global(.dark) .toolLabel {
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.chatList {
|
||||
position: relative;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
padding: 4px;
|
||||
}
|
||||
|
||||
.chatList::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 4px;
|
||||
left: 0;
|
||||
right: 0;
|
||||
height: 36px;
|
||||
background: rgba(118, 185, 0, 0.08);
|
||||
border: 1px solid rgba(118, 185, 0, 0.2);
|
||||
border-radius: 8px;
|
||||
transform: translateY(calc(var(--highlight-offset, 0) - 3px));
|
||||
transition: transform 0.2s ease-out;
|
||||
pointer-events: none;
|
||||
z-index: 0;
|
||||
}
|
||||
|
||||
.chatItem {
|
||||
position: relative;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 6px 12px;
|
||||
border-radius: 8px;
|
||||
height: 36px;
|
||||
cursor: pointer;
|
||||
z-index: 1;
|
||||
transition: background-color 0.2s ease;
|
||||
}
|
||||
|
||||
.chatItem:hover {
|
||||
background-color: rgba(0, 0, 0, 0.03);
|
||||
}
|
||||
|
||||
.chatName {
|
||||
flex: 1;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
transition: color 0.2s ease;
|
||||
color: #1f2937;
|
||||
line-height: 24px;
|
||||
padding-top: 1px;
|
||||
}
|
||||
|
||||
:global(.dark) .chatName {
|
||||
color: #ffffff;
|
||||
}
|
||||
|
||||
.chatItem.active .chatName {
|
||||
color: #76B900;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.chatActions {
|
||||
display: flex;
|
||||
gap: 4px;
|
||||
opacity: 0;
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
|
||||
.chatItem:hover .chatActions {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.chatActionButton {
|
||||
background: none;
|
||||
border: none;
|
||||
color: #718096;
|
||||
padding: 4px;
|
||||
cursor: pointer;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
line-height: 1;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.chatActionButton:hover {
|
||||
background-color: rgba(0, 0, 0, 0.05);
|
||||
color: #1f2937;
|
||||
}
|
||||
|
||||
.loadingText, .noChatText {
|
||||
color: #718096;
|
||||
font-size: 0.875rem;
|
||||
text-align: center;
|
||||
padding: 8px;
|
||||
}
|
||||
|
||||
:global(.dark) .loadingText, :global(.dark) .noChatText {
|
||||
color: #9ca3af;
|
||||
}
|
||||
|
||||
.activeIndicator {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background-color: #76B900;
|
||||
margin-right: 8px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.chatInfo {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
flex: 1;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.chatDate {
|
||||
font-size: 0.75rem;
|
||||
color: #718096;
|
||||
}
|
||||
|
||||
:global(.dark) .chatDate {
|
||||
color: #9ca3af;
|
||||
}
|
||||
|
||||
.actionButton {
|
||||
width: 100%;
|
||||
padding: 10px 16px;
|
||||
background-color: #76B900;
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
.actionButton:hover {
|
||||
background-color: #669f00;
|
||||
}
|
||||
|
||||
.refreshButton {
|
||||
margin-top: 8px;
|
||||
padding: 6px 12px;
|
||||
background-color: #e2e8f0;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
font-size: 0.75rem;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
:global(.dark) .refreshButton {
|
||||
background-color: #374151;
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.refreshButton:hover {
|
||||
background-color: #cbd5e0;
|
||||
}
|
||||
|
||||
:global(.dark) .refreshButton:hover {
|
||||
background-color: #4b5563;
|
||||
}
|
||||
|
||||
.refreshButton:disabled {
|
||||
opacity: 0.6;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.sourcesContainer {
|
||||
margin-top: 8px;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 4px;
|
||||
padding: 8px;
|
||||
}
|
||||
|
||||
:global(.dark) .sourcesContainer {
|
||||
background-color: #374151;
|
||||
border-color: #4b5563;
|
||||
}
|
||||
|
||||
.sourceItem {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: 6px;
|
||||
}
|
||||
|
||||
.sourceItem input[type="checkbox"] {
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
.sourceItem label {
|
||||
font-size: 14px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
:global(.dark) .sourceItem label {
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.noSources {
|
||||
color: #718096;
|
||||
font-size: 14px;
|
||||
padding: 8px 0;
|
||||
}
|
||||
|
||||
:global(.dark) .noSources {
|
||||
color: #9ca3af;
|
||||
}
|
||||
|
||||
.modelSelect {
|
||||
padding: 8px 12px;
|
||||
border-radius: 6px;
|
||||
border: 1px solid #e2e8f0;
|
||||
background-color: #f9fafb;
|
||||
font-size: 0.875rem;
|
||||
color: #4a5568;
|
||||
width: 100%;
|
||||
margin-top: 4px;
|
||||
cursor: pointer;
|
||||
appearance: none;
|
||||
background-image: url("data:image/svg+xml;charset=US-ASCII,%3Csvg%20xmlns%3D%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%20width%3D%22292.4%22%20height%3D%22292.4%22%3E%3Cpath%20fill%3D%22%23007CB2%22%20d%3D%22M287%2069.4a17.6%2017.6%200%200%200-13-5.4H18.4c-5%200-9.3%201.8-12.9%205.4A17.6%2017.6%200%200%200%200%2082.2c0%205%201.8%209.3%205.4%2012.9l128%20127.9c3.6%203.6%207.8%205.4%2012.8%205.4s9.2-1.8%2012.8-5.4L287%2095c3.5-3.5%205.4-7.8%205.4-12.8%200-5-1.9-9.2-5.5-12.8z%22%2F%3E%3C%2Fsvg%3E");
|
||||
background-repeat: no-repeat;
|
||||
background-position: right 12px top 50%;
|
||||
background-size: 12px auto;
|
||||
}
|
||||
|
||||
:global(.dark) .modelSelect {
|
||||
background-color: #374151;
|
||||
border-color: #4b5563;
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.modelSelect:hover {
|
||||
border-color: #cbd5e0;
|
||||
}
|
||||
|
||||
:global(.dark) .modelSelect:hover {
|
||||
border-color: #6b7280;
|
||||
}
|
||||
|
||||
.modelSelect:focus {
|
||||
outline: none;
|
||||
border-color: #76B900;
|
||||
box-shadow: 0 0 0 3px rgba(118, 185, 0, 0.2);
|
||||
}
|
||||
|
||||
:global(.dark) .modelSelect:focus {
|
||||
box-shadow: 0 0 0 3px rgba(118, 185, 0, 0.4);
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.sidebar {
|
||||
width: 280px;
|
||||
}
|
||||
}
|
||||
|
||||
.newChatButton {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
background-color: #76B900;
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 12px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
transition: all 0.3s ease;
|
||||
box-shadow: 0 2px 8px rgba(118, 185, 0, 0.2);
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.newChatButton::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background: rgba(255, 255, 255, 0.1);
|
||||
transform: translate(-50%, -50%) scale(0);
|
||||
border-radius: 50%;
|
||||
transition: transform 0.5s ease;
|
||||
}
|
||||
|
||||
.newChatButton:hover {
|
||||
background-color: #669f00;
|
||||
transform: translateY(-1px);
|
||||
box-shadow: 0 4px 12px rgba(118, 185, 0, 0.3);
|
||||
}
|
||||
|
||||
.newChatButton:hover::before {
|
||||
transform: translate(-50%, -50%) scale(2);
|
||||
}
|
||||
|
||||
:global(.dark) .newChatButton {
|
||||
box-shadow: 0 2px 8px rgba(118, 185, 0, 0.4);
|
||||
}
|
||||
|
||||
:global(.dark) .newChatButton:hover {
|
||||
box-shadow: 0 4px 12px rgba(118, 185, 0, 0.5);
|
||||
}
|
||||
|
||||
.newChatButton svg {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
transition: transform 0.3s ease;
|
||||
}
|
||||
|
||||
.newChatButton:hover svg {
|
||||
transform: rotate(90deg);
|
||||
}
|
||||
|
||||
.chatButtonsContainer {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
margin-top: 16px;
|
||||
}
|
||||
|
||||
.clearChatsButton {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
background-color: #dc2626;
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 12px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
transition: all 0.3s ease;
|
||||
box-shadow: 0 2px 8px rgba(220, 38, 38, 0.2);
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.clearChatsButton:disabled {
|
||||
background-color: #9ca3af;
|
||||
cursor: not-allowed;
|
||||
box-shadow: none;
|
||||
color: #d1d5db;
|
||||
}
|
||||
|
||||
.clearChatsButton:disabled:hover {
|
||||
transform: none;
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
.clearChatsButton::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background: rgba(255, 255, 255, 0.1);
|
||||
transform: translate(-50%, -50%) scale(0);
|
||||
border-radius: 50%;
|
||||
transition: transform 0.5s ease;
|
||||
}
|
||||
|
||||
.clearChatsButton:hover:not(:disabled) {
|
||||
background-color: #b91c1c;
|
||||
transform: translateY(-1px);
|
||||
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.3);
|
||||
}
|
||||
|
||||
.clearChatsButton:hover:not(:disabled)::before {
|
||||
transform: translate(-50%, -50%) scale(2);
|
||||
}
|
||||
|
||||
:global(.dark) .clearChatsButton {
|
||||
box-shadow: 0 2px 8px rgba(220, 38, 38, 0.4);
|
||||
}
|
||||
|
||||
:global(.dark) .clearChatsButton:hover:not(:disabled) {
|
||||
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.5);
|
||||
}
|
||||
|
||||
:global(.dark) .clearChatsButton:disabled {
|
||||
background-color: #4b5563;
|
||||
color: #6b7280;
|
||||
}
|
||||
|
||||
.clearChatsButton svg {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
transition: transform 0.3s ease;
|
||||
}
|
||||
|
||||
.clearChatsButton:hover:not(:disabled) svg {
|
||||
transform: scale(1.1);
|
||||
}
|
||||
|
||||
:global(.dark) .chatItem:hover {
|
||||
background-color: rgba(255, 255, 255, 0.03);
|
||||
}
|
||||
|
||||
:global(.dark) .chatList::before {
|
||||
background: rgba(118, 185, 0, 0.15);
|
||||
border-color: rgba(118, 185, 0, 0.3);
|
||||
}
|
||||
@ -0,0 +1,217 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
.welcomeContainer {
|
||||
position: absolute;
|
||||
left: 50%;
|
||||
top: 45%;
|
||||
transform: translate(-50%, -50%);
|
||||
text-align: center;
|
||||
max-width: 900px;
|
||||
width: 100%;
|
||||
padding: 20px;
|
||||
z-index: 10;
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.welcomeMessage {
|
||||
font-size: 42px;
|
||||
font-weight: 200;
|
||||
color: #1f2937;
|
||||
max-width: 800px;
|
||||
letter-spacing: -0.5px;
|
||||
background: linear-gradient(45deg, #76B900, #669f00);
|
||||
background-clip: text;
|
||||
-webkit-background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
margin: 0 auto 40px auto;
|
||||
opacity: 0;
|
||||
animation: fadeInUp 1.2s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
|
||||
}
|
||||
|
||||
:global(.dark) .welcomeMessage {
|
||||
color: #e5e7eb;
|
||||
background: linear-gradient(45deg, #76B900, #9be024);
|
||||
background-clip: text;
|
||||
-webkit-background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
}
|
||||
|
||||
.agentCards {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(2, 1fr);
|
||||
gap: 24px;
|
||||
width: 100%;
|
||||
max-width: 600px;
|
||||
margin: 0 auto;
|
||||
|
||||
}
|
||||
|
||||
.agentCard {
|
||||
background: white;
|
||||
border-radius: 12px;
|
||||
padding: 24px;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
|
||||
transition: all 0.3s ease;
|
||||
cursor: pointer;
|
||||
text-align: center;
|
||||
border: 1px solid rgba(118, 185, 0, 0.1);
|
||||
opacity: 0;
|
||||
transform: translateY(30px);
|
||||
}
|
||||
|
||||
.agentCard.animate1 {
|
||||
animation: fadeInUp 0.6s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
|
||||
animation-delay: 0.3s;
|
||||
}
|
||||
|
||||
.agentCard.animate2 {
|
||||
animation: fadeInUp 0.6s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
|
||||
animation-delay: 0.4s;
|
||||
}
|
||||
|
||||
.agentCard.animate3 {
|
||||
animation: fadeInUp 0.6s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
|
||||
animation-delay: 0.6s;
|
||||
}
|
||||
|
||||
.agentCard.animate4 {
|
||||
animation: fadeInUp 0.6s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
|
||||
animation-delay: 0.7s;
|
||||
}
|
||||
|
||||
.agentCard.animate1:hover,
|
||||
.agentCard.animate2:hover,
|
||||
.agentCard.animate3:hover,
|
||||
.agentCard.animate4:hover {
|
||||
transform: translateY(-4px);
|
||||
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.15);
|
||||
border-color: rgba(118, 185, 0, 0.3);
|
||||
}
|
||||
|
||||
:global(.dark) .agentCard {
|
||||
background: #1f2937;
|
||||
border-color: rgba(118, 185, 0, 0.2);
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
|
||||
:global(.dark) .agentCard:hover {
|
||||
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4);
|
||||
border-color: rgba(118, 185, 0, 0.4);
|
||||
}
|
||||
|
||||
.agentIcon {
|
||||
width: 48px;
|
||||
height: 48px;
|
||||
margin: 0 auto 16px auto;
|
||||
background: linear-gradient(45deg, #76B900, #669f00);
|
||||
border-radius: 12px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.agentTitle {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: #1f2937;
|
||||
margin: 0 0 8px 0;
|
||||
line-height: 1.2;
|
||||
}
|
||||
|
||||
:global(.dark) .agentTitle {
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.agentSubtitle {
|
||||
font-size: 14px;
|
||||
font-weight: 400;
|
||||
color: #6b7280;
|
||||
margin: 0;
|
||||
line-height: 1.2;
|
||||
}
|
||||
|
||||
:global(.dark) .agentSubtitle {
|
||||
color: #9ca3af;
|
||||
}
|
||||
|
||||
@media (max-width: 640px) {
|
||||
.agentCards {
|
||||
grid-template-columns: 1fr;
|
||||
gap: 16px;
|
||||
max-width: 400px;
|
||||
}
|
||||
|
||||
.welcomeContainer {
|
||||
max-width: 95%;
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.welcomeMessage {
|
||||
font-size: 32px;
|
||||
margin-bottom: 32px;
|
||||
}
|
||||
|
||||
.agentCard {
|
||||
padding: 20px;
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes fadeInUp {
|
||||
0% {
|
||||
opacity: 0;
|
||||
transform: translateY(30px);
|
||||
}
|
||||
100% {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes welcomeBubble {
|
||||
0% {
|
||||
opacity: 0;
|
||||
transform: translate(-50%, 30px) scale(0.8);
|
||||
filter: blur(10px);
|
||||
}
|
||||
60% {
|
||||
opacity: 0.8;
|
||||
filter: blur(0px);
|
||||
}
|
||||
100% {
|
||||
opacity: 1;
|
||||
transform: translate(-50%, -50%) scale(1);
|
||||
filter: blur(0px);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes welcomeDisappear {
|
||||
0% {
|
||||
opacity: 1;
|
||||
transform: translate(-50%, -50%) scale(1);
|
||||
filter: blur(0px);
|
||||
}
|
||||
100% {
|
||||
opacity: 0;
|
||||
transform: translate(-50%, -80px) scale(0.9);
|
||||
filter: blur(10px);
|
||||
}
|
||||
}
|
||||
|
||||
.welcomeMessage.hide {
|
||||
animation: welcomeDisappear 0.6s cubic-bezier(0.4, 0, 0.2, 1) forwards;
|
||||
}
|
||||
@ -0,0 +1,41 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
export interface ChatConfig {
|
||||
id: string;
|
||||
name: string;
|
||||
model: string;
|
||||
selectedSources: string[];
|
||||
}
|
||||
|
||||
export interface AppConfig {
|
||||
currentChatId: string;
|
||||
chats: ChatConfig[];
|
||||
sources: string[];
|
||||
}
|
||||
|
||||
export const DEFAULT_CONFIG: AppConfig = {
|
||||
currentChatId: "default-chat",
|
||||
chats: [
|
||||
{
|
||||
id: "default-chat",
|
||||
name: "New Chat",
|
||||
model: "gpt-3.5-turbo",
|
||||
selectedSources: []
|
||||
}
|
||||
],
|
||||
sources: []
|
||||
};
|
||||
@ -0,0 +1,34 @@
|
||||
/*
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
*/
|
||||
import type { Config } from "tailwindcss";
|
||||
|
||||
export default {
|
||||
content: [
|
||||
"./src/pages/**/*.{js,ts,jsx,tsx,mdx}",
|
||||
"./src/components/**/*.{js,ts,jsx,tsx,mdx}",
|
||||
"./src/app/**/*.{js,ts,jsx,tsx,mdx}",
|
||||
],
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
background: "var(--background)",
|
||||
foreground: "var(--foreground)",
|
||||
},
|
||||
},
|
||||
},
|
||||
plugins: [],
|
||||
} satisfies Config;
|
||||
27
nvidia/multi-agent-chatbot/assets/frontend/tsconfig.json
Normal file
@ -0,0 +1,27 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2017",
|
||||
"lib": ["dom", "dom.iterable", "esnext"],
|
||||
"allowJs": true,
|
||||
"skipLibCheck": true,
|
||||
"strict": true,
|
||||
"noEmit": true,
|
||||
"esModuleInterop": true,
|
||||
"module": "esnext",
|
||||
"moduleResolution": "bundler",
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"jsx": "preserve",
|
||||
"incremental": true,
|
||||
"plugins": [
|
||||
{
|
||||
"name": "next"
|
||||
}
|
||||
],
|
||||
"paths": {
|
||||
"@/*": ["./src/*"]
|
||||
}
|
||||
},
|
||||
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
|
||||
"exclude": ["node_modules"]
|
||||
}
|
||||
56
nvidia/multi-agent-chatbot/assets/setup.sh
Executable file
@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
set -euo pipefail
|
||||
|
||||
ROOT_DIR="$(pwd)"
|
||||
MODELS_DIR="$ROOT_DIR/models"
|
||||
mkdir -p "$MODELS_DIR"
|
||||
cd "$MODELS_DIR"
|
||||
|
||||
download_if_needed() {
|
||||
url="$1"
|
||||
file="$2"
|
||||
if [ -f "$file" ]; then
|
||||
echo "$file already exists, skipping."
|
||||
else
|
||||
curl -C - -L -o "$file" "$url"
|
||||
fi
|
||||
}
|
||||
|
||||
download_if_needed "https://huggingface.co/TheBloke/deepseek-coder-6.7B-instruct-GGUF/resolve/main/deepseek-coder-6.7b-instruct.Q8_0.gguf" "deepseek-coder-6.7b-instruct.Q8_0.gguf"
|
||||
|
||||
download_if_needed "https://huggingface.co/Qwen/Qwen3-Embedding-4B-GGUF/resolve/main/Qwen3-Embedding-4B-Q8_0.gguf" "Qwen3-Embedding-4B-Q8_0.gguf"
|
||||
|
||||
# Comment next three lines if you want to use gpt-oss-20b
|
||||
download_if_needed "https://huggingface.co/ggml-org/gpt-oss-120b-GGUF/resolve/main/gpt-oss-120b-mxfp4-00001-of-00003.gguf" "gpt-oss-120b-mxfp4-00001-of-00003.gguf"
|
||||
|
||||
download_if_needed "https://huggingface.co/ggml-org/gpt-oss-120b-GGUF/resolve/main/gpt-oss-120b-mxfp4-00002-of-00003.gguf" "gpt-oss-120b-mxfp4-00002-of-00003.gguf"
|
||||
|
||||
download_if_needed "https://huggingface.co/ggml-org/gpt-oss-120b-GGUF/resolve/main/gpt-oss-120b-mxfp4-00003-of-00003.gguf" "gpt-oss-120b-mxfp4-00003-of-00003.gguf"
|
||||
|
||||
# Uncomment next line if you want to use gpt-oss-20b
|
||||
# download_if_needed "https://huggingface.co/ggml-org/gpt-oss-20b-GGUF/resolve/main/gpt-oss-20b-mxfp4.gguf" "gpt-oss-20b-mxfp4.gguf"
|
||||
|
||||
echo "All models downloaded."
|
||||
|
||||
cd "$ROOT_DIR"
|
||||
|
||||
echo "Starting Docker Compose services..."
|
||||
docker compose -f docker-compose.yml -f docker-compose-models.yml up -d --build
|
||||
|
||||
echo "Docker Compose services are up and running."
|
||||
168
nvidia/vlm-finetuning/assets/Dockerfile
Normal file
@ -0,0 +1,168 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
# Dockerfile stage to build GPU-accelerated ffmpeg
|
||||
FROM nvcr.io/nvidia/pytorch:25.09-py3 AS ffmpeg-builder
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
yasm \
|
||||
libx264-dev \
|
||||
libfaac-dev \
|
||||
libmp3lame-dev \
|
||||
libtheora-dev \
|
||||
libvorbis-dev \
|
||||
libxvidcore-dev \
|
||||
libxext-dev \
|
||||
libxfixes-dev \
|
||||
build-essential \
|
||||
git \
|
||||
pkg-config && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc-11 g++-11
|
||||
|
||||
ENV PATH=/usr/local/cuda/bin:${PATH} \
|
||||
CUDA_HOME=/usr/local/cuda \
|
||||
NVCC=/usr/local/cuda/bin/nvcc \
|
||||
CC=/usr/bin/gcc-11 \
|
||||
CXX=/usr/bin/g++-11 \
|
||||
CUDAHOSTCXX=/usr/bin/g++-11 \
|
||||
FFMPEG_VERSION=4.4.6 \
|
||||
LD_LIBRARY_PATH=/usr/local/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
RUN mkdir -p /deps && \
|
||||
cd /deps && \
|
||||
wget -q https://ffmpeg.org/releases/ffmpeg-${FFMPEG_VERSION}.tar.xz && \
|
||||
tar -xf ffmpeg-${FFMPEG_VERSION}.tar.xz && \
|
||||
rm ffmpeg-${FFMPEG_VERSION}.tar.xz && \
|
||||
cd /deps/ffmpeg-${FFMPEG_VERSION} && \
|
||||
apt-get update && \
|
||||
apt-get install -y libdrm-dev && \
|
||||
./configure \
|
||||
--prefix=/usr/local \
|
||||
--enable-nonfree \
|
||||
--enable-shared \
|
||||
--disable-static \
|
||||
--enable-libdrm \
|
||||
--enable-v4l2-m2m \
|
||||
--enable-gpl \
|
||||
--enable-libx264 \
|
||||
--extra-cflags="-I/usr/include/aarch64-linux-gnu" \
|
||||
--extra-ldflags="-L/usr/lib/aarch64-linux-gnu/tegra" \
|
||||
--nvccflags="-ccbin=/usr/bin/g++-11" \
|
||||
|| (echo "---- tail ffbuild/config.log ----" && tail -n 200 ffbuild/config.log && exit 1) && \
|
||||
make -j"$(nproc)" && \
|
||||
make install && \
|
||||
ldconfig && \
|
||||
echo "✅ FFmpeg installed:" && \
|
||||
ffmpeg -hide_banner -version | head -n 8 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Dockerfile stage to compile decord from source
|
||||
FROM nvcr.io/nvidia/pytorch:25.09-py3 AS decord-builder
|
||||
COPY --from=ffmpeg-builder /usr/local /usr/local
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
git \
|
||||
cmake \
|
||||
ninja-build \
|
||||
pkg-config \
|
||||
python3-dev \
|
||||
python3-pip \
|
||||
gcc-11 \
|
||||
g++-11 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV CC=/usr/bin/gcc-11 \
|
||||
CXX=/usr/bin/g++-11 \
|
||||
PATH=/usr/local/bin:/lib/aarch64-linux-gnu:${PATH} \
|
||||
LD_LIBRARY_PATH=/usr/lib/aarch64-linux-gnu/:/lib/aarch64-linux-gnu/:/usr/local/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH} \
|
||||
PKG_CONFIG_PATH=/usr/local/lib/pkgconfig:/usr/lib/aarch64-linux-gnu/pkgconfig${PKG_CONFIG_PATH:+:$PKG_CONFIG_PATH}
|
||||
|
||||
RUN ln -sf /usr/lib/aarch64-linux-gnu/libnvcuvid.so.1 /usr/lib/aarch64-linux-gnu/libnvcuvid.so && \
|
||||
cp /usr/lib/aarch64-linux-gnu/libnvcuvid.so* /usr/local/cuda/lib64/ || true && \
|
||||
ln -sf /usr/lib/aarch64-linux-gnu/libnvcuvid.so.1 /usr/local/cuda/lib64/libnvcuvid.so && \
|
||||
echo "/usr/lib/aarch64-linux-gnu" > /etc/ld.so.conf.d/nvidia.conf && \
|
||||
ldconfig && \
|
||||
python3 -m pip install --no-cache-dir --upgrade pip wheel build numpy && \
|
||||
python3 -m pip install --no-cache-dir --upgrade pip ninja && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends libnvidia-decode-575 libnvidia-encode-575
|
||||
|
||||
RUN cd /workspace && \
|
||||
git clone --recursive https://github.com/dmlc/decord && \
|
||||
cmake -S /workspace/decord -B /workspace/decord/build \
|
||||
-G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DUSE_CUDA=ON \
|
||||
-DCMAKE_CUDA_HOST_COMPILER=/usr/bin/g++-11 \
|
||||
-DCMAKE_C_COMPILER=/usr/bin/gcc-11 \
|
||||
-DCMAKE_CXX_COMPILER=/usr/bin/g++-11 \
|
||||
-DFFMPEG_ROOT=/usr/local \
|
||||
-DUSE_VIDEO_CODEC=OFF && \
|
||||
cd /workspace/decord/build && \
|
||||
ninja -j"$(nproc)" && \
|
||||
cd /workspace/decord/python && \
|
||||
python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel build && \
|
||||
python3 -m build --wheel
|
||||
|
||||
# Dockerfile for demo
|
||||
FROM nvcr.io/nvidia/pytorch:25.09-py3
|
||||
COPY --from=ffmpeg-builder /usr/local /usr/local
|
||||
COPY --from=decord-builder /workspace/decord/python/dist/*.whl /tmp/wheels/
|
||||
|
||||
ARG HF_TOKEN
|
||||
|
||||
RUN pip install --no-cache-dir /tmp/wheels/*.whl && \
|
||||
rm -rf /tmp/wheels && \
|
||||
apt-get update && \
|
||||
apt-get install -y libdrm2 libdrm-dev libx264-dev && \
|
||||
pip install streamlit timm wandb && \
|
||||
hf auth login --token $HF_TOKEN
|
||||
|
||||
# Set CUDA environment variables
|
||||
ENV CUDA_HOME=/usr/local/cuda-13.0/
|
||||
ENV CUDA_PATH=$CUDA_HOME
|
||||
ENV PATH=$CUDA_HOME/bin:$PATH
|
||||
ENV LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
|
||||
ENV C_INCLUDE_PATH=$CUDA_HOME/include:$C_INCLUDE_PATH
|
||||
ENV CPLUS_INCLUDE_PATH=$CUDA_HOME/include:$CPLUS_INCLUDE_PATH
|
||||
|
||||
# install triton from source for latest blackwell support
|
||||
RUN git clone https://github.com/triton-lang/triton.git && \
|
||||
cd triton && \
|
||||
git checkout c5d671f91d90f40900027382f98b17a3e04045f6 && \
|
||||
pip install -r python/requirements.txt && \
|
||||
pip install . && \
|
||||
cd ..
|
||||
|
||||
# install xformers from source for blackwell support
|
||||
RUN git clone --depth=1 https://github.com/facebookresearch/xformers --recursive && \
|
||||
cd xformers && \
|
||||
export TORCH_CUDA_ARCH_LIST="12.1" && \
|
||||
python setup.py install && \
|
||||
cd ..
|
||||
# install unsloth without depedencies so we can build them from source
|
||||
RUN pip install unsloth unsloth_zoo bitsandbytes==0.48.0
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
52
nvidia/vlm-finetuning/assets/README.md
Normal file
@ -0,0 +1,52 @@
|
||||
# VLM Fine-tuning Recipes
|
||||
|
||||
This repository contains comprehensive fine-tuning recipes for Vision-Language Models (VLMs), supporting both **image** and **video** understanding tasks with modern models and training techniques.
|
||||
|
||||
## 🎯 Available Recipes
|
||||
|
||||
### 📸 Image VLM Fine-tuning (`ui_image/`)
|
||||
- **Model**: Qwen2.5-VL-7B-Instruct
|
||||
- **Task**: Wildfire detection from satellite imagery
|
||||
- **Training Method**: GRPO (Generalized Reward Preference Optimization) and LoRA (Low-rank Adaptation)
|
||||
|
||||
### 🎥 Video VLM Fine-tuning (`ui_video/`)
|
||||
- **Model**: InternVL3-8B
|
||||
- **Task**: Dangerous driving detection and structured metadata generation
|
||||
- **Training Method**: Supervised Fine-tuning on Multimodal Instructions
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### 1. Build the Docker Container
|
||||
|
||||
```bash
|
||||
# Build the VLM fine-tuning container
|
||||
docker build --build-arg HF_TOKEN=$HF_TOKEN -t vlm_demo .
|
||||
```
|
||||
|
||||
### 2. Launch the Container
|
||||
|
||||
```bash
|
||||
# Run the container with GPU support
|
||||
docker run -it \
|
||||
--gpus=all \
|
||||
--net=host \
|
||||
--ipc=host \
|
||||
--ulimit memlock=-1 \
|
||||
--ulimit stack=67108864 \
|
||||
-v $(pwd):/vlm_finetuning \
|
||||
-v $HOME/.cache/huggingface:/root/.cache/huggingface \
|
||||
vlm_demo
|
||||
|
||||
# Enter the mounted directory
|
||||
cd /vlm_finetuning
|
||||
```
|
||||
|
||||
> **Note**: The same Docker container and launch commands work for both image and video VLM recipes. The container includes all necessary dependencies including FFmpeg, Decord, and optimized libraries for both workflows.
|
||||
|
||||
## 📚 Detailed Instructions
|
||||
|
||||
Each recipe includes comprehensive documentation:
|
||||
|
||||
- **[Image VLM README](ui_image/README.md)**: Complete guide for wildfire detection fine-tuning with Qwen2.5-VL, including dataset setup, GRPO training configuration, and interactive inference
|
||||
- **[Video VLM README](ui_video/README.md)**: Full walkthrough for dangerous driving detection with InternVL3, covering video data preparation, training notebooks, and structured output generation
|
||||
|
||||
26
nvidia/vlm-finetuning/assets/launch.sh
Executable file
@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
docker run -it \
|
||||
--gpus=all \
|
||||
--net=host \
|
||||
--ipc=host \
|
||||
-w $HOME \
|
||||
-v $HOME:$HOME \
|
||||
-v $HOME/.cache/huggingface:/root/.cache/huggingface \
|
||||
vlm_demo
|
||||
@ -0,0 +1,3 @@
|
||||
[theme]
|
||||
base="dark"
|
||||
greenTextColor = "#76b900"
|
||||
438
nvidia/vlm-finetuning/assets/ui_image/Image_VLM.py
Normal file
@ -0,0 +1,438 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from unsloth import FastVisionModel
|
||||
|
||||
import os
|
||||
import re
|
||||
import gc
|
||||
import yaml
|
||||
import glob
|
||||
import random
|
||||
import subprocess
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
from PIL import Image
|
||||
import streamlit as st
|
||||
|
||||
|
||||
REASONING_START = "<REASONING>"
|
||||
REASONING_END = "</REASONING>"
|
||||
SOLUTION_START = "<SOLUTION>"
|
||||
SOLUTION_END = "</SOLUTION>"
|
||||
|
||||
|
||||
def initialize_session_state(resources):
|
||||
# Initialize page-specific session state
|
||||
st.session_state["base"] = st.session_state.get("base", resources["base"])
|
||||
st.session_state["finetuned"] = st.session_state.get("finetuned", resources["finetuned"])
|
||||
st.session_state["current_image"] = st.session_state.get("current_image", glob.glob("assets/image_vlm/images/*/*")[0])
|
||||
st.session_state["train_process"] = st.session_state.get("train_process", None)
|
||||
|
||||
|
||||
def load_config():
|
||||
config_key = "config"
|
||||
if getattr(st.session_state, config_key, None) is None:
|
||||
with open("src/image_vlm_config.yaml", "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
setattr(st.session_state, config_key, config)
|
||||
else:
|
||||
config = getattr(st.session_state, config_key)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def initialize_resources(inference_config):
|
||||
base_model, base_tokenizer = load_model_for_inference(inference_config, "base")
|
||||
finetuned_model, finetuned_tokenizer = load_model_for_inference(inference_config, "finetuned")
|
||||
|
||||
return {
|
||||
"base": {"model": base_model, "tokenizer": base_tokenizer},
|
||||
"finetuned": {"model": finetuned_model, "tokenizer": finetuned_tokenizer},
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
# set page ui
|
||||
st.title("Image VLM Finetuning")
|
||||
st.caption("A DGX Spark showcase for on-device VLM finetuning")
|
||||
# st.page_link("https://github.com/your-username/your-repo", label="GitHub", icon=":material/github:")
|
||||
|
||||
# load css
|
||||
with open("src/styles.css", "r") as f:
|
||||
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
||||
|
||||
# load resources
|
||||
config = load_config()
|
||||
if st.session_state.get("base", None) is None:
|
||||
st.toast("Loading model", icon="⏳", duration="short")
|
||||
resource = initialize_resources(config["inference"])
|
||||
if st.session_state.get("base", None) is None:
|
||||
st.toast("Model loaded", icon="✅", duration="short")
|
||||
initialize_session_state(resource)
|
||||
|
||||
# train section
|
||||
st.markdown("---")
|
||||
train_section()
|
||||
|
||||
# inference Section
|
||||
st.markdown("---")
|
||||
inference_section()
|
||||
|
||||
|
||||
def train_section():
|
||||
st.header("GRPO Training")
|
||||
|
||||
column_1, column_2, column_3 = st.columns(3, gap="large")
|
||||
with column_1:
|
||||
finetuning_method = st.selectbox(
|
||||
"Finetuning Method:",
|
||||
["LoRA", "QLoRA", "Full Fine-tuning"],
|
||||
)
|
||||
|
||||
# update lora config
|
||||
if finetuning_method in ("QLoRA", "LoRA"):
|
||||
lora_config = st.session_state["config"]["train"]["model"]["lora_config"]
|
||||
|
||||
with column_2:
|
||||
lora_rank = st.slider(
|
||||
"LoRA Rank",
|
||||
min_value=8,
|
||||
max_value=64,
|
||||
value=lora_config["rank"],
|
||||
step=8,
|
||||
)
|
||||
|
||||
with column_3:
|
||||
lora_alpha = st.slider(
|
||||
"LoRA Alpha",
|
||||
min_value=8,
|
||||
max_value=64,
|
||||
value=lora_config["alpha"],
|
||||
step=8,
|
||||
)
|
||||
|
||||
st.session_state["config"]["train"]["model"]["lora_config"].update({
|
||||
'rank': lora_rank,
|
||||
'alpha': lora_alpha,
|
||||
})
|
||||
|
||||
# update model config based on selection
|
||||
st.session_state["config"]["train"]["model"]["use_lora"] = finetuning_method == "LoRA"
|
||||
st.session_state["config"]["train"]["model"]["use_qlora"] = finetuning_method == "QLoRA"
|
||||
|
||||
# update train config
|
||||
st.write("")
|
||||
column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
|
||||
with column_1:
|
||||
finetune_vision_layers = st.toggle(
|
||||
"Finetune Vision Layers",
|
||||
value=st.session_state["config"]["train"]["model"]["finetune_vision_layers"])
|
||||
|
||||
with column_2:
|
||||
finetune_language_layers = st.toggle(
|
||||
"Finetune Language Layers",
|
||||
value=st.session_state["config"]["train"]["model"]["finetune_language_layers"])
|
||||
|
||||
with column_3:
|
||||
finetune_attention_modules = st.toggle(
|
||||
"Finetune Attention Modules",
|
||||
value=st.session_state["config"]["train"]["model"]["finetune_attention_modules"])
|
||||
|
||||
with column_4:
|
||||
finetune_mlp_modules = st.toggle(
|
||||
"Finetune MLP Modules",
|
||||
value=st.session_state["config"]["train"]["model"]["finetune_mlp_modules"])
|
||||
|
||||
st.write("")
|
||||
column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
|
||||
with column_1:
|
||||
epochs = st.slider(
|
||||
"Epochs",
|
||||
min_value=1,
|
||||
max_value=100,
|
||||
value=st.session_state["config"]["train"]["hyperparameters"]["epochs"])
|
||||
|
||||
with column_2:
|
||||
batch_size = st.select_slider(
|
||||
"Batch Size",
|
||||
options=[1, 2, 4, 8, 16],
|
||||
value=st.session_state["config"]["train"]["hyperparameters"]["batch_size"])
|
||||
|
||||
with column_3:
|
||||
learning_rate = st.number_input(
|
||||
"Learning Rate",
|
||||
min_value=1e-6,
|
||||
max_value=1e-2,
|
||||
value=float(st.session_state["config"]["train"]["hyperparameters"]["learning_rate"]),
|
||||
format="%.2e")
|
||||
|
||||
with column_4:
|
||||
optimizer = st.selectbox(
|
||||
"Optimizer",
|
||||
options=["adamw_torch", "adafactor"])
|
||||
|
||||
st.session_state["config"]["train"]["hyperparameters"].update({
|
||||
'epochs': epochs,
|
||||
'batch_size': batch_size,
|
||||
'learning_rate': learning_rate,
|
||||
'optimizer': optimizer,
|
||||
})
|
||||
|
||||
st.session_state["config"]["train"]["model"].update({
|
||||
'finetune_vision_layers': finetune_vision_layers,
|
||||
'finetune_language_layers': finetune_language_layers,
|
||||
'finetune_attention_modules': finetune_attention_modules,
|
||||
'finetune_mlp_modules': finetune_mlp_modules,
|
||||
})
|
||||
|
||||
st.write("")
|
||||
column_1, column_2, column_3, column_4 = st.columns(4, gap="large")
|
||||
with column_1:
|
||||
enable_grpo = st.toggle(
|
||||
"Enable GRPO",
|
||||
value=st.session_state["config"]["train"]["hyperparameters"]["enable_grpo"],
|
||||
disabled=True)
|
||||
|
||||
with column_2:
|
||||
format_reward = st.number_input(
|
||||
"Reward for reasoning format",
|
||||
min_value=0.0,
|
||||
max_value=5.0,
|
||||
value=float(st.session_state["config"]["train"]["hyperparameters"]["format_reward"]),
|
||||
format="%.2e")
|
||||
|
||||
with column_3:
|
||||
correctness_reward = st.number_input(
|
||||
"Reward for correct response",
|
||||
min_value=0.0,
|
||||
max_value=5.0,
|
||||
value=float(st.session_state["config"]["train"]["hyperparameters"]["correctness_reward"]),
|
||||
format="%.2e")
|
||||
|
||||
with column_4:
|
||||
num_generations = st.number_input(
|
||||
"Number of generations",
|
||||
min_value=1,
|
||||
max_value=16,
|
||||
value=st.session_state["config"]["train"]["hyperparameters"]["num_generations"],
|
||||
format="%.2e")
|
||||
|
||||
# Training control
|
||||
st.write("")
|
||||
column_1, column_2, column_3 = st.columns([4, 4, 1])
|
||||
|
||||
with column_1:
|
||||
button_type = "secondary" if st.session_state["train_process"] else "primary"
|
||||
if st.button("▶️ Start Finetuning", type=button_type, width="stretch", disabled=bool(st.session_state["train_process"])):
|
||||
if st.session_state["train_process"] is None:
|
||||
# store config
|
||||
with open("src/train.yaml", "w") as f:
|
||||
yaml.dump(st.session_state["config"]["train"], f, default_flow_style=False)
|
||||
|
||||
# start training
|
||||
st.session_state["train_process"] = subprocess.Popen(
|
||||
["python", "src/train_image_vlm.py"],
|
||||
stdout=None, stderr=None
|
||||
)
|
||||
else:
|
||||
st.toast("Training already in progress", icon="❌", duration="short")
|
||||
|
||||
with column_2:
|
||||
button_type = "primary" if st.session_state["train_process"] else "secondary"
|
||||
if st.button("⏹️ Stop Finetuning", type=button_type, width="stretch", disabled=not bool(st.session_state["train_process"])):
|
||||
if st.session_state["train_process"] is not None:
|
||||
st.session_state["train_process"].terminate()
|
||||
st.session_state["train_process"] = None
|
||||
st.toast("Training stopped", icon="✅", duration="short")
|
||||
st.toast("Re-deploy the app with updated finetuned model", icon=":material/info:", duration="short")
|
||||
else:
|
||||
st.toast("No training to stop", icon="❌", duration="short")
|
||||
|
||||
with column_3:
|
||||
if st.session_state["train_process"]:
|
||||
st.badge("Running", icon=":material/hourglass_arrow_up:", color="green", width="stretch")
|
||||
else:
|
||||
st.badge("Idle", icon=":material/hourglass_disabled:", color="red", width="stretch")
|
||||
|
||||
# display wandb
|
||||
runs = wandb.Api().runs(f"{os.environ.get('WANDB_ENTITY')}/{os.environ.get('WANDB_PROJECT')}")
|
||||
if runs:
|
||||
base_url = runs[0].url
|
||||
loss_url = f"{base_url}?panelDisplayName=train%2Floss&panelSectionName=train"
|
||||
memory_url = f"{base_url}?panelDisplayName=GPU+Memory+Allocated+%28%25%29&panelSectionName=System"
|
||||
|
||||
column_1, column_2 = st.columns(2)
|
||||
with column_1:
|
||||
st.markdown(f"""
|
||||
<div class="wandb-wrapper">
|
||||
<iframe src="{loss_url}" class="wandb-iframe"></iframe>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
with column_2:
|
||||
st.markdown(f"""
|
||||
<div class="wandb-wrapper">
|
||||
<iframe src="{memory_url}" class="wandb-iframe"></iframe>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
|
||||
def inference_section():
|
||||
st.header("Image Inference")
|
||||
|
||||
columns = st.columns([3, 3, 1, 2])
|
||||
with columns[1]:
|
||||
with st.container(border=True, horizontal_alignment="center", vertical_alignment="center"):
|
||||
image_holder = st.empty()
|
||||
image_holder.image(st.session_state["current_image"])
|
||||
|
||||
with columns[3]:
|
||||
if st.button("🎲 Test another sample"):
|
||||
while True:
|
||||
current_image = random.choice(glob.glob("assets/image_vlm/images/*/*"))
|
||||
if current_image != st.session_state["current_image"]:
|
||||
break
|
||||
st.session_state["current_image"] = current_image
|
||||
image_holder.image(st.session_state["current_image"])
|
||||
|
||||
columns = st.columns(2, gap="small")
|
||||
with columns[0]:
|
||||
with st.container(border=True):
|
||||
st.write("##### :green[Base Qwen2.5-VL-7B]")
|
||||
base_generation = st.empty()
|
||||
base_generation.write("...")
|
||||
|
||||
with columns[1]:
|
||||
with st.container(border=True):
|
||||
st.write("##### :green[Finetuned Qwen2.5-VL-7B]")
|
||||
finetuned_generation = st.empty()
|
||||
finetuned_generation.write("...")
|
||||
|
||||
columns = st.columns([9, 1], gap="small")
|
||||
with columns[0]:
|
||||
prompt = st.text_input(
|
||||
"Prompt Input",
|
||||
label_visibility="collapsed",
|
||||
key="prompt_input",
|
||||
on_change=lambda: st.session_state.update(prompt=st.session_state["prompt_input"])
|
||||
)
|
||||
|
||||
with columns[1]:
|
||||
if st.button("Generate", width="stretch"):
|
||||
if st.session_state.get("prompt", None):
|
||||
st.session_state["prompt"] = prompt
|
||||
|
||||
with st.spinner("Running..."):
|
||||
response = start_inference("base")
|
||||
base_generation.markdown(response)
|
||||
|
||||
with st.spinner("Running..."):
|
||||
response = start_inference("finetuned")
|
||||
finetuned_generation.markdown(response)
|
||||
|
||||
|
||||
def load_model_for_inference(config, model_type):
|
||||
if model_type == "finetuned":
|
||||
model_name = config["finetuned_model_id"]
|
||||
elif model_type == "base":
|
||||
model_name = config["model_id"]
|
||||
else:
|
||||
raise ValueError(f"Invalid model type: {model_type}")
|
||||
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=config["max_seq_length"],
|
||||
load_in_4bit=False,
|
||||
)
|
||||
|
||||
FastVisionModel.for_inference(model)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def start_inference(model_type):
|
||||
# define prompt
|
||||
prompt = st.session_state["prompt"]
|
||||
if model_type == "finetuned":
|
||||
prompt = (
|
||||
f"{prompt}. Also first provide your reasoning or working out"\
|
||||
f" on how you would go about identifying the presence of wildfire affected regions between {REASONING_START} and {REASONING_END}"
|
||||
f" and then your final answer between {SOLUTION_START} and (put a simple Yes or No here) {SOLUTION_END}"
|
||||
)
|
||||
|
||||
# load image
|
||||
image = Image.open(st.session_state["current_image"])
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# construct instruction prompt
|
||||
prompt = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# apply chat template
|
||||
prompt = st.session_state[f"{model_type}_image_vlm"]["tokenizer"].apply_chat_template(
|
||||
prompt,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
# tokenize inputs
|
||||
inputs = st.session_state[f"{model_type}_image_vlm"]["tokenizer"](
|
||||
image,
|
||||
prompt,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
).to("cuda")
|
||||
|
||||
# perform inference
|
||||
response = st.session_state[f"{model_type}_image_vlm"]["model"].generate(
|
||||
**inputs,
|
||||
max_new_tokens=1024,
|
||||
use_cache=True,
|
||||
do_sample=False
|
||||
)[0][inputs["input_ids"].shape[1]: ]
|
||||
|
||||
# decode tokens
|
||||
response = st.session_state[f"{model_type}_image_vlm"]["tokenizer"].decode(response, skip_special_tokens=True)
|
||||
|
||||
# format response
|
||||
if model_type == "finetuned":
|
||||
response = response.replace(REASONING_START, "```")
|
||||
response = response.replace(REASONING_END, "```")
|
||||
|
||||
# Handle solution formatting with proper newline handling
|
||||
solution_pattern = f'{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}'
|
||||
solution_match = re.search(solution_pattern, response, re.DOTALL)
|
||||
if solution_match:
|
||||
solution_content = solution_match.group(1).strip()
|
||||
response = re.sub(solution_pattern, f"**{solution_content}**", response, flags=re.DOTALL)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
137
nvidia/vlm-finetuning/assets/ui_image/README.md
Normal file
@ -0,0 +1,137 @@
|
||||
# Image VLM Fine-tuning with Qwen2.5-VL
|
||||
|
||||
This project demonstrates fine-tuning Vision-Language Models (VLMs) for image understanding tasks, specifically using the Qwen2.5-VL-7B model for wildfire detection from satellite imagery using GRPO (Generalized Reward Preference Optimization).
|
||||
|
||||
## Overview
|
||||
|
||||
The project includes:
|
||||
- **Interactive Training Interface**: Streamlit-based UI for configuring and monitoring VLM fine-tuning
|
||||
- **GRPO Training**: Advanced preference optimization for better reasoning capabilities
|
||||
- **Multiple Fine-tuning Methods**: Support for LoRA, QLoRA, and Full Finetuning
|
||||
- **Side-by-side Inference**: Compare base model vs fine-tuned model performance
|
||||
|
||||
## Getting Started
|
||||
|
||||
> **Note**: These instructions assume you are already inside the Docker container. For container setup, refer to the main project README at `vlm-finetuning/`.
|
||||
|
||||
### 1. Set Up Weights & Biases
|
||||
|
||||
Configure your wandb credentials for training monitoring:
|
||||
|
||||
```bash
|
||||
export WANDB_PROJECT="vlm_finetuning"
|
||||
export WANDB_ENTITY=<WANDB_USERNAME>
|
||||
export WANDB_API_KEY=<WANDB_API_KEY>
|
||||
```
|
||||
|
||||
### 2. Launch the Application
|
||||
|
||||
```bash
|
||||
# Start the Streamlit interface
|
||||
streamlit run Image_VLM.py
|
||||
```
|
||||
|
||||
The application will be available at `http://localhost:8501`
|
||||
|
||||
## Training
|
||||
|
||||
### Dataset
|
||||
|
||||
The project uses a **wildfire detection dataset** with satellite imagery for training the model to identify wildfire-affected regions. The dataset includes:
|
||||
- Satellite and aerial imagery from wildfire-affected areas
|
||||
- Binary classification: wildfire vs no wildfire
|
||||
|
||||
#### Dataset Setup
|
||||
|
||||
1. **Download from Kaggle**: Visit the [Wildfire Prediction Dataset](https://www.kaggle.com/datasets/abdelghaniaaba/wildfire-prediction-dataset) on Kaggle
|
||||
|
||||
2. **Get the curl command**: On the Kaggle dataset page, click the download button and copy the curl command provided
|
||||
|
||||
3. **Download and extract**: Run the following commands in your container:
|
||||
|
||||
```bash
|
||||
mkdir data
|
||||
cd data
|
||||
|
||||
# Paste the curl command from Kaggle here, and then continue to unzip the dataset
|
||||
|
||||
unzip -qq wildfire-prediction-dataset.zip
|
||||
rm wildfire-prediction-dataset.zip
|
||||
cd ..
|
||||
```
|
||||
|
||||
> **Note**: You'll need to be logged into Kaggle and may need to accept the dataset terms before the download link works.
|
||||
|
||||
### Training Configuration
|
||||
|
||||
Configure training through the interactive interface:
|
||||
|
||||
#### Model Settings
|
||||
- **Base Model**: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
- **Fine-tuning Method**: Choose from LoRA, QLoRA, or Full Finetuning
|
||||
- **LoRA Parameters**: Adjustable rank (8-64) and alpha (8-64)
|
||||
|
||||
#### Training Parameters
|
||||
- **Epochs**: 1-100 (default: 10)
|
||||
- **Batch Size**: 1, 2, 4, 8, or 16 (default: 2)
|
||||
- **Learning Rate**: 1e-6 to 1e-2 (default: 1e-5)
|
||||
- **Optimizer**: AdamW or Adafactor
|
||||
|
||||
#### GRPO Settings
|
||||
- **Format Reward**: 2.0 (reward for proper reasoning format)
|
||||
- **Correctness Reward**: 5.0 (reward for correct answers)
|
||||
- **Number of Generations**: 4 (for preference optimization)
|
||||
|
||||
### Training Process
|
||||
|
||||
1. **Configure Parameters**: Use the web interface to set training hyperparameters
|
||||
2. **Start Training**: Click "▶️ Start Finetuning" to begin GRPO training
|
||||
3. **Monitor Progress**: View real-time loss curves and GPU utilization via embedded wandb charts
|
||||
4. **Stop if Needed**: Use "⏹️ Stop Finetuning" to halt training early
|
||||
|
||||
> **Important**: After training completes, follow these steps:
|
||||
> 1. **Stop the UI**: Use Ctrl+C to stop the Streamlit application
|
||||
> 2. **Update Config**: Edit `src/image_vlm_config.yaml` and change the `finetuned_model_id` path to point to your newly trained model in the `saved_model/` directory
|
||||
> 3. **Restart UI**: Launch the application again to test your fine-tuned model
|
||||
|
||||
## Inference
|
||||
|
||||
### Interactive Comparison
|
||||
|
||||

|
||||
*Side-by-side comparison showing base model vs fine-tuned model performance on wildfire detection*
|
||||
|
||||
The inference section provides:
|
||||
- **Sample Images**: Test on wildfire detection samples from both categories (wildfire/no wildfire)
|
||||
- **Dual Inference**: Run both base and fine-tuned models simultaneously
|
||||
- **Random Sampling**: Test different samples with the "🎲 Test another sample" button
|
||||
- **Structured Reasoning**: Fine-tuned model provides reasoning in `<REASONING>` tags before final answer
|
||||
|
||||
### Sample Questions
|
||||
|
||||
The interface includes prompts for wildfire detection:
|
||||
- "Identify if this region has been affected by a wildfire"
|
||||
- The fine-tuned model provides structured reasoning followed by a Yes/No answer
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
ui_image/
|
||||
├── Image_VLM_Finetuning.py # Main Streamlit application
|
||||
├── README.md # This file
|
||||
├── src/
|
||||
│ ├── image_vlm_config.yaml # Configuration file (update finetuned_model_id after training)
|
||||
│ └── styles.css # Custom UI styling
|
||||
├── assets/
|
||||
│ └── image_vlm/
|
||||
│ └── images/
|
||||
│ ├── wildfire/ # Wildfire-affected images
|
||||
│ └── nowildfire/ # Non-wildfire images
|
||||
├── assets/
|
||||
│ └── inference_screenshot.png # UI demonstration screenshot
|
||||
└── saved_model/ # Training checkpoints directory (update config to point here)
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
The `src/image_vlm_config.yaml` file contains all training and inference settings:
|
||||
|
After Width: | Height: | Size: 18 KiB |
|
After Width: | Height: | Size: 36 KiB |
|
After Width: | Height: | Size: 27 KiB |
|
After Width: | Height: | Size: 15 KiB |
|
After Width: | Height: | Size: 1.2 MiB |
16
nvidia/vlm-finetuning/assets/ui_image/src/__init__.py
Executable file
@ -0,0 +1,16 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
50
nvidia/vlm-finetuning/assets/ui_image/src/image_vlm_config.yaml
Executable file
@ -0,0 +1,50 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
inference:
|
||||
model_id: unsloth/Qwen2.5-VL-7B-Instruct
|
||||
finetuned_model_id: RLakshmi24/qwen_wildfire_qrpo_lora
|
||||
max_seq_length: 16384
|
||||
|
||||
train:
|
||||
model:
|
||||
model_id: unsloth/Qwen2.5-VL-7B-Instruct
|
||||
max_seq_length: 16384
|
||||
use_lora: true
|
||||
use_qlora: false
|
||||
lora_config:
|
||||
rank: 32
|
||||
alpha: 64
|
||||
dropout: 0.05
|
||||
finetune_vision_layers: true
|
||||
finetune_language_layers: true
|
||||
finetune_attention_modules: true
|
||||
finetune_mlp_modules: true
|
||||
|
||||
data:
|
||||
dataset_id: data
|
||||
|
||||
hyperparameters:
|
||||
epochs: 10
|
||||
batch_size: 2
|
||||
enable_grpo: true
|
||||
num_generations: 4
|
||||
format_reward: 2.0
|
||||
learning_rate: 1e-5
|
||||
correctness_reward: 5.0
|
||||
optimizer: adamw_torch
|
||||
output_dir: saved_model
|
||||
178
nvidia/vlm-finetuning/assets/ui_image/src/styles.css
Executable file
@ -0,0 +1,178 @@
|
||||
/*
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
*/
|
||||
|
||||
/* VLM Fine-tuning Streamlit App Styles */
|
||||
:root {
|
||||
--nvidia-green: #76b900;
|
||||
}
|
||||
|
||||
/* Set maximum page width - multiple selectors for compatibility */
|
||||
.main .block-container,
|
||||
.block-container,
|
||||
.main > div,
|
||||
section.main > div {
|
||||
max-width: 1200px !important;
|
||||
margin: 0 auto !important;
|
||||
padding-left: 1rem !important;
|
||||
padding-right: 1rem !important;
|
||||
}
|
||||
|
||||
/* h3 {
|
||||
color:#76b900 !important;
|
||||
} */
|
||||
|
||||
.main > div {
|
||||
padding-top: 1rem;
|
||||
}
|
||||
|
||||
/* Global button styling - applies to ALL buttons */
|
||||
.stButton > button,
|
||||
button[kind="primary"],
|
||||
button[kind="secondary"] {
|
||||
width: 100% !important;
|
||||
font-size: 16px !important;
|
||||
font-weight: 600 !important;
|
||||
padding: 12px 20px !important;
|
||||
border-radius: 8px !important;
|
||||
border: none !important;
|
||||
transition: all 0.3s ease !important;
|
||||
cursor: pointer !important;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
|
||||
}
|
||||
|
||||
/* Primary button styling */
|
||||
.stButton > button[kind="primary"],
|
||||
button[kind="primary"] {
|
||||
background: linear-gradient(135deg, #32864c 0%, #6edd93 100%) !important;
|
||||
/* background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; */
|
||||
color: white !important;
|
||||
}
|
||||
|
||||
/* Primary button hover effect */
|
||||
.stButton > button[kind="primary"]:hover,
|
||||
button[kind="primary"]:hover {
|
||||
transform: translateY(-2px) !important;
|
||||
box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
|
||||
}
|
||||
|
||||
/* Secondary button styling */
|
||||
.stButton > button[kind="secondary"],
|
||||
button[kind="secondary"] {
|
||||
background: linear-gradient(135deg, #b6fb93 0%, #57f586 100%) !important;
|
||||
/* background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important; */
|
||||
/* color: white !important; */
|
||||
color: darkslategray !important;
|
||||
}
|
||||
|
||||
/* Regular button styling */
|
||||
.stButton > button:not([kind]) {
|
||||
background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%) !important;
|
||||
color: white !important;
|
||||
}
|
||||
|
||||
/* Button hover effect for all buttons */
|
||||
.stButton > button:hover {
|
||||
transform: translateY(-2px) !important;
|
||||
box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
|
||||
}
|
||||
|
||||
/* Button active/pressed state */
|
||||
.stButton > button:active {
|
||||
transform: translateY(0px) !important;
|
||||
}
|
||||
|
||||
/* Metric card styling */
|
||||
.metric-card {
|
||||
background-color: #f0f2f6;
|
||||
padding: 1rem;
|
||||
border-radius: 0.5rem;
|
||||
border-left: 4px solid #ff6b6b;
|
||||
}
|
||||
|
||||
/* Center images and maintain aspect ratio */
|
||||
.stImage > img {
|
||||
display: block;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
max-width: 500px;
|
||||
height: auto;
|
||||
max-height: 300px
|
||||
}
|
||||
|
||||
/* Larger text input for questions */
|
||||
.stTextInput > div > div > input {
|
||||
font-size: 18px !important;
|
||||
padding: 10px !important;
|
||||
}
|
||||
|
||||
/* Larger inference button */
|
||||
.inference-button button,
|
||||
.inference-button .stButton > button {
|
||||
padding: 20px !important;
|
||||
height: 200px !important;
|
||||
min-height: 100% !important;
|
||||
max-height: 100% !important;
|
||||
display: flex !important;
|
||||
align-items: center !important;
|
||||
justify-content: center !important;
|
||||
line-height: 1.2 !important;
|
||||
white-space: nowrap !important;
|
||||
font-size: large !important;
|
||||
}
|
||||
|
||||
|
||||
.card {
|
||||
background-color: #1F2022;
|
||||
color: #bdc3c7;
|
||||
padding: 1.5rem;
|
||||
border-radius: 0.5rem;
|
||||
/* outline: 2px solid var(--nvidia-green); */
|
||||
box-shadow: 0 2px 8px rgba(0,0,0,0.10);
|
||||
text-align: left;
|
||||
height: 300px;
|
||||
overflow-y: auto;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.card--empty {
|
||||
text-align: center;
|
||||
font-style: italic;
|
||||
height: 300px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
/* Wandb iframe wrapper for cropping */
|
||||
.wandb-wrapper {
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
border-radius: 15px;
|
||||
height: var(--visible-height, 320px);
|
||||
}
|
||||
|
||||
/* Wandb iframe styling */
|
||||
.wandb-iframe {
|
||||
border: none;
|
||||
width: 100%;
|
||||
position: absolute;
|
||||
left: 0;
|
||||
height: var(--iframe-height, 600px);
|
||||
top: var(--shift-up, -230px);
|
||||
}
|
||||
25
nvidia/vlm-finetuning/assets/ui_image/src/train.yaml
Normal file
@ -0,0 +1,25 @@
|
||||
data:
|
||||
dataset_id: data
|
||||
hyperparameters:
|
||||
batch_size: 1
|
||||
correctness_reward: 5.0
|
||||
enable_grpo: true
|
||||
epochs: 2
|
||||
format_reward: 2.0
|
||||
learning_rate: 1.0e-05
|
||||
num_generations: 4
|
||||
optimizer: adamw_torch
|
||||
output_dir: saved_model
|
||||
model:
|
||||
finetune_attention_modules: true
|
||||
finetune_language_layers: true
|
||||
finetune_mlp_modules: true
|
||||
finetune_vision_layers: true
|
||||
lora_config:
|
||||
alpha: 64
|
||||
dropout: 0.05
|
||||
rank: 16
|
||||
max_seq_length: 16384
|
||||
model_id: unsloth/Qwen2.5-VL-7B-Instruct
|
||||
use_lora: true
|
||||
use_qlora: false
|
||||
204
nvidia/vlm-finetuning/assets/ui_image/src/train_image_vlm.py
Normal file
@ -0,0 +1,204 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from unsloth import FastVisionModel
|
||||
|
||||
import re
|
||||
import yaml
|
||||
|
||||
from PIL import ImageFile
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
REASONING_START = "<REASONING>"
|
||||
REASONING_END = "</REASONING>"
|
||||
SOLUTION_START = "<SOLUTION>"
|
||||
SOLUTION_END = "</SOLUTION>"
|
||||
|
||||
|
||||
def load_model_for_train(config):
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
model_name=config["model"]["model_id"],
|
||||
max_seq_length=config["model"]["max_seq_length"],
|
||||
load_in_4bit=config["model"]["use_qlora"],
|
||||
)
|
||||
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers=config["model"]["finetune_vision_layers"],
|
||||
finetune_language_layers=config["model"]["finetune_language_layers"],
|
||||
finetune_attention_modules=config["model"]["finetune_attention_modules"],
|
||||
finetune_mlp_modules=config["model"]["finetune_mlp_modules"],
|
||||
r=config["model"]["lora_config"]["rank"],
|
||||
lora_alpha=config["model"]["lora_config"]["alpha"],
|
||||
lora_dropout=config["model"]["lora_config"]["dropout"],
|
||||
bias="none",
|
||||
random_state=42,
|
||||
use_rslora=False,
|
||||
loftq_config=None,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def format_instruction(sample, label_dict):
|
||||
label = label_dict.int2str(sample["label"])
|
||||
if label == "nowildfire":
|
||||
answer = "No"
|
||||
else:
|
||||
answer = "Yes"
|
||||
|
||||
# reasoning prompt
|
||||
prompt = "Identify if this region has been affected by a wildfire"
|
||||
prompt = (
|
||||
f"{prompt}. Also first provide your reasoning or working out"\
|
||||
f" on how you would go about identifying the presence of wildfire affected regions between {REASONING_START} and {REASONING_END}"
|
||||
f" and then your final answer between {SOLUTION_START} and (put a simple Yes or No here) {SOLUTION_END}"
|
||||
)
|
||||
|
||||
# convert image format
|
||||
image = sample["image"]
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# construct instruction prompt
|
||||
prompt = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
return {"prompt": prompt, "image": sample["image"], "answer": answer}
|
||||
|
||||
|
||||
def load_wildfire_dataset(config, tokenizer):
|
||||
# load dataset
|
||||
train_dataset = load_dataset(config["data"]["dataset_id"])["train"]
|
||||
|
||||
# preprocess the dataset
|
||||
train_dataset = train_dataset.map(
|
||||
lambda sample: format_instruction(sample, train_dataset.features["label"]),
|
||||
num_proc=8)
|
||||
train_dataset = train_dataset.map(
|
||||
lambda sample: {
|
||||
"prompt": tokenizer.apply_chat_template(
|
||||
sample["prompt"],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
}, num_proc=8)
|
||||
|
||||
return train_dataset
|
||||
|
||||
|
||||
def format_reward_func(completions, **kwargs):
|
||||
thinking_pattern = f'{REASONING_START}(.*?){REASONING_END}'
|
||||
answer_pattern = f'{SOLUTION_START}(.*?){SOLUTION_END}'
|
||||
|
||||
scores = []
|
||||
for completion in completions:
|
||||
score = 0
|
||||
thinking_matches = re.findall(thinking_pattern, completion, re.DOTALL)
|
||||
answer_matches = re.findall(answer_pattern, completion, re.DOTALL)
|
||||
if len(thinking_matches) == 1:
|
||||
score += 1.0
|
||||
if len(answer_matches) == 1:
|
||||
score += 1.0
|
||||
|
||||
# penalize excessive addCriterion predictions in qwen2.5-vl
|
||||
if len(completion) != 0:
|
||||
removal = completion.replace("addCriterion", "").replace("\n", "")
|
||||
if (len(completion)-len(removal))/len(completion) >= 0.5:
|
||||
score -= 1.0
|
||||
|
||||
scores.append(score)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
|
||||
answer_pattern = f'{SOLUTION_START}(.*?){SOLUTION_END}'
|
||||
|
||||
responses = [re.findall(answer_pattern, completion, re.DOTALL) for completion in completions]
|
||||
q = prompts[0]
|
||||
print("----------------------------------")
|
||||
print(f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{completions[0]}")
|
||||
return [
|
||||
5.0 if len(r)==1 and a == r[0].replace('\n','') else 0.0
|
||||
for r, a in zip(responses, answer)
|
||||
]
|
||||
|
||||
|
||||
def start_train(config):
|
||||
# load base model for finetuning
|
||||
model, tokenizer = load_model_for_train(config)
|
||||
|
||||
# load dataset
|
||||
train_dataset = load_wildfire_dataset(config, tokenizer)
|
||||
|
||||
# define training arguments
|
||||
training_args = GRPOConfig(
|
||||
learning_rate=config["hyperparameters"]["learning_rate"],
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type="cosine",
|
||||
optim="adamw_torch",
|
||||
logging_steps=1,
|
||||
log_completions=False,
|
||||
per_device_train_batch_size=config["hyperparameters"]["batch_size"],
|
||||
gradient_accumulation_steps=1,
|
||||
num_generations=2,
|
||||
max_prompt_length=config["model"]["max_seq_length"],
|
||||
max_completion_length=config["model"]["max_seq_length"],
|
||||
num_train_epochs=config["hyperparameters"]["epochs"],
|
||||
save_steps=100,
|
||||
max_grad_norm=0.1,
|
||||
report_to="none",
|
||||
output_dir=config["hyperparameters"]["output_dir"],
|
||||
# importance_sampling_level="sequence",
|
||||
# mask_truncated_completions=False,
|
||||
# loss_type="dr_grpo",
|
||||
)
|
||||
|
||||
# start training
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=train_dataset,
|
||||
reward_funcs=[
|
||||
format_reward_func,
|
||||
correctness_reward_func,
|
||||
],
|
||||
args=training_args,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("src/train.yaml", "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
start_train(config)
|
||||
@ -0,0 +1,3 @@
|
||||
[theme]
|
||||
base="dark"
|
||||
greenTextColor = "#76b900"
|
||||
137
nvidia/vlm-finetuning/assets/ui_video/README.md
Normal file
@ -0,0 +1,137 @@
|
||||
# Video VLM Fine-tuning with InternVL3
|
||||
|
||||
This project demonstrates fine-tuning the InternVL3 model for video analysis, specifically for dangerous driving detection and structured metadata generation from driving videos.
|
||||
|
||||
## Workflow Overview
|
||||
|
||||

|
||||
|
||||
### Training Workflow Steps:
|
||||
|
||||
1. **🎥 Dashcam Footage**: Dashcam footage from the Nexar Collision Prediction dataset
|
||||
2. **Generate Structed caption**: Leverage a very large VLM (InternVL3-78B) to generate structured captions from raw videos
|
||||
3. **🧠 Train InternVL3 Model**: Perform Supervised Finetuning on InternVL3-8B to extract structured metadata
|
||||
4. **🚀 Fine-tuned VLM**: Trained model ready for analysing driver behaviour and risk factors
|
||||
|
||||
|
||||
## Training
|
||||
|
||||
### Data Requirements
|
||||
|
||||
Your dataset should be structured as follows:
|
||||
```
|
||||
dataset/
|
||||
├── videos/
|
||||
│ ├── video1.mp4
|
||||
│ ├── video2.mp4
|
||||
│ └── ...
|
||||
└── metadata.jsonl # Contains video paths and labels
|
||||
```
|
||||
|
||||
Each line in `metadata.jsonl` should contain:
|
||||
```json
|
||||
{
|
||||
"video": "videos/video1.mp4",
|
||||
"caption": "Description of the video events",
|
||||
"event_type": "collision" | "near_miss" | "no_incident",
|
||||
"rule_violations": choose relevant items from ["speeding", "failure_to_yield", "ignoring_traffic_signs"],
|
||||
"intended_driving_action": "turn_left" | "turn_right" | "change_lanes",
|
||||
"traffic_density": "low" | "high",
|
||||
"visibility": "good" | "bad"
|
||||
}
|
||||
```
|
||||
|
||||
### Running Training
|
||||
|
||||
1. **Update Dataset Path**: Edit the training notebook to point to your dataset:
|
||||
```python
|
||||
dataset_path = "/path/to/your/dataset"
|
||||
```
|
||||
|
||||
2. **Run Training Notebook**:
|
||||
```bash
|
||||
# Inside the container, navigate to the training directory
|
||||
cd ui_video/train
|
||||
jupyter notebook video_vlm.ipynb
|
||||
```
|
||||
|
||||
3. **Monitor Training**: Training progress and metrics are displayed directly in the notebook interface.
|
||||
|
||||
### Training Configuration
|
||||
|
||||
Key training parameters configurable:
|
||||
- **Model**: InternVL3-8B
|
||||
- **Video Frames**: 12 to 16 frames per video
|
||||
- **Sampling Mode**: Uniform temporal sampling
|
||||
- **LoRA Configuration**: Efficient parameter updates for large-scale fine-tuning
|
||||
- **Hyperparameters**: Exhaustive suite of hyperparameters to tune for video VLM finetuning
|
||||
|
||||
## Inference
|
||||
|
||||
### Running Inference
|
||||
|
||||
1. **Streamlit Web Interface**:
|
||||
```bash
|
||||
# Start the interactive web interface
|
||||
cd ui_video
|
||||
streamlit run Video_VLM.py
|
||||
```
|
||||
|
||||
The interface provides:
|
||||
- Dashcam video gallery and playback
|
||||
- Side-by-side comparison between base and finetuned model
|
||||
- JSON output generation
|
||||
- Tabular view of structured data extracted for analysis
|
||||
|
||||
2. **Configuration**: Edit `src/video_vlm_config.yaml` to modify model settings, frame count, and sampling strategy.
|
||||
|
||||
### Sample Output
|
||||
|
||||
The model generates structured JSON output like:
|
||||
```json
|
||||
{
|
||||
"caption": "A vehicle makes a dangerous lane change without signaling while speeding on a highway during daytime with clear weather conditions.",
|
||||
"event_type": "near_miss",
|
||||
"cause_of_risk": ["speeding", "risky_maneuver"],
|
||||
"presence_of_rule_violations": ["failure_to_use_turn_signals"],
|
||||
"intended_driving_action": ["change_lanes"],
|
||||
"traffic_density": "medium",
|
||||
"driving_setting": ["highway"],
|
||||
"time_of_day": "day",
|
||||
"light_conditions": "normal",
|
||||
"weather": "clear",
|
||||
"scene": "highway"
|
||||
}
|
||||
```
|
||||
|
||||
Inference Screenshot
|
||||
|
||||

|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
ui_video/
|
||||
├── README.md # This file
|
||||
├── Video_VLM.py # Streamlit web interface for inference
|
||||
├── src/
|
||||
│ ├── styles.css # CSS styling for Streamlit app
|
||||
│ └── video_vlm_config.yaml # Model and inference configuration
|
||||
├── train/
|
||||
│ └── video_vlm.ipynb # Jupyter notebook for model training
|
||||
└── assets/
|
||||
└── video_vlm/
|
||||
├── videos/ # Sample video files
|
||||
└── thumbnails/ # Video thumbnail previews
|
||||
|
||||
# Root directory also contains:
|
||||
├── Dockerfile # Multi-stage Docker build with FFmpeg/Decord
|
||||
└── launch.sh # Docker launch script
|
||||
```
|
||||
|
||||
## Model Capabilities
|
||||
|
||||
The fine-tuned InternVL3 model can:
|
||||
- **Video Analysis**: Process multi-frame dashcam footage for comprehensive scene understanding
|
||||
- **Safety Detection**: Identify dangerous driving patterns, near-misses, and traffic violations
|
||||
- **Structured Output**: Generate JSON metadata with standardized driving scene categories
|
||||
413
nvidia/vlm-finetuning/assets/ui_video/Video_VLM.py
Normal file
@ -0,0 +1,413 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import json
|
||||
import yaml
|
||||
import string
|
||||
import random
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
import streamlit as st
|
||||
import torchvision.transforms as T
|
||||
from decord import VideoReader, cpu
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
|
||||
SCAP_PROMPT = """You are a vision-language assistant analyzing driving videos. You will receive a 5-second video clip of a specific scene. {prompt}
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Dense Caption
|
||||
Generate a 2 sentence caption describing:
|
||||
- Ego vehicle behavior
|
||||
- Interactions with other vehicles or pedestrians
|
||||
|
||||
Focus on **what happens**, **when**, and **who/what is involved**, using only visible information and metadata.
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Structured JSON
|
||||
Generate the caption from the perspective of the ego vehicle in a structured JSON object with:
|
||||
|
||||
- `"caption"`: from Task 1
|
||||
- `"event_type"`: "collision" | "near_miss" | "no_incident"
|
||||
- `"rule_violations"`: choose relevant items from ["speeding", "failure_to_yield", "ignoring_traffic_signs"]
|
||||
- `"intended_action"`: "turn_left" | "turn_right" | "change_lanes"
|
||||
- `"traffic_density"`: "low" | "high"
|
||||
- `"visibility"`: "good" | "bad"
|
||||
- `"scene"`: "Urban" | "Sub-urban" | "Rural" | "Highway"
|
||||
|
||||
**Rules:**
|
||||
1. Use only visible info and metadata.
|
||||
2. Do not invent details.
|
||||
3. Include all fields; enum values must match allowed options.
|
||||
4. Output a single valid JSON object—no extra text or markdown.
|
||||
"""
|
||||
|
||||
|
||||
def random_id():
|
||||
chars = string.ascii_letters + string.digits
|
||||
return "".join(random.choices(chars, k=8)).lower()
|
||||
|
||||
|
||||
def initialize_session_state(resources):
|
||||
# Initialize page-specific session state
|
||||
st.session_state["base_video_vlm"] = st.session_state.get("base_video_vlm", resources["base"])
|
||||
st.session_state["finetuned_video_vlm"] = st.session_state.get("finetuned_video_vlm", resources["finetuned"])
|
||||
st.session_state["current_sample"] = st.session_state.get("current_sample", None)
|
||||
st.session_state["df"] = st.session_state.get("df",
|
||||
pd.DataFrame(columns=[
|
||||
"Driver ID",
|
||||
"Event Type",
|
||||
"Rule Violations",
|
||||
"Intended Action",
|
||||
"Traffic Density",
|
||||
"Driving Scene",
|
||||
"Visibility"]))
|
||||
|
||||
|
||||
def load_config():
|
||||
config_key = "config_video_vlm"
|
||||
if getattr(st.session_state, config_key, None) is None:
|
||||
with open("src/video_vlm_config.yaml", "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
setattr(st.session_state, config_key, config)
|
||||
else:
|
||||
config = getattr(st.session_state, config_key)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def initialize_resources(inference_config):
|
||||
base_model = load_model_for_inference(inference_config, "base")
|
||||
finetuned_model = load_model_for_inference(inference_config, "finetuned")
|
||||
|
||||
return {
|
||||
"base": {"model": base_model},
|
||||
"finetuned": {"model": finetuned_model},
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
# set page ui
|
||||
st.title(":green[Video VLM on DGX Spark]")
|
||||
st.caption("Driver Behavior Analysis via Structured Video Captioning")
|
||||
# st.page_link("https://github.com/your-username/your-repo", label="GitHub", icon=":material/github:")
|
||||
|
||||
# load css
|
||||
with open("src/styles.css", "r") as f:
|
||||
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
||||
|
||||
# load resources
|
||||
config = load_config()
|
||||
if st.session_state.get("base_video_vlm", None) is None:
|
||||
st.toast("Loading model", icon="⏳", duration="short")
|
||||
resource = initialize_resources(config["inference"])
|
||||
if st.session_state.get("base_video_vlm", None) is None:
|
||||
st.toast("Model loaded", icon="✅", duration="short")
|
||||
initialize_session_state(resource)
|
||||
|
||||
# gallery section
|
||||
st.markdown("---")
|
||||
st.header("Dashcam Gallery")
|
||||
|
||||
columns = st.columns([4, 1, 1, 4, 1, 1, 4, 1], gap="small")
|
||||
|
||||
with columns[0]:
|
||||
st.image("assets/video_vlm/thumbnails/1.png")
|
||||
|
||||
with columns[1]:
|
||||
if st.button(":material/file_open:", key="video_1"):
|
||||
st.session_state["current_sample"] = "assets/video_vlm/videos/1.mp4"
|
||||
|
||||
with columns[3]:
|
||||
st.image("assets/video_vlm/thumbnails/2.png")
|
||||
|
||||
with columns[4]:
|
||||
if st.button(":material/file_open:", key="video_2"):
|
||||
st.session_state["current_sample"] = "assets/video_vlm/videos/2.mp4"
|
||||
|
||||
with columns[6]:
|
||||
st.image("assets/video_vlm/thumbnails/3.png")
|
||||
|
||||
with columns[7]:
|
||||
if st.button(":material/file_open:", key="video_3"):
|
||||
st.session_state["current_sample"] = "assets/video_vlm/videos/3.mp4"
|
||||
|
||||
# inference section
|
||||
st.markdown("---")
|
||||
st.header("Video Inference")
|
||||
|
||||
with st.container(border=True):
|
||||
if st.session_state["current_sample"]:
|
||||
st.video(st.session_state["current_sample"], autoplay=True, loop=True, muted=True)
|
||||
else:
|
||||
st.write(":gray[Please select a video from the dashcam gallery.]")
|
||||
|
||||
columns = st.columns(2, gap="small")
|
||||
with columns[0]:
|
||||
with st.container(border=True):
|
||||
st.write("##### :green[Base InternVL3-8B]")
|
||||
base_generation = st.empty()
|
||||
base_generation.write("...")
|
||||
|
||||
with columns[1]:
|
||||
with st.container(border=True):
|
||||
st.write("##### :green[Finetuned InternVL3-8B]")
|
||||
finetuned_generation = st.empty()
|
||||
finetuned_generation.write("...")
|
||||
|
||||
columns = st.columns([9, 1], gap="small")
|
||||
with columns[0]:
|
||||
prompt = st.text_input(
|
||||
"Prompt Input",
|
||||
label_visibility="collapsed",
|
||||
key="prompt_input",
|
||||
on_change=lambda: st.session_state.update(prompt=st.session_state.prompt_input)
|
||||
)
|
||||
|
||||
with columns[1]:
|
||||
if st.button("Generate", width="stretch"):
|
||||
if st.session_state.get("prompt", None):
|
||||
st.session_state["prompt"] = prompt
|
||||
|
||||
with st.spinner("Running..."):
|
||||
response = start_inference("base")
|
||||
base_generation.markdown(response)
|
||||
|
||||
with st.spinner("Running..."):
|
||||
response = start_inference("finetuned")
|
||||
finetuned_generation.markdown(response)
|
||||
|
||||
response = json.loads(response[7: -3].strip())
|
||||
response["caption"] = random_id() # replace caption with driver id
|
||||
st.session_state["df"].loc[len(st.session_state["df"])] = list(response.values())
|
||||
|
||||
# data analysis section
|
||||
st.markdown("---")
|
||||
st.header("Data Analysis")
|
||||
|
||||
with st.container(border=True):
|
||||
st.dataframe(st.session_state["df"])
|
||||
|
||||
|
||||
class InternVLVideoProcessor:
|
||||
def __init__(self):
|
||||
self.frame_size = 448
|
||||
|
||||
self.transform = T.Compose([
|
||||
T.Resize(self.frame_size, interpolation=InterpolationMode.BICUBIC),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
|
||||
best_ratio_diff = float('inf')
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_ratio = ratio
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
if area > 0.5 * self.frame_size * self.frame_size * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
def dynamic_preprocess(self, frame, min_num=1, max_num=12, use_thumbnail=False):
|
||||
orig_width, orig_height = frame.size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set(
|
||||
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
||||
i * j <= max_num and i * j >= min_num)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, self.frame_size)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = self.frame_size * target_aspect_ratio[0]
|
||||
target_height = self.frame_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
|
||||
# resize the image
|
||||
resized_img = frame.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // self.frame_size)) * self.frame_size,
|
||||
(i // (target_width // self.frame_size)) * self.frame_size,
|
||||
((i % (target_width // self.frame_size)) + 1) * self.frame_size,
|
||||
((i // (target_width // self.frame_size)) + 1) * self.frame_size
|
||||
)
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
assert len(processed_images) == blocks
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = frame.resize((self.frame_size, self.frame_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
return processed_images
|
||||
|
||||
def load_video(self, video_path, num_frames, start_frame=None, end_frame=None):
|
||||
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
||||
|
||||
if start_frame is None:
|
||||
start_frame = 0
|
||||
if end_frame is None:
|
||||
end_frame = len(vr) - 1
|
||||
|
||||
# sample a random number of equally-spaced frames from the video
|
||||
frame_indices = np.linspace(
|
||||
start_frame,
|
||||
end_frame,
|
||||
num_frames,
|
||||
dtype=int
|
||||
)
|
||||
|
||||
pixel_values_list, num_patches_list = [], []
|
||||
for frame_index in frame_indices:
|
||||
img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
|
||||
img = self.dynamic_preprocess(img, use_thumbnail=True, max_num=1)
|
||||
pixel_values = [self.transform(tile) for tile in img]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
num_patches_list.append(pixel_values.shape[0])
|
||||
pixel_values_list.append(pixel_values)
|
||||
pixel_values = torch.cat(pixel_values_list)
|
||||
return pixel_values, num_patches_list
|
||||
|
||||
|
||||
class InternVLModel:
|
||||
def __init__(self, model_path):
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
use_flash_attn=False,
|
||||
).eval()
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=True,
|
||||
use_fast=False
|
||||
)
|
||||
|
||||
self.processor = InternVLVideoProcessor()
|
||||
self.generation_config = dict(
|
||||
max_new_tokens=1024,
|
||||
do_sample=False,
|
||||
temperature=0.0,
|
||||
num_beams=1,
|
||||
top_k=1,
|
||||
top_p=1.0)
|
||||
|
||||
@torch.no_grad()
|
||||
def infer(self, video_path, prompt, sampling_mode, num_frames=32, chunk_duration=2):
|
||||
if sampling_mode == "default":
|
||||
return self._infer_default(video_path, prompt, num_frames)
|
||||
elif sampling_mode == "real-time":
|
||||
return self._infer_realtime(video_path, prompt, num_frames, chunk_duration)
|
||||
|
||||
def _infer_default(self, video_path, prompt, num_frames):
|
||||
pixel_values, num_patches_list = self.processor.load_video(video_path, num_frames)
|
||||
pixel_values = pixel_values.to(device=self.model.device, dtype=torch.bfloat16)
|
||||
|
||||
video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
|
||||
prompt = f"{video_prefix}{prompt}"
|
||||
|
||||
response = self.model.chat(
|
||||
self.tokenizer,
|
||||
pixel_values,
|
||||
prompt,
|
||||
self.generation_config,
|
||||
num_patches_list=num_patches_list
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _infer_realtime(self, video_path, prompt, num_frames, chunk_duration):
|
||||
video = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
||||
fps = video.get_avg_fps()
|
||||
total_frames = len(video)
|
||||
frames_per_chunk = int(chunk_duration * fps)
|
||||
|
||||
for start_frame in range(0, total_frames, frames_per_chunk):
|
||||
end_frame = start_frame + frames_per_chunk
|
||||
if end_frame > total_frames:
|
||||
return
|
||||
|
||||
pixel_values, num_patches_list = self.processor.load_video(video_path, num_frames, start_frame, end_frame)
|
||||
pixel_values = pixel_values.to(device=self.model.device, dtype=torch.bfloat16)
|
||||
|
||||
video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
|
||||
prompt = f"{video_prefix}{prompt}"
|
||||
|
||||
response = self.model.chat(
|
||||
self.tokenizer,
|
||||
pixel_values,
|
||||
prompt,
|
||||
self.generation_config,
|
||||
num_patches_list=num_patches_list
|
||||
)
|
||||
|
||||
yield response
|
||||
|
||||
|
||||
def load_model_for_inference(config, model_type):
|
||||
if model_type == "finetuned":
|
||||
model_name = config["finetuned_model_id"]
|
||||
elif model_type == "base":
|
||||
model_name = config["model_id"]
|
||||
else:
|
||||
raise ValueError(f"Invalid model type: {model_type}")
|
||||
|
||||
return InternVLModel(model_name)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def start_inference(model_type):
|
||||
# define prompt
|
||||
prompt = st.session_state["prompt"]
|
||||
if model_type == "finetuned":
|
||||
prompt = SCAP_PROMPT.format(prompt=prompt)
|
||||
|
||||
response = st.session_state[f"{model_type}_video_vlm"]["model"].infer(
|
||||
st.session_state["current_sample"],
|
||||
prompt,
|
||||
num_frames=st.session_state["config_video_vlm"]["inference"]["num_frames"],
|
||||
sampling_mode=st.session_state["config_video_vlm"]["inference"]["sampling_mode"]
|
||||
)
|
||||
|
||||
if model_type == "finetuned":
|
||||
response = f"```json\n{json.dumps(json.loads(response), indent=2)}\n```"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
After Width: | Height: | Size: 1.3 MiB |
BIN
nvidia/vlm-finetuning/assets/ui_video/assets/training_video.png
Normal file
|
After Width: | Height: | Size: 554 KiB |
|
After Width: | Height: | Size: 1.2 MiB |
|
After Width: | Height: | Size: 1.1 MiB |