+
+
+
+## Default Models
+| Model | Quantization | Model Type | VRAM |
+|------------------------------|--------------|------------|-------------|
+| GPT-OSS:120B | MXFP4 | Chat | ~ 63.5 GB |
+| Deepseek-Coder:6.7B-Instruct | Q8 | Coding | ~ 9.5 GB |
+| Qwen2.5-VL:7B-Instruct | BF16 | Image | ~ 35.4 GB |
+| Qwen3-Embedding-4B | Q8 | Embedding | ~ 5.39 GB |
+
+**Total VRAM required:** ~114 GB
+
+> **Warning**:
+> Since the default models use the majority of the available VRAM, use `nvidia-smi` to check that nothing else is already running on the DGX Spark. If something is, switch to `gpt-oss-20b` by following [this guide](#using-different-models).
+
+---
+
+## Quick Start
+#### 1. Clone the repository and change into the multi-agent chatbot directory
+
+#### 2. Run the setup script
+The setup script pulls the model GGUF files from Hugging Face, builds the base llama.cpp server images, and starts all the required Docker services: the model servers, the backend API server, and the frontend UI.
+```bash
+chmod +x setup.sh
+./setup.sh
+```
+Wait for all the containers to become ready and healthy.
+```bash
+watch 'docker ps --format "table {{.ID}}\t{{.Names}}\t{{.Status}}"'
+```
+> Note: Downloading the model files may take ~10 minutes, and starting the containers may take another ~10 minutes, depending on network speed. Look for `server is listening on http://0.0.0.0:8000` in the logs of the model server containers.
+
+
+#### 3. Access the frontend UI
+
+Open your browser and go to: [http://localhost:3000](http://localhost:3000)
+
+> Note: If you are running this on a remote GPU over an SSH connection, run the following in a new terminal window to forward the ports, so that you can access the UI at localhost:3000 and the UI can communicate with the backend at localhost:8000:
+>```bash
+> ssh -L 3000:localhost:3000 -L 8000:localhost:8000 username@IP-address
+>```
+
+You should see the following UI in your browser:
+
+
+#### 4. Try out the sample prompts
+Click on any of the tiles on the frontend to try out the supervisor and the other agents.
+
+#### RAG Agent:
+Before trying out the RAG agent, upload the example PDF document [NVIDIA Blackwell Whitepaper](https://images.nvidia.com/aem-dam/Solutions/geforce/blackwell/nvidia-rtx-blackwell-gpu-architecture.pdf) as context: click the "Attach" icon in the text input area at the bottom of the UI, then check the box for the document in the "Select Sources" section on the left side of the UI.
+
+
+
+#### Image Understanding Agent:
+You can either provide URLs or drag and drop images.
+
+**Example Prompt:**
+
+Describe this image: https://en.wikipedia.org/wiki/London_Bridge#/media/File:London_Bridge_from_St_Olaf_Stairs.jpg
+
+## Customizations
+
+### Using different models
+
+You can swap the model that the supervisor agent uses, for example to gpt-oss-20b.
+
+1. In `setup.sh`, uncomment the line to download gpt-oss-20b.
+> Note: If you already have the model files downloaded, you can skip to step 2.
+2. In `docker-compose-models.yml`, uncomment the block for gpt-oss-20b.
+> Note: Since the default models use nearly all of the available VRAM, you will need to comment out the block for gpt-oss-120b in `docker-compose-models.yml`.
+3. In `docker-compose.yml`, add `gpt-oss-20b` to the `MODELS` environment variable (line 40).
+> Note: This name should match the container name that you set for this model in `docker-compose-models.yml`.
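After step 3, the relevant fragment of `docker-compose.yml` would look roughly like the sketch below. The `backend` service name is an assumption for illustration; the backend parses `MODELS` as a comma-separated list, and each name must match the corresponding model container name in `docker-compose-models.yml`.

```yaml
# Illustrative fragment of docker-compose.yml (service name is an assumption)
services:
  backend:
    environment:
      # Comma-separated; each entry must match a model container name
      - MODELS=gpt-oss-20b,deepseek-coder-6.7b-instruct,qwen2.5-vl-7b-instruct,qwen3-embedding-4b
```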
+
+### Adding MCP servers and tools
+
+1. You can add more MCP servers and tools under [backend/tools/mcp_servers](backend/tools/mcp_servers/) following existing examples.
+
+2. If you added an MCP server, remember to add it to the server configs in [backend/client.py](backend/client.py).
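A new stdio MCP server entry in `backend/client.py` follows the same shape as the existing ones in `MCPClient.server_configs`. The server name and script path below are hypothetical placeholders:

```python
# Hypothetical entry to add to MCPClient.server_configs in backend/client.py.
# "my-new-server" and the script path are placeholders for illustration.
server_configs = {
    "my-new-server": {
        "command": "python",
        "args": ["tools/mcp_servers/my_new_server.py"],
        "transport": "stdio",
    },
}

# Every stdio entry needs these three keys for MultiServerMCPClient
for name, cfg in server_configs.items():
    assert {"command", "args", "transport"} <= cfg.keys()
```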
diff --git a/nvidia/multi-agent-chatbot/assets/backend/.python-version b/nvidia/multi-agent-chatbot/assets/backend/.python-version
new file mode 100644
index 0000000..c8cfe39
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/.python-version
@@ -0,0 +1 @@
+3.10
diff --git a/nvidia/multi-agent-chatbot/assets/backend/Dockerfile b/nvidia/multi-agent-chatbot/assets/backend/Dockerfile
new file mode 100644
index 0000000..a7b6625
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/Dockerfile
@@ -0,0 +1,33 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+FROM python:3.12
+
+RUN apt-get update && \
+ apt-get install -y curl && \
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+
+WORKDIR /app
+
+ENV PATH="/root/.local/bin:$PATH"
+
+# Copy dependency manifests first so `uv sync` is cached across source changes
+COPY pyproject.toml uv.lock ./
+RUN uv sync
+
+COPY . .
+
+CMD ["uv", "run", "--", "uvicorn", "main:app", "--reload", "--host", "0.0.0.0", "--port", "8000"]
\ No newline at end of file
diff --git a/nvidia/multi-agent-chatbot/assets/backend/README.md b/nvidia/multi-agent-chatbot/assets/backend/README.md
new file mode 100644
index 0000000..ce30b54
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/README.md
@@ -0,0 +1 @@
+# Chatbot Backend API Server
\ No newline at end of file
diff --git a/nvidia/multi-agent-chatbot/assets/backend/__init__.py b/nvidia/multi-agent-chatbot/assets/backend/__init__.py
new file mode 100644
index 0000000..df2d0a6
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/__init__.py
@@ -0,0 +1,16 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/nvidia/multi-agent-chatbot/assets/backend/agent.py b/nvidia/multi-agent-chatbot/assets/backend/agent.py
new file mode 100644
index 0000000..ef0da1e
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/agent.py
@@ -0,0 +1,586 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""ChatAgent implementation for LLM-powered conversational AI with tool calling."""
+
+import asyncio
+import contextlib
+import json
+from typing import AsyncIterator, List, Dict, Any, TypedDict, Optional, Callable, Awaitable
+
+from langchain_core.messages import HumanMessage, AIMessage, AnyMessage, SystemMessage, ToolMessage, ToolCall
+from langchain_core.utils.function_calling import convert_to_openai_tool
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import END, START, StateGraph
+from openai import AsyncOpenAI
+
+from client import MCPClient
+from logger import logger
+from prompts import Prompts
+from postgres_storage import PostgreSQLConversationStorage
+from utils import convert_langgraph_messages_to_openai
+
+
+memory = MemorySaver()
+SENTINEL = object()
+StreamCallback = Callable[[Dict[str, Any]], Awaitable[None]]
+
+
+class State(TypedDict, total=False):
+    iterations: int
+    messages: List[AnyMessage]
+    chat_id: Optional[str]
+    image_data: Optional[str]
+    process_image_used: bool
+
+
+class ChatAgent:
+ """Main conversational agent with tool calling and agent delegation capabilities.
+
+ This agent orchestrates conversation flow using a LangGraph state machine that can:
+ - Generate responses using LLMs
+ - Execute tool calls (including MCP tools)
+ - Handle image processing
+ - Manage conversation history via Redis
+ """
+
+ def __init__(self, vector_store, config_manager, postgres_storage: PostgreSQLConversationStorage):
+ """Initialize the chat agent.
+
+ Args:
+ vector_store: VectorStore instance for document retrieval
+ config_manager: ConfigManager for reading configuration
+ postgres_storage: PostgreSQL storage for conversation persistence
+ """
+ self.vector_store = vector_store
+ self.config_manager = config_manager
+ self.conversation_store = postgres_storage
+        self.current_model = None
+        self.max_iterations = 3
+
+ self.mcp_client = None
+ self.openai_tools = None
+ self.tools_by_name = None
+ self.system_prompt = None
+
+ self.graph = self._build_graph()
+ self.stream_callback = None
+ self.last_state = None
+
+ @classmethod
+ async def create(cls, vector_store, config_manager, postgres_storage: PostgreSQLConversationStorage):
+ """
+ Asynchronously creates and initializes a ChatAgent instance.
+
+ This factory method ensures that all async setup, like loading tools,
+ is completed before the agent is ready to be used.
+ """
+ agent = cls(vector_store, config_manager, postgres_storage)
+ await agent.init_tools()
+
+ available_tools = list(agent.tools_by_name.values()) if agent.tools_by_name else []
+ template_vars = {
+ "tools": "\n".join([f"- {tool.name}: {tool.description}" for tool in available_tools]) if available_tools else "No tools available",
+ }
+ agent.system_prompt = Prompts.get_template("supervisor_agent").render(template_vars)
+
+ logger.debug(f"Agent initialized with {len(available_tools)} tools.")
+ agent.set_current_model(config_manager.get_selected_model())
+ return agent
+
+ async def init_tools(self) -> None:
+ """Initialize MCP client and tools with retry logic.
+
+ Sets up the MCP client, retrieves available tools, converts them to OpenAI format,
+ and initializes specialized agents like the coding agent.
+ """
+ self.mcp_client = await MCPClient().init()
+
+ base_delay, max_retries = 0.1, 10
+ mcp_tools = []
+
+ for attempt in range(max_retries):
+ try:
+ mcp_tools = await self.mcp_client.get_tools()
+ break
+ except Exception as e:
+ logger.warning(f"MCP tools initialization attempt {attempt + 1} failed: {e}")
+ if attempt == max_retries - 1:
+ logger.error(f"MCP servers not ready after {max_retries} attempts, continuing without MCP tools")
+ mcp_tools = []
+ break
+ wait_time = base_delay * (2 ** attempt)
+ await asyncio.sleep(wait_time)
+ logger.info(f"MCP servers not ready, retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})")
+
+ self.tools_by_name = {tool.name: tool for tool in mcp_tools}
+ logger.debug(f"Loaded {len(mcp_tools)} MCP tools: {list(self.tools_by_name.keys())}")
+
+ if mcp_tools:
+ mcp_tools_openai = [convert_to_openai_tool(tool) for tool in mcp_tools]
+ logger.debug(f"MCP tools converted to OpenAI format: {mcp_tools_openai}")
+
+ self.openai_tools = [
+ {"type": "function", "function": tool['function']}
+ for tool in mcp_tools_openai
+ ]
+ logger.debug(f"Final OpenAI tools format: {self.openai_tools}")
+ else:
+ self.openai_tools = []
+ logger.warning("No MCP tools available - agent will run with limited functionality")
+
+ def set_current_model(self, model_name: str) -> None:
+ """Set the current model for completions.
+
+ Args:
+ model_name: Name of the model to use
+
+ Raises:
+ ValueError: If the model is not available
+ """
+ available_models = self.config_manager.get_available_models()
+
+ try:
+ if model_name in available_models:
+ self.current_model = model_name
+ logger.info(f"Switched to model: {model_name}")
+ self.model_client = AsyncOpenAI(
+ base_url=f"http://{self.current_model}:8000/v1",
+ api_key="api_key"
+ )
+ else:
+ raise ValueError(f"Model {model_name} is not available. Available models: {available_models}")
+ except Exception as e:
+ logger.error(f"Error setting current model: {e}")
+ raise ValueError(f"Model {model_name} is not available. Available models: {available_models}")
+
+ def should_continue(self, state: State) -> str:
+ """Determine whether to continue the tool calling loop.
+
+ Args:
+ state: Current graph state
+
+ Returns:
+ "end" if no more tool calls or max iterations reached, "continue" otherwise
+ """
+ messages = state.get("messages", [])
+ if not messages:
+ return "end"
+
+ last_message = messages[-1]
+ iterations = state.get("iterations", 0)
+ has_tool_calls = bool(last_message.tool_calls) if hasattr(last_message, 'tool_calls') else False
+
+ logger.debug({
+ "message": "GRAPH: should_continue decision",
+ "chat_id": state.get("chat_id"),
+ "iterations": iterations,
+ "max_iterations": self.max_iterations,
+ "has_tool_calls": has_tool_calls,
+ "tool_calls_count": len(last_message.tool_calls) if has_tool_calls else 0
+ })
+
+ if iterations >= self.max_iterations:
+ logger.debug({
+ "message": "GRAPH: should_continue → END (max iterations reached)",
+ "chat_id": state.get("chat_id"),
+ "final_message_preview": str(last_message)[:100] + "..." if len(str(last_message)) > 100 else str(last_message)
+ })
+ return "end"
+
+ if not has_tool_calls:
+ logger.debug({"message": "GRAPH: should_continue → END (no tool calls)", "chat_id": state.get("chat_id")})
+ return "end"
+
+ logger.debug({"message": "GRAPH: should_continue → CONTINUE (has tool calls)", "chat_id": state.get("chat_id")})
+ return "continue"
+
+ async def tool_node(self, state: State) -> Dict[str, Any]:
+ """Execute tools from the last AI message's tool calls.
+
+ Args:
+ state: Current graph state
+
+ Returns:
+ Updated state with tool results and incremented iteration count
+ """
+ logger.debug({
+ "message": "GRAPH: ENTERING NODE - action/tool_node",
+ "chat_id": state.get("chat_id"),
+ "iterations": state.get("iterations", 0)
+ })
+ await self.stream_callback({'type': 'node_start', 'data': 'tool_node'})
+
+ outputs = []
+ messages = state.get("messages", [])
+ last_message = messages[-1]
+ for i, tool_call in enumerate(last_message.tool_calls):
+ logger.debug(f'Executing tool {i+1}/{len(last_message.tool_calls)}: {tool_call["name"]} with args: {tool_call["args"]}')
+ await self.stream_callback({'type': 'tool_start', 'data': tool_call["name"]})
+
+ try:
+ if tool_call["name"] == "explain_image" and state.get("image_data"):
+ tool_args = tool_call["args"].copy()
+ tool_args["image"] = state["image_data"]
+ logger.info(f'Executing tool {tool_call["name"]} with args: {tool_args}')
+ tool_result = await self.tools_by_name[tool_call["name"]].ainvoke(tool_args)
+ state["process_image_used"] = True
+ else:
+ tool_result = await self.tools_by_name[tool_call["name"]].ainvoke(tool_call["args"])
+ if "code" in tool_call["name"]:
+ content = str(tool_result)
+ elif isinstance(tool_result, str):
+ content = tool_result
+ else:
+ content = json.dumps(tool_result)
+ except Exception as e:
+ logger.error(f'Error executing tool {tool_call["name"]}: {str(e)}', exc_info=True)
+ content = f"Error executing tool '{tool_call['name']}': {str(e)}"
+
+ await self.stream_callback({'type': 'tool_end', 'data': tool_call["name"]})
+
+ outputs.append(
+ ToolMessage(
+ content=content,
+ name=tool_call["name"],
+ tool_call_id=tool_call["id"],
+ )
+ )
+
+ state["iterations"] = state.get("iterations", 0) + 1
+
+ logger.debug({
+ "message": "GRAPH: EXITING NODE - action/tool_node",
+ "chat_id": state.get("chat_id"),
+ "iterations": state.get("iterations"),
+ "tools_executed": len(outputs),
+ "next_step": "→ returning to generate"
+ })
+ await self.stream_callback({'type': 'node_end', 'data': 'tool_node'})
+        return {"messages": messages + outputs, "iterations": state["iterations"]}
+
+ async def generate(self, state: State) -> Dict[str, Any]:
+ """Generate AI response using the current model.
+
+ Args:
+ state: Current graph state
+
+ Returns:
+ Updated state with new AI message
+ """
+ messages = convert_langgraph_messages_to_openai(state.get("messages", []))
+ logger.debug({
+ "message": "GRAPH: ENTERING NODE - generate",
+ "chat_id": state.get("chat_id"),
+ "iterations": state.get("iterations", 0),
+ "current_model": self.current_model,
+ "message_count": len(state.get("messages", []))
+ })
+ await self.stream_callback({'type': 'node_start', 'data': 'generate'})
+
+ supports_tools = self.current_model in {"gpt-oss-20b", "gpt-oss-120b"}
+ has_tools = supports_tools and self.openai_tools and len(self.openai_tools) > 0
+
+ logger.debug({
+ "message": "Tool calling debug info",
+ "chat_id": state.get("chat_id"),
+ "current_model": self.current_model,
+ "supports_tools": supports_tools,
+ "openai_tools_count": len(self.openai_tools) if self.openai_tools else 0,
+ "openai_tools": self.openai_tools,
+ "has_tools": has_tools
+ })
+
+ tool_params = {}
+ if has_tools:
+ tool_params = {
+ "tools": self.openai_tools,
+ "tool_choice": "auto"
+ }
+
+ stream = await self.model_client.chat.completions.create(
+ model=self.current_model,
+ messages=messages,
+ temperature=0,
+ top_p=1,
+ stream=True,
+ **tool_params
+ )
+
+ llm_output_buffer, tool_calls_buffer = await self._stream_response(stream, self.stream_callback)
+ tool_calls = self._format_tool_calls(tool_calls_buffer)
+ raw_output = "".join(llm_output_buffer)
+
+ logger.debug({
+ "message": "Tool call generation results",
+ "chat_id": state.get("chat_id"),
+ "tool_calls_buffer": tool_calls_buffer,
+ "formatted_tool_calls": tool_calls,
+ "tool_calls_count": len(tool_calls),
+ "raw_output_length": len(raw_output),
+ "raw_output": raw_output[:200] + "..." if len(raw_output) > 200 else raw_output
+ })
+
+ response = AIMessage(
+ content=raw_output,
+ **({"tool_calls": tool_calls} if tool_calls else {})
+ )
+
+ logger.debug({
+ "message": "GRAPH: EXITING NODE - generate",
+ "chat_id": state.get("chat_id"),
+ "iterations": state.get("iterations", 0),
+ "response_length": len(response.content) if response.content else 0,
+ "tool_calls_generated": len(tool_calls),
+ "tool_calls_names": [tc["name"] for tc in tool_calls] if tool_calls else [],
+ "next_step": "→ should_continue decision"
+ })
+ await self.stream_callback({'type': 'node_end', 'data': 'generate'})
+ return {"messages": state.get("messages", []) + [response]}
+
+ def _build_graph(self) -> StateGraph:
+ """Build the LangGraph state machine for conversation flow.
+
+ Returns:
+ Compiled StateGraph with nodes and conditional edges
+ """
+ workflow = StateGraph(State)
+
+ workflow.add_node("generate", self.generate)
+ workflow.add_node("action", self.tool_node)
+ workflow.add_edge(START, "generate")
+ workflow.add_conditional_edges(
+ "generate",
+ self.should_continue,
+ {
+ "continue": "action",
+ "end": END,
+ },
+ )
+ workflow.add_edge("action", "generate")
+
+ return workflow.compile(checkpointer=memory)
+
+ def _format_tool_calls(self, tool_calls_buffer: Dict[int, Dict[str, str]]) -> List[ToolCall]:
+ """Parse streamed tool call buffer into ToolCall objects.
+
+ Args:
+ tool_calls_buffer: Buffer of streamed tool call data
+
+ Returns:
+ List of formatted ToolCall objects
+ """
+ if not tool_calls_buffer:
+ return []
+
+ tool_calls = []
+ for i in sorted(tool_calls_buffer):
+ item = tool_calls_buffer[i]
+ try:
+ parsed_args = json.loads(item["arguments"] or "{}")
+ except json.JSONDecodeError:
+ parsed_args = {}
+
+ tool_calls.append(
+ ToolCall(
+ name=item["name"],
+ args=parsed_args,
+ id=item["id"] or f"call_{i}",
+ )
+ )
+ return tool_calls
+
+ async def _stream_response(self, stream, stream_callback: StreamCallback) -> tuple[List[str], Dict[int, Dict[str, str]]]:
+ """Process streaming LLM response and extract content and tool calls.
+
+ Args:
+ stream: Async stream from LLM
+ stream_callback: Callback for streaming events
+
+ Returns:
+ Tuple of (content_buffer, tool_calls_buffer)
+ """
+ llm_output_buffer = []
+ tool_calls_buffer = {}
+ saw_tool_finish = False
+
+ async for chunk in stream:
+ for choice in getattr(chunk, "choices", []) or []:
+ delta = getattr(choice, "delta", None)
+ if not delta:
+ continue
+
+ content = getattr(delta, "content", None)
+ if content:
+ await stream_callback({"type": "token", "data": content})
+ llm_output_buffer.append(content)
+ for tc in getattr(delta, "tool_calls", []) or []:
+ idx = getattr(tc, "index", None)
+ if idx is None:
+ idx = 0 if not tool_calls_buffer else max(tool_calls_buffer) + 1
+ entry = tool_calls_buffer.setdefault(idx, {"id": None, "name": None, "arguments": ""})
+
+ if getattr(tc, "id", None):
+ entry["id"] = tc.id
+
+ fn = getattr(tc, "function", None)
+ if fn:
+ if getattr(fn, "name", None):
+ entry["name"] = fn.name
+ if getattr(fn, "arguments", None):
+ entry["arguments"] += fn.arguments
+
+ finish_reason = getattr(choice, "finish_reason", None)
+ if finish_reason == "tool_calls":
+ saw_tool_finish = True
+ break
+
+ if saw_tool_finish:
+ break
+
+ return llm_output_buffer, tool_calls_buffer
+
+ async def query(self, query_text: str, chat_id: str, image_data: str = None) -> AsyncIterator[Dict[str, Any]]:
+ """Process user query and stream response tokens.
+
+ Args:
+ query_text: User's input text
+            chat_id: Unique chat identifier
+            image_data: Optional base64-encoded image attached to the message
+
+ Yields:
+ Streaming events and tokens
+ """
+ logger.debug({
+ "message": "GRAPH: STARTING EXECUTION",
+ "chat_id": chat_id,
+ "query": query_text[:100] + "..." if len(query_text) > 100 else query_text,
+ "graph_flow": "START → generate → should_continue → action → generate → END"
+ })
+
+ config = {"configurable": {"thread_id": chat_id}}
+
+ try:
+ existing_messages = await self.conversation_store.get_messages(chat_id, limit=10)
+
+ base_system_prompt = self.system_prompt
+ if image_data:
+ image_context = "\n\nIMAGE CONTEXT: The user has uploaded an image with their message. You MUST use the explain_image tool to analyze it."
+ system_prompt_with_image = base_system_prompt + image_context
+ messages_to_process = [SystemMessage(content=system_prompt_with_image)]
+ else:
+ messages_to_process = [SystemMessage(content=base_system_prompt)]
+
+ if existing_messages:
+ for msg in existing_messages:
+ if not isinstance(msg, SystemMessage):
+ messages_to_process.append(msg)
+
+ messages_to_process.append(HumanMessage(content=query_text))
+
+ config_obj = self.config_manager.read_config()
+
+ initial_state = {
+ "iterations": 0,
+ "chat_id": chat_id,
+ "messages": messages_to_process,
+ "image_data": image_data if image_data else None,
+ "process_image_used": False
+ }
+
+ model_name = self.config_manager.get_selected_model()
+ if self.current_model != model_name:
+ self.set_current_model(model_name)
+
+ logger.debug({
+ "message": "GRAPH: LAUNCHING EXECUTION",
+ "chat_id": chat_id,
+ "initial_state": {
+ "iterations": initial_state["iterations"],
+ "message_count": len(initial_state["messages"]),
+ }
+ })
+
+ self.last_state = None
+ token_q: asyncio.Queue[Any] = asyncio.Queue()
+ self.stream_callback = lambda event: self._queue_writer(event, token_q)
+ runner = asyncio.create_task(self._run_graph(initial_state, config, chat_id, token_q))
+
+ try:
+ while True:
+ item = await token_q.get()
+ if item is SENTINEL:
+ break
+ yield item
+ except Exception as stream_error:
+ logger.error({"message": "Error in streaming", "error": str(stream_error)}, exc_info=True)
+ finally:
+ with contextlib.suppress(asyncio.CancelledError):
+ await runner
+
+ logger.debug({
+ "message": "GRAPH: EXECUTION COMPLETED",
+ "chat_id": chat_id,
+ "final_iterations": self.last_state.get("iterations", 0) if self.last_state else 0
+ })
+
+ except Exception as e:
+ logger.error({"message": "GRAPH: EXECUTION FAILED", "error": str(e), "chat_id": chat_id}, exc_info=True)
+ yield {"type": "error", "data": f"Error performing query: {str(e)}"}
+
+
+ async def _queue_writer(self, event: Dict[str, Any], token_q: asyncio.Queue) -> None:
+ """Write events to the streaming queue.
+
+ Args:
+ event: Event data to queue
+ token_q: Queue for streaming events
+ """
+ await token_q.put(event)
+
+ async def _run_graph(self, initial_state: Dict[str, Any], config: Dict[str, Any], chat_id: str, token_q: asyncio.Queue) -> None:
+ """Run the graph execution in background task.
+
+ Args:
+ initial_state: Starting state for graph
+ config: LangGraph configuration
+ chat_id: Chat identifier
+ token_q: Queue for streaming events
+ """
+ try:
+ async for final_state in self.graph.astream(
+ initial_state,
+ config=config,
+ stream_mode="values",
+ stream_writer=lambda event: self._queue_writer(event, token_q)
+ ):
+ self.last_state = final_state
+ finally:
+ try:
+ if self.last_state and self.last_state.get("messages"):
+ final_msg = self.last_state["messages"][-1]
+ try:
+ logger.debug(f'Saving messages to conversation store for chat: {chat_id}')
+ await self.conversation_store.save_messages(chat_id, self.last_state["messages"])
+ except Exception as save_err:
+ logger.warning({"message": "Failed to persist conversation", "chat_id": chat_id, "error": str(save_err)})
+
+ content = getattr(final_msg, "content", None)
+ if content:
+ await token_q.put(content)
+ finally:
+ await token_q.put(SENTINEL)
diff --git a/nvidia/multi-agent-chatbot/assets/backend/client.py b/nvidia/multi-agent-chatbot/assets/backend/client.py
new file mode 100644
index 0000000..a1d0e9b
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/client.py
@@ -0,0 +1,93 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Multi-Server MCP Client for connecting to multiple MCP servers.
+
+This module provides a unified client interface for connecting to and managing
+multiple Model Context Protocol (MCP) servers. It handles server configuration,
+initialization, and tool retrieval across different server types.
+"""
+
+from typing import List, Optional
+
+from langchain_mcp_adapters.client import MultiServerMCPClient
+from mcp.types import Tool
+
+
+class MCPClient:
+ """Client for managing connections to multiple MCP servers.
+
+ Provides a unified interface for connecting to and interacting with
+ various MCP servers including RAG, image understanding, and weather services.
+ """
+
+ def __init__(self):
+ """Initialize the MCP client with predefined server configurations."""
+ self.server_configs = {
+ "image-understanding-server": {
+ "command": "python",
+ "args": ["tools/mcp_servers/image_understanding.py"],
+ "transport": "stdio",
+ },
+ "code-generation-server": {
+ "command": "python",
+ "args": ["tools/mcp_servers/code_generation.py"],
+ "transport": "stdio",
+ },
+ "rag-server": {
+ "command": "python",
+ "args": ["tools/mcp_servers/rag.py"],
+ "transport": "stdio",
+ },
+ "weather-server": {
+ "command": "python",
+ "args": ["tools/mcp_servers/weather_test.py"],
+ "transport": "stdio",
+ }
+ }
+ self.mcp_client: MultiServerMCPClient | None = None
+
+ async def init(self):
+ """Initialize the multi-server MCP client.
+
+ Returns:
+ MCPClient: Self for method chaining
+
+ Raises:
+ Exception: If client initialization fails
+ """
+ self.mcp_client = MultiServerMCPClient(self.server_configs)
+ return self
+
+ async def get_tools(self):
+ """Retrieve available tools from all connected MCP servers.
+
+ Returns:
+ List[Tool]: List of available tools from all servers
+
+ Raises:
+ RuntimeError: If client is not initialized
+ Exception: If tool retrieval fails
+ """
+ if not self.mcp_client:
+ raise RuntimeError("MCP client not initialized. Call `await init()` first.")
+
+ try:
+ tools = await self.mcp_client.get_tools()
+ return tools
+ except Exception as error:
+ print("Error encountered connecting to MCP server. Is the server running? Is your config server path correct?\n")
+ raise error
\ No newline at end of file
diff --git a/nvidia/multi-agent-chatbot/assets/backend/config.py b/nvidia/multi-agent-chatbot/assets/backend/config.py
new file mode 100644
index 0000000..6d5f6bc
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/config.py
@@ -0,0 +1,165 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""ConfigManager for managing the configuration of the chat application."""
+
+import json
+import os
+import logging
+import threading
+from typing import List
+
+from logger import logger
+from models import ChatConfig
+
+
+class ConfigManager:
+ def __init__(self, config_path: str):
+ """Initialize the ConfigManager"""
+ self.config_path = config_path
+ self.config = None
+ self._last_modified = 0
+ self._lock = threading.Lock()
+ self._ensure_config_exists()
+ self.read_config()
+
+ def _ensure_config_exists(self) -> None:
+ """Ensure config.json exists, creating it with default values if not."""
+        models = os.getenv("MODELS", "")
+
+ if models:
+ models = [model.strip() for model in models.split(",") if model.strip()]
+ else:
+ logger.warning("MODELS environment variable not set, using empty models list")
+
+ if not os.path.exists(self.config_path):
+ logger.debug(f"Config file {self.config_path} not found, creating default config")
+ default_config = ChatConfig(
+ sources=[],
+ models=models,
+ selected_model=models[0] if models else None,
+ selected_sources=[],
+ current_chat_id=None
+ )
+
+ with open(self.config_path, "w") as f:
+ json.dump(default_config.model_dump(), f, indent=2)
+ else:
+ try:
+ with open(self.config_path, "r") as f:
+ data = json.load(f)
+ existing_config = ChatConfig(**data)
+
+ if models:
+ existing_config.models = models
+ if not existing_config.selected_model or existing_config.selected_model not in models:
+ existing_config.selected_model = models[0]
+
+ with open(self.config_path, "w") as f:
+ json.dump(existing_config.model_dump(), f, indent=2)
+
+ logger.debug(f"Updated existing config with models: {models}")
+ except Exception as e:
+ logger.error(f"Error updating existing config: {e}")
+ default_config = ChatConfig(
+ sources=[],
+ models=models,
+ selected_model=models[0] if models else None,
+ selected_sources=[],
+ current_chat_id=None
+ )
+ with open(self.config_path, "w") as f:
+ json.dump(default_config.model_dump(), f, indent=2)
+
+ def read_config(self) -> ChatConfig:
+ """Read config from file, but only if it has changed since last read."""
+ with self._lock:
+ try:
+ current_mtime = os.path.getmtime(self.config_path)
+ if self.config is None or current_mtime > self._last_modified:
+ with open(self.config_path, "r") as f:
+ data = json.load(f)
+ self.config = ChatConfig(**data)
+ self._last_modified = current_mtime
+ return self.config
+ except Exception as e:
+ logger.error(f"Error reading config: {e}")
+ if self.config is None:
+ models_env = os.getenv("MODELS", "")
+ if models_env:
+ models = [model.strip() for model in models_env.split(",") if model.strip()]
+ else:
+ models = []
+
+ self.config = ChatConfig(
+ sources=[],
+ models=models,
+ selected_model=models[0] if models else "gpt-oss-120b",
+ selected_sources=[],
+ current_chat_id="1"
+ )
+ return self.config
+
+ def write_config(self, new_config: ChatConfig) -> None:
+ """Thread-safe write config to file."""
+ with self._lock:
+ with open(self.config_path, "w") as f:
+ json.dump(new_config.model_dump(), f, indent=2)
+ self.config = new_config
+ self._last_modified = os.path.getmtime(self.config_path)
+
+ def get_sources(self) -> List[str]:
+ """Return list of available sources."""
+ self.config = self.read_config()
+ return self.config.sources
+
+ def get_selected_sources(self) -> List[str]:
+ """Return list of selected sources."""
+ self.config = self.read_config()
+ return self.config.selected_sources
+
+ def get_available_models(self) -> List[str]:
+ """Return list of available models."""
+ self.config = self.read_config()
+ return self.config.models
+
+ def get_selected_model(self) -> str:
+ """Return the selected model."""
+ self.config = self.read_config()
+ logger.debug(f"Selected model: {self.config.selected_model}")
+ return self.config.selected_model
+
+ def get_current_chat_id(self) -> str:
+ """Return the current chat id."""
+ self.config = self.read_config()
+ return self.config.current_chat_id
+
+ def updated_selected_sources(self, new_sources: List[str]) -> None:
+ """Update the selected sources in the config."""
+ self.config = self.read_config().model_copy(update={"selected_sources": new_sources})
+ self.write_config(self.config)
+
+ def updated_selected_model(self, new_model: str) -> None:
+ """Update the selected model in the config."""
+ self.config = self.read_config().model_copy(update={"selected_model": new_model})
+ logger.debug(f"Updated selected model to: {new_model}")
+ self.write_config(self.config)
+
+ def updated_current_chat_id(self, new_chat_id: str) -> None:
+ """Update the current chat id in the config."""
+ self.config = self.read_config().model_copy(update={"current_chat_id": new_chat_id})
+ self.write_config(self.config)
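
The `MODELS` environment-variable handling that `_ensure_config_exists` and `read_config` rely on can be sketched in isolation. This is a standalone illustration of the comma-separated parsing and the default-config JSON shape, not the module itself (`ChatConfig` is replaced here by a plain dict):

```python
import json
import os
import tempfile

# Parse MODELS the way ConfigManager does: comma-separated, whitespace-trimmed.
os.environ["MODELS"] = "gpt-oss-120b, deepseek-coder-6.7b-instruct"
models_env = os.environ.get("MODELS", "")
models = [m.strip() for m in models_env.split(",") if m.strip()]

# Write a default config with the same fields ChatConfig serializes.
config = {
    "sources": [],
    "models": models,
    "selected_model": models[0] if models else None,
    "selected_sources": [],
    "current_chat_id": None,
}
path = os.path.join(tempfile.mkdtemp(), "config.json")
with open(path, "w") as f:
    json.dump(config, f, indent=2)

with open(path) as f:
    loaded = json.load(f)
```

The first model in the list becomes the default `selected_model`, which is why the order of the `MODELS` variable matters.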
diff --git a/nvidia/multi-agent-chatbot/assets/backend/logger.py b/nvidia/multi-agent-chatbot/assets/backend/logger.py
new file mode 100644
index 0000000..2c61cab
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/logger.py
@@ -0,0 +1,145 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import json
+import logging
+import sys
+import traceback
+from datetime import datetime
+from typing import Any, Dict, Optional
+
+
+class JsonFormatter(logging.Formatter):
+ """
+ Formatter that outputs JSON strings after parsing the log record.
+ """
+ def __init__(self, **kwargs):
+ super().__init__()
+ self.default_keys = {
+ 'timestamp': 'asctime',
+ 'level': 'levelname',
+ 'message': 'message'
+ }
+ self.default_keys.update(kwargs)
+
+ def format(self, record: logging.LogRecord) -> str:
+ """
+ Format the log record as JSON.
+ """
+ log_record = {}
+
+ log_record['timestamp'] = datetime.utcnow().isoformat() + 'Z'
+ log_record['level'] = record.levelname
+ log_record['logger'] = record.name
+
+ if isinstance(record.msg, dict):
+ log_record['message'] = record.msg.get('message', '')
+ for key, value in record.msg.items():
+ if key != 'message':
+ log_record[key] = value
+ else:
+ log_record['message'] = record.getMessage()
+
+ if record.exc_info:
+ log_record['exception'] = {
+ 'type': record.exc_info[0].__name__,
+ 'message': str(record.exc_info[1]),
+ 'traceback': traceback.format_exception(*record.exc_info)
+ }
+
+ for key, value in record.__dict__.items():
+ if key not in ['msg', 'args', 'exc_info', 'exc_text', 'stack_info', 'lineno',
+ 'funcName', 'created', 'msecs', 'relativeCreated', 'levelname',
+ 'levelno', 'pathname', 'filename', 'module', 'name', 'thread',
+ 'threadName', 'processName', 'process']:
+ log_record[key] = value
+
+ return json.dumps(log_record)
+
+def setup_logger(name: str = 'backend',
+ level: int = logging.INFO,
+ log_file: Optional[str] = 'app.log') -> logging.Logger:
+ """
+ Set up a JSON logger with console and file handlers.
+
+ Args:
+ name: Logger name
+ level: Logging level
+ log_file: Path to log file (None for no file logging)
+
+ Returns:
+ Configured logger instance
+ """
+ logger = logging.getLogger(name)
+ logger.setLevel(level)
+
+ logger.propagate = False
+
+ for handler in logger.handlers[:]:
+ logger.removeHandler(handler)
+
+ console_handler = logging.StreamHandler(sys.stdout)
+ console_handler.setFormatter(JsonFormatter())
+ logger.addHandler(console_handler)
+
+ if log_file:
+ file_handler = logging.FileHandler(log_file)
+ file_handler.setFormatter(JsonFormatter())
+ logger.addHandler(file_handler)
+
+ return logger
+
+logger = setup_logger()
+
+
+def log_request(request_data: Dict[str, Any], endpoint: str) -> None:
+ """
+ Log an API request with structured data.
+ """
+ logger.info({
+ 'message': f'API request to {endpoint}',
+ 'endpoint': endpoint,
+ 'request_data': request_data
+ })
+
+
+def log_response(response_data: Dict[str, Any], endpoint: str, status_code: int = 200) -> None:
+ """
+ Log an API response with structured data.
+ """
+ logger.info({
+ 'message': f'API response from {endpoint}',
+ 'endpoint': endpoint,
+ 'status_code': status_code,
+ 'response_data': response_data
+ })
+
+
+def log_error(error: Exception, endpoint: Optional[str] = None, request_data: Optional[Dict[str, Any]] = None) -> None:
+ """
+ Log an error with structured data.
+ """
+ error_data = {
+ 'message': f'Error: {str(error)}',
+ 'error_type': error.__class__.__name__,
+ }
+
+ if endpoint:
+ error_data['endpoint'] = endpoint
+
+ if request_data:
+ error_data['request_data'] = request_data
+
+ logger.error(error_data, exc_info=True)
\ No newline at end of file
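
The dict-message convention used by `JsonFormatter` (a dict's `message` key becomes the log message and its remaining keys are merged into the JSON record) can be reproduced in a minimal, self-contained form. This is a simplified stand-in for illustration, not the formatter above:

```python
import io
import json
import logging

class MiniJsonFormatter(logging.Formatter):
    """Minimal re-creation of the dict-merging pattern used above."""
    def format(self, record: logging.LogRecord) -> str:
        rec = {"level": record.levelname, "logger": record.name}
        if isinstance(record.msg, dict):
            # Dict messages: 'message' key is the text, other keys are merged in.
            rec["message"] = record.msg.get("message", "")
            rec.update({k: v for k, v in record.msg.items() if k != "message"})
        else:
            rec["message"] = record.getMessage()
        return json.dumps(rec)

buf = io.StringIO()
handler = logging.StreamHandler(buf)
handler.setFormatter(MiniJsonFormatter())
log = logging.getLogger("demo")
log.setLevel(logging.INFO)
log.addHandler(handler)

log.info({"message": "API request to /ingest", "endpoint": "/ingest"})
out = json.loads(buf.getvalue())
```

This is what makes helpers like `log_request` work: they pass structured dicts to `logger.info`, and the extra keys end up as top-level JSON fields.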
diff --git a/nvidia/multi-agent-chatbot/assets/backend/main.py b/nvidia/multi-agent-chatbot/assets/backend/main.py
new file mode 100644
index 0000000..cbac396
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/main.py
@@ -0,0 +1,516 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""FastAPI backend server for the chatbot application.
+
+This module provides the main HTTP API endpoints and WebSocket connections for:
+- Real-time chat via WebSocket
+- File upload and document ingestion
+- Configuration management (models, sources, chat settings)
+- Chat history management
+- Vector store operations
+"""
+
+import base64
+import json
+import os
+import uuid
+from contextlib import asynccontextmanager
+from typing import List, Optional, Dict
+
+from fastapi import FastAPI, File, Form, UploadFile, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect
+from fastapi.middleware.cors import CORSMiddleware
+
+from agent import ChatAgent
+from config import ConfigManager
+from logger import logger, log_request, log_response, log_error
+from models import ChatIdRequest, ChatRenameRequest, SelectedModelRequest
+from postgres_storage import PostgreSQLConversationStorage
+from utils import process_and_ingest_files_background
+from vector_store import create_vector_store_with_config
+
+POSTGRES_HOST = os.getenv("POSTGRES_HOST", "postgres")
+POSTGRES_PORT = int(os.getenv("POSTGRES_PORT", 5432))
+POSTGRES_DB = os.getenv("POSTGRES_DB", "chatbot")
+POSTGRES_USER = os.getenv("POSTGRES_USER", "chatbot_user")
+POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "chatbot_password")
+
+config_manager = ConfigManager("./config.json")
+
+postgres_storage = PostgreSQLConversationStorage(
+ host=POSTGRES_HOST,
+ port=POSTGRES_PORT,
+ database=POSTGRES_DB,
+ user=POSTGRES_USER,
+ password=POSTGRES_PASSWORD
+)
+
+vector_store = create_vector_store_with_config(config_manager)
+
+vector_store._initialize_store()
+
+agent: ChatAgent | None = None
+indexing_tasks: Dict[str, str] = {}
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ """Application lifespan manager for startup and shutdown tasks."""
+ global agent
+ logger.debug("Initializing PostgreSQL storage and agent...")
+
+ try:
+ await postgres_storage.init_pool()
+ logger.info("PostgreSQL storage initialized successfully")
+ logger.debug("Initializing ChatAgent...")
+ agent = await ChatAgent.create(
+ vector_store=vector_store,
+ config_manager=config_manager,
+ postgres_storage=postgres_storage
+ )
+ logger.info("ChatAgent initialized successfully.")
+ except Exception as e:
+ logger.error(f"Failed to initialize PostgreSQL storage: {e}")
+ raise
+
+ yield
+
+ try:
+ await postgres_storage.close()
+ logger.debug("PostgreSQL storage closed successfully")
+ except Exception as e:
+ logger.error(f"Error closing PostgreSQL storage: {e}")
+
+
+app = FastAPI(
+ title="Chatbot API",
+ description="Backend API for LLM-powered chatbot with RAG capabilities",
+ version="1.0.0",
+ lifespan=lifespan
+)
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["http://localhost:3000"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+
+@app.websocket("/ws/chat/{chat_id}")
+async def websocket_endpoint(websocket: WebSocket, chat_id: str):
+ """WebSocket endpoint for real-time chat communication.
+
+ Args:
+ websocket: WebSocket connection
+ chat_id: Unique chat identifier
+ """
+ logger.debug(f"WebSocket connection attempt for chat_id: {chat_id}")
+ try:
+ await websocket.accept()
+ logger.debug(f"WebSocket connection accepted for chat_id: {chat_id}")
+
+ history_messages = await postgres_storage.get_messages(chat_id)
+ # Skip the first stored message (the system prompt) so it is not echoed to the client.
+ history = [postgres_storage._message_to_dict(msg) for msg in history_messages[1:]]
+ await websocket.send_json({"type": "history", "messages": history})
+
+ while True:
+ data = await websocket.receive_text()
+ client_message = json.loads(data)
+ new_message = client_message.get("message")
+ image_id = client_message.get("image_id")
+
+ image_data = None
+ if image_id:
+ image_data = await postgres_storage.get_image(image_id)
+ logger.debug(f"Retrieved image data for image_id: {image_id}, data length: {len(image_data) if image_data else 0}")
+
+ try:
+ async for event in agent.query(query_text=new_message, chat_id=chat_id, image_data=image_data):
+ await websocket.send_json(event)
+ except Exception as query_error:
+ logger.error(f"Error in agent.query: {str(query_error)}", exc_info=True)
+ await websocket.send_json({"type": "error", "content": f"Error processing request: {str(query_error)}"})
+
+ final_messages = await postgres_storage.get_messages(chat_id)
+ final_history = [postgres_storage._message_to_dict(msg) for msg in final_messages[1:]]
+ await websocket.send_json({"type": "history", "messages": final_history})
+
+ except WebSocketDisconnect:
+ logger.debug(f"Client disconnected from chat {chat_id}")
+ except Exception as e:
+ logger.error(f"WebSocket error for chat {chat_id}: {str(e)}", exc_info=True)
+
+
+@app.post("/upload-image")
+async def upload_image(image: UploadFile = File(...), chat_id: str = Form(...)):
+ """Upload and store an image for chat processing.
+
+ Args:
+ image: Uploaded image file
+ chat_id: Chat identifier for context
+
+ Returns:
+ Dictionary with generated image_id
+ """
+ image_data = await image.read()
+ image_base64 = base64.b64encode(image_data).decode('utf-8')
+ data_uri = f"data:{image.content_type};base64,{image_base64}"
+ image_id = str(uuid.uuid4())
+ await postgres_storage.store_image(image_id, data_uri)
+ return {"image_id": image_id}
+
+
+@app.post("/ingest")
+async def ingest_files(files: Optional[List[UploadFile]] = File(None), background_tasks: BackgroundTasks = None):
+ """Ingest documents for vector search and RAG.
+
+ Args:
+ files: List of uploaded files to process
+ background_tasks: FastAPI background tasks manager
+
+ Returns:
+ Task information for tracking ingestion progress
+ """
+ if not files:
+ raise HTTPException(status_code=400, detail="No files provided for ingestion")
+
+ try:
+ log_request({"file_count": len(files)}, "/ingest")
+
+ task_id = str(uuid.uuid4())
+
+ file_info = []
+ for file in files:
+ content = await file.read()
+ file_info.append({
+ "filename": file.filename,
+ "content": content
+ })
+
+ indexing_tasks[task_id] = "queued"
+
+ background_tasks.add_task(
+ process_and_ingest_files_background,
+ file_info,
+ vector_store,
+ config_manager,
+ task_id,
+ indexing_tasks
+ )
+
+ response = {
+ "message": f"Files queued for processing. Indexing {len(files)} files in the background.",
+ "files": [file.filename for file in files],
+ "status": "queued",
+ "task_id": task_id
+ }
+
+ log_response(response, "/ingest")
+ return response
+
+ except Exception as e:
+ log_error(e, "/ingest")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error queuing files for ingestion: {str(e)}"
+ )
+
+
+@app.get("/ingest/status/{task_id}")
+async def get_indexing_status(task_id: str):
+ """Get the status of a file ingestion task.
+
+ Args:
+ task_id: Unique task identifier
+
+ Returns:
+ Current task status
+ """
+ if task_id in indexing_tasks:
+ return {"status": indexing_tasks[task_id]}
+ else:
+ raise HTTPException(status_code=404, detail="Task not found")
+
+
+@app.get("/sources")
+async def get_sources():
+ """Get all available document sources."""
+ try:
+ config = config_manager.read_config()
+ return {"sources": config.sources}
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Error getting sources: {str(e)}")
+
+
+@app.get("/selected_sources")
+async def get_selected_sources():
+ """Get currently selected document sources for RAG."""
+ try:
+ config = config_manager.read_config()
+ return {"sources": config.selected_sources}
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Error getting selected sources: {str(e)}")
+
+
+@app.post("/selected_sources")
+async def update_selected_sources(selected_sources: List[str]):
+ """Update the selected document sources for RAG.
+
+ Args:
+ selected_sources: List of source names to use for retrieval
+ """
+ try:
+ config_manager.updated_selected_sources(selected_sources)
+ return {"status": "success", "message": "Selected sources updated successfully"}
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Error updating selected sources: {str(e)}")
+
+
+@app.get("/selected_model")
+async def get_selected_model():
+ """Get the currently selected LLM model."""
+ try:
+ model = config_manager.get_selected_model()
+ return {"model": model}
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Error getting selected model: {str(e)}")
+
+
+@app.post("/selected_model")
+async def update_selected_model(request: SelectedModelRequest):
+ """Update the selected LLM model.
+
+ Args:
+ request: Model selection request with model name
+ """
+ try:
+ logger.debug(f"Updating selected model to: {request.model}")
+ config_manager.updated_selected_model(request.model)
+ return {"status": "success", "message": "Selected model updated successfully"}
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Error updating selected model: {str(e)}")
+
+
+@app.get("/available_models")
+async def get_available_models():
+ """Get list of all available LLM models."""
+ try:
+ models = config_manager.get_available_models()
+ return {"models": models}
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Error getting available models: {str(e)}")
+
+
+@app.get("/chats")
+async def list_chats():
+ """Get list of all chat conversations."""
+ try:
+ chat_ids = await postgres_storage.list_conversations()
+ return {"chats": chat_ids}
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Error listing chats: {str(e)}")
+
+
+@app.get("/chat_id")
+async def get_chat_id():
+ """Get the current active chat ID, creating a conversation if it doesn't exist."""
+ try:
+ config = config_manager.read_config()
+ current_chat_id = config.current_chat_id
+
+ if current_chat_id and await postgres_storage.exists(current_chat_id):
+ return {
+ "status": "success",
+ "chat_id": current_chat_id
+ }
+
+ new_chat_id = str(uuid.uuid4())
+
+ await postgres_storage.save_messages_immediate(new_chat_id, [])
+ await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
+
+ config_manager.updated_current_chat_id(new_chat_id)
+
+ return {
+ "status": "success",
+ "chat_id": new_chat_id
+ }
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error getting chat ID: {str(e)}"
+ )
+
+
+@app.post("/chat_id")
+async def update_chat_id(request: ChatIdRequest):
+ """Update the current active chat ID.
+
+ Args:
+ request: Chat ID update request
+ """
+ try:
+ config_manager.updated_current_chat_id(request.chat_id)
+ return {
+ "status": "success",
+ "message": f"Current chat ID updated to {request.chat_id}",
+ "chat_id": request.chat_id
+ }
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error updating chat ID: {str(e)}"
+ )
+
+
+@app.get("/chat/{chat_id}/metadata")
+async def get_chat_metadata(chat_id: str):
+ """Get metadata for a specific chat.
+
+ Args:
+ chat_id: Unique chat identifier
+
+ Returns:
+ Chat metadata including name
+ """
+ try:
+ metadata = await postgres_storage.get_chat_metadata(chat_id)
+ return metadata
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error getting chat metadata: {str(e)}"
+ )
+
+
+@app.post("/chat/rename")
+async def rename_chat(request: ChatRenameRequest):
+ """Rename a chat conversation.
+
+ Args:
+ request: Chat rename request with chat_id and new_name
+ """
+ try:
+ await postgres_storage.set_chat_metadata(request.chat_id, request.new_name)
+ return {
+ "status": "success",
+ "message": f"Chat {request.chat_id} renamed to {request.new_name}"
+ }
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error renaming chat: {str(e)}"
+ )
+
+
+@app.post("/chat/new")
+async def create_new_chat():
+ """Create a new chat conversation and set it as current."""
+ try:
+ new_chat_id = str(uuid.uuid4())
+ await postgres_storage.save_messages_immediate(new_chat_id, [])
+ await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
+
+ config_manager.updated_current_chat_id(new_chat_id)
+
+ return {
+ "status": "success",
+ "message": "New chat created",
+ "chat_id": new_chat_id
+ }
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error creating new chat: {str(e)}"
+ )
+
+
+@app.delete("/chat/{chat_id}")
+async def delete_chat(chat_id: str):
+ """Delete a specific chat and its messages.
+
+ Args:
+ chat_id: Unique chat identifier to delete
+ """
+ try:
+ success = await postgres_storage.delete_conversation(chat_id)
+
+ if success:
+ return {
+ "status": "success",
+ "message": f"Chat {chat_id} deleted successfully"
+ }
+ else:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Chat {chat_id} not found"
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error deleting chat: {str(e)}"
+ )
+
+
+@app.delete("/chats/clear")
+async def clear_all_chats():
+ """Clear all chat conversations and create a new default chat."""
+ try:
+ chat_ids = await postgres_storage.list_conversations()
+ cleared_count = 0
+
+ for chat_id in chat_ids:
+ if await postgres_storage.delete_conversation(chat_id):
+ cleared_count += 1
+
+ new_chat_id = str(uuid.uuid4())
+ await postgres_storage.save_messages_immediate(new_chat_id, [])
+ await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
+
+ config_manager.updated_current_chat_id(new_chat_id)
+
+ return {
+ "status": "success",
+ "message": f"Cleared {cleared_count} chats and created new chat",
+ "new_chat_id": new_chat_id,
+ "cleared_count": cleared_count
+ }
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error clearing all chats: {str(e)}"
+ )
+
+
+@app.delete("/collections/{collection_name}")
+async def delete_collection(collection_name: str):
+ """Delete a document collection from the vector store.
+
+ Args:
+ collection_name: Name of the collection to delete
+ """
+ try:
+ success = vector_store.delete_collection(collection_name)
+ if success:
+ return {"status": "success", "message": f"Collection '{collection_name}' deleted successfully"}
+ else:
+ raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found or could not be deleted")
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(e)}")
+
+
+if __name__ == "__main__":
+ import uvicorn
+ uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=False)
\ No newline at end of file
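
The WebSocket handler above expects each client frame to be a JSON object with a `message` string and an optional `image_id` (set by a prior `/upload-image` call). A quick sketch of that wire shape, shown standalone without a running server:

```python
import json

# Shape of a client -> server frame consumed by websocket_endpoint.
# "image_id" would come from the /upload-image response; None means text-only.
frame = json.dumps({"message": "Summarize the uploaded document", "image_id": None})

# What the handler does on receipt: json.loads, then .get() on both keys.
client_message = json.loads(frame)
new_message = client_message.get("message")
image_id = client_message.get("image_id")
```

Server-to-client frames are likewise JSON objects with a `type` field (`"history"`, `"error"`, or the streamed agent events), which is what a frontend switches on.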
diff --git a/nvidia/multi-agent-chatbot/assets/backend/models.py b/nvidia/multi-agent-chatbot/assets/backend/models.py
new file mode 100644
index 0000000..ecfc6e9
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/models.py
@@ -0,0 +1,35 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pydantic import BaseModel
+from typing import Optional, List
+
+class ChatConfig(BaseModel):
+ sources: List[str]
+ models: List[str]
+ selected_model: Optional[str] = None
+ selected_sources: Optional[List[str]] = None
+ current_chat_id: Optional[str] = None
+
+class ChatIdRequest(BaseModel):
+ chat_id: str
+
+class ChatRenameRequest(BaseModel):
+ chat_id: str
+ new_name: str
+
+class SelectedModelRequest(BaseModel):
+ model: str
diff --git a/nvidia/multi-agent-chatbot/assets/backend/postgres_storage.py b/nvidia/multi-agent-chatbot/assets/backend/postgres_storage.py
new file mode 100644
index 0000000..51dea0a
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/postgres_storage.py
@@ -0,0 +1,571 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""PostgreSQL-based conversation storage with caching and I/O optimization."""
+
+import json
+import time
+from typing import Dict, List, Optional, Any
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+import asyncio
+import asyncpg
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage, ToolMessage
+
+from logger import logger
+
+
+@dataclass
+class CacheEntry:
+ """Cache entry with TTL support."""
+ data: Any
+ timestamp: float
+ ttl: float = 300
+
+ def is_expired(self) -> bool:
+ return time.time() - self.timestamp > self.ttl
+
+
+class PostgreSQLConversationStorage:
+ """PostgreSQL-based conversation storage with intelligent caching and I/O optimization."""
+
+ def __init__(
+ self,
+ host: str = 'postgres',
+ port: int = 5432,
+ database: str = 'chatbot',
+ user: str = 'chatbot_user',
+ password: str = 'chatbot_password',
+ pool_size: int = 10,
+ cache_ttl: int = 300
+ ):
+ """Initialize PostgreSQL connection pool and caching.
+
+ Args:
+ host: PostgreSQL host
+ port: PostgreSQL port
+ database: Database name
+ user: Database user
+ password: Database password
+ pool_size: Connection pool size
+ cache_ttl: Cache TTL in seconds
+ """
+ self.host = host
+ self.port = port
+ self.database = database
+ self.user = user
+ self.password = password
+ self.pool_size = pool_size
+ self.cache_ttl = cache_ttl
+
+ self.pool: Optional[asyncpg.Pool] = None
+
+ self._message_cache: Dict[str, CacheEntry] = {}
+ self._metadata_cache: Dict[str, CacheEntry] = {}
+ self._image_cache: Dict[str, CacheEntry] = {}
+ self._chat_list_cache: Optional[CacheEntry] = None
+
+ self._pending_saves: Dict[str, List[BaseMessage]] = {}
+ self._save_lock = asyncio.Lock()
+ self._batch_save_task: Optional[asyncio.Task] = None
+
+ self._cache_hits = 0
+ self._cache_misses = 0
+ self._db_operations = 0
+
+ async def init_pool(self) -> None:
+ """Initialize the connection pool and create tables."""
+ try:
+ await self._ensure_database_exists()
+
+ self.pool = await asyncpg.create_pool(
+ host=self.host,
+ port=self.port,
+ database=self.database,
+ user=self.user,
+ password=self.password,
+ min_size=2,
+ max_size=self.pool_size,
+ command_timeout=30
+ )
+
+ await self._create_tables()
+ logger.debug("PostgreSQL connection pool initialized successfully")
+
+ self._batch_save_task = asyncio.create_task(self._batch_save_worker())
+
+ except Exception as e:
+ logger.error(f"Failed to initialize PostgreSQL pool: {e}")
+ raise
+
+ async def _ensure_database_exists(self) -> None:
+ """Ensure the target database exists, create if it doesn't."""
+ try:
+ conn = await asyncpg.connect(
+ host=self.host,
+ port=self.port,
+ database='postgres',
+ user=self.user,
+ password=self.password
+ )
+
+ try:
+ result = await conn.fetchval(
+ "SELECT 1 FROM pg_database WHERE datname = $1",
+ self.database
+ )
+
+ if not result:
+ await conn.execute(f'CREATE DATABASE "{self.database}"')
+ logger.debug(f"Created database: {self.database}")
+ else:
+ logger.debug(f"Database {self.database} already exists")
+
+ finally:
+ await conn.close()
+
+ except Exception as e:
+ logger.error(f"Error ensuring database exists: {e}")
+
+ async def close(self) -> None:
+ """Close the connection pool and cleanup."""
+ if self._batch_save_task:
+ self._batch_save_task.cancel()
+ try:
+ await self._batch_save_task
+ except asyncio.CancelledError:
+ pass
+
+ if self.pool:
+ await self.pool.close()
+ logger.debug("PostgreSQL connection pool closed")
+
+ async def _create_tables(self) -> None:
+ """Create necessary tables if they don't exist."""
+ async with self.pool.acquire() as conn:
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS conversations (
+ chat_id VARCHAR(255) PRIMARY KEY,
+ messages JSONB NOT NULL,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ message_count INTEGER DEFAULT 0
+ )
+ """)
+
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS chat_metadata (
+ chat_id VARCHAR(255) PRIMARY KEY,
+ name VARCHAR(500),
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (chat_id) REFERENCES conversations(chat_id) ON DELETE CASCADE
+ )
+ """)
+
+ await conn.execute("""
+ CREATE TABLE IF NOT EXISTS images (
+ image_id VARCHAR(255) PRIMARY KEY,
+ image_data TEXT NOT NULL,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ expires_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP + INTERVAL '1 hour')
+ )
+ """)
+
+ await conn.execute("CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at)")
+ await conn.execute("CREATE INDEX IF NOT EXISTS idx_images_expires_at ON images(expires_at)")
+
+ await conn.execute("""
+ CREATE OR REPLACE FUNCTION update_updated_at_column()
+ RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.updated_at = CURRENT_TIMESTAMP;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql'
+ """)
+
+ await conn.execute("""
+ DROP TRIGGER IF EXISTS update_conversations_updated_at ON conversations
+ """)
+ await conn.execute("""
+ CREATE TRIGGER update_conversations_updated_at
+ BEFORE UPDATE ON conversations
+ FOR EACH ROW
+ EXECUTE FUNCTION update_updated_at_column()
+ """)
+
+ def _message_to_dict(self, message: BaseMessage) -> Dict:
+ """Convert a message object to a dictionary for storage."""
+ result = {
+ "type": message.__class__.__name__,
+ "content": message.content,
+ }
+
+ if hasattr(message, "tool_calls") and message.tool_calls:
+ result["tool_calls"] = message.tool_calls
+
+ if isinstance(message, ToolMessage):
+ result["tool_call_id"] = getattr(message, "tool_call_id", None)
+ result["name"] = getattr(message, "name", None)
+
+ return result
+
+ def _dict_to_message(self, data: Dict) -> BaseMessage:
+ """Convert a dictionary back to a message object."""
+ msg_type = data["type"]
+ content = data["content"]
+
+ if msg_type == "AIMessage":
+ msg = AIMessage(content=content)
+ if "tool_calls" in data:
+ msg.tool_calls = data["tool_calls"]
+ return msg
+ elif msg_type == "HumanMessage":
+ return HumanMessage(content=content)
+ elif msg_type == "SystemMessage":
+ return SystemMessage(content=content)
+ elif msg_type == "ToolMessage":
+ return ToolMessage(
+ content=content,
+ tool_call_id=data.get("tool_call_id", ""),
+ name=data.get("name", "")
+ )
+ else:
+ return HumanMessage(content=content)
+
+ def _get_cached_messages(self, chat_id: str) -> Optional[List[BaseMessage]]:
+ """Get messages from cache if available and not expired."""
+ cache_entry = self._message_cache.get(chat_id)
+ if cache_entry and not cache_entry.is_expired():
+ self._cache_hits += 1
+ return cache_entry.data
+
+ self._cache_misses += 1
+ return None
+
+ def _cache_messages(self, chat_id: str, messages: List[BaseMessage]) -> None:
+ """Cache messages with TTL."""
+ self._message_cache[chat_id] = CacheEntry(
+ data=messages.copy(),
+ timestamp=time.time(),
+ ttl=self.cache_ttl
+ )
+
+ def _invalidate_cache(self, chat_id: str) -> None:
+ """Invalidate cache entries for a chat."""
+ self._message_cache.pop(chat_id, None)
+ self._metadata_cache.pop(chat_id, None)
+ self._chat_list_cache = None
+
+ async def exists(self, chat_id: str) -> bool:
+ """Check if a conversation exists (with caching)."""
+ # A non-empty cached history proves the conversation exists; an empty
+ # cache entry is inconclusive, so fall through to the database check.
+ cached_messages = self._get_cached_messages(chat_id)
+ if cached_messages:
+ return True
+
+ async with self.pool.acquire() as conn:
+ result = await conn.fetchval(
+ "SELECT EXISTS(SELECT 1 FROM conversations WHERE chat_id = $1)",
+ chat_id
+ )
+ self._db_operations += 1
+ return result
+
+ async def get_messages(self, chat_id: str, limit: Optional[int] = None) -> List[BaseMessage]:
+ """Retrieve messages for a chat session with caching."""
+ cached_messages = self._get_cached_messages(chat_id)
+ if cached_messages is not None:
+ return cached_messages[-limit:] if limit else cached_messages
+
+ async with self.pool.acquire() as conn:
+ row = await conn.fetchrow(
+ "SELECT messages FROM conversations WHERE chat_id = $1",
+ chat_id
+ )
+ self._db_operations += 1
+
+ if not row:
+ return []
+
+ messages_data = row['messages']
+ if isinstance(messages_data, str):
+ messages_data = json.loads(messages_data)
+ messages = [self._dict_to_message(msg_data) for msg_data in messages_data]
+
+ self._cache_messages(chat_id, messages)
+
+ return messages[-limit:] if limit else messages
+
+ async def save_messages(self, chat_id: str, messages: List[BaseMessage]) -> None:
+ """Save messages with batching for performance."""
+ async with self._save_lock:
+ self._pending_saves[chat_id] = messages.copy()
+
+ self._cache_messages(chat_id, messages)
+
+ async def save_messages_immediate(self, chat_id: str, messages: List[BaseMessage]) -> None:
+ """Save messages immediately without batching - for critical operations."""
+ serialized_messages = [self._message_to_dict(msg) for msg in messages]
+
+ async with self.pool.acquire() as conn:
+ await conn.execute("""
+ INSERT INTO conversations (chat_id, messages, message_count)
+ VALUES ($1, $2, $3)
+ ON CONFLICT (chat_id)
+ DO UPDATE SET
+ messages = EXCLUDED.messages,
+ message_count = EXCLUDED.message_count,
+ updated_at = CURRENT_TIMESTAMP
+ """, chat_id, json.dumps(serialized_messages), len(messages))
+ self._db_operations += 1
+
+ self._cache_messages(chat_id, messages)
+ self._chat_list_cache = None
+
+ async def _batch_save_worker(self) -> None:
+ """Background worker to batch save operations."""
+ while True:
+ try:
+ await asyncio.sleep(1.0)
+
+ async with self._save_lock:
+ if not self._pending_saves:
+ continue
+
+ saves_to_process = self._pending_saves.copy()
+ self._pending_saves.clear()
+
+ async with self.pool.acquire() as conn:
+ async with conn.transaction():
+ for chat_id, messages in saves_to_process.items():
+ serialized_messages = [self._message_to_dict(msg) for msg in messages]
+
+ await conn.execute("""
+ INSERT INTO conversations (chat_id, messages, message_count)
+ VALUES ($1, $2, $3)
+ ON CONFLICT (chat_id)
+ DO UPDATE SET
+ messages = EXCLUDED.messages,
+ message_count = EXCLUDED.message_count,
+ updated_at = CURRENT_TIMESTAMP
+ """, chat_id, json.dumps(serialized_messages), len(messages))
+
+ self._db_operations += len(saves_to_process)
+ if saves_to_process:
+ logger.debug(f"Batch saved {len(saves_to_process)} conversations")
+ self._chat_list_cache = None
+
+ except asyncio.CancelledError:
+ break
+ except Exception as e:
+ logger.error(f"Error in batch save worker: {e}")
+
+ async def add_message(self, chat_id: str, message: BaseMessage) -> None:
+ """Add a single message to conversation (optimized)."""
+ current_messages = await self.get_messages(chat_id)
+ current_messages.append(message)
+
+ await self.save_messages(chat_id, current_messages)
+
+ async def delete_conversation(self, chat_id: str) -> bool:
+ """Delete a conversation by chat_id."""
+ try:
+ async with self.pool.acquire() as conn:
+ result = await conn.execute(
+ "DELETE FROM conversations WHERE chat_id = $1",
+ chat_id
+ )
+ self._db_operations += 1
+
+ self._invalidate_cache(chat_id)
+
+ return "DELETE 1" in result
+ except Exception as e:
+ logger.error(f"Error deleting conversation {chat_id}: {e}")
+ return False
+
+ async def list_conversations(self) -> List[str]:
+ """List all conversation IDs with caching."""
+ if self._chat_list_cache and not self._chat_list_cache.is_expired():
+ self._cache_hits += 1
+ return self._chat_list_cache.data
+
+ async with self.pool.acquire() as conn:
+ rows = await conn.fetch(
+ "SELECT chat_id FROM conversations ORDER BY updated_at DESC"
+ )
+ self._db_operations += 1
+
+ chat_ids = [row['chat_id'] for row in rows]
+
+ self._chat_list_cache = CacheEntry(
+ data=chat_ids,
+ timestamp=time.time(),
+ ttl=60
+ )
+ self._cache_misses += 1
+
+ return chat_ids
+
+ async def store_image(self, image_id: str, image_base64: str) -> None:
+ """Store base64 image data with TTL."""
+ async with self.pool.acquire() as conn:
+ await conn.execute("""
+ INSERT INTO images (image_id, image_data)
+ VALUES ($1, $2)
+ ON CONFLICT (image_id)
+ DO UPDATE SET
+ image_data = EXCLUDED.image_data,
+ created_at = CURRENT_TIMESTAMP,
+ expires_at = CURRENT_TIMESTAMP + INTERVAL '1 hour'
+ """, image_id, image_base64)
+ self._db_operations += 1
+
+ self._image_cache[image_id] = CacheEntry(
+ data=image_base64,
+ timestamp=time.time(),
+ ttl=3600
+ )
+
+ async def get_image(self, image_id: str) -> Optional[str]:
+ """Retrieve base64 image data with caching."""
+ cache_entry = self._image_cache.get(image_id)
+ if cache_entry and not cache_entry.is_expired():
+ self._cache_hits += 1
+ return cache_entry.data
+
+ async with self.pool.acquire() as conn:
+ row = await conn.fetchrow("""
+ SELECT image_data FROM images
+ WHERE image_id = $1 AND expires_at > CURRENT_TIMESTAMP
+ """, image_id)
+ self._db_operations += 1
+
+ # Count the in-memory cache miss whether or not the database has the image.
+ self._cache_misses += 1
+
+ if row:
+ image_data = row['image_data']
+ self._image_cache[image_id] = CacheEntry(
+ data=image_data,
+ timestamp=time.time(),
+ ttl=3600
+ )
+ return image_data
+
+ return None
+
+ async def get_chat_metadata(self, chat_id: str) -> Optional[Dict]:
+ """Get chat metadata with caching."""
+ cache_entry = self._metadata_cache.get(chat_id)
+ if cache_entry and not cache_entry.is_expired():
+ self._cache_hits += 1
+ return cache_entry.data
+
+ async with self.pool.acquire() as conn:
+ row = await conn.fetchrow(
+ "SELECT name, created_at FROM chat_metadata WHERE chat_id = $1",
+ chat_id
+ )
+ self._db_operations += 1
+
+ if row:
+ metadata = {
+ "name": row['name'],
+ "created_at": row['created_at'].isoformat()
+ }
+ else:
+ metadata = {"name": f"Chat {chat_id[:8]}"}
+
+ self._metadata_cache[chat_id] = CacheEntry(
+ data=metadata,
+ timestamp=time.time(),
+ ttl=self.cache_ttl
+ )
+ self._cache_misses += 1
+
+ return metadata
+
+ async def set_chat_metadata(self, chat_id: str, name: str) -> None:
+ """Set chat metadata."""
+ async with self.pool.acquire() as conn:
+ await conn.execute("""
+ INSERT INTO chat_metadata (chat_id, name)
+ VALUES ($1, $2)
+ ON CONFLICT (chat_id)
+ DO UPDATE SET
+ name = EXCLUDED.name,
+ updated_at = CURRENT_TIMESTAMP
+ """, chat_id, name)
+ self._db_operations += 1
+
+ self._metadata_cache[chat_id] = CacheEntry(
+ data={"name": name},
+ timestamp=time.time(),
+ ttl=self.cache_ttl
+ )
+
+ async def cleanup_expired_images(self) -> int:
+ """Clean up expired images and return count of deleted images."""
+ async with self.pool.acquire() as conn:
+ result = await conn.execute(
+ "DELETE FROM images WHERE expires_at < CURRENT_TIMESTAMP"
+ )
+ self._db_operations += 1
+
+ expired_keys = [
+ key for key, entry in self._image_cache.items()
+ if entry.is_expired()
+ ]
+ for key in expired_keys:
+ del self._image_cache[key]
+
+ deleted_count = int(result.split()[-1]) if result else 0
+ if deleted_count > 0:
+ logger.debug(f"Cleaned up {deleted_count} expired images")
+
+ return deleted_count
+
+ def get_cache_stats(self) -> Dict[str, Any]:
+ """Get cache performance statistics."""
+ total_requests = self._cache_hits + self._cache_misses
+ hit_rate = (self._cache_hits / total_requests * 100) if total_requests > 0 else 0
+
+ return {
+ "cache_hits": self._cache_hits,
+ "cache_misses": self._cache_misses,
+ "hit_rate_percent": round(hit_rate, 2),
+ "db_operations": self._db_operations,
+ "cached_conversations": len(self._message_cache),
+ "cached_metadata": len(self._metadata_cache),
+ "cached_images": len(self._image_cache)
+ }
+
+ def load_conversation_history(self, chat_id: str) -> "asyncio.Task":
+ """Legacy shim: schedules the async load and returns the task; callers must await it."""
+ return asyncio.create_task(self._load_conversation_history_dict(chat_id))
+
+ async def _load_conversation_history_dict(self, chat_id: str) -> List[Dict]:
+ """Load conversation history in dict format for compatibility."""
+ messages = await self.get_messages(chat_id)
+ return [self._message_to_dict(msg) for msg in messages]
+
+ def save_conversation_history(self, chat_id: str, messages: List[Dict]) -> "asyncio.Task":
+ """Legacy shim: schedules the async save and returns the task; callers must await it."""
+ message_objects = [self._dict_to_message(msg) for msg in messages]
+ return asyncio.create_task(self.save_messages(chat_id, message_objects))
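The `save_messages` / `_batch_save_worker` pair above is a write-behind cache: saves land in an in-memory pending map under a lock, and a background task flushes the whole batch roughly once per second, so rapid successive saves of the same chat collapse into a single database write. A minimal, self-contained sketch of the same pattern (all names here are illustrative, not part of the diff):

```python
import asyncio

class WriteBehindBuffer:
    """Coalesce rapid writes per key; a background task flushes them in batches."""

    def __init__(self, flush, interval=0.01):
        self._flush = flush             # async callable receiving {key: latest_value}
        self._pending = {}
        self._lock = asyncio.Lock()
        self._interval = interval
        self._task = None
        self.flush_count = 0

    async def put(self, key, value):
        async with self._lock:
            self._pending[key] = value  # later puts for a key overwrite earlier ones

    async def _worker(self):
        while True:
            await asyncio.sleep(self._interval)
            async with self._lock:
                if not self._pending:
                    continue
                batch, self._pending = dict(self._pending), {}
            await self._flush(batch)    # one round-trip per batch, not per put
            self.flush_count += 1

    def start(self):
        self._task = asyncio.create_task(self._worker())

    async def stop(self):
        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            pass
        if self._pending:               # drain anything still buffered
            await self._flush(self._pending)
            self.flush_count += 1

async def demo():
    saved = {}

    async def fake_db_write(batch):
        saved.update(batch)

    buf = WriteBehindBuffer(fake_db_write, interval=0.01)
    buf.start()
    for i in range(100):                # 100 writes to one key...
        await buf.put("chat-1", i)
    await asyncio.sleep(0.05)           # ...give the worker time to flush
    await buf.stop()
    return saved, buf.flush_count

saved, flushes = asyncio.run(demo())
print(saved, flushes)
```

The trade-off is durability for throughput: anything still in the pending map when the process dies is lost, which is presumably why the diff also keeps `save_messages_immediate` for critical writes.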
diff --git a/nvidia/multi-agent-chatbot/assets/backend/prompts.py b/nvidia/multi-agent-chatbot/assets/backend/prompts.py
new file mode 100644
index 0000000..d4af5b9
--- /dev/null
+++ b/nvidia/multi-agent-chatbot/assets/backend/prompts.py
@@ -0,0 +1,152 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import jinja2
+from typing import Dict
+
+
+SUPERVISOR_AGENT_STR = """
+You are a supervisor agent whose role is to be a helpful planner that can use tools to answer questions. DO NOT WRITE CODE YOURSELF, ALWAYS USE THE TOOLS.
+
+{% if tools %}
+IMPORTANT: You have access to these tools and you MUST use them when applicable and use tool response in your final answer:
+{{ tools }}
+
+CRITICAL RULES:
+- **ALWAYS** use a tool when the user's request matches a tool's capability. For example:
+ - If the user asks to "generate code", "develop", "build", "create", "write a script", "make a website", "develop an app", etc. → **MUST** use the write_code tool with appropriate programming_language parameter
+ - If the user asks to "search", "find", "summarize", "analyze documents/reports", "key points", etc. → **MUST** use the search_documents tool with the query, don't add any other text to the query. You can assume that the user has already uploaded the document and just call the tool.
+ - If the user asks to analyze/describe/understand an image (e.g., "what's in this image", "describe the picture") → **MUST** use the explain_image tool
+
+- **NEVER EVER generate code yourself** - you are FORBIDDEN from writing code directly. ALWAYS use the write_code tool for ANY coding requests
+- **DO NOT** try to answer questions from documents yourself - always use the search_documents tool
+
+CODING KEYWORDS that REQUIRE write_code tool:
+- "code", "develop", "build", "create", "write", "make", "implement", "program", "script", "website", "app", "function", "class", "HTML", "CSS", "JavaScript", "Python", "React", "component"
+
+
+Batching policy:
+- **Batch** when: (a) calls are independent (e.g., weather in two cities), (b) calls target different tools without dependency, or (c) multiple calls to the same tool with different arguments.
+- **Do not batch** when: a call’s arguments depend on a previous tool’s output (e.g., writing code which depends on the output of a search_documents tool).
+
+Output protocol:
+- In the first assistant message of a turn, if tools are needed, **emit all tool calls together** (as multiple tool_calls). Do not include narrative text before the tool_calls unless required by the API.
+- After the ToolMessages arrive, produce a single assistant message with the final answer incorporating all results. Do not call the tools again for the same purpose.
+- **CRITICAL**: When you receive tool results, you MUST use them in your final response. Do NOT ignore successful tool results or claim you don't have information when tools have already provided it.
+- If any tool call succeeds, base your answer on the successful results. Ignore failed tool calls if you have successful ones.
+- Always present the information from successful tool calls as your definitive answer.
+
+
+Few-shot examples:
+# Direct coding request
+User: Create a responsive personal website for my AI development business
+Assistant (tool calls immediately):
+- write_code({"query": "Create a responsive personal website for my AI development business", "programming_language": "HTML"})
+
+# Batching independent calls
+User: now, can you get the weather in egypt and the rain forecast in malibu?
+Assistant (tool calls in one message):
+- get_weather({"location": "Egypt"})
+- get_rain_forecast({"location": "Malibu"})
+
+# Staged dependent calls
+User: Search my documents for design requirements then build a website based on those requirements
+Assistant (first message; dependent plan):
+- search_documents({"query": "design requirements website"})
+# (Wait for ToolMessage)
+Assistant (after ToolMessage):
+- write_code({"query": "build a website based on these design requirements: {ingestMessage}