#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""PostgreSQL-based conversation storage with caching and I/O optimization."""
|
|
|
|
import json
|
|
import time
|
|
from typing import Dict, List, Optional, Any
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta
|
|
import asyncio
|
|
import asyncpg
|
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage, ToolMessage
|
|
|
|
from logger import logger
|
|
|
|
|
|
@dataclass
class CacheEntry:
    """Cache entry with TTL support."""
    data: Any
    timestamp: float
    ttl: float = 300

    def is_expired(self) -> bool:
        return time.time() - self.timestamp > self.ttl

class PostgreSQLConversationStorage:
    """PostgreSQL-based conversation storage with intelligent caching and I/O optimization."""

    def __init__(
        self,
        host: str = 'postgres',
        port: int = 5432,
        database: str = 'chatbot',
        user: str = 'chatbot_user',
        password: str = 'chatbot_password',
        pool_size: int = 10,
        cache_ttl: int = 300
    ):
        """Initialize PostgreSQL connection pool and caching.

        Args:
            host: PostgreSQL host
            port: PostgreSQL port
            database: Database name
            user: Database user
            password: Database password
            pool_size: Connection pool size
            cache_ttl: Cache TTL in seconds
        """
        self.host = host
        self.port = port
        self.database = database
        self.user = user
        self.password = password
        self.pool_size = pool_size
        self.cache_ttl = cache_ttl

        self.pool: Optional[asyncpg.Pool] = None

        self._message_cache: Dict[str, CacheEntry] = {}
        self._metadata_cache: Dict[str, CacheEntry] = {}
        self._image_cache: Dict[str, CacheEntry] = {}
        self._chat_list_cache: Optional[CacheEntry] = None

        self._pending_saves: Dict[str, List[BaseMessage]] = {}
        self._save_lock = asyncio.Lock()
        self._batch_save_task: Optional[asyncio.Task] = None

        self._cache_hits = 0
        self._cache_misses = 0
        self._db_operations = 0

    async def init_pool(self) -> None:
        """Initialize the connection pool and create tables."""
        try:
            await self._ensure_database_exists()

            self.pool = await asyncpg.create_pool(
                host=self.host,
                port=self.port,
                database=self.database,
                user=self.user,
                password=self.password,
                min_size=2,
                max_size=self.pool_size,
                command_timeout=30
            )

            await self._create_tables()
            logger.debug("PostgreSQL connection pool initialized successfully")

            self._batch_save_task = asyncio.create_task(self._batch_save_worker())

        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL pool: {e}")
            raise

    async def _ensure_database_exists(self) -> None:
        """Ensure the target database exists, creating it if it doesn't."""
        try:
            conn = await asyncpg.connect(
                host=self.host,
                port=self.port,
                database='postgres',
                user=self.user,
                password=self.password
            )

            try:
                result = await conn.fetchval(
                    "SELECT 1 FROM pg_database WHERE datname = $1",
                    self.database
                )

                if not result:
                    await conn.execute(f'CREATE DATABASE "{self.database}"')
                    logger.debug(f"Created database: {self.database}")
                else:
                    logger.debug(f"Database {self.database} already exists")

            finally:
                await conn.close()

        except Exception as e:
            # Non-fatal: if the server is genuinely unreachable, pool creation
            # in init_pool will fail and surface the error.
            logger.error(f"Error ensuring database exists: {e}")

    async def close(self) -> None:
        """Close the connection pool and cleanup."""
        if self._batch_save_task:
            self._batch_save_task.cancel()
            try:
                await self._batch_save_task
            except asyncio.CancelledError:
                pass

        if self.pool:
            await self.pool.close()
            logger.debug("PostgreSQL connection pool closed")

    async def _create_tables(self) -> None:
        """Create necessary tables if they don't exist."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS conversations (
                    chat_id VARCHAR(255) PRIMARY KEY,
                    messages JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    message_count INTEGER DEFAULT 0
                )
            """)

            await conn.execute("""
                CREATE TABLE IF NOT EXISTS chat_metadata (
                    chat_id VARCHAR(255) PRIMARY KEY,
                    name VARCHAR(500),
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (chat_id) REFERENCES conversations(chat_id) ON DELETE CASCADE
                )
            """)

            await conn.execute("""
                CREATE TABLE IF NOT EXISTS images (
                    image_id VARCHAR(255) PRIMARY KEY,
                    image_data TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP + INTERVAL '1 hour')
                )
            """)

            await conn.execute("CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at)")
            await conn.execute("CREATE INDEX IF NOT EXISTS idx_images_expires_at ON images(expires_at)")

            await conn.execute("""
                CREATE OR REPLACE FUNCTION update_updated_at_column()
                RETURNS TRIGGER AS $$
                BEGIN
                    NEW.updated_at = CURRENT_TIMESTAMP;
                    RETURN NEW;
                END;
                $$ language 'plpgsql'
            """)

            await conn.execute("""
                DROP TRIGGER IF EXISTS update_conversations_updated_at ON conversations
            """)
            await conn.execute("""
                CREATE TRIGGER update_conversations_updated_at
                    BEFORE UPDATE ON conversations
                    FOR EACH ROW
                    EXECUTE FUNCTION update_updated_at_column()
            """)

    def _message_to_dict(self, message: BaseMessage) -> Dict:
        """Convert a message object to a dictionary for storage."""
        result = {
            "type": message.__class__.__name__,
            "content": message.content,
        }

        if hasattr(message, "tool_calls") and message.tool_calls:
            result["tool_calls"] = message.tool_calls

        if isinstance(message, ToolMessage):
            result["tool_call_id"] = getattr(message, "tool_call_id", None)
            result["name"] = getattr(message, "name", None)

        return result

    def _dict_to_message(self, data: Dict) -> BaseMessage:
        """Convert a dictionary back to a message object."""
        msg_type = data["type"]
        content = data["content"]

        if msg_type == "AIMessage":
            msg = AIMessage(content=content)
            if "tool_calls" in data:
                msg.tool_calls = data["tool_calls"]
            return msg
        elif msg_type == "HumanMessage":
            return HumanMessage(content=content)
        elif msg_type == "SystemMessage":
            return SystemMessage(content=content)
        elif msg_type == "ToolMessage":
            return ToolMessage(
                content=content,
                tool_call_id=data.get("tool_call_id", ""),
                name=data.get("name", "")
            )
        else:
            return HumanMessage(content=content)

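    # Illustrative note: after a round trip through _message_to_dict, the
    # `conversations.messages` JSONB column holds a plain list of dicts. The
    # values below are made-up examples; only the keys reflect the stored shape:
    #
    #   [
    #       {"type": "HumanMessage", "content": "What's the weather?"},
    #       {"type": "AIMessage", "content": "", "tool_calls": [...]},
    #       {"type": "ToolMessage", "content": "72F and sunny",
    #        "tool_call_id": "call_1", "name": "get_weather"}
    #   ]
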
    def _get_cached_messages(self, chat_id: str) -> Optional[List[BaseMessage]]:
        """Get messages from cache if available and not expired."""
        cache_entry = self._message_cache.get(chat_id)
        if cache_entry and not cache_entry.is_expired():
            self._cache_hits += 1
            return cache_entry.data

        self._cache_misses += 1
        return None

    def _cache_messages(self, chat_id: str, messages: List[BaseMessage]) -> None:
        """Cache messages with TTL."""
        self._message_cache[chat_id] = CacheEntry(
            data=messages.copy(),
            timestamp=time.time(),
            ttl=self.cache_ttl
        )

    def _invalidate_cache(self, chat_id: str) -> None:
        """Invalidate cache entries for a chat."""
        self._message_cache.pop(chat_id, None)
        self._metadata_cache.pop(chat_id, None)
        self._chat_list_cache = None

    async def exists(self, chat_id: str) -> bool:
        """Check if a conversation exists (with caching)."""
        cached_messages = self._get_cached_messages(chat_id)
        if cached_messages is not None:
            return len(cached_messages) > 0

        async with self.pool.acquire() as conn:
            result = await conn.fetchval(
                "SELECT EXISTS(SELECT 1 FROM conversations WHERE chat_id = $1)",
                chat_id
            )
            self._db_operations += 1
            return result

    async def get_messages(self, chat_id: str, limit: Optional[int] = None) -> List[BaseMessage]:
        """Retrieve messages for a chat session with caching."""
        cached_messages = self._get_cached_messages(chat_id)
        if cached_messages is not None:
            return cached_messages[-limit:] if limit else cached_messages

        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT messages FROM conversations WHERE chat_id = $1",
                chat_id
            )
            self._db_operations += 1

            if not row:
                return []

            messages_data = row['messages']
            if isinstance(messages_data, str):
                messages_data = json.loads(messages_data)
            messages = [self._dict_to_message(msg_data) for msg_data in messages_data]

            self._cache_messages(chat_id, messages)

            return messages[-limit:] if limit else messages

    async def save_messages(self, chat_id: str, messages: List[BaseMessage]) -> None:
        """Save messages with batching for performance."""
        # Write-behind: record the latest snapshot and let _batch_save_worker
        # flush it to PostgreSQL roughly once per second; the cache is updated
        # immediately so subsequent reads stay fresh.
        async with self._save_lock:
            self._pending_saves[chat_id] = messages.copy()

        self._cache_messages(chat_id, messages)

    async def save_messages_immediate(self, chat_id: str, messages: List[BaseMessage]) -> None:
        """Save messages immediately without batching - for critical operations."""
        serialized_messages = [self._message_to_dict(msg) for msg in messages]

        async with self.pool.acquire() as conn:
            await conn.execute("""
                INSERT INTO conversations (chat_id, messages, message_count)
                VALUES ($1, $2, $3)
                ON CONFLICT (chat_id)
                DO UPDATE SET
                    messages = EXCLUDED.messages,
                    message_count = EXCLUDED.message_count,
                    updated_at = CURRENT_TIMESTAMP
            """, chat_id, json.dumps(serialized_messages), len(messages))
            self._db_operations += 1

        self._cache_messages(chat_id, messages)
        self._chat_list_cache = None

    async def _batch_save_worker(self) -> None:
        """Background worker to batch save operations."""
        while True:
            try:
                await asyncio.sleep(1.0)

                async with self._save_lock:
                    if not self._pending_saves:
                        continue

                    saves_to_process = self._pending_saves.copy()
                    self._pending_saves.clear()

                async with self.pool.acquire() as conn:
                    async with conn.transaction():
                        for chat_id, messages in saves_to_process.items():
                            serialized_messages = [self._message_to_dict(msg) for msg in messages]

                            await conn.execute("""
                                INSERT INTO conversations (chat_id, messages, message_count)
                                VALUES ($1, $2, $3)
                                ON CONFLICT (chat_id)
                                DO UPDATE SET
                                    messages = EXCLUDED.messages,
                                    message_count = EXCLUDED.message_count,
                                    updated_at = CURRENT_TIMESTAMP
                            """, chat_id, json.dumps(serialized_messages), len(messages))

                self._db_operations += len(saves_to_process)
                if saves_to_process:
                    logger.debug(f"Batch saved {len(saves_to_process)} conversations")
                    self._chat_list_cache = None

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in batch save worker: {e}")

    async def add_message(self, chat_id: str, message: BaseMessage) -> None:
        """Add a single message to conversation (optimized)."""
        current_messages = await self.get_messages(chat_id)
        current_messages.append(message)

        await self.save_messages(chat_id, current_messages)

    async def delete_conversation(self, chat_id: str) -> bool:
        """Delete a conversation by chat_id."""
        try:
            async with self.pool.acquire() as conn:
                result = await conn.execute(
                    "DELETE FROM conversations WHERE chat_id = $1",
                    chat_id
                )
                self._db_operations += 1

            self._invalidate_cache(chat_id)

            return "DELETE 1" in result
        except Exception as e:
            logger.error(f"Error deleting conversation {chat_id}: {e}")
            return False

    async def list_conversations(self) -> List[str]:
        """List all conversation IDs with caching."""
        if self._chat_list_cache and not self._chat_list_cache.is_expired():
            self._cache_hits += 1
            return self._chat_list_cache.data

        async with self.pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT chat_id FROM conversations ORDER BY updated_at DESC"
            )
            self._db_operations += 1

            chat_ids = [row['chat_id'] for row in rows]

            self._chat_list_cache = CacheEntry(
                data=chat_ids,
                timestamp=time.time(),
                ttl=60
            )
            self._cache_misses += 1

            return chat_ids

    async def store_image(self, image_id: str, image_base64: str) -> None:
        """Store base64 image data with TTL."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                INSERT INTO images (image_id, image_data)
                VALUES ($1, $2)
                ON CONFLICT (image_id)
                DO UPDATE SET
                    image_data = EXCLUDED.image_data,
                    created_at = CURRENT_TIMESTAMP,
                    expires_at = CURRENT_TIMESTAMP + INTERVAL '1 hour'
            """, image_id, image_base64)
            self._db_operations += 1

        self._image_cache[image_id] = CacheEntry(
            data=image_base64,
            timestamp=time.time(),
            ttl=3600
        )

    async def get_image(self, image_id: str) -> Optional[str]:
        """Retrieve base64 image data with caching."""
        cache_entry = self._image_cache.get(image_id)
        if cache_entry and not cache_entry.is_expired():
            self._cache_hits += 1
            return cache_entry.data

        async with self.pool.acquire() as conn:
            row = await conn.fetchrow("""
                SELECT image_data FROM images
                WHERE image_id = $1 AND expires_at > CURRENT_TIMESTAMP
            """, image_id)
            self._db_operations += 1

            if row:
                image_data = row['image_data']
                self._image_cache[image_id] = CacheEntry(
                    data=image_data,
                    timestamp=time.time(),
                    ttl=3600
                )
                self._cache_misses += 1
                return image_data

            return None

    async def get_chat_metadata(self, chat_id: str) -> Optional[Dict]:
        """Get chat metadata with caching."""
        cache_entry = self._metadata_cache.get(chat_id)
        if cache_entry and not cache_entry.is_expired():
            self._cache_hits += 1
            return cache_entry.data

        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT name, created_at FROM chat_metadata WHERE chat_id = $1",
                chat_id
            )
            self._db_operations += 1

            if row:
                metadata = {
                    "name": row['name'],
                    "created_at": row['created_at'].isoformat()
                }
            else:
                metadata = {"name": f"Chat {chat_id[:8]}"}

            self._metadata_cache[chat_id] = CacheEntry(
                data=metadata,
                timestamp=time.time(),
                ttl=self.cache_ttl
            )
            self._cache_misses += 1

            return metadata

    async def set_chat_metadata(self, chat_id: str, name: str) -> None:
        """Set chat metadata."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                INSERT INTO chat_metadata (chat_id, name)
                VALUES ($1, $2)
                ON CONFLICT (chat_id)
                DO UPDATE SET
                    name = EXCLUDED.name,
                    updated_at = CURRENT_TIMESTAMP
            """, chat_id, name)
            self._db_operations += 1

        self._metadata_cache[chat_id] = CacheEntry(
            data={"name": name},
            timestamp=time.time(),
            ttl=self.cache_ttl
        )

    async def cleanup_expired_images(self) -> int:
        """Clean up expired images and return count of deleted images."""
        async with self.pool.acquire() as conn:
            result = await conn.execute(
                "DELETE FROM images WHERE expires_at < CURRENT_TIMESTAMP"
            )
            self._db_operations += 1

        expired_keys = [
            key for key, entry in self._image_cache.items()
            if entry.is_expired()
        ]
        for key in expired_keys:
            del self._image_cache[key]

        deleted_count = int(result.split()[-1]) if result else 0
        if deleted_count > 0:
            logger.debug(f"Cleaned up {deleted_count} expired images")

        return deleted_count

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache performance statistics."""
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (self._cache_hits / total_requests * 100) if total_requests > 0 else 0

        return {
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate_percent": round(hit_rate, 2),
            "db_operations": self._db_operations,
            "cached_conversations": len(self._message_cache),
            "cached_metadata": len(self._metadata_cache),
            "cached_images": len(self._image_cache)
        }

    def load_conversation_history(self, chat_id: str) -> asyncio.Task:
        """Legacy method - schedules the async load and returns the task.

        Must be called from within a running event loop; await the returned
        task to get the conversation history as a list of dicts.
        """
        return asyncio.create_task(self._load_conversation_history_dict(chat_id))

    async def _load_conversation_history_dict(self, chat_id: str) -> List[Dict]:
        """Load conversation history in dict format for compatibility."""
        messages = await self.get_messages(chat_id)
        return [self._message_to_dict(msg) for msg in messages]

    def save_conversation_history(self, chat_id: str, messages: List[Dict]) -> asyncio.Task:
        """Legacy method - schedules the async save and returns the task.

        Must be called from within a running event loop.
        """
        message_objects = [self._dict_to_message(msg) for msg in messages]
        return asyncio.create_task(self.save_messages(chat_id, message_objects))
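

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; it assumes a reachable PostgreSQL
    # instance matching the default connection parameters above and is not part
    # of the service itself). save_messages_immediate is used here so the write
    # is flushed before the pool and batch worker are shut down.
    async def _example() -> None:
        storage = PostgreSQLConversationStorage()
        await storage.init_pool()
        try:
            chat_id = "example-chat"
            await storage.save_messages_immediate(chat_id, [
                HumanMessage(content="Hello"),
                AIMessage(content="Hi! How can I help?"),
            ])
            history = await storage.get_messages(chat_id)
            logger.debug(f"Loaded {len(history)} messages; stats: {storage.get_cache_stats()}")
        finally:
            await storage.close()

    asyncio.run(_example())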