dgx-spark-playbooks/nvidia/multi-agent-chatbot/assets/backend/main.py
2025-10-04 21:21:42 +00:00

516 lines
17 KiB
Python

#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""FastAPI backend server for the chatbot application.
This module provides the main HTTP API endpoints and WebSocket connections for:
- Real-time chat via WebSocket
- File upload and document ingestion
- Configuration management (models, sources, chat settings)
- Chat history management
- Vector store operations
"""
import base64
import json
import os
import uuid
from contextlib import asynccontextmanager
from typing import List, Optional, Dict
from fastapi import FastAPI, File, Form, UploadFile, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from agent import ChatAgent
from config import ConfigManager
from logger import logger, log_request, log_response, log_error
from models import ChatIdRequest, ChatRenameRequest, SelectedModelRequest
from postgres_storage import PostgreSQLConversationStorage
from utils import process_and_ingest_files_background
from vector_store import create_vector_store_with_config
POSTGRES_HOST = os.getenv("POSTGRES_HOST", "postgres")
POSTGRES_PORT = int(os.getenv("POSTGRES_PORT", 5432))
POSTGRES_DB = os.getenv("POSTGRES_DB", "chatbot")
POSTGRES_USER = os.getenv("POSTGRES_USER", "chatbot_user")
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "chatbot_password")
config_manager = ConfigManager("./config.json")
postgres_storage = PostgreSQLConversationStorage(
host=POSTGRES_HOST,
port=POSTGRES_PORT,
database=POSTGRES_DB,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD
)
vector_store = create_vector_store_with_config(config_manager)
vector_store._initialize_store()
agent: ChatAgent | None = None
indexing_tasks: Dict[str, str] = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager for startup and shutdown tasks."""
global agent
logger.debug("Initializing PostgreSQL storage and agent...")
try:
await postgres_storage.init_pool()
logger.info("PostgreSQL storage initialized successfully")
logger.debug("Initializing ChatAgent...")
agent = await ChatAgent.create(
vector_store=vector_store,
config_manager=config_manager,
postgres_storage=postgres_storage
)
logger.info("ChatAgent initialized successfully.")
except Exception as e:
logger.error(f"Failed to initialize PostgreSQL storage: {e}")
raise
yield
try:
await postgres_storage.close()
logger.debug("PostgreSQL storage closed successfully")
except Exception as e:
logger.error(f"Error closing PostgreSQL storage: {e}")
app = FastAPI(
title="Chatbot API",
description="Backend API for LLM-powered chatbot with RAG capabilities",
version="1.0.0",
lifespan=lifespan
)
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.websocket("/ws/chat/{chat_id}")
async def websocket_endpoint(websocket: WebSocket, chat_id: str):
"""WebSocket endpoint for real-time chat communication.
Args:
websocket: WebSocket connection
chat_id: Unique chat identifier
"""
logger.debug(f"WebSocket connection attempt for chat_id: {chat_id}")
try:
await websocket.accept()
logger.debug(f"WebSocket connection accepted for chat_id: {chat_id}")
history_messages = await postgres_storage.get_messages(chat_id)
history = [postgres_storage._message_to_dict(msg) for i, msg in enumerate(history_messages) if i != 0]
await websocket.send_json({"type": "history", "messages": history})
while True:
data = await websocket.receive_text()
client_message = json.loads(data)
new_message = client_message.get("message")
image_id = client_message.get("image_id")
image_data = None
if image_id:
image_data = await postgres_storage.get_image(image_id)
logger.debug(f"Retrieved image data for image_id: {image_id}, data length: {len(image_data) if image_data else 0}")
try:
async for event in agent.query(query_text=new_message, chat_id=chat_id, image_data=image_data):
await websocket.send_json(event)
except Exception as query_error:
logger.error(f"Error in agent.query: {str(query_error)}", exc_info=True)
await websocket.send_json({"type": "error", "content": f"Error processing request: {str(query_error)}"})
final_messages = await postgres_storage.get_messages(chat_id)
final_history = [postgres_storage._message_to_dict(msg) for i, msg in enumerate(final_messages) if i != 0]
await websocket.send_json({"type": "history", "messages": final_history})
except WebSocketDisconnect:
logger.debug(f"Client disconnected from chat {chat_id}")
except Exception as e:
logger.error(f"WebSocket error for chat {chat_id}: {str(e)}", exc_info=True)
@app.post("/upload-image")
async def upload_image(image: UploadFile = File(...), chat_id: str = Form(...)):
"""Upload and store an image for chat processing.
Args:
image: Uploaded image file
chat_id: Chat identifier for context
Returns:
Dictionary with generated image_id
"""
image_data = await image.read()
image_base64 = base64.b64encode(image_data).decode('utf-8')
data_uri = f"data:{image.content_type};base64,{image_base64}"
image_id = str(uuid.uuid4())
await postgres_storage.store_image(image_id, data_uri)
return {"image_id": image_id}
@app.post("/ingest")
async def ingest_files(files: Optional[List[UploadFile]] = File(None), background_tasks: BackgroundTasks = None):
"""Ingest documents for vector search and RAG.
Args:
files: List of uploaded files to process
background_tasks: FastAPI background tasks manager
Returns:
Task information for tracking ingestion progress
"""
try:
log_request({"file_count": len(files) if files else 0}, "/ingest")
task_id = str(uuid.uuid4())
file_info = []
for file in files:
content = await file.read()
file_info.append({
"filename": file.filename,
"content": content
})
indexing_tasks[task_id] = "queued"
background_tasks.add_task(
process_and_ingest_files_background,
file_info,
vector_store,
config_manager,
task_id,
indexing_tasks
)
response = {
"message": f"Files queued for processing. Indexing {len(files)} files in the background.",
"files": [file.filename for file in files],
"status": "queued",
"task_id": task_id
}
log_response(response, "/ingest")
return response
except Exception as e:
log_error(e, "/ingest")
raise HTTPException(
status_code=500,
detail=f"Error queuing files for ingestion: {str(e)}"
)
@app.get("/ingest/status/{task_id}")
async def get_indexing_status(task_id: str):
"""Get the status of a file ingestion task.
Args:
task_id: Unique task identifier
Returns:
Current task status
"""
if task_id in indexing_tasks:
return {"status": indexing_tasks[task_id]}
else:
raise HTTPException(status_code=404, detail="Task not found")
@app.get("/sources")
async def get_sources():
"""Get all available document sources."""
try:
config = config_manager.read_config()
return {"sources": config.sources}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting sources: {str(e)}")
@app.get("/selected_sources")
async def get_selected_sources():
"""Get currently selected document sources for RAG."""
try:
config = config_manager.read_config()
return {"sources": config.selected_sources}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting selected sources: {str(e)}")
@app.post("/selected_sources")
async def update_selected_sources(selected_sources: List[str]):
"""Update the selected document sources for RAG.
Args:
selected_sources: List of source names to use for retrieval
"""
try:
config_manager.updated_selected_sources(selected_sources)
return {"status": "success", "message": "Selected sources updated successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error updating selected sources: {str(e)}")
@app.get("/selected_model")
async def get_selected_model():
"""Get the currently selected LLM model."""
try:
model = config_manager.get_selected_model()
return {"model": model}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting selected model: {str(e)}")
@app.post("/selected_model")
async def update_selected_model(request: SelectedModelRequest):
"""Update the selected LLM model.
Args:
request: Model selection request with model name
"""
try:
logger.debug(f"Updating selected model to: {request.model}")
config_manager.updated_selected_model(request.model)
return {"status": "success", "message": "Selected model updated successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error updating selected model: {str(e)}")
@app.get("/available_models")
async def get_available_models():
"""Get list of all available LLM models."""
try:
models = config_manager.get_available_models()
return {"models": models}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting available models: {str(e)}")
@app.get("/chats")
async def list_chats():
"""Get list of all chat conversations."""
try:
chat_ids = await postgres_storage.list_conversations()
return {"chats": chat_ids}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error listing chats: {str(e)}")
@app.get("/chat_id")
async def get_chat_id():
"""Get the current active chat ID, creating a conversation if it doesn't exist."""
try:
config = config_manager.read_config()
current_chat_id = config.current_chat_id
if current_chat_id and await postgres_storage.exists(current_chat_id):
return {
"status": "success",
"chat_id": current_chat_id
}
new_chat_id = str(uuid.uuid4())
await postgres_storage.save_messages_immediate(new_chat_id, [])
await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
config_manager.updated_current_chat_id(new_chat_id)
return {
"status": "success",
"chat_id": new_chat_id
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error getting chat ID: {str(e)}"
)
@app.post("/chat_id")
async def update_chat_id(request: ChatIdRequest):
"""Update the current active chat ID.
Args:
request: Chat ID update request
"""
try:
config_manager.updated_current_chat_id(request.chat_id)
return {
"status": "success",
"message": f"Current chat ID updated to {request.chat_id}",
"chat_id": request.chat_id
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error updating chat ID: {str(e)}"
)
@app.get("/chat/{chat_id}/metadata")
async def get_chat_metadata(chat_id: str):
"""Get metadata for a specific chat.
Args:
chat_id: Unique chat identifier
Returns:
Chat metadata including name
"""
try:
metadata = await postgres_storage.get_chat_metadata(chat_id)
return metadata
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error getting chat metadata: {str(e)}"
)
@app.post("/chat/rename")
async def rename_chat(request: ChatRenameRequest):
"""Rename a chat conversation.
Args:
request: Chat rename request with chat_id and new_name
"""
try:
await postgres_storage.set_chat_metadata(request.chat_id, request.new_name)
return {
"status": "success",
"message": f"Chat {request.chat_id} renamed to {request.new_name}"
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error renaming chat: {str(e)}"
)
@app.post("/chat/new")
async def create_new_chat():
"""Create a new chat conversation and set it as current."""
try:
new_chat_id = str(uuid.uuid4())
await postgres_storage.save_messages_immediate(new_chat_id, [])
await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
config_manager.updated_current_chat_id(new_chat_id)
return {
"status": "success",
"message": "New chat created",
"chat_id": new_chat_id
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error creating new chat: {str(e)}"
)
@app.delete("/chat/{chat_id}")
async def delete_chat(chat_id: str):
"""Delete a specific chat and its messages.
Args:
chat_id: Unique chat identifier to delete
"""
try:
success = await postgres_storage.delete_conversation(chat_id)
if success:
return {
"status": "success",
"message": f"Chat {chat_id} deleted successfully"
}
else:
raise HTTPException(
status_code=404,
detail=f"Chat {chat_id} not found"
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error deleting chat: {str(e)}"
)
@app.delete("/chats/clear")
async def clear_all_chats():
"""Clear all chat conversations and create a new default chat."""
try:
chat_ids = await postgres_storage.list_conversations()
cleared_count = 0
for chat_id in chat_ids:
if await postgres_storage.delete_conversation(chat_id):
cleared_count += 1
new_chat_id = str(uuid.uuid4())
await postgres_storage.save_messages_immediate(new_chat_id, [])
await postgres_storage.set_chat_metadata(new_chat_id, f"Chat {new_chat_id[:8]}")
config_manager.updated_current_chat_id(new_chat_id)
return {
"status": "success",
"message": f"Cleared {cleared_count} chats and created new chat",
"new_chat_id": new_chat_id,
"cleared_count": cleared_count
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error clearing all chats: {str(e)}"
)
@app.delete("/collections/{collection_name}")
async def delete_collection(collection_name: str):
"""Delete a document collection from the vector store.
Args:
collection_name: Name of the collection to delete
"""
try:
success = vector_store.delete_collection(collection_name)
if success:
return {"status": "success", "message": f"Collection '{collection_name}' deleted successfully"}
else:
raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found or could not be deleted")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=False)