# dgx-spark-playbooks/nvidia/multi-agent-chatbot/assets/backend/postgres_storage.py
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""PostgreSQL-based conversation storage with caching and I/O optimization."""
import json
import time
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime, timedelta
import asyncio
import asyncpg
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage, ToolMessage
from logger import logger
@dataclass
class CacheEntry:
    """A cached value tagged with its creation time and lifetime."""
    # Payload being cached (messages, metadata, image data, ...).
    data: Any
    # time.time() value recorded when the entry was stored.
    timestamp: float
    # Lifetime in seconds; entries older than this are considered stale.
    ttl: float = 300

    def is_expired(self) -> bool:
        """Return True once the entry has outlived its TTL."""
        age = time.time() - self.timestamp
        return age > self.ttl
class PostgreSQLConversationStorage:
    """PostgreSQL-based conversation storage with intelligent caching and I/O optimization.

    Reads are served from in-memory TTL caches where possible; writes are
    coalesced into a pending queue and flushed once per second by a background
    batch worker. Call ``init_pool()`` before use and ``close()`` on shutdown —
    ``close()`` flushes any queued writes so batched saves are not lost.
    """

    def __init__(
        self,
        host: str = 'postgres',
        port: int = 5432,
        database: str = 'chatbot',
        user: str = 'chatbot_user',
        password: str = 'chatbot_password',
        pool_size: int = 10,
        cache_ttl: int = 300
    ):
        """Initialize connection settings and in-memory caches.

        Args:
            host: PostgreSQL host
            port: PostgreSQL port
            database: Database name
            user: Database user
            password: Database password
            pool_size: Maximum connection pool size
            cache_ttl: Cache TTL in seconds for message/metadata caches
        """
        self.host = host
        self.port = port
        self.database = database
        self.user = user
        self.password = password
        self.pool_size = pool_size
        self.cache_ttl = cache_ttl
        self.pool: Optional[asyncpg.Pool] = None
        # Per-chat caches keyed by chat_id; the chat-id list has its own entry.
        self._message_cache: Dict[str, CacheEntry] = {}
        self._metadata_cache: Dict[str, CacheEntry] = {}
        self._image_cache: Dict[str, CacheEntry] = {}
        self._chat_list_cache: Optional[CacheEntry] = None
        # Writes queued for the batch worker; guarded by _save_lock.
        self._pending_saves: Dict[str, List[BaseMessage]] = {}
        self._save_lock = asyncio.Lock()
        self._batch_save_task: Optional[asyncio.Task] = None
        # Simple performance counters surfaced by get_cache_stats().
        self._cache_hits = 0
        self._cache_misses = 0
        self._db_operations = 0

    async def init_pool(self) -> None:
        """Initialize the connection pool, create tables, and start the batch worker.

        Raises:
            Exception: re-raised from asyncpg if the pool cannot be created.
        """
        try:
            await self._ensure_database_exists()
            self.pool = await asyncpg.create_pool(
                host=self.host,
                port=self.port,
                database=self.database,
                user=self.user,
                password=self.password,
                min_size=2,
                max_size=self.pool_size,
                command_timeout=30
            )
            await self._create_tables()
            logger.debug("PostgreSQL connection pool initialized successfully")
            self._batch_save_task = asyncio.create_task(self._batch_save_worker())
        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL pool: {e}")
            raise

    async def _ensure_database_exists(self) -> None:
        """Ensure the target database exists, creating it if necessary.

        Best-effort: failures are logged but not raised, since the database may
        be provisioned externally and init_pool() will surface any real
        connectivity problem when the pool is created.
        """
        try:
            # Connect to the default 'postgres' database to issue CREATE DATABASE.
            conn = await asyncpg.connect(
                host=self.host,
                port=self.port,
                database='postgres',
                user=self.user,
                password=self.password
            )
            try:
                result = await conn.fetchval(
                    "SELECT 1 FROM pg_database WHERE datname = $1",
                    self.database
                )
                if not result:
                    # CREATE DATABASE cannot take bind parameters; the name comes
                    # from our own configuration, not untrusted input.
                    await conn.execute(f'CREATE DATABASE "{self.database}"')
                    logger.debug(f"Created database: {self.database}")
                else:
                    logger.debug(f"Database {self.database} already exists")
            finally:
                await conn.close()
        except Exception as e:
            logger.error(f"Error ensuring database exists: {e}")

    async def close(self) -> None:
        """Stop the batch worker, flush queued saves, and close the pool.

        Previously any saves still sitting in the pending queue were silently
        dropped when the worker was cancelled; they are now written out first.
        """
        if self._batch_save_task:
            self._batch_save_task.cancel()
            try:
                await self._batch_save_task
            except asyncio.CancelledError:
                pass
        # Persist anything the worker had not flushed yet.
        await self._flush_pending_saves()
        if self.pool:
            await self.pool.close()
            logger.debug("PostgreSQL connection pool closed")

    async def _create_tables(self) -> None:
        """Create tables, indexes, and the updated_at trigger if they don't exist."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS conversations (
                    chat_id VARCHAR(255) PRIMARY KEY,
                    messages JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    message_count INTEGER DEFAULT 0
                )
            """)
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS chat_metadata (
                    chat_id VARCHAR(255) PRIMARY KEY,
                    name VARCHAR(500),
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (chat_id) REFERENCES conversations(chat_id) ON DELETE CASCADE
                )
            """)
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS images (
                    image_id VARCHAR(255) PRIMARY KEY,
                    image_data TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP + INTERVAL '1 hour')
                )
            """)
            await conn.execute("CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at)")
            await conn.execute("CREATE INDEX IF NOT EXISTS idx_images_expires_at ON images(expires_at)")
            # Keep updated_at current on every UPDATE via a trigger.
            await conn.execute("""
                CREATE OR REPLACE FUNCTION update_updated_at_column()
                RETURNS TRIGGER AS $$
                BEGIN
                    NEW.updated_at = CURRENT_TIMESTAMP;
                    RETURN NEW;
                END;
                $$ language 'plpgsql'
            """)
            # Recreate the trigger idempotently (CREATE TRIGGER has no IF NOT EXISTS).
            await conn.execute("""
                DROP TRIGGER IF EXISTS update_conversations_updated_at ON conversations
            """)
            await conn.execute("""
                CREATE TRIGGER update_conversations_updated_at
                    BEFORE UPDATE ON conversations
                    FOR EACH ROW
                    EXECUTE FUNCTION update_updated_at_column()
            """)

    def _message_to_dict(self, message: BaseMessage) -> Dict:
        """Convert a LangChain message object to a JSON-serializable dict."""
        result = {
            "type": message.__class__.__name__,
            "content": message.content,
        }
        # Preserve AI tool-call requests and tool responses round-trip.
        if hasattr(message, "tool_calls") and message.tool_calls:
            result["tool_calls"] = message.tool_calls
        if isinstance(message, ToolMessage):
            result["tool_call_id"] = getattr(message, "tool_call_id", None)
            result["name"] = getattr(message, "name", None)
        return result

    def _dict_to_message(self, data: Dict) -> BaseMessage:
        """Convert a stored dict back to a message object.

        Unknown types fall back to HumanMessage so old/foreign rows still load.
        """
        msg_type = data["type"]
        content = data["content"]
        if msg_type == "AIMessage":
            msg = AIMessage(content=content)
            if "tool_calls" in data:
                msg.tool_calls = data["tool_calls"]
            return msg
        elif msg_type == "HumanMessage":
            return HumanMessage(content=content)
        elif msg_type == "SystemMessage":
            return SystemMessage(content=content)
        elif msg_type == "ToolMessage":
            return ToolMessage(
                content=content,
                tool_call_id=data.get("tool_call_id", ""),
                name=data.get("name", "")
            )
        else:
            return HumanMessage(content=content)

    def _get_cached_messages(self, chat_id: str) -> Optional[List[BaseMessage]]:
        """Get messages from cache if available and not expired; updates hit/miss stats."""
        cache_entry = self._message_cache.get(chat_id)
        if cache_entry and not cache_entry.is_expired():
            self._cache_hits += 1
            return cache_entry.data
        self._cache_misses += 1
        return None

    def _cache_messages(self, chat_id: str, messages: List[BaseMessage]) -> None:
        """Cache a defensive copy of the message list with the configured TTL."""
        self._message_cache[chat_id] = CacheEntry(
            data=messages.copy(),
            timestamp=time.time(),
            ttl=self.cache_ttl
        )

    def _invalidate_cache(self, chat_id: str) -> None:
        """Drop all cache entries that could be stale for the given chat."""
        self._message_cache.pop(chat_id, None)
        self._metadata_cache.pop(chat_id, None)
        self._chat_list_cache = None

    async def exists(self, chat_id: str) -> bool:
        """Check whether a conversation exists.

        A non-empty cached message list answers without touching the database;
        an empty or missing cache entry falls through to the DB, since a row
        may exist even when nothing (or an empty list) is cached.
        """
        cached_messages = self._get_cached_messages(chat_id)
        if cached_messages:
            return True
        async with self.pool.acquire() as conn:
            result = await conn.fetchval(
                "SELECT EXISTS(SELECT 1 FROM conversations WHERE chat_id = $1)",
                chat_id
            )
            self._db_operations += 1
            return result

    async def get_messages(self, chat_id: str, limit: Optional[int] = None) -> List[BaseMessage]:
        """Retrieve messages for a chat session, serving from cache when possible.

        Always returns a fresh list, so callers may mutate the result without
        corrupting the cache entry (add_message() relies on this).

        Args:
            chat_id: Conversation identifier.
            limit: If given, only the last ``limit`` messages are returned.
        """
        cached_messages = self._get_cached_messages(chat_id)
        if cached_messages is not None:
            # Slicing copies; plain hits get an explicit copy to avoid aliasing.
            return cached_messages[-limit:] if limit else list(cached_messages)
        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT messages FROM conversations WHERE chat_id = $1",
                chat_id
            )
            self._db_operations += 1
            if not row:
                return []
            messages_data = row['messages']
            # asyncpg may hand JSONB back as a str depending on codec setup.
            if isinstance(messages_data, str):
                messages_data = json.loads(messages_data)
            messages = [self._dict_to_message(msg_data) for msg_data in messages_data]
            self._cache_messages(chat_id, messages)
            return messages[-limit:] if limit else messages

    async def save_messages(self, chat_id: str, messages: List[BaseMessage]) -> None:
        """Queue messages for the batch worker and refresh the cache immediately."""
        async with self._save_lock:
            self._pending_saves[chat_id] = messages.copy()
        self._cache_messages(chat_id, messages)

    async def save_messages_immediate(self, chat_id: str, messages: List[BaseMessage]) -> None:
        """Save messages immediately without batching — for critical operations."""
        async with self.pool.acquire() as conn:
            await self._upsert_conversation(conn, chat_id, messages)
        self._db_operations += 1
        self._cache_messages(chat_id, messages)
        self._chat_list_cache = None

    async def _upsert_conversation(self, conn, chat_id: str, messages: List[BaseMessage]) -> None:
        """Insert or update one conversation row (shared by immediate and batched saves)."""
        serialized_messages = [self._message_to_dict(msg) for msg in messages]
        await conn.execute("""
            INSERT INTO conversations (chat_id, messages, message_count)
            VALUES ($1, $2, $3)
            ON CONFLICT (chat_id)
            DO UPDATE SET
                messages = EXCLUDED.messages,
                message_count = EXCLUDED.message_count,
                updated_at = CURRENT_TIMESTAMP
        """, chat_id, json.dumps(serialized_messages), len(messages))

    async def _flush_pending_saves(self) -> None:
        """Write all queued saves to the database in a single transaction."""
        # Snapshot and clear under the lock; do the DB work outside it.
        async with self._save_lock:
            saves_to_process = self._pending_saves.copy()
            self._pending_saves.clear()
        if not saves_to_process or not self.pool:
            return
        async with self.pool.acquire() as conn:
            async with conn.transaction():
                for chat_id, messages in saves_to_process.items():
                    await self._upsert_conversation(conn, chat_id, messages)
        self._db_operations += len(saves_to_process)
        logger.debug(f"Batch saved {len(saves_to_process)} conversations")
        self._chat_list_cache = None

    async def _batch_save_worker(self) -> None:
        """Background task: flush queued saves roughly once per second."""
        while True:
            try:
                await asyncio.sleep(1.0)
                await self._flush_pending_saves()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the worker alive; queued saves stay pending for retry.
                logger.error(f"Error in batch save worker: {e}")

    async def add_message(self, chat_id: str, message: BaseMessage) -> None:
        """Append a single message to a conversation.

        NOTE(review): read-modify-write is not atomic across concurrent
        callers for the same chat_id — last writer wins.
        """
        current_messages = await self.get_messages(chat_id)
        current_messages.append(message)
        await self.save_messages(chat_id, current_messages)

    async def delete_conversation(self, chat_id: str) -> bool:
        """Delete a conversation; returns True if a row was removed."""
        try:
            async with self.pool.acquire() as conn:
                result = await conn.execute(
                    "DELETE FROM conversations WHERE chat_id = $1",
                    chat_id
                )
                self._db_operations += 1
                self._invalidate_cache(chat_id)
                # chat_id is the primary key, so the status is DELETE 0 or DELETE 1.
                return result == "DELETE 1"
        except Exception as e:
            logger.error(f"Error deleting conversation {chat_id}: {e}")
            return False

    async def list_conversations(self) -> List[str]:
        """List all conversation IDs, newest first, with a short-lived cache."""
        if self._chat_list_cache and not self._chat_list_cache.is_expired():
            self._cache_hits += 1
            return self._chat_list_cache.data
        self._cache_misses += 1
        async with self.pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT chat_id FROM conversations ORDER BY updated_at DESC"
            )
            self._db_operations += 1
            chat_ids = [row['chat_id'] for row in rows]
            # Short TTL: the list changes often as conversations are updated.
            self._chat_list_cache = CacheEntry(
                data=chat_ids,
                timestamp=time.time(),
                ttl=60
            )
            return chat_ids

    async def store_image(self, image_id: str, image_base64: str) -> None:
        """Store base64 image data with a one-hour server-side TTL."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                INSERT INTO images (image_id, image_data)
                VALUES ($1, $2)
                ON CONFLICT (image_id)
                DO UPDATE SET
                    image_data = EXCLUDED.image_data,
                    created_at = CURRENT_TIMESTAMP,
                    expires_at = CURRENT_TIMESTAMP + INTERVAL '1 hour'
            """, image_id, image_base64)
            self._db_operations += 1
        # Mirror the DB TTL in the local cache.
        self._image_cache[image_id] = CacheEntry(
            data=image_base64,
            timestamp=time.time(),
            ttl=3600
        )

    async def get_image(self, image_id: str) -> Optional[str]:
        """Retrieve base64 image data, or None if missing/expired."""
        cache_entry = self._image_cache.get(image_id)
        if cache_entry and not cache_entry.is_expired():
            self._cache_hits += 1
            return cache_entry.data
        # Count the miss whether or not the row exists.
        self._cache_misses += 1
        async with self.pool.acquire() as conn:
            row = await conn.fetchrow("""
                SELECT image_data FROM images
                WHERE image_id = $1 AND expires_at > CURRENT_TIMESTAMP
            """, image_id)
            self._db_operations += 1
            if row:
                image_data = row['image_data']
                self._image_cache[image_id] = CacheEntry(
                    data=image_data,
                    timestamp=time.time(),
                    ttl=3600
                )
                return image_data
            return None

    async def get_chat_metadata(self, chat_id: str) -> Optional[Dict]:
        """Get chat metadata, synthesizing a default name when no row exists."""
        cache_entry = self._metadata_cache.get(chat_id)
        if cache_entry and not cache_entry.is_expired():
            self._cache_hits += 1
            return cache_entry.data
        self._cache_misses += 1
        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT name, created_at FROM chat_metadata WHERE chat_id = $1",
                chat_id
            )
            self._db_operations += 1
            if row:
                metadata = {
                    "name": row['name'],
                    "created_at": row['created_at'].isoformat()
                }
            else:
                # No stored metadata: derive a display name from the id prefix.
                metadata = {"name": f"Chat {chat_id[:8]}"}
            self._metadata_cache[chat_id] = CacheEntry(
                data=metadata,
                timestamp=time.time(),
                ttl=self.cache_ttl
            )
            return metadata

    async def set_chat_metadata(self, chat_id: str, name: str) -> None:
        """Upsert the chat's display name and refresh the metadata cache."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                INSERT INTO chat_metadata (chat_id, name)
                VALUES ($1, $2)
                ON CONFLICT (chat_id)
                DO UPDATE SET
                    name = EXCLUDED.name,
                    updated_at = CURRENT_TIMESTAMP
            """, chat_id, name)
            self._db_operations += 1
        self._metadata_cache[chat_id] = CacheEntry(
            data={"name": name},
            timestamp=time.time(),
            ttl=self.cache_ttl
        )

    async def cleanup_expired_images(self) -> int:
        """Delete expired image rows (and stale cache entries); return the row count."""
        async with self.pool.acquire() as conn:
            result = await conn.execute(
                "DELETE FROM images WHERE expires_at < CURRENT_TIMESTAMP"
            )
            self._db_operations += 1
        # Prune the local cache in lockstep.
        expired_keys = [
            key for key, entry in self._image_cache.items()
            if entry.is_expired()
        ]
        for key in expired_keys:
            del self._image_cache[key]
        # asyncpg returns a status string like "DELETE 3".
        deleted_count = int(result.split()[-1]) if result else 0
        if deleted_count > 0:
            logger.debug(f"Cleaned up {deleted_count} expired images")
        return deleted_count

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache performance statistics for monitoring."""
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (self._cache_hits / total_requests * 100) if total_requests > 0 else 0
        return {
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate_percent": round(hit_rate, 2),
            "db_operations": self._db_operations,
            "cached_conversations": len(self._message_cache),
            "cached_metadata": len(self._metadata_cache),
            "cached_images": len(self._image_cache)
        }

    def load_conversation_history(self, chat_id: str):
        """Legacy sync entry point for loading history as a list of dicts.

        Inside a running event loop this schedules the load and returns the
        asyncio.Task (await it for the list). Outside a loop it runs the
        coroutine to completion and returns the list directly — the previous
        implementation raised RuntimeError in that case and never returned
        the annotated List[Dict].
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(self._load_conversation_history_dict(chat_id))
        return loop.create_task(self._load_conversation_history_dict(chat_id))

    async def _load_conversation_history_dict(self, chat_id: str) -> List[Dict]:
        """Load conversation history in dict format for compatibility."""
        messages = await self.get_messages(chat_id)
        return [self._message_to_dict(msg) for msg in messages]

    def save_conversation_history(self, chat_id: str, messages: List[Dict]):
        """Legacy sync entry point for saving a list-of-dict history.

        Schedules the save on the running loop (returning the Task), or runs
        it synchronously when no loop is active.
        """
        message_objects = [self._dict_to_message(msg) for msg in messages]
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(self.save_messages(chat_id, message_objects))
        return loop.create_task(self.save_messages(chat_id, message_objects))