Mirror of https://github.com/NVIDIA/dgx-spark-playbooks.git (synced 2026-04-30 05:26:43 +00:00)

chore: Regenerate all playbooks

This commit is contained in: parent cdd90b989f, commit 27fe116e71

@@ -12,9 +12,7 @@
- [Step 4. Validate TensorRT-LLM installation](#step-4-validate-tensorrt-llm-installation)
- [Step 5. Create cache directory](#step-5-create-cache-directory)
- [Step 6. Validate setup with quickstart_advanced](#step-6-validate-setup-with-quickstartadvanced)
  - [LLM quickstart example](#llm-quickstart-example)
- [Step 7. Validate setup with quickstart_multimodal](#step-7-validate-setup-with-quickstartmultimodal)
  - [VLM quickstart example](#vlm-quickstart-example)
- [Step 8. Serve LLM with OpenAI-compatible API](#step-8-serve-llm-with-openai-compatible-api)
- [Step 9. Troubleshooting](#step-9-troubleshooting)
- [Step 10. Cleanup and rollback](#step-10-cleanup-and-rollback)

@@ -39,6 +37,15 @@

## Overview

## Basic idea

**NVIDIA TensorRT-LLM (TRT-LLM)** is an open-source library for optimizing and accelerating large language model (LLM) inference on NVIDIA GPUs.

It provides highly efficient kernels, memory management, and parallelism strategies—like tensor, pipeline, and sequence parallelism—so developers can serve LLMs with lower latency and higher throughput.

TRT-LLM integrates with frameworks like Hugging Face and PyTorch, making it easier to deploy state-of-the-art models at scale.

## What you'll accomplish

You'll set up TensorRT-LLM to optimize and deploy large language models on NVIDIA Spark with

@@ -89,13 +96,17 @@ The following models are supported with TensorRT-LLM on Spark. All listed models
| **Llama-4-Scout-17B-16E-Instruct** | NVFP4 | ✅ | `nvidia/Llama-4-Scout-17B-16E-Instruct-FP4` |
| **Qwen3-235B-A22B (two Sparks only)** | NVFP4 | ✅ | `nvidia/Qwen3-235B-A22B-FP4` |

**Note:** You can use the NVFP4 Quantization documentation to generate your own NVFP4-quantized checkpoints for your favorite models. This enables you to take advantage of the performance and memory benefits of NVFP4 quantization even for models not already published by NVIDIA.

Reminder: not all model architectures are supported for NVFP4 quantization.

## Time & risk

**Duration**: 45-60 minutes for setup and API server deployment

**Risk level**: Medium - container pulls and model downloads may fail due to network issues

**Rollback**: Stop inference servers and remove downloaded models to free resources.

## Single Spark

@@ -170,7 +181,7 @@ mkdir -p $HOME/.cache/huggingface/

This quickstart validates your TensorRT-LLM setup end-to-end by testing model loading, inference engine initialization, and GPU execution with real text generation. It's the fastest way to confirm everything works before starting the inference API server.

**LLM quickstart example**

#### Llama 3.1 8B Instruct
```bash
@@ -241,7 +252,7 @@ docker run \
```

### Step 7. Validate setup with quickstart_multimodal

**VLM quickstart example**

This demonstrates vision-language model capabilities by running inference with image understanding. The example uses multimodal inputs to validate both text and vision processing pipelines.

@@ -405,9 +416,7 @@ docker rmi nvcr.io/nvidia/tensorrt-llm/release:spark-single-gpu-dev

### Step 1. Review Spark clustering documentation

Go to the official DGX Spark clustering documentation to understand the networking requirements and setup procedures: [DGX Spark Clustering Documentation](https://docs.nvidia.com/dgx/dgx-spark/spark-clustering.html)

Review the networking configuration options and choose the appropriate setup method for your environment.

nvidia/txt2kg/assets/.cursor/rules/nextjs.mdc (new file, 5 lines)

@@ -0,0 +1,5 @@
---
description: nextjs projects
alwaysApply: false
---
Use pnpm as the main package manager

nvidia/txt2kg/assets/.dockerignore (new file, 3 lines)

@@ -0,0 +1,3 @@
node_modules
.next
.git

nvidia/txt2kg/assets/.gitignore (new file, vendored, 59 lines)

@@ -0,0 +1,59 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.

# dependencies
/node_modules

# next.js
/.next/
/out/

# production
/build

# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.pnpm-debug.log*

# env files
.env*

# vercel
.vercel

# typescript
*.tsbuildinfo
next-env.d.ts
volumes/
frontend/data/
scripts/*.pt
frontend/node_modules
frontend/.next
frontend/.env
frontend/.env.local
frontend/pnpm-lock.yaml
frontend/pnpm-workspace.yaml/volumes/etcd/
/frontend/node_modules/
/frontend/.next/
/node_modules/
/volumes/etcd/
.DS_Store
*.log
/volumes/
video-demo.md

.trae/
.vscode/

biorxiv_creative_commons/

# benchmark results
benchmark_results/

# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python

nvidia/txt2kg/assets/LICENSE (new file, 21 lines)

@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 NVIDIA Corporation

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

nvidia/txt2kg/assets/README.md (new file, 307 lines)

@@ -0,0 +1,307 @@
# NVIDIA txt2kg

Use the following documentation to learn about NVIDIA txt2kg.
- [Overview](#overview)
- [Key Features](#key-features)
- [Target Audience](#target-audience)
- [Software Components](#software-components)
- [Technical Diagram](#technical-diagram)
- [GPU-Accelerated Visualization](#gpu-accelerated-visualization)
- [Minimum System Requirements](#minimum-system-requirements)
  - [OS Requirements](#os-requirements)
  - [Deployment Options](#deployment-options)
  - [Driver Versions](#driver-versions)
  - [Hardware Requirements](#hardware-requirements)
- [Next Steps](#next-steps)
- [Deployment Guide](#deployment-guide)
  - [Standard Deployment](#standard-deployment)
  - [PyGraphistry GPU-Accelerated Deployment](#pygraphistry-gpu-accelerated-deployment)
- [Available Customizations](#available-customizations)
- [License](#license)

## Overview

This blueprint serves as a reference solution for knowledge graph extraction and querying with Retrieval Augmented Generation (RAG). This txt2kg blueprint extracts knowledge triples from text and constructs a knowledge graph for visualization and querying, creating a more structured form of information retrieval compared to traditional RAG approaches. By leveraging graph databases and entity relationships, this blueprint delivers more contextually rich answers that better represent complex relationships in your data.

By default, this blueprint leverages **Ollama** for local LLM inference, providing a fully self-contained solution that runs entirely on your own hardware. You can optionally use NVIDIA-hosted models available in the [NVIDIA API Catalog](https://build.nvidia.com) or vLLM for advanced GPU-accelerated inference.

## Key Features

[Watch the demo video](https://drive.google.com/file/d/1a0VG67zx_pGT4WyPTPH2ynefhfy2I0Py/view?usp=sharing)

- Knowledge triple extraction from text documents
- Knowledge graph construction and visualization
- **Local-first architecture** with Ollama for LLM inference
- Graph-based RAG for more contextual answers
- Graph database integration with ArangoDB
- Local vector embeddings with Pinecone-compatible storage
- GPU-accelerated LLM inference with Ollama and optional vLLM
- Sentence Transformers for efficient embedding generation
- Interactive knowledge graph visualization with Three.js WebGPU
- Optional NVIDIA API integration for cloud-based models
- Fully containerized deployment with Docker Compose
- Decomposable and customizable

## Target Audience

This blueprint is for:

- **Developers**: Developers who want to quickly set up a local-first Graph-based RAG solution
- **Data Scientists**: Data scientists who want to extract structured knowledge from unstructured text
- **Enterprise Architects**: Architects seeking to combine knowledge graph and RAG solutions for their organization
- **Privacy-Conscious Users**: Organizations requiring fully local, air-gapped deployments
- **GPU Researchers**: Researchers wanting to leverage GPU acceleration for LLM inference and graph visualization

## Software Components

The following are the default components included in this blueprint:

* **LLM Inference**
  * **Ollama** (default): Local LLM inference with GPU acceleration
    * Default model: `llama3.1:8b`
    * Supports any Ollama-compatible model
  * **vLLM** (optional): Advanced GPU-accelerated inference with quantization
    * Default model: `meta-llama/Llama-3.2-3B-Instruct`
  * **NVIDIA API** (optional): Cloud-based models via NVIDIA API Catalog
* **Vector Database & Embedding**
  * **SentenceTransformer**: Local embedding generation
    * Model: `all-MiniLM-L6-v2`
  * **Pinecone (Local)**: Self-hosted vector storage and similarity search
    * No cloud API key required
    * Compatible with Pinecone client libraries
* **Knowledge Graph Database**
  * **ArangoDB**: Graph database for storing knowledge triples (entities and relationships)
    * Web interface on port 8529
    * No authentication required (configurable)
* **Graph Visualization**
  * **Three.js WebGPU**: Client-side GPU-accelerated graph rendering
    * Optional remote WebGPU clustering for large graphs
* **Frontend & API**
  * **Next.js**: Modern React framework with API routes

## Technical Diagram

The architecture follows this workflow:
1. User uploads documents through the txt2kg web UI
2. Documents are processed and chunked for analysis
3. **Ollama** extracts knowledge triples (subject-predicate-object) from the text using local LLM inference
4. Triples are stored in the **ArangoDB** graph database
5. **SentenceTransformer** generates entity embeddings
6. Embeddings are stored in the local **Pinecone** vector database
7. User queries are processed through graph-based RAG:
   - KNN search identifies relevant entities in the vector database
   - Graph traversal enhances context with entity relationships from ArangoDB
   - Ollama generates responses using the enriched context
8. Results are visualized with **Three.js WebGPU** rendering in the browser

## GPU-Accelerated LLM Inference

This blueprint includes **GPU-accelerated LLM inference** with Ollama:

### Ollama Features
- **Fully local inference**: No cloud dependencies or API keys required
- **GPU acceleration**: Automatic CUDA support with NVIDIA GPUs
- **Multiple model support**: Use any Ollama-compatible model
- **Optimized performance**: Flash attention, KV cache optimization, and quantization
- **Easy model management**: Pull and switch models with simple commands
- **Privacy-first**: All data processing happens on your hardware

### Default Configuration
- Model: `llama3.1:8b`
- GPU memory fraction: 0.9 (90% of available VRAM)
- Flash attention enabled
- Q8_0 KV cache for memory efficiency

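To confirm these settings are in effect, you can inspect the running container (a quick sketch; `ollama ps` is part of the standard Ollama CLI, and the container name comes from the compose file):

```bash
# Show loaded models and whether they are running on the GPU
docker exec ollama-compose ollama ps

# Inspect the OLLAMA_* settings the container was started with
docker exec ollama-compose env | grep OLLAMA
```
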
### Using Different Models
```bash
# Pull a different model
docker exec ollama-compose ollama pull llama3.1:70b

# Update environment variable in docker-compose.yml
OLLAMA_MODEL=llama3.1:70b
```

## Minimum System Requirements

### OS Requirements

Ubuntu 22.04 or later

### Deployment Options

- [Standard Docker Compose](./deploy/compose/docker-compose.yml) (Default - Ollama + ArangoDB + Pinecone)
- [vLLM Docker Compose](./deploy/compose/docker-compose.vllm.yml) (Advanced - vLLM for FP8 and NVFP4 quantization)
- [Complete Docker Compose](./deploy/compose/docker-compose.complete.yml) (Full stack with MinIO S3)

### Driver Versions

- GPU Driver - 530.30.02+
- CUDA version - 12.0+

### Hardware Requirements

- **For Ollama LLM inference**:
  - NVIDIA GPU with CUDA support (GTX 1060 or newer, RTX series recommended)
  - VRAM requirements depend on model size:
    - 8B models: 6-8 GB VRAM
    - 70B models: 48 GB+ VRAM (or use quantized versions)
  - System RAM: 16 GB+ recommended
- **For vLLM (optional)**:
  - NVIDIA GPU with Ampere architecture or newer (RTX 30xx+, A100, H100)
  - Support for FP8 quantization for optimal performance
  - Similar VRAM requirements to Ollama

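To check your GPU against these requirements before choosing a model size, a query like the following works (standard `nvidia-smi` flags):

```bash
# Report GPU model, total VRAM, and driver version in CSV form
nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv
```
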
## Next Steps

- Clone the repository
- Install Docker and the NVIDIA Container Toolkit
- Deploy with Docker Compose (no API keys required!)
- Pull your preferred Ollama model
- Upload documents and explore the knowledge graph
- Customize for your specific use case

## Deployment Guide

### Environment Variables

**No API keys are required for the default deployment!** All services run locally.

The default configuration uses:
- Local Ollama (no API key needed)
- Local Pinecone (no API key needed)
- Local ArangoDB (no authentication by default)
- Local SentenceTransformer embeddings

#### Optional Environment Variables

```bash
# Ollama configuration (optional - defaults are set)
OLLAMA_BASE_URL=http://ollama:11434/v1
OLLAMA_MODEL=llama3.1:8b

# NVIDIA API (optional - for cloud models)
NVIDIA_API_KEY=your-nvidia-api-key

# vLLM configuration (optional)
VLLM_BASE_URL=http://vllm:8001/v1
VLLM_MODEL=meta-llama/Llama-3.2-3B-Instruct
```

### Standard Deployment

1. **Clone the repository:**

```bash
git clone <repository-url>
cd txt2kg
```

2. **Start the application:**

```bash
./start.sh
```

That's it! No configuration is needed. The script will:
- Start all required services with Docker Compose
- Set up the ArangoDB database
- Initialize local Pinecone vector storage
- Launch Ollama with GPU acceleration
- Start the Next.js frontend

3. **Pull an Ollama model (first time only):**

```bash
docker exec ollama-compose ollama pull llama3.1:8b
```

4. **Access the application:**
- **Web UI**: http://localhost:3001
- **ArangoDB**: http://localhost:8529 (no authentication required)
- **Ollama API**: http://localhost:11434

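As a quick smoke test, the same endpoints the stack uses for its health checks can be probed from the host (these URLs mirror the compose health checks and the ArangoDB init script):

```bash
# Ollama should list the models you pulled
curl http://localhost:11434/api/tags

# ArangoDB should report its version
curl http://localhost:8529/_api/version
```
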
### Advanced Deployment Options

#### Using vLLM for FP8 Quantization

vLLM provides advanced GPU acceleration with FP8 quantization for a smaller memory footprint:

```bash
# Use vLLM compose file
docker compose -f deploy/compose/docker-compose.vllm.yml up -d
```

vLLM is recommended for:
- Newer NVIDIA GPUs (Ampere architecture or later)
- Production deployments requiring maximum throughput
- Memory-constrained environments (FP8 uses less VRAM)

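Once the vLLM container is healthy, its OpenAI-compatible endpoint should respond; this mirrors the health check baked into the compose file:

```bash
# Should return a JSON list containing meta-llama/Llama-3.2-3B-Instruct
curl http://localhost:8001/v1/models
```
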
#### GPU Setup Prerequisites

1. **Install NVIDIA Container Toolkit**:

```bash
# Ubuntu/Debian
distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list

sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
sudo systemctl restart docker
```

2. **Verify GPU Access**:

```bash
docker run --rm --gpus all nvidia/cuda:12.0-base-ubuntu22.04 nvidia-smi
```

#### Troubleshooting

**Check Service Logs**:
```bash
# View all service logs
docker compose logs -f

# View Ollama logs
docker compose logs -f ollama

# View vLLM logs (if using vLLM)
docker compose -f deploy/compose/docker-compose.vllm.yml logs -f vllm
```

**GPU Issues**:
```bash
# Check GPU availability
nvidia-smi

# Verify Docker GPU access
docker run --rm --gpus all nvidia/cuda:12.0-base nvidia-smi
```

**Ollama Model Management**:
```bash
# List available models
docker exec ollama-compose ollama list

# Pull a different model
docker exec ollama-compose ollama pull mistral

# Remove a model to free space
docker exec ollama-compose ollama rm llama3.1:8b
```

## Available Customizations

The following are some of the customizations you can make:

- **Switch Ollama models**: Use any model from Ollama's library (Llama, Qwen, etc.)
- **Modify extraction prompts**: Customize how triples are extracted from text
- **Adjust embedding parameters**: Change the SentenceTransformer model
- **Implement custom entity relationships**: Define domain-specific relationship types
- **Add domain-specific knowledge sources**: Integrate external ontologies or taxonomies
- **Configure GPU settings**: Optimize VRAM usage and performance for your hardware
- **Switch to vLLM**: Use vLLM for advanced quantization and higher throughput
- **Use the NVIDIA API**: Connect to cloud models for specific use cases

## License

[MIT](LICENSE)

This project is licensed under the MIT License. It will download and install additional third-party open source software projects and containers.

nvidia/txt2kg/assets/deploy/README.md (new file, 38 lines)

@@ -0,0 +1,38 @@
# Deployment Configuration

This directory contains all deployment-related configuration for the txt2kg project.

## Structure

- **compose/**: Docker Compose files for local development and testing
  - `docker-compose.yml`: Main Docker Compose configuration
  - `docker-compose.gnn.yml`: Docker Compose configuration for GNN components
  - `docker-compose.neo4j.yml`: Docker Compose configuration for Neo4j

- **docker/**: Docker-related files
  - Dockerfile
  - Initialization scripts for services

- **services/**: Containerized services
  - **gnn_model/**: Graph Neural Network model service
  - **sentence-transformers/**: Sentence transformer service for embeddings

## Usage

To start the default services:

```bash
docker-compose -f deploy/compose/docker-compose.yml up
```

To include GNN components:

```bash
docker-compose -f deploy/compose/docker-compose.yml -f deploy/compose/docker-compose.gnn.yml up
```

To include Neo4j:

```bash
docker-compose -f deploy/compose/docker-compose.yml -f deploy/compose/docker-compose.neo4j.yml up
```

nvidia/txt2kg/assets/deploy/app/Dockerfile (new file, 48 lines)

@@ -0,0 +1,48 @@
# Use the official Node.js image from the Docker Hub
FROM node:18-slim

# Set environment variables to avoid interactive prompts
ENV DEBIAN_FRONTEND=noninteractive
ENV NPM_CONFIG_YES=true
ENV PNPM_HOME=/pnpm
ENV PATH="$PNPM_HOME:$PATH"

# Set the working directory
WORKDIR /app

# Install pnpm globally with --force and yes flags
RUN npm install -g pnpm --force --yes

# Copy package files ONLY (for better Docker layer caching)
# Copy package.json (required) and pnpm-lock.yaml (optional)
COPY ./frontend/package.json ./
COPY ./frontend/pnpm-lock.yaml* ./

# Copy the scripts directory (needed for setup-pinecone)
COPY ./scripts/ /scripts/

# Update the setup-pinecone.js path in package.json
RUN sed -i 's|"setup-pinecone": "node ../scripts/setup-pinecone.js"|"setup-pinecone": "node /scripts/setup-pinecone.js"|g' package.json

# Install project dependencies (this layer will be cached if package files don't change)
# Use --no-frozen-lockfile as fallback if lockfile is missing or out of sync
RUN pnpm config set auto-install-peers true && \
    if [ -f pnpm-lock.yaml ]; then \
        echo "Lock file found, installing with frozen lockfile..." && \
        (pnpm install --no-optional --frozen-lockfile || pnpm install --no-optional --no-frozen-lockfile); \
    else \
        echo "No lock file found, installing without frozen lockfile..." && \
        pnpm install --no-optional --no-frozen-lockfile; \
    fi

# Copy the rest of the frontend files
COPY ./frontend/ ./

# Build the application
RUN pnpm build

# Expose the port the app runs on
EXPOSE 3000

# Start the application
CMD ["pnpm", "start"]

nvidia/txt2kg/assets/deploy/app/Dockerfile.remote-webgpu (new file, 50 lines)

@@ -0,0 +1,50 @@
# Remote WebGPU Clustering Service Dockerfile
# Based on NVIDIA PyTorch Geometric container which includes cuGraph/RAPIDS for GPU acceleration

FROM nvcr.io/nvidia/pyg:25.05-py3

# Set working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    curl \
    wget \
    git \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies for remote WebGPU service
COPY requirements-remote-webgpu.txt .
RUN pip install --no-cache-dir -r requirements-remote-webgpu.txt

# Install additional dependencies for WebRTC streaming
RUN pip install --no-cache-dir \
    opencv-python-headless \
    plotly \
    kaleido \
    Pillow \
    redis

# Copy service files
COPY remote_webgpu_clustering_service.py .
COPY unified_gpu_service.py .
COPY local_gpu_viz_service.py .
COPY simple_webgpu_test.py .

# Create directories for temporary files
RUN mkdir -p /tmp/webrtc_frames

# Set environment variables
ENV PYTHONPATH=/app
ENV CUDA_VISIBLE_DEVICES=0

# Expose port
EXPOSE 8083

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8083/health || exit 1

# Start the main remote WebGPU clustering service
CMD ["python", "remote_webgpu_clustering_service.py"]

nvidia/txt2kg/assets/deploy/app/arangodb-init/create-database.js (new file, 22 lines; file path inferred from the init.sh script below, which executes create-database.js from the same directory)

@@ -0,0 +1,22 @@
// ArangoDB initialization script to create the txt2kg database
// This script is executed automatically when the ArangoDB container starts

db._createDatabase("txt2kg");
console.log("Database 'txt2kg' created successfully!");

// Optional: Create collections needed by your application
// Replace with actual collections you need
/*
const db = require("@arangodb").db;
db._useDatabase("txt2kg");

if (!db._collection("entities")) {
  db._createDocumentCollection("entities");
  console.log("Collection 'entities' created");
}

if (!db._collection("relationships")) {
  db._createEdgeCollection("relationships");
  console.log("Collection 'relationships' created");
}
*/

nvidia/txt2kg/assets/deploy/app/arangodb-init/init.sh (new executable file, 19 lines)

@@ -0,0 +1,19 @@
#!/bin/bash
set -e

# Wait for ArangoDB to be ready
echo "Waiting for ArangoDB to start..."
until curl --silent --fail http://localhost:8529/_api/version > /dev/null; do
  echo "ArangoDB is unavailable - sleeping"
  sleep 1
done

echo "ArangoDB is up - executing initialization script"

# Run the database creation script
arangosh \
  --server.endpoint tcp://127.0.0.1:8529 \
  --server.authentication false \
  --javascript.execute /docker-entrypoint-initdb.d/create-database.js

echo "Initialization completed"

nvidia/txt2kg/assets/deploy/app/pinecone-init.sh (new executable file, 47 lines)

@@ -0,0 +1,47 @@
#!/bin/sh

# Script to initialize Pinecone index at container startup
echo "Initializing Pinecone index..."

# Wait for the Pinecone service to become available
echo "Waiting for Pinecone service to start..."
max_attempts=30
attempt=1

while [ $attempt -le $max_attempts ]; do
  if curl -s --head http://pinecone:5080 > /dev/null; then
    echo "Pinecone service is up!"
    break
  fi
  echo "Waiting for Pinecone service (attempt $attempt/$max_attempts)..."
  attempt=$((attempt + 1))
  sleep 2
done

if [ $attempt -gt $max_attempts ]; then
  echo "Timed out waiting for Pinecone service"
  exit 1
fi

# Create the index directly
echo "Creating index 'entity-embeddings'..."
curl -v -X POST "http://pinecone:5080/create_index" \
  -H "Content-Type: application/json" \
  -d '{
    "name": "entity-embeddings",
    "dimension": 384,
    "metric": "cosine"
  }'

# Also try alternate endpoint as fallback
echo "Trying alternate endpoint..."
curl -v -X POST "http://pinecone:5080/indexes" \
  -H "Content-Type: application/json" \
  -H "Api-Key: pclocal" \
  -d '{
    "name": "entity-embeddings",
    "dimension": 384,
    "metric": "cosine"
  }'

echo "Pinecone initialization complete"

nvidia/txt2kg/assets/deploy/compose/docker-compose.complete.yml (new file, 140 lines)

@@ -0,0 +1,140 @@
version: '3.8'

services:
  app:
    build:
      context: ../..
      dockerfile: deploy/app/Dockerfile
    ports:
      - '3001:3000'
    environment:
      - ARANGODB_URL=http://arangodb:8529
      - ARANGODB_DB=txt2kg
      - PINECONE_HOST=entity-embeddings
      - PINECONE_PORT=5081
      - PINECONE_API_KEY=pclocal
      - PINECONE_ENVIRONMENT=local
      - LANGCHAIN_TRACING_V2=true
      - SENTENCE_TRANSFORMER_URL=http://sentence-transformers:80
      - MODEL_NAME=all-MiniLM-L6-v2
      - GRPC_SSL_CIPHER_SUITES=HIGH+ECDSA:HIGH+aRSA
      - NODE_TLS_REJECT_UNAUTHORIZED=0
      # - XAI_API_KEY=${XAI_API_KEY} # xAI integration removed
      - OLLAMA_BASE_URL=http://ollama:11434/v1
      - OLLAMA_MODEL=qwen3:1.7b
      - S3_ENDPOINT=http://minio:9000
      - S3_REGION=us-east-1
      - S3_BUCKET=txt2kg
      - S3_ACCESS_KEY=minioadmin
      - S3_SECRET_KEY=minioadmin
    networks:
      - pinecone-net
      - s3-net
      - default
    depends_on:
      - arangodb
      - entity-embeddings
      - sentence-transformers
      - minio

  arangodb:
    image: arangodb:latest
    ports:
      - '8529:8529'
    environment:
      - ARANGO_NO_AUTH=1
    volumes:
      - arangodb_data:/var/lib/arangodb3
      - arangodb_apps_data:/var/lib/arangodb3-apps

  arangodb-init:
    image: arangodb:latest
    depends_on:
      arangodb:
        condition: service_started
    restart: on-failure
    entrypoint: >
      sh -c "
      echo 'Waiting for ArangoDB to start...' &&
      sleep 10 &&
      echo 'Creating txt2kg database...' &&
      arangosh --server.endpoint tcp://arangodb:8529 --server.authentication false --javascript.execute-string 'try { db._createDatabase(\"txt2kg\"); console.log(\"Database txt2kg created successfully!\"); } catch(e) { if(e.message.includes(\"duplicate\")) { console.log(\"Database txt2kg already exists\"); } else { throw e; } }'
      "

  entity-embeddings:
    image: ghcr.io/pinecone-io/pinecone-index:latest
    container_name: entity-embeddings
    environment:
      PORT: 5081
      INDEX_TYPE: serverless
      VECTOR_TYPE: dense
      DIMENSION: 384
      METRIC: cosine
      INDEX_NAME: entity-embeddings
    ports:
      - "5081:5081"
    platform: linux/amd64
    networks:
      - pinecone-net
    restart: unless-stopped

  sentence-transformers:
    build:
      context: ../../deploy/services/sentence-transformers
      dockerfile: Dockerfile
    ports:
      - '8000:80'
    environment:
      - MODEL_NAME=all-MiniLM-L6-v2
    networks:
      - default

  # MinIO S3-compatible storage
  minio:
    image: minio/minio:latest
    container_name: txt2kg-minio
    ports:
      - "9000:9000" # API endpoint
      - "9001:9001" # Web console
    environment:
      - MINIO_ROOT_USER=minioadmin
      - MINIO_ROOT_PASSWORD=minioadmin
    volumes:
      - minio_data:/data
    command: server /data --console-address ":9001"
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
      interval: 30s
      timeout: 20s
      retries: 3
    networks:
      - s3-net
      - default

  createbucket:
    image: minio/mc
    depends_on:
      - minio
    entrypoint: >
      /bin/sh -c "
      sleep 5;
      /usr/bin/mc config host add myminio http://minio:9000 minioadmin minioadmin;
      /usr/bin/mc mb myminio/txt2kg;
      /usr/bin/mc policy set public myminio/txt2kg;
      exit 0;
      "
    networks:
      - s3-net

volumes:
  arangodb_data:
  arangodb_apps_data:
  minio_data:

networks:
  pinecone-net:
    name: pinecone
  s3-net:
    name: s3-network
  default:
    driver: bridge

nvidia/txt2kg/assets/deploy/compose/docker-compose.vllm.yml (new file, 137 lines)

@@ -0,0 +1,137 @@
services:
  app:
    build:
      context: ../..
      dockerfile: deploy/app/Dockerfile
    ports:
      - '3001:3000'
    environment:
      - ARANGODB_URL=http://arangodb:8529
      - ARANGODB_DB=txt2kg
      - PINECONE_HOST=entity-embeddings
      - PINECONE_PORT=5081
      - PINECONE_API_KEY=pclocal
      - PINECONE_ENVIRONMENT=local
      - LANGCHAIN_TRACING_V2=true
      - SENTENCE_TRANSFORMER_URL=http://sentence-transformers:80
      - MODEL_NAME=all-MiniLM-L6-v2
      - GRPC_SSL_CIPHER_SUITES=HIGH+ECDSA:HIGH+aRSA
      - NODE_TLS_REJECT_UNAUTHORIZED=0
      - OLLAMA_BASE_URL=http://ollama:11434/v1
      - OLLAMA_MODEL=qwen3:1.7b
      - VLLM_BASE_URL=http://vllm:8001/v1
      - VLLM_MODEL=meta-llama/Llama-3.2-3B-Instruct
      - REMOTE_WEBGPU_SERVICE_URL=http://txt2kg-remote-webgpu:8083
    networks:
      - pinecone-net
      - default
      - txt2kg-network
    depends_on:
      - arangodb
      - entity-embeddings
      - sentence-transformers
      - vllm

  arangodb:
    image: arangodb:latest
    ports:
      - '8529:8529'
    environment:
      - ARANGO_NO_AUTH=1
    volumes:
      - arangodb_data:/var/lib/arangodb3
      - arangodb_apps_data:/var/lib/arangodb3-apps

  arangodb-init:
    image: arangodb:latest
    depends_on:
      arangodb:
        condition: service_started
    restart: on-failure
    entrypoint: >
      sh -c "
      echo 'Waiting for ArangoDB to start...' &&
      sleep 10 &&
      echo 'Creating txt2kg database...' &&
      arangosh --server.endpoint tcp://arangodb:8529 --server.authentication false --javascript.execute-string 'try { db._createDatabase(\"txt2kg\"); console.log(\"Database txt2kg created successfully!\"); } catch(e) { if(e.message.includes(\"duplicate\")) { console.log(\"Database txt2kg already exists\"); } else { throw e; } }'
      "

  entity-embeddings:
    image: ghcr.io/pinecone-io/pinecone-index:latest
    container_name: entity-embeddings
    environment:
      PORT: 5081
      INDEX_TYPE: serverless
      VECTOR_TYPE: dense
      DIMENSION: 384
      METRIC: cosine
      INDEX_NAME: entity-embeddings
    ports:
      - "5081:5081"
    platform: linux/amd64
    networks:
      - pinecone-net
    restart: unless-stopped

  sentence-transformers:
    build:
      context: ../../deploy/services/sentence-transformers
      dockerfile: Dockerfile
    ports:
      - '8000:80'
    environment:
      - MODEL_NAME=all-MiniLM-L6-v2
    networks:
      - default

  vllm:
    build:
      context: ../../deploy/services/vllm
      dockerfile: Dockerfile
    container_name: vllm-service
    ports:
      - '8001:8001'
    environment:
      # Model configuration
      - VLLM_MODEL=meta-llama/Llama-3.2-3B-Instruct
      - VLLM_TENSOR_PARALLEL_SIZE=1
      - VLLM_MAX_MODEL_LEN=4096
      - VLLM_GPU_MEMORY_UTILIZATION=0.9
      # NVfp4 quantization settings
      - VLLM_QUANTIZATION=fp8
      - VLLM_KV_CACHE_DTYPE=fp8
      # Service configuration
      - VLLM_PORT=8001
      - VLLM_HOST=0.0.0.0
      # Performance tuning
      - CUDA_VISIBLE_DEVICES=0
      - NCCL_DEBUG=INFO
    volumes:
      - vllm_models:/app/models
      - /tmp:/tmp
      # Mount model cache for faster startup
      - ~/.cache/huggingface:/root/.cache/huggingface
    networks:
      - default
    restart: unless-stopped
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8001/v1/models"]
      interval: 30s
      timeout: 10s
      retries: 5
      start_period: 120s # Longer start period for model loading

volumes:
  arangodb_data:
  arangodb_apps_data:
  vllm_models:

networks:
  pinecone-net:
    name: pinecone
  default:
    driver: bridge
  txt2kg-network:
    driver: bridge

nvidia/txt2kg/assets/deploy/compose/docker-compose.yml (new file, 168 lines)

@@ -0,0 +1,168 @@
services:
  app:
    build:
      context: ../..
      dockerfile: deploy/app/Dockerfile
    ports:
      - '3001:3000'
    environment:
      - ARANGODB_URL=http://arangodb:8529
      - ARANGODB_DB=txt2kg
      - PINECONE_HOST=entity-embeddings
      - PINECONE_PORT=5081
      - PINECONE_API_KEY=pclocal
      - PINECONE_ENVIRONMENT=local
      - LANGCHAIN_TRACING_V2=true
      - SENTENCE_TRANSFORMER_URL=http://sentence-transformers:80
      - MODEL_NAME=all-MiniLM-L6-v2
      - GRPC_SSL_CIPHER_SUITES=HIGH+ECDSA:HIGH+aRSA
      - NODE_TLS_REJECT_UNAUTHORIZED=0
      # - XAI_API_KEY=${XAI_API_KEY} # xAI integration removed
      - OLLAMA_BASE_URL=http://ollama:11434/v1
      - OLLAMA_MODEL=llama3.1:8b
      - VLLM_BASE_URL=http://vllm:8001/v1
      - VLLM_MODEL=meta-llama/Llama-3.2-3B-Instruct
      - REMOTE_WEBGPU_SERVICE_URL=http://txt2kg-remote-webgpu:8083
      # Node.js timeout configurations for large model processing
      - NODE_OPTIONS=--max-http-header-size=80000
      - UV_THREADPOOL_SIZE=128
      - HTTP_TIMEOUT=1800000
      - REQUEST_TIMEOUT=1800000
    networks:
      - pinecone-net
      - default
      - txt2kg-network

  arangodb:
    image: arangodb:latest
    ports:
      - '8529:8529'
    environment:
      - ARANGO_NO_AUTH=1
    volumes:
      - arangodb_data:/var/lib/arangodb3
      - arangodb_apps_data:/var/lib/arangodb3-apps

  arangodb-init:
    image: arangodb:latest
    depends_on:
      arangodb:
        condition: service_started
    restart: on-failure
    entrypoint: >
      sh -c "
      echo 'Waiting for ArangoDB to start...' &&
      sleep 10 &&
      echo 'Creating txt2kg database...' &&
      arangosh --server.endpoint tcp://arangodb:8529 --server.authentication false --javascript.execute-string 'try { db._createDatabase(\"txt2kg\"); console.log(\"Database txt2kg created successfully!\"); } catch(e) { if(e.message.includes(\"duplicate\")) { console.log(\"Database txt2kg already exists\"); } else { throw e; } }'
      "

  entity-embeddings:
    image: ghcr.io/pinecone-io/pinecone-index:latest
    container_name: entity-embeddings
    environment:
      PORT: 5081
      INDEX_TYPE: serverless
      VECTOR_TYPE: dense
      DIMENSION: 384
      METRIC: cosine
      INDEX_NAME: entity-embeddings
    ports:
      - "5081:5081"
    platform: linux/amd64
    networks:
      - pinecone-net
    restart: unless-stopped

  sentence-transformers:
    build:
      context: ../../deploy/services/sentence-transformers
      dockerfile: Dockerfile
    ports:
      - '8000:80'
    environment:
      - MODEL_NAME=all-MiniLM-L6-v2
    networks:
      - default

  ollama:
    build:
      context: ../services/ollama
      dockerfile: Dockerfile
    image: ollama-custom:latest
    container_name: ollama-compose
    ports:
      - '11434:11434'
    volumes:
      - ollama_data:/root/.ollama
    environment:
      - OLLAMA_FLASH_ATTENTION=1 # Enable flash attention for better performance
      - OLLAMA_KEEP_ALIVE=30m # Keep models loaded for 30 minutes
      - OLLAMA_CUDA=1 # Enable CUDA acceleration
      - OLLAMA_LLM_LIBRARY=cuda # Use CUDA library for LLM operations
      - OLLAMA_NUM_PARALLEL=1 # Process one request at a time for 70B models
      - OLLAMA_MAX_LOADED_MODELS=1 # Load only one model at a time to avoid VRAM contention
      - OLLAMA_KV_CACHE_TYPE=q8_0 # Reduce KV cache VRAM usage with minimal performance impact
      - OLLAMA_GPU_LAYERS=999 # Use maximum GPU layers
      - OLLAMA_GPU_MEMORY_FRACTION=0.9 # Use 90% of GPU memory
      - CUDA_VISIBLE_DEVICES=0 # Use GPU 0 (change to 'all' for multi-GPU)
    networks:
      - default
    restart: unless-stopped
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:11434/api/tags"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 60s

  vllm:
    build:
      context: ../../deploy/services/vllm
      dockerfile: Dockerfile
    container_name: vllm-service
    ports:
      - '8001:8001'
    environment:
      - VLLM_MODEL=meta-llama/Llama-3.2-3B-Instruct
      - VLLM_TENSOR_PARALLEL_SIZE=1
      - VLLM_MAX_MODEL_LEN=4096
      - VLLM_GPU_MEMORY_UTILIZATION=0.9
      - VLLM_QUANTIZATION=fp8
      - VLLM_KV_CACHE_DTYPE=fp8
      - VLLM_PORT=8001
      - VLLM_HOST=0.0.0.0
    volumes:
      - vllm_models:/app/models
      - /tmp:/tmp
    networks:
      - default
    restart: unless-stopped
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8001/v1/models"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 60s

volumes:
  arangodb_data:
  arangodb_apps_data:
  ollama_data:
  vllm_models:

networks:
  pinecone-net:
    name: pinecone
  default:
    driver: bridge
  txt2kg-network:
    driver: bridge

nvidia/txt2kg/assets/deploy/services/gnn_model/Dockerfile (new file, 26 lines)

@@ -0,0 +1,26 @@
FROM nvcr.io/nvidia/pyg:25.03-py3

WORKDIR /app

# Install Flask and other required packages
RUN pip install --no-cache-dir \
    flask==2.0.1 \
    gunicorn==20.1.0 \
    tqdm

# Create model directory
RUN mkdir -p /app/models

# Copy application code
COPY services/gnn_model/app.py /app/

# Set environment variables
ENV MODEL_PATH=/app/models/tech-qa-model.pt
ENV PYTHONUNBUFFERED=1
ENV FLASK_APP=app.py

# Expose the port
EXPOSE 5000

# Run the service with gunicorn
CMD ["gunicorn", "--bind", "0.0.0.0:5000", "app:app"]

nvidia/txt2kg/assets/deploy/services/gnn_model/README.md (new file, 95 lines)

@@ -0,0 +1,95 @@
# GNN Model Service

This service provides a REST API for serving predictions from a Graph Neural Network (GNN) model trained to enhance RAG (Retrieval Augmented Generation) performance. It allows comparing GNN-based knowledge graph retrieval with traditional RAG approaches.

## Overview

The service exposes a simple API to:
- Load a pre-trained GNN model that combines graph structures with language models
- Process queries by incorporating graph-structured knowledge
- Return predictions that leverage both text and graph relationships

## Getting Started

### Prerequisites

- Docker and Docker Compose
- The trained model file (created using `train_export.py`)

### Running the Service

The service is included in the main docker-compose configuration. Simply run:

```bash
docker-compose up -d
```

This will start the GNN model service along with the other services in the system.

## Training the Model

Before using the service, you need to train the GNN model:

```bash
# Create the models directory if it doesn't exist
mkdir -p models

# Run the training script
python deploy/services/gnn_model/train_export.py --output_dir models
```

This will create the `tech-qa-model.pt` file in the models directory, which the service will load.

## API Endpoints

### Health Check

```
GET /health
```

Returns the health status of the service.

### Prediction

```
POST /predict
```

Request body:
```json
{
  "question": "Your question here",
  "context": "Retrieved context information"
}
```

Response:
```json
{
  "question": "Your question here",
  "answer": "The generated answer"
}
```

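For a quick manual test with curl (a sketch; this assumes port 5000, the port the service's Dockerfile exposes, is published on localhost):

```bash
# Health check
curl http://localhost:5000/health

# Prediction request using the documented payload shape
curl -X POST http://localhost:5000/predict \
  -H "Content-Type: application/json" \
  -d '{"question": "Your question here", "context": "Retrieved context information"}'
```
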
## Using the Client Example

A simple client script is provided to test the service:

```bash
python deploy/services/gnn_model/client_example.py --question "What is the capital of France?" --context "France is a country in Western Europe. Its capital is Paris, which is known for the Eiffel Tower."
```

This script also includes a placeholder for comparing the GNN-based approach with a traditional RAG approach.

## Architecture

The GNN model service uses:
- A Graph Attention Network (GAT) to process graph-structured data
- A Language Model (LLM) to generate answers
- A combined architecture (GRetriever) that leverages both components

## Limitations

- The current implementation requires graph construction to be handled separately
- The `create_graph_from_text` function in the service is a placeholder that needs to be implemented based on your specific graph construction approach

nvidia/txt2kg/assets/deploy/services/gnn_model/app.py (new file, 114 lines)

@@ -0,0 +1,114 @@
#!/usr/bin/env python3

import os
import torch
from flask import Flask, request, jsonify
import torch_geometric
from torch_geometric.nn import GAT, LLM, GRetriever

app = Flask(__name__)

# Constants
MODEL_PATH = os.environ.get('MODEL_PATH', '/app/models/tech-qa-model.pt')
LLM_GENERATOR_NAME = os.environ.get('LLM_GENERATOR_NAME', 'meta-llama/Meta-Llama-3.1-8B-Instruct')
GNN_HID_CHANNELS = int(os.environ.get('GNN_HID_CHANNELS', '1024'))
GNN_LAYERS = int(os.environ.get('GNN_LAYERS', '4'))

# Prompt template for questions
prompt_template = """Answer this question based on retrieved contexts. Just give the answer without explanation.
[QUESTION] {question} [END_QUESTION]
[RETRIEVED_CONTEXTS] {context} [END_RETRIEVED_CONTEXTS]
Answer: """

# Load the model
def load_model():
    print(f"Loading model from {MODEL_PATH}")

    # Create the GNN component
    gnn = GAT(in_channels=768, hidden_channels=GNN_HID_CHANNELS,
              out_channels=1024, num_layers=GNN_LAYERS, heads=4)

    # Create the LLM component
    llm = LLM(model_name=LLM_GENERATOR_NAME)

    # Create the GRetriever model
    model = GRetriever(llm=llm, gnn=gnn)

    # Load trained weights
    if os.path.exists(MODEL_PATH):
        state_dict = torch.load(MODEL_PATH, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        print("Model loaded successfully")
    else:
        print(f"WARNING: Model file not found at {MODEL_PATH}. Using untrained model.")

    return model

# Initialize model
model = None

@app.before_first_request
def initialize():
    global model
    model = load_model()

@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({"status": "healthy"})

@app.route('/predict', methods=['POST'])
def predict():
    if not request.is_json:
        return jsonify({"error": "Request must be JSON"}), 400

    data = request.get_json()

    if 'question' not in data:
        return jsonify({"error": "Question is required"}), 400

    if 'context' not in data:
        return jsonify({"error": "Context is required"}), 400

    question = data['question']
    context = data['context']

    # Format the question with context using the prompt template
    formatted_question = prompt_template.format(question=question, context=context)

    # Prepare input for the model
    # Note: In a real implementation, you'd need to convert text to graph structure
    # Here we're assuming a simplified interface for demonstration
    try:
        # Create a PyTorch Geometric Data object
        # This is simplified and would need to be adapted to your actual graph structure
        graph_data = create_graph_from_text(context)

        # Generate prediction
        with torch.no_grad():
            prediction = model.generate(
                input_question=[formatted_question],
                input_graph=graph_data
            )[0]  # Get first prediction since we're processing one sample

        return jsonify({
            "question": question,
            "answer": prediction
        })

    except Exception as e:
        return jsonify({"error": str(e)}), 500

def create_graph_from_text(text):
    """
    Convert text to a graph structure for the GNN.
    This is a placeholder - you'll need to implement the actual conversion
    based on your specific graph construction approach.
    """
    # This would need to be implemented based on how your graphs are constructed;
    # until then, the function deliberately raises rather than returning a dummy graph
    raise NotImplementedError("Graph creation from text needs to be implemented")

if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))
    app.run(host='0.0.0.0', port=port)

nvidia/txt2kg/assets/deploy/services/gnn_model/client_example.py (new file, 83 lines; file path inferred from the README above, which invokes client_example.py from this directory)

@@ -0,0 +1,83 @@
#!/usr/bin/env python3

import requests
import json
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Client for GNN Model Service')
    parser.add_argument('--url', type=str, default='http://localhost:5000',
                        help='URL of the GNN model service')
    parser.add_argument('--question', type=str, required=True,
                        help='Question to ask')
    parser.add_argument('--context', type=str, required=True,
                        help='Context information to provide')

    return parser.parse_args()

def query_gnn_model(url, question, context):
    """
    Query the GNN model service with a question and context
    """
    endpoint = f"{url}/predict"

    payload = {
        "question": question,
        "context": context
    }

    headers = {
        "Content-Type": "application/json"
    }

    try:
        response = requests.post(endpoint, json=payload, headers=headers)
        if response.status_code == 200:
            return response.json()
        else:
            print(f"Error: {response.status_code}")
            print(response.text)
            return None
    except Exception as e:
        print(f"Error connecting to GNN service: {e}")
        return None

def query_rag_model(question, context):
    """
    Simple Pure RAG approach for comparison.
    This is a placeholder - in a real implementation, you would have a separate RAG service
    or use a local LLM with context insertion
    """
    # This would typically call an API or use a local LLM
    print("Note: This is a placeholder for a Pure RAG implementation")
    return {
        "question": question,
        "answer": "Placeholder RAG answer. Implement real RAG for comparison."
    }

def compare_approaches(gnn_result, rag_result):
    """
    Compare the results from GNN and Pure RAG approaches
    """
    print("\n----- COMPARISON -----")
    print(f"Question: {gnn_result['question']}")
    print(f"GNN Answer: {gnn_result['answer']}")
    print(f"RAG Answer: {rag_result['answer']}")
    print("----------------------\n")

if __name__ == "__main__":
    args = parse_args()

    print(f"Querying GNN model at {args.url}...")
    gnn_result = query_gnn_model(args.url, args.question, args.context)

    if gnn_result:
        print("GNN Query successful!")

        # Get RAG result for comparison
        rag_result = query_rag_model(args.question, args.context)

        # Compare the approaches
        compare_approaches(gnn_result, rag_result)
    else:
        print("Failed to get response from GNN model service.")

164
nvidia/txt2kg/assets/deploy/services/gnn_model/train_export.py
Normal file
164
nvidia/txt2kg/assets/deploy/services/gnn_model/train_export.py
Normal file
@ -0,0 +1,164 @@
#!/usr/bin/env python3

import os
import argparse
import torch
from torch_geometric import seed_everything
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GAT, LLM, GRetriever


def parse_args():
    parser = argparse.ArgumentParser(description='Train and export GNN model for service')
    parser.add_argument('--dataset_file', type=str, default='tech_qa.pt', help='Path to load dataset')
    parser.add_argument('--output_dir', type=str, default='models', help='Directory to save model')
    parser.add_argument('--model_save_path', type=str, default='tech-qa-model.pt', help='Model file name')
    parser.add_argument('--gnn_hidden_channels', type=int, default=1024, help='Hidden channels for GNN')
    parser.add_argument('--num_gnn_layers', type=int, default=4, help='Number of GNN layers')
    parser.add_argument('--llm_generator_name', type=str, default='meta-llama/Meta-Llama-3.1-8B-Instruct',
                        help='LLM to use for generation')
    parser.add_argument('--epochs', type=int, default=2, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=1, help='Training batch size')
    parser.add_argument('--eval_batch_size', type=int, default=2, help='Evaluation batch size')
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')

    return parser.parse_args()


def load_dataset(dataset_path):
    """
    Load preprocessed dataset from file
    """
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"Dataset file not found at {dataset_path}. Please run preprocess_data.py first.")

    print(f"Loading dataset from {dataset_path}...")
    data_lists = torch.load(dataset_path, weights_only=False)
    print("Dataset loaded successfully!")
    print(f"Train set size: {len(data_lists['train'])}")
    print(f"Validation set size: {len(data_lists['validation'])}")
    print(f"Test set size: {len(data_lists['test'])}")

    return data_lists


def train_model(args, data_lists):
    """
    Train the GNN model
    """
    batch_size = args.batch_size
    eval_batch_size = args.eval_batch_size
    hidden_channels = args.gnn_hidden_channels
    num_gnn_layers = args.num_gnn_layers

    train_loader = DataLoader(data_lists["train"], batch_size=batch_size,
                              drop_last=True, pin_memory=True, shuffle=True)
    val_loader = DataLoader(data_lists["validation"], batch_size=eval_batch_size,
                            drop_last=False, pin_memory=True, shuffle=False)

    # Create GNN model
    gnn = GAT(in_channels=768, hidden_channels=hidden_channels,
              out_channels=1024, num_layers=num_gnn_layers, heads=4)

    # Create LLM model
    llm = LLM(model_name=args.llm_generator_name)

    # Create the combined GRetriever model
    model = GRetriever(llm=llm, gnn=gnn)

    # Training setup
    params = [p for _, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW([{
        'params': params, 'lr': args.lr, 'weight_decay': 0.05
    }], betas=(0.9, 0.95))

    # Prompt template for questions
    prompt_template = """Answer this question based on retrieved contexts. Just give the answer without explanation.
[QUESTION] {question} [END_QUESTION]
[RETRIEVED_CONTEXTS] {context} [END_RETRIEVED_CONTEXTS]
Answer: """

    # Training loop
    for epoch in range(args.epochs):
        model.train()
        epoch_loss = 0
        print(f'Epoch: {epoch + 1}/{args.epochs}')

        for batch in train_loader:
            new_qs = []
            for i, q in enumerate(batch["question"]):
                # insert context
                new_qs.append(
                    prompt_template.format(question=q, context=batch.text_context[i]))
            batch.question = new_qs

            optimizer.zero_grad()
            loss = model(
                input_question=batch.question,
                input_graph=batch,
                output_labels=batch.label
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
            optimizer.step()

            epoch_loss += float(loss)

        avg_train_loss = epoch_loss / len(train_loader)
        print(f'Train Loss: {avg_train_loss:.4f}')

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                new_qs = []
                for i, q in enumerate(batch["question"]):
                    # insert context
                    new_qs.append(
                        prompt_template.format(question=q, context=batch.text_context[i]))
                batch.question = new_qs

                loss = model(
                    input_question=batch.question,
                    input_graph=batch,
                    output_labels=batch.label
                )
                val_loss += float(loss)

        avg_val_loss = val_loss / len(val_loader)
        print(f'Validation Loss: {avg_val_loss:.4f}')

    return model


def save_model(model, save_path):
    """
    Save the trained model
    """
    directory = os.path.dirname(save_path)
    # Guard against an empty dirname when save_path has no directory component
    if directory and not os.path.exists(directory):
        os.makedirs(directory)

    print(f"Saving model to {save_path}")
    torch.save(model.state_dict(), save_path)
    print("Model saved successfully!")


if __name__ == '__main__':
    # Set seed for reproducibility
    seed_everything(50)

    # Parse arguments
    args = parse_args()

    # Load dataset
    dataset_path = os.path.join(args.output_dir, args.dataset_file)
    data_lists = load_dataset(dataset_path)

    # Train model
    model = train_model(args, data_lists)

    # Save model
    model_path = os.path.join(args.output_dir, args.model_save_path)
    save_model(model, model_path)

    print(f"Model has been trained and saved to {model_path}")
    print("This model can now be used by the GNN service.")
nvidia/txt2kg/assets/deploy/services/gpu-viz/Dockerfile
@ -0,0 +1,40 @@
# Use latest NVIDIA PyG container which includes cuGraph and graph-related packages
FROM nvcr.io/nvidia/pyg:25.08-py3

# Ensure we're running as root for system package installation
USER root

# Set working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    curl \
    wget \
    git \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first to leverage Docker cache
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the service code
COPY unified_gpu_service.py .
COPY pygraphistry_service.py .

# Create a non-root user for security (using a different UID to avoid conflicts)
RUN useradd -m -u 1001 appuser && chown -R appuser:appuser /app
USER appuser

# Expose unified service port
EXPOSE 8080

# Health check for unified service
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8080/api/health || exit 1

# Start unified service
CMD ["python", "unified_gpu_service.py"]
@ -0,0 +1,305 @@
# GPU Rendering Library Options for Remote Visualization

## 🎯 **Yes! Three.js is Perfect for Adding GPU Rendering**

Your existing **Three.js v0.176.0** stack is ideal for adding true GPU-accelerated WebGL rendering to the remote service. Here's a comprehensive comparison of the options:

## 🚀 **Option 1: Three.js (Recommended)**

### **Why Three.js is Perfect**
- ✅ **Already in your stack** - Three.js v0.176.0 in package.json
- ✅ **Mature WebGL abstraction** - Handles GPU complexity
- ✅ **InstancedMesh for performance** - Single draw call for millions of nodes
- ✅ **Built-in optimizations** - Frustum culling, LOD, memory management
- ✅ **Easy development** - High-level API, good documentation

### **Three.js GPU Features for Graph Rendering**

#### **1. InstancedMesh for Mass Node Rendering**
```javascript
// Single GPU draw call for 100k+ nodes
const geometry = new THREE.CircleGeometry(1, 8);
const material = new THREE.MeshBasicMaterial({ vertexColors: true });
const instancedMesh = new THREE.InstancedMesh(geometry, material, nodeCount);

// Set position, scale, color for each instance
const matrix = new THREE.Matrix4();
const color = new THREE.Color();

nodes.forEach((node, i) => {
  matrix.makeScale(node.size, node.size, 1);
  matrix.setPosition(node.x, node.y, 0);
  instancedMesh.setMatrixAt(i, matrix);

  color.setHex(node.clusterColor);
  instancedMesh.setColorAt(i, color);
});

// GPU renders all nodes in one call
scene.add(instancedMesh);
```

#### **2. BufferGeometry for Edge Performance**
```javascript
// GPU-optimized edge rendering
const positions = new Float32Array(edgeCount * 6);
const colors = new Float32Array(edgeCount * 6);

edges.forEach((edge, i) => {
  const idx = i * 6;
  // Source vertex
  positions[idx] = edge.source.x;
  positions[idx + 1] = edge.source.y;
  // Target vertex
  positions[idx + 3] = edge.target.x;
  positions[idx + 4] = edge.target.y;
});

const geometry = new THREE.BufferGeometry();
geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3));
geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3));

const lineSegments = new THREE.LineSegments(geometry, material);
```

#### **3. Built-in Performance Optimizations**
```javascript
// Three.js GPU optimizations
renderer.sortObjects = false; // Disable expensive sorting
renderer.setPixelRatio(Math.min(devicePixelRatio, 2)); // Limit pixel density

// Frustum culling (automatic)
// Level-of-detail (LOD) support
// Automatic geometry merging
// GPU texture atlasing
```

### **Performance Comparison**

| Approach | 10k Nodes | 100k Nodes | 1M Nodes | FPS |
|----------|-----------|------------|----------|-----|
| **D3.js SVG** | ✅ Good | ❌ Slow | ❌ Unusable | 15fps |
| **Three.js Standard** | ✅ Excellent | ✅ Good | ❌ Slow | 45fps |
| **Three.js Instanced** | ✅ Excellent | ✅ Excellent | ✅ Good | 60fps |

## 🔧 **Option 2: deck.gl (For Data-Heavy Visualizations)**

### **Pros**
- ✅ **Built for large datasets** - Optimized for millions of points
- ✅ **GPU compute via WebGL2** - Heavy per-point work stays on the GPU
- ✅ **Built-in graph layouts** - Force-directed on GPU
- ✅ **Excellent performance** - 1M+ nodes at 60fps

### **Cons**
- ❌ **Large bundle size** - Adds ~500KB
- ❌ **Complex API** - Steeper learning curve
- ❌ **React-focused** - Less suitable for iframe embedding

```javascript
// deck.gl GPU-accelerated approach
import { ScatterplotLayer, LineLayer } from '@deck.gl/layers';

const nodeLayer = new ScatterplotLayer({
  data: nodes,
  getPosition: d => [d.x, d.y],
  getRadius: d => d.size,
  getFillColor: d => d.color,
  radiusUnits: 'pixels',
  // GPU instancing automatically enabled
});

const edgeLayer = new LineLayer({
  data: edges,
  getSourcePosition: d => [d.source.x, d.source.y],
  getTargetPosition: d => [d.target.x, d.target.y],
  getColor: [100, 100, 100],
  getWidth: 1
});
```

## ⚡ **Option 3: regl (Raw WebGL Performance)**

### **Pros**
- ✅ **Maximum performance** - Direct WebGL access
- ✅ **Small bundle** - ~50KB
- ✅ **Full control** - Custom shaders, compute pipelines
- ✅ **Functional API** - Clean, predictable

### **Cons**
- ❌ **Low-level complexity** - Manual memory management
- ❌ **Shader development** - GLSL programming required
- ❌ **More development time** - Everything custom

```javascript
// regl direct WebGL approach
const drawNodes = regl({
  vert: `
    attribute vec2 position;
    attribute float size;
    attribute vec3 color;
    varying vec3 vColor;

    void main() {
      gl_Position = vec4(position, 0, 1);
      gl_PointSize = size;
      vColor = color;
    }
  `,

  frag: `
    precision mediump float;
    varying vec3 vColor;

    void main() {
      gl_FragColor = vec4(vColor, 1);
    }
  `,

  attributes: {
    position: nodePositions,
    size: nodeSizes,
    color: nodeColors
  },

  count: nodeCount,
  primitive: 'points'
});
```

## 🎮 **Option 4: WebGPU (Future-Proof)**

### **Pros**
- ✅ **Next-generation API** - Successor to WebGL
- ✅ **Compute shaders** - True parallel processing
- ✅ **Better performance** - Lower overhead
- ✅ **Multi-threading** - Parallel command buffers

### **Cons**
- ❌ **Limited browser support** - Chrome/Edge only (2024)
- ❌ **New API** - Rapidly changing specification
- ❌ **Complex setup** - More verbose than WebGL

```javascript
// WebGPU approach (future)
const adapter = await navigator.gpu.requestAdapter();
const device = await adapter.requestDevice();

const computePipeline = device.createComputePipeline({
  layout: 'auto', // required by the current WebGPU spec
  compute: {
    module: device.createShaderModule({
      code: `
        // Storage buffers for the simulation state (declarations added for completeness)
        @group(0) @binding(0) var<storage, read_write> positions : array<vec2<f32>>;
        @group(0) @binding(1) var<storage, read_write> velocities : array<vec2<f32>>;

        @compute @workgroup_size(64)
        fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
          let index = global_id.x;
          if (index >= arrayLength(&positions)) { return; }

          // GPU-parallel force calculation
          var force = vec2<f32>(0.0, 0.0);
          for (var i = 0u; i < arrayLength(&positions); i++) {
            if (i != index) {
              let diff = positions[index] - positions[i];
              let dist = length(diff);
              force += normalize(diff) * (1.0 / (dist * dist));
            }
          }

          velocities[index] += force * 0.01;
          positions[index] += velocities[index] * 0.1;
        }
      `
    }),
    entryPoint: 'main'
  }
});
```

## 🏆 **Recommendation: Three.js Integration**

### **For Your Use Case, Three.js is Optimal Because:**

1. **Already Available** - No new dependencies
2. **Proven Performance** - Handles 100k+ nodes smoothly
3. **Easy Integration** - Replace D3.js rendering with Three.js
4. **Maintenance** - Well-documented, stable API
5. **Development Speed** - Rapid implementation

### **Implementation Strategy**

#### **Phase 1: Basic Three.js WebGL (Week 1)**
```python
# Enhanced remote service with Three.js
def _generate_threejs_html(self, session_data, config):
    return f"""
    <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/0.176.0/three.min.js"></script>
    <script>
        // Basic Three.js WebGL rendering
        const renderer = new THREE.WebGLRenderer({{
            powerPreference: "high-performance"
        }});
        const scene = new THREE.Scene();
        const camera = new THREE.PerspectiveCamera(75, width/height, 0.1, 1000);

        // Render nodes and edges with GPU
        createNodeVisualization();
        createEdgeVisualization();
    </script>
    """
```

#### **Phase 2: GPU Optimization (Week 2)**
- Add InstancedMesh for node rendering
- Implement BufferGeometry for edges
- Enable frustum culling and LOD (see the sketch below)
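
The culling/LOD item can be sketched directly against Three.js. This is a minimal, hypothetical example, assuming a `scene` from Phase 1 and a `nodes` array with `x`, `y`, and `size` fields; none of these names come from the service code:

```typescript
import * as THREE from 'three';

interface GraphNode { x: number; y: number; size: number; } // assumed shape

function buildLodNodes(scene: THREE.Scene, nodes: GraphNode[]): void {
  const material = new THREE.MeshBasicMaterial({ color: 0x69b3a2 });
  // Shared geometries: detailed circle up close, coarse triangle far away
  const highDetail = new THREE.CircleGeometry(1, 32);
  const lowDetail = new THREE.CircleGeometry(1, 3);

  for (const node of nodes) {
    const lod = new THREE.LOD();
    lod.addLevel(new THREE.Mesh(highDetail, material), 0);  // within 50 units
    lod.addLevel(new THREE.Mesh(lowDetail, material), 50);  // beyond 50 units
    lod.position.set(node.x, node.y, 0);
    lod.scale.setScalar(node.size);
    lod.frustumCulled = true; // Three.js skips off-screen objects (on by default)
    scene.add(lod);
  }
}
```

Per-node `LOD` objects trade some draw-call efficiency for simplicity; for very large graphs the InstancedMesh approach above remains the better fit.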

#### **Phase 3: Advanced Features (Week 3)**
- GPU-based interaction (raycasting) - see the sketch below
- Smooth camera controls
- Real-time layout animation
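
For the raycasting item, a hedged sketch of pointer picking; `camera` and `scene` are assumed to exist from Phase 1, and `instanceId` identifies the node when a hit lands on an `InstancedMesh`:

```typescript
import * as THREE from 'three';

const raycaster = new THREE.Raycaster();
const pointer = new THREE.Vector2();

function onPointerDown(event: PointerEvent, camera: THREE.Camera, scene: THREE.Scene): void {
  // Convert pixel coordinates to normalized device coordinates (-1..1)
  pointer.x = (event.clientX / window.innerWidth) * 2 - 1;
  pointer.y = -(event.clientY / window.innerHeight) * 2 + 1;

  raycaster.setFromCamera(pointer, camera);
  const hits = raycaster.intersectObjects(scene.children, true);
  if (hits.length > 0) {
    // For InstancedMesh hits, instanceId identifies the picked node instance
    console.log('Picked:', hits[0].instanceId ?? hits[0].object.name);
  }
}
```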

### **Expected Performance Improvements**

| Feature | D3.js SVG | Three.js WebGL | Improvement |
|---------|-----------|----------------|-------------|
| **50k nodes** | 5 FPS | 60 FPS | **12x faster** |
| **Animation** | Choppy | Smooth | **Fluid motion** |
| **Memory usage** | 200MB DOM | 50MB GPU | **4x less memory** |
| **Interaction** | Laggy | Responsive | **Real-time** |

## 💡 **Implementation Roadmap**

### **Step 1: Replace HTML Template**
```python
# In remote_gpu_rendering_service.py
def _generate_interactive_html(self, session_data, config):
    if config.get('use_webgl', True):
        return self._generate_threejs_webgl_html(session_data, config)
    else:
        return self._generate_d3_svg_html(session_data, config)  # Fallback
```

### **Step 2: Add WebGL Configuration**
```typescript
// In RemoteGPUViewer component
const processWithWebGLOptimization = async () => {
  const config = {
    use_webgl: nodeCount > 5000,
    instanced_rendering: nodeCount > 10000,
    lod_enabled: nodeCount > 25000,
    render_quality: 'high'
  };
  // Process with enhanced GPU service
};
```

### **Step 3: Performance Monitoring**
```javascript
// Built-in Three.js performance monitoring
console.log('Render Info:', {
  triangles: renderer.info.render.triangles,
  calls: renderer.info.render.calls,
  geometries: renderer.info.memory.geometries,
  textures: renderer.info.memory.textures
});
```

**Result**: Your remote GPU service will provide **true GPU-accelerated rendering** with minimal development effort by leveraging your existing Three.js stack.
@ -0,0 +1,264 @@
# JavaScript Library Stack Integration with Remote GPU Rendering

## 🚀 **Library Architecture Overview**

Your project leverages a sophisticated JavaScript stack optimized for graph visualization performance:

### **Core Visualization Libraries**
```json
{
  "3d-force-graph": "^1.77.0",  // WebGL 3D graph rendering
  "three": "^0.176.0",          // WebGL/WebGPU 3D engine
  "d3": "^7.9.0",               // Data binding & force simulation
  "@types/d3": "^7.4.3",        // TypeScript definitions
  "@types/three": "^0.175.0"    // Three.js TypeScript support
}
```

### **Frontend Framework**
```json
{
  "next": "15.1.0",           // React framework with SSR
  "react": "^19",             // Component architecture
  "tailwindcss": "^3.4.17"    // Utility-first CSS
}
```

## 🎯 **Performance Optimization Strategies**

### **1. Dynamic Import Strategy**

**Problem:** Large visualization libraries increase the initial bundle size.
**Solution:** Conditional loading based on graph complexity.

```typescript
// ForceGraphWrapper.tsx - Dynamic loading pattern
const ForceGraph3D = (await import('3d-force-graph')).default;

// Benefits:
// - Reduces initial bundle by ~2MB
// - Enables GPU capability detection
// - Prevents SSR WebGL conflicts
```

### **2. GPU Capability Detection**

**Enhanced detection based on your library capabilities:**

```typescript
const shouldUseRemoteRendering = (nodeCount: number) => {
  const hasWebGPU = 'gpu' in navigator;
  const maxWebGLNodes = window.WebGL2RenderingContext ? 50000 : 10000;
  const maxWebGPUNodes = hasWebGPU ? 100000 : 25000;
  const maxLocalNodes = hasWebGPU ? maxWebGPUNodes : maxWebGLNodes;

  // Three.js geometry memory limits
  const estimatedMemoryMB = (nodeCount * 64) / (1024 * 1024);
  const maxClientMemory = hasWebGPU ? 512 : 256; // MB

  return nodeCount > maxLocalNodes || estimatedMemoryMB > maxClientMemory;
};
```

### **3. Library-Specific Optimizations**

#### **Three.js Renderer Settings**
```typescript
const optimizeForThreeJS = (nodeCount: number) => ({
  // Instanced rendering for large graphs
  instance_rendering: nodeCount > 10000,

  // Texture optimization
  texture_atlasing: nodeCount > 5000,
  max_texture_size: nodeCount > 25000 ? 2048 : 1024,

  // Performance culling
  frustum_culling: nodeCount > 15000,
  occlusion_culling: nodeCount > 25000,

  // Level-of-detail for distant nodes
  enable_lod: nodeCount > 25000
});
```

#### **D3.js Force Simulation Tuning**
```typescript
const optimizeForD3 = (nodeCount: number) => ({
  // Reduced iterations for large graphs
  physics_iterations: nodeCount > 50000 ? 100 : 300,

  // Faster convergence
  alpha_decay: nodeCount > 50000 ? 0.05 : 0.02,

  // More damping for stability
  velocity_decay: nodeCount > 50000 ? 0.6 : 0.4
});
```

## 🔧 **Remote GPU Service Integration**

### **Enhanced HTML Template Generation**

The remote GPU service now generates HTML compatible with your frontend:

```python
def _generate_interactive_html(self, session_data: dict, config: dict) -> str:
    html_template = f"""
    <!-- Using D3.js v7.9.0 consistent with frontend -->
    <script src="https://d3js.org/d3.v7.min.js"></script>

    <script>
        // Configuration matching your library versions
        const config = {{
            d3_version: "7.9.0",           // Match package.json
            threejs_version: "0.176.0",    // Match package.json
            force_graph_version: "1.77.0", // Match package.json

            // Performance settings based on render quality
            maxParticles: {settings['particles']},
            lineWidth: {settings['line_width']},
            nodeDetail: {settings['node_detail']}
        }};

        // D3 force simulation with GPU-optimized parameters
        this.simulation = d3.forceSimulation()
            .force("link", d3.forceLink().id(d => d.id).distance(60))
            .force("charge", d3.forceManyBody().strength(-120))
            .force("center", d3.forceCenter(this.width / 2, this.height / 2))
            .alphaDecay(0.02)
            .velocityDecay(0.4);
    </script>
    """
```

### **Frontend Component Integration**

```typescript
// RemoteGPUViewer.tsx - Library-aware processing
const processGraphWithLibraryOptimization = async () => {
  const optimizedConfig = {
    // Frontend library compatibility
    d3_version: "7.9.0",
    threejs_version: "0.176.0",
    force_graph_version: "1.77.0",

    // WebGL optimization features
    webgl_features: {
      instance_rendering: nodeCount > 10000,
      texture_atlasing: nodeCount > 5000,
      frustum_culling: nodeCount > 15000
    },

    // Performance tuning
    progressive_loading: nodeCount > 25000,
    gpu_memory_management: true
  };

  const response = await fetch('/api/render', {
    method: 'POST',
    body: JSON.stringify({ graph_data, config: optimizedConfig })
  });
};
```

## 📊 **Performance Benchmarks by Library Stack**

### **Client-Side Rendering Limits**

| Library Stack | Max Nodes | Memory Usage | Performance |
|---------------|-----------|--------------|-------------|
| **D3.js + SVG** | 5,000 | ~50MB | Good interaction |
| **Three.js + WebGL** | 50,000 | ~256MB | Smooth 60fps |
| **Three.js + WebGPU** | 100,000 | ~512MB | GPU-accelerated |
| **Remote GPU** | 1M+ | ~100KB transfer | Server-rendered |

### **Rendering Strategy Decision Tree**

```typescript
const selectRenderingStrategy = (nodeCount: number) => {
  if (nodeCount < 5000) {
    return "local_svg";      // D3.js + SVG DOM
  } else if (nodeCount < 25000) {
    return "local_webgl";    // Three.js + WebGL
  } else if (nodeCount < 100000 && hasWebGPU) {
    return "local_webgpu";   // Three.js + WebGPU
  } else {
    return "remote_gpu";     // Remote cuGraph + GPU
  }
};
```

## 🚀 **Advanced Integration Features**

### **1. Progressive Loading**
```typescript
// For graphs >25k nodes, enable progressive loading
if (nodeCount > 25000) {
  config.progressive_loading = true;
  config.initial_load_size = 10000;  // Load first 10k nodes
  config.batch_size = 5000;          // Load 5k at a time
}
```

### **2. WebSocket Real-time Updates**
```typescript
// Real-time parameter updates via WebSocket
const updateLayoutAlgorithm = (algorithm: string) => {
  if (wsRef.current?.readyState === WebSocket.OPEN) {
    wsRef.current.send(JSON.stringify({
      type: "update_params",
      layout_algorithm: algorithm
    }));
  }
};
```

### **3. Memory-Aware Quality Settings**
```typescript
const adjustQuality = (availableMemory: number, nodeCount: number) => {
  if (availableMemory < 256) return "low";     // Mobile devices
  if (availableMemory < 512) return "medium";  // Standard devices
  if (nodeCount > 100000) return "high";       // Large graphs
  return "ultra";                              // High-end systems
};
```

## 💡 **Best Practices for Your Stack**

### **1. Bundle Optimization**
- Use dynamic imports for 3D libraries
- Lazy load based on graph size detection
- Implement service worker caching for repeated visualizations (see the sketch below)
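
A minimal sketch of the service-worker item, assuming the worker is compiled with TypeScript's `webworker` lib and registered by the app shell; the cache name and the GET-only rule are illustrative (the Cache API cannot store POST responses):

```typescript
/// <reference lib="webworker" />
declare const self: ServiceWorkerGlobalScope;

const CACHE_NAME = 'viz-cache-v1'; // illustrative cache name

self.addEventListener('fetch', (event: FetchEvent) => {
  const { request } = event;

  // The Cache API only stores GET requests, so skip POST render calls
  if (request.method !== 'GET') return;

  event.respondWith(
    caches.open(CACHE_NAME).then(async (cache) => {
      const cached = await cache.match(request);
      if (cached) return cached; // repeated visualization loads come from cache
      const fresh = await fetch(request);
      cache.put(request, fresh.clone());
      return fresh;
    })
  );
});
```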

### **2. Memory Management**
```typescript
// Cleanup Three.js resources
const cleanup = () => {
  if (graphRef.current) {
    graphRef.current.scene?.traverse((object) => {
      if (object.geometry) object.geometry.dispose();
      if (object.material) object.material.dispose();
    });
    graphRef.current.renderer?.dispose();
  }
};
```

### **3. Responsive Rendering**
```typescript
// Adjust complexity based on device capabilities
const getDeviceCapabilities = () => ({
  memory: (navigator as any).deviceMemory || 4, // GB
  cores: navigator.hardwareConcurrency || 4,
  gpu: 'gpu' in navigator ? 'webgpu' : 'webgl'
});
```

## 🎯 **Integration Results**

✅ **Seamless fallback** between local and remote rendering
✅ **Library version consistency** across client and server
✅ **Memory-aware quality adjustment** based on device capabilities
✅ **Progressive enhancement** from SVG → WebGL → WebGPU → Remote GPU
✅ **Real-time parameter updates** via WebSocket
✅ **Zero-config optimization** based on graph complexity

This integration provides the best of both worlds: the interactivity of your existing Three.js/D3.js stack for smaller graphs, and the scalability of remote GPU processing for large-scale visualizations.
nvidia/txt2kg/assets/deploy/services/gpu-viz/README.md
@ -0,0 +1,221 @@
# Unified GPU Graph Visualization Service

## 🚀 Overview

The unified service combines **PyGraphistry Cloud** and **Local GPU (cuGraph)** processing into a single FastAPI service, giving you maximum flexibility for graph visualization.

## ⚡ Processing Modes

| Mode | Description | Requirements |
|------|-------------|--------------|
| **PyGraphistry Cloud** | Interactive GPU embeds in browser | API credentials |
| **Local GPU (cuGraph)** | Full GPU processing on your hardware | NVIDIA GPU + cuGraph |
| **Local CPU** | NetworkX fallback processing | None |

## 🛠️ Quick Setup

### 1. Set Environment Variables (Optional)
```bash
# For PyGraphistry Cloud features
export GRAPHISTRY_PERSONAL_KEY="your_personal_key"
export GRAPHISTRY_SECRET_KEY="your_secret_key"
```

### 2. Run the Service

#### Option A: Direct Python
```bash
cd services
python unified_gpu_service.py
```

#### Option B: Using Startup Script
```bash
cd services
./start_gpu_services.sh
```

#### Option C: Docker (NVIDIA PyG Container)
```bash
cd services
docker build -t unified-gpu-viz .
docker run --gpus all -p 8080:8080 \
  -e GRAPHISTRY_PERSONAL_KEY="your_key" \
  -e GRAPHISTRY_SECRET_KEY="your_secret" \
  unified-gpu-viz
```

## 📡 API Usage

### Process Graph with Mode Selection

```bash
curl -X POST http://localhost:8080/api/visualize \
  -H "Content-Type: application/json" \
  -d '{
    "graph_data": {
      "nodes": [{"id": "1", "name": "Node 1"}, {"id": "2", "name": "Node 2"}],
      "links": [{"source": "1", "target": "2", "name": "edge_1_2"}]
    },
    "processing_mode": "local_gpu",
    "layout_algorithm": "force_atlas2",
    "clustering_algorithm": "leiden",
    "compute_centrality": true
  }'
```

### Check Available Capabilities

```bash
curl http://localhost:8080/api/capabilities
```

Response:
```json
{
  "processing_modes": {
    "pygraphistry_cloud": {"available": true, "description": "..."},
    "local_gpu": {"available": true, "description": "..."},
    "local_cpu": {"available": true, "description": "..."}
  },
  "has_rapids": true,
  "gpu_available": true
}
```
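
As a sketch, a client could query this endpoint to pick the best available mode before submitting a graph. The helper and its preference order below are illustrative assumptions, not service behavior:

```typescript
// Hypothetical client-side helper built on the /api/capabilities response above
type Capabilities = {
  processing_modes: Record<string, { available: boolean; description: string }>;
};

async function pickProcessingMode(base = 'http://localhost:8080'): Promise<string> {
  const res = await fetch(`${base}/api/capabilities`);
  const caps: Capabilities = await res.json();

  // Prefer local GPU, then cloud embeds, then the CPU fallback (assumed ranking)
  for (const mode of ['local_gpu', 'pygraphistry_cloud', 'local_cpu']) {
    if (caps.processing_modes[mode]?.available) return mode;
  }
  return 'local_cpu';
}
```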

## 🎯 Frontend Integration

### React Component Usage

```tsx
import { UnifiedGPUViewer } from '@/components/unified-gpu-viewer'

function MyApp() {
  const graphData = {
    nodes: [...],
    links: [...]
  }

  return (
    <UnifiedGPUViewer
      graphData={graphData}
      onError={(error) => console.error(error)}
    />
  )
}
```

### Mode-Specific Processing

```javascript
// PyGraphistry Cloud mode
const cloudResponse = await fetch('/api/unified-gpu/visualize', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    graph_data: { nodes, links },
    processing_mode: 'pygraphistry_cloud',
    layout_type: 'force',
    clustering: true,
    gpu_acceleration: true
  })
})

// Local GPU mode
const gpuResponse = await fetch('/api/unified-gpu/visualize', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    graph_data: { nodes, links },
    processing_mode: 'local_gpu',
    layout_algorithm: 'force_atlas2',
    clustering_algorithm: 'leiden',
    compute_centrality: true
  })
})
```

## 🔧 Configuration Options

### PyGraphistry Cloud Mode
- `layout_type`: "force", "circular", "hierarchical"
- `gpu_acceleration`: true/false
- `clustering`: true/false

### Local GPU Mode
- `layout_algorithm`: "force_atlas2", "spectral", "fruchterman_reingold"
- `clustering_algorithm`: "leiden", "louvain", "spectral"
- `compute_centrality`: true/false

### Local CPU Mode
- Basic processing with NetworkX fallback
- No additional configuration needed

## 📊 Response Format

```json
{
  "processed_nodes": [...],
  "processed_edges": [...],
  "processing_mode": "local_gpu",
  "embed_url": "https://hub.graphistry.com/...",  // Only for cloud mode
  "layout_positions": {...},                      // Only for local GPU mode
  "clusters": {...},
  "centrality": {...},
  "stats": {
    "node_count": 1000,
    "edge_count": 5000,
    "gpu_accelerated": true,
    "layout_computed": true,
    "clusters_computed": true
  },
  "timestamp": "2024-01-01T12:00:00Z"
}
```

## 🚀 Benefits of Unified Approach

### ✅ Advantages
- **Single service** - One port, one deployment
- **Mode switching** - Choose the best processing mode per graph
- **Fallback handling** - Graceful degradation if the GPU is unavailable
- **Consistent API** - Same interface for all modes
- **Better testing** - Easy comparison between modes

### 🎯 Use Cases
- **PyGraphistry Cloud**: Sharing visualizations, demos, production embeds
- **Local GPU**: Private data, large-scale processing, custom algorithms
- **Local CPU**: Development, testing, small graphs

## 🐛 Troubleshooting

### GPU Not Detected
```bash
# Check GPU availability
nvidia-smi

# Check RAPIDS installation
python -c "import cudf, cugraph; print('RAPIDS OK')"
```

### PyGraphistry Credentials
```bash
# Verify credentials are set
echo $GRAPHISTRY_PERSONAL_KEY
echo $GRAPHISTRY_SECRET_KEY

# Test connection
python -c "import graphistry; graphistry.register(personal_key_id='$GRAPHISTRY_PERSONAL_KEY', personal_key_secret='$GRAPHISTRY_SECRET_KEY'); print('PyGraphistry OK')"
```

### Service Health
```bash
curl http://localhost:8080/api/health
```

## 📈 Performance Tips

1. **Large graphs (>100k nodes)**: Use `local_gpu` mode
2. **Sharing/demos**: Use `pygraphistry_cloud` mode
3. **Development**: Use `local_cpu` mode for speed
4. **Mixed workloads**: Switch modes dynamically based on graph size (see the sketch below)
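
A sketch of tip 4. The node-count thresholds are illustrative assumptions, and the request shape follows the `/api/visualize` example above:

```typescript
// Hypothetical helper: choose a processing mode from graph size, then submit
async function visualize(graph: { nodes: unknown[]; links: unknown[] }) {
  const n = graph.nodes.length;
  const mode = n > 100_000 ? 'local_gpu'
             : n > 5_000 ? 'pygraphistry_cloud'
             : 'local_cpu'; // thresholds are assumptions; tune for your hardware

  const res = await fetch('http://localhost:8080/api/visualize', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ graph_data: graph, processing_mode: mode }),
  });
  return res.json();
}
```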
@ -0,0 +1,443 @@
import os
import json
import numpy as np
import pandas as pd
from typing import Dict, List, Any, Optional, Tuple
import asyncio
import logging
from datetime import datetime
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
import uvicorn

# GPU-accelerated imports (available in NVIDIA PyG container)
try:
    import cudf
    import cugraph
    import cupy as cp
    from cuml import UMAP
    HAS_RAPIDS = True
    print("✓ RAPIDS cuGraph/cuDF/cuML available")
except ImportError:
    HAS_RAPIDS = False
    print("⚠ RAPIDS not available, falling back to CPU")
    import networkx as nx

try:
    import torch
    import torch_geometric
    HAS_TORCH_GEOMETRIC = True
    print("✓ PyTorch Geometric available")
except ImportError:
    HAS_TORCH_GEOMETRIC = False
    print("⚠ PyTorch Geometric not available")

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GraphData(BaseModel):
    nodes: List[Dict[str, Any]]
    links: List[Dict[str, Any]]


class VisualizationRequest(BaseModel):
    graph_data: GraphData
    layout_algorithm: str = "force_atlas2"  # force_atlas2, fruchterman_reingold, spectral
    clustering_algorithm: str = "leiden"    # leiden, louvain, spectral
    gpu_acceleration: bool = True
    compute_centrality: bool = True


class GPUGraphProcessor:
    """GPU-accelerated graph processing using cuGraph"""

    def __init__(self):
        self.use_gpu = HAS_RAPIDS
        logger.info(f"GPU Graph Processor initialized (GPU: {self.use_gpu})")

    def create_cugraph_from_data(self, nodes: List[Dict], edges: List[Dict]) -> Tuple['cugraph.Graph', 'cudf.DataFrame']:
        """Create cuGraph from node/edge data"""
        if not self.use_gpu:
            raise RuntimeError("GPU libraries not available")

        # Create edge dataframe
        edge_data = []
        for edge in edges:
            edge_data.append({
                'src': edge['source'],
                'dst': edge['target'],
                'weight': edge.get('weight', 1.0)
            })

        # Convert to cuDF
        edges_df = cudf.DataFrame(edge_data)

        # Create cuGraph
        G = cugraph.Graph()
        G.from_cudf_edgelist(edges_df, source='src', destination='dst', edge_attr='weight')

        return G, edges_df

    def compute_gpu_layout(self, G, algorithm: str = "force_atlas2") -> Dict[str, Tuple[float, float]]:
        """Compute GPU-accelerated graph layout"""
        try:
            if algorithm == "force_atlas2":
                layout_df = cugraph.force_atlas2(G)
            elif algorithm == "fruchterman_reingold":
                # Use spectral as fallback since FR might not be available
                layout_df = cugraph.spectral_layout(G, dim=2)
            else:  # spectral
                layout_df = cugraph.spectral_layout(G, dim=2)

            # Convert to dictionary
            positions = {}
            for _, row in layout_df.iterrows():
                node_id = str(row['vertex'])
                positions[node_id] = (float(row['x']), float(row['y']))

            logger.info(f"Computed {algorithm} layout for {len(positions)} nodes on GPU")
            return positions

        except Exception as e:
            logger.error(f"GPU layout computation failed: {e}")
            return {}

    def compute_gpu_clustering(self, G, algorithm: str = "leiden") -> Dict[str, int]:
        """Compute GPU-accelerated community detection"""
        try:
            if algorithm == "leiden":
                clusters_df, modularity = cugraph.leiden(G)
            elif algorithm == "louvain":
                clusters_df, modularity = cugraph.louvain(G)
            else:  # spectral clustering
                clusters_df = cugraph.spectral_clustering(G, n_clusters=10)
                modularity = 0.0

            # Convert to dictionary
            clusters = {}
            for _, row in clusters_df.iterrows():
                node_id = str(row['vertex'])
                clusters[node_id] = int(row['partition'])

            logger.info(f"Computed {algorithm} clustering on GPU (modularity: {modularity:.3f})")
            return clusters

        except Exception as e:
            logger.error(f"GPU clustering failed: {e}")
            return {}

    def compute_gpu_centrality(self, G) -> Dict[str, Dict[str, float]]:
        """Compute GPU-accelerated centrality measures"""
        centrality_data = {}

        try:
            # PageRank
            pagerank_df = cugraph.pagerank(G)
            pagerank = {}
            for _, row in pagerank_df.iterrows():
                pagerank[str(row['vertex'])] = float(row['pagerank'])
            centrality_data['pagerank'] = pagerank

            # Betweenness centrality (for smaller graphs)
            if G.number_of_vertices() < 5000:
                betweenness_df = cugraph.betweenness_centrality(G)
                betweenness = {}
                for _, row in betweenness_df.iterrows():
                    betweenness[str(row['vertex'])] = float(row['betweenness_centrality'])
                centrality_data['betweenness'] = betweenness

            logger.info("Computed centrality measures on GPU")
            return centrality_data

        except Exception as e:
            logger.error(f"GPU centrality computation failed: {e}")
            return {}


class LocalGPUVisualizationService:
    """Local GPU-powered interactive graph visualization service"""

    def __init__(self):
        self.gpu_processor = GPUGraphProcessor()
        self.active_connections: List[WebSocket] = []

    async def process_graph(self, request: VisualizationRequest) -> Dict[str, Any]:
        """Process graph with GPU acceleration"""
        try:
            nodes = request.graph_data.nodes
            edges = request.graph_data.links

            result = {
                "nodes": nodes.copy(),
                "edges": edges.copy(),
                "gpu_processed": False,
                "layout_positions": {},
                "clusters": {},
                "centrality": {},
                "stats": {},
                "timestamp": datetime.now().isoformat()
            }

            if request.gpu_acceleration and self.gpu_processor.use_gpu:
                logger.info("=== GPU PROCESSING START ===")

                # Create cuGraph
                G, edges_df = self.gpu_processor.create_cugraph_from_data(nodes, edges)

                # Compute layout on GPU
                positions = self.gpu_processor.compute_gpu_layout(G, request.layout_algorithm)
                if positions:
                    result["layout_positions"] = positions
                    # Add positions to nodes
                    for node in result["nodes"]:
                        node_id = str(node["id"])
                        if node_id in positions:
                            node["x"], node["y"] = positions[node_id]

                # Compute clustering on GPU
                clusters = self.gpu_processor.compute_gpu_clustering(G, request.clustering_algorithm)
                if clusters:
                    result["clusters"] = clusters
                    # Add cluster info to nodes
                    for node in result["nodes"]:
                        node_id = str(node["id"])
                        if node_id in clusters:
                            node["cluster"] = clusters[node_id]

                # Compute centrality on GPU (initialized so stats below never sees it unset)
                centrality = {}
                if request.compute_centrality:
                    centrality = self.gpu_processor.compute_gpu_centrality(G)
                    result["centrality"] = centrality
                    # Add centrality to nodes
                    for node in result["nodes"]:
                        node_id = str(node["id"])
                        for metric, values in centrality.items():
                            if node_id in values:
                                node[metric] = values[node_id]

                result["gpu_processed"] = True
                result["stats"] = {
                    "node_count": len(nodes),
                    "edge_count": len(edges),
                    "gpu_accelerated": True,
                    "layout_computed": len(positions) > 0,
                    "clusters_computed": len(clusters) > 0,
                    "centrality_computed": len(centrality) > 0
                }

                logger.info("=== GPU PROCESSING COMPLETE ===")

            return result

        except Exception as e:
            logger.error(f"Graph processing failed: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    async def broadcast_update(self, data: Dict[str, Any]):
        """Broadcast updates to all connected WebSocket clients"""
        if self.active_connections:
            message = json.dumps(data)
            for connection in self.active_connections.copy():
                try:
                    await connection.send_text(message)
                except WebSocketDisconnect:
                    self.active_connections.remove(connection)


# FastAPI app
app = FastAPI(title="Local GPU Graph Visualization", version="1.0.0")
service = LocalGPUVisualizationService()


@app.post("/api/process")
async def process_graph(request: VisualizationRequest):
    """Process graph with local GPU acceleration"""
    result = await service.process_graph(request)

    # Broadcast to connected WebSocket clients
    await service.broadcast_update({
        "type": "graph_processed",
        "data": result
    })

    return result


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time updates"""
    await websocket.accept()
    service.active_connections.append(websocket)

    try:
        while True:
            # Keep connection alive
            await websocket.receive_text()
    except WebSocketDisconnect:
        service.active_connections.remove(websocket)


@app.get("/api/capabilities")
async def get_capabilities():
    """Get GPU capabilities"""
    return {
        "has_rapids": HAS_RAPIDS,
        "has_torch_geometric": HAS_TORCH_GEOMETRIC,
        "gpu_available": HAS_RAPIDS,
        "supported_layouts": ["force_atlas2", "spectral", "fruchterman_reingold"],
        "supported_clustering": ["leiden", "louvain", "spectral"],
        "gpu_memory": "N/A"  # Could add GPU memory info here
    }


@app.get("/", response_class=HTMLResponse)
async def get_visualization_page():
    """Serve the interactive visualization page"""
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Local GPU Graph Visualization</title>
        <script src="https://d3js.org/d3.v7.min.js"></script>
        <style>
            body { margin: 0; font-family: Arial, sans-serif; background: #1a1a1a; color: white; }
            #controls { position: absolute; top: 10px; left: 10px; z-index: 100; background: rgba(0,0,0,0.8); padding: 10px; border-radius: 5px; }
            #graph { width: 100vw; height: 100vh; }
            .node { cursor: pointer; }
            .link { stroke: #999; stroke-opacity: 0.6; }
            button { margin: 5px; padding: 5px 10px; }
        </style>
    </head>
    <body>
        <div id="controls">
            <h3>🚀 Local GPU Visualization</h3>
            <button onclick="loadSampleGraph()">Load Sample Graph</button>
            <div id="status">Ready</div>
        </div>
        <div id="graph"></div>

        <script>
            const width = window.innerWidth;
            const height = window.innerHeight;

            const svg = d3.select("#graph")
                .append("svg")
                .attr("width", width)
                .attr("height", height);

            const g = svg.append("g");

            // Add zoom behavior
            const zoom = d3.zoom()
                .scaleExtent([0.1, 10])
                .on("zoom", (event) => {
                    g.attr("transform", event.transform);
                });

            svg.call(zoom);

            // WebSocket connection for real-time updates
            const ws = new WebSocket(`ws://localhost:8081/ws`);

            ws.onmessage = function(event) {
                const message = JSON.parse(event.data);
                if (message.type === 'graph_processed') {
                    renderGraph(message.data);
                }
            };

            function renderGraph(data) {
                console.log("Rendering graph with", data.nodes.length, "nodes");

                // Clear previous graph
                g.selectAll("*").remove();

                // Create links
                const links = g.selectAll(".link")
                    .data(data.edges)
                    .enter().append("line")
                    .attr("class", "link")
                    .attr("stroke-width", 1);

                // Create nodes
                const nodes = g.selectAll(".node")
                    .data(data.nodes)
                    .enter().append("circle")
                    .attr("class", "node")
                    .attr("r", d => Math.sqrt((d.pagerank || 0.001) * 1000) + 2)
                    .attr("fill", d => d3.schemeCategory10[d.cluster % 10] || "#69b3a2")
                    .attr("stroke", "#fff")
                    .attr("stroke-width", 1.5);

                // Add node labels for important nodes
                const labels = g.selectAll(".label")
                    .data(data.nodes.filter(d => (d.pagerank || 0) > 0.01))
                    .enter().append("text")
                    .attr("class", "label")
                    .attr("dy", -3)
                    .attr("text-anchor", "middle")
                    .style("font-size", "10px")
                    .style("fill", "white")
                    .text(d => d.id);

                // Position nodes using GPU-computed coordinates
                if (data.layout_positions && Object.keys(data.layout_positions).length > 0) {
                    nodes.attr("cx", d => (data.layout_positions[d.id] && data.layout_positions[d.id][0]) || width/2)
                        .attr("cy", d => (data.layout_positions[d.id] && data.layout_positions[d.id][1]) || height/2);

                    labels.attr("x", d => (data.layout_positions[d.id] && data.layout_positions[d.id][0]) || width/2)
                        .attr("y", d => (data.layout_positions[d.id] && data.layout_positions[d.id][1]) || height/2);

                    links.attr("x1", d => (data.layout_positions[d.source] && data.layout_positions[d.source][0]) || width/2)
                        .attr("y1", d => (data.layout_positions[d.source] && data.layout_positions[d.source][1]) || height/2)
                        .attr("x2", d => (data.layout_positions[d.target] && data.layout_positions[d.target][0]) || width/2)
                        .attr("y2", d => (data.layout_positions[d.target] && data.layout_positions[d.target][1]) || height/2);
                } else {
                    // Fallback to force simulation
                    const simulation = d3.forceSimulation(data.nodes)
                        .force("link", d3.forceLink(data.edges).id(d => d.id))
                        .force("charge", d3.forceManyBody().strength(-30))
                        .force("center", d3.forceCenter(width / 2, height / 2));

                    simulation.on("tick", () => {
                        links.attr("x1", d => d.source.x)
                            .attr("y1", d => d.source.y)
                            .attr("x2", d => d.target.x)
                            .attr("y2", d => d.target.y);

                        nodes.attr("cx", d => d.x)
                            .attr("cy", d => d.y);

                        labels.attr("x", d => d.x)
                            .attr("y", d => d.y);
                    });
                }

                // Add tooltips
                nodes.append("title")
                    .text(d => `Node: ${d.id}\\nCluster: ${d.cluster || 'N/A'}\\nPageRank: ${(d.pagerank || 0).toFixed(4)}`);

                document.getElementById("status").innerHTML =
                    `Rendered ${data.nodes.length} nodes, ${data.edges.length} edges (GPU: ${data.gpu_processed})`;
            }

            async function loadSampleGraph() {
                // This would load your graph data and send it for processing
                document.getElementById("status").innerHTML = "Loading sample graph...";

                // You can integrate this with your existing graph generation
                // For now, this is a placeholder
                alert("Connect this to your graph generation service!");
            }
        </script>
    </body>
    </html>
    """


@app.get("/api/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "gpu_available": HAS_RAPIDS,
        "torch_geometric": HAS_TORCH_GEOMETRIC,
        "timestamp": datetime.now().isoformat()
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8081)
@ -0,0 +1,712 @@
|
||||
import graphistry
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any, Optional
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import networkx as nx
|
||||
from enum import Enum
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize PyGraphistry
|
||||
def init_graphistry():
|
||||
"""Initialize PyGraphistry with GPU acceleration"""
|
||||
try:
|
||||
# Set up authentication - check for different credential types
|
||||
api_key = os.getenv('GRAPHISTRY_API_KEY')
|
||||
personal_key = os.getenv('GRAPHISTRY_PERSONAL_KEY')
|
||||
secret_key = os.getenv('GRAPHISTRY_SECRET_KEY')
|
||||
username = os.getenv('GRAPHISTRY_USERNAME')
|
||||
password = os.getenv('GRAPHISTRY_PASSWORD')
|
||||
|
||||
if personal_key and secret_key:
|
||||
# Configure for cloud API with personal key and secret
|
||||
graphistry.register(
|
||||
api=3,
|
||||
protocol="https",
|
||||
server="hub.graphistry.com",
|
||||
personal_key_id=personal_key,
|
||||
personal_key_secret=secret_key
|
||||
)
|
||||
logger.info("PyGraphistry initialized with personal key/secret for cloud GPU acceleration")
|
||||
return True
|
||||
elif api_key:
|
||||
# Configure for cloud API with API key
|
||||
graphistry.register(api=3, protocol="https", server="hub.graphistry.com", api_key=api_key)
|
||||
logger.info("PyGraphistry initialized with API key for cloud GPU acceleration")
|
||||
return True
|
||||
elif username and password:
|
||||
# Configure for cloud API with username/password
|
||||
graphistry.register(api=3, protocol="https", server="hub.graphistry.com",
|
||||
username=username, password=password)
|
||||
logger.info("PyGraphistry initialized with username/password for cloud GPU acceleration")
|
||||
return True
|
||||
else:
|
||||
# Configure for local mode
|
||||
graphistry.register(api=3)
|
||||
logger.info("PyGraphistry initialized in local CPU mode")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PyGraphistry: {e}")
|
||||
return False
|
||||
|
||||
class GraphPattern(str, Enum):
|
||||
RANDOM = "random"
|
||||
SCALE_FREE = "scale-free"
|
||||
SMALL_WORLD = "small-world"
|
||||
CLUSTERED = "clustered"
|
||||
HIERARCHICAL = "hierarchical"
|
||||
GRID = "grid"
|
||||
|
||||
class GraphData(BaseModel):
|
||||
nodes: List[Dict[str, Any]]
|
||||
links: List[Dict[str, Any]]
|
||||
|
||||
class GraphGenerationRequest(BaseModel):
|
||||
num_nodes: int
|
||||
pattern: GraphPattern = GraphPattern.SCALE_FREE
|
||||
avg_degree: Optional[int] = 5
|
||||
num_clusters: Optional[int] = 100
|
||||
small_world_k: Optional[int] = 6
|
||||
small_world_p: Optional[float] = 0.1
|
||||
grid_dimensions: Optional[List[int]] = [100, 100]
|
||||
seed: Optional[int] = None
|
||||
|
||||
class VisualizationRequest(BaseModel):
|
||||
graph_data: GraphData
|
||||
layout_type: Optional[str] = "force"
|
||||
gpu_acceleration: Optional[bool] = True
|
||||
clustering: Optional[bool] = False
|
||||
node_size_attribute: Optional[str] = None
|
||||
node_color_attribute: Optional[str] = None
|
||||
edge_weight_attribute: Optional[str] = None
|
||||
|
||||
class GraphGenerationStatus(BaseModel):
|
||||
task_id: str
|
||||
status: str # "running", "completed", "failed"
|
||||
progress: float
|
||||
message: str
|
||||
result: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
class LargeGraphGenerator:
|
||||
"""Optimized graph generation using NetworkX and NumPy for performance"""
|
||||
|
||||
@staticmethod
|
||||
def generate_random_graph(num_nodes: int, avg_degree: int = 5, seed: Optional[int] = None) -> GraphData:
|
||||
"""Generate random graph using Erdős–Rényi model"""
|
||||
if seed:
|
||||
np.random.seed(seed)
|
||||
|
||||
# Calculate edge probability for desired average degree
|
||||
p = avg_degree / (num_nodes - 1)
|
||||
|
||||
# Use NetworkX for efficient generation
|
||||
G = nx.erdos_renyi_graph(num_nodes, p, seed=seed)
|
||||
|
||||
return LargeGraphGenerator._networkx_to_graphdata(G)
|
||||
|
||||
@staticmethod
|
||||
def generate_scale_free_graph(num_nodes: int, m: int = 3, seed: Optional[int] = None) -> GraphData:
|
||||
"""Generate scale-free graph using Barabási–Albert model"""
|
||||
G = nx.barabasi_albert_graph(num_nodes, m, seed=seed)
|
||||
return LargeGraphGenerator._networkx_to_graphdata(G)
|
||||
|
||||
@staticmethod
|
||||
    def generate_small_world_graph(num_nodes: int, k: int = 6, p: float = 0.1, seed: Optional[int] = None) -> GraphData:
        """Generate small-world graph using the Watts-Strogatz model"""
        G = nx.watts_strogatz_graph(num_nodes, k, p, seed=seed)
        return LargeGraphGenerator._networkx_to_graphdata(G)

    @staticmethod
    def generate_clustered_graph(num_nodes: int, num_clusters: int = 100, seed: Optional[int] = None) -> GraphData:
        """Generate clustered graph with intra- and inter-cluster connections"""
        # Explicit None check so that seed=0 is honored
        if seed is not None:
            np.random.seed(seed)

        cluster_size = num_nodes // num_clusters
        G = nx.Graph()

        # Add nodes with cluster information
        for i in range(num_nodes):
            cluster_id = i // cluster_size
            G.add_node(i, cluster=cluster_id)

        # Generate intra-cluster edges
        intra_prob = 0.1
        for cluster in range(num_clusters):
            cluster_start = cluster * cluster_size
            cluster_end = min(cluster_start + cluster_size, num_nodes)
            cluster_nodes = list(range(cluster_start, cluster_end))

            # Create subgraph for cluster
            cluster_subgraph = nx.erdos_renyi_graph(len(cluster_nodes), intra_prob)

            # Add edges to main graph with proper node mapping
            for edge in cluster_subgraph.edges():
                G.add_edge(cluster_nodes[edge[0]], cluster_nodes[edge[1]])

        # Generate inter-cluster edges
        inter_prob = 0.001
        for i in range(num_nodes):
            for j in range(i + 1, num_nodes):
                if G.nodes[i].get('cluster') != G.nodes[j].get('cluster'):
                    if np.random.random() < inter_prob:
                        G.add_edge(i, j)

        return LargeGraphGenerator._networkx_to_graphdata(G)

    @staticmethod
    def generate_hierarchical_graph(num_nodes: int, branching_factor: int = 3, seed: Optional[int] = None) -> GraphData:
        """Generate hierarchical (tree-like) graph"""
        # Note: branching_factor is currently unused; the tree shape comes from
        # nx.random_tree, which is deprecated in recent NetworkX releases
        # (nx.random_labeled_tree is the replacement there).
        G = nx.random_tree(num_nodes, seed=seed)

        # Add some cross-links to make it more interesting
        if seed is not None:
            np.random.seed(seed)

        # Add 10% additional edges for cross-connections
        num_additional_edges = max(1, num_nodes // 10)
        nodes = list(G.nodes())

        for _ in range(num_additional_edges):
            u, v = np.random.choice(nodes, 2, replace=False)
            if not G.has_edge(u, v):
                G.add_edge(u, v)

        return LargeGraphGenerator._networkx_to_graphdata(G)

    @staticmethod
    def generate_grid_graph(dimensions: List[int], seed: Optional[int] = None) -> GraphData:
        """Generate 2D or 3D grid graph"""
        if len(dimensions) == 2:
            G = nx.grid_2d_graph(dimensions[0], dimensions[1])
        elif len(dimensions) == 3:
            G = nx.grid_graph(dimensions)
        else:
            raise ValueError("Grid dimensions must be 2D or 3D")

        # Convert coordinate tuples to integer node IDs
        mapping = {node: i for i, node in enumerate(G.nodes())}
        G = nx.relabel_nodes(G, mapping)

        return LargeGraphGenerator._networkx_to_graphdata(G)

    @staticmethod
    def _networkx_to_graphdata(G: nx.Graph) -> GraphData:
        """Convert NetworkX graph to GraphData format"""
        nodes = []
        links = []

        # Convert nodes
        for node_id in G.nodes():
            node_data = G.nodes[node_id]
            node = {
                "id": f"n{node_id}",
                "name": f"Node {node_id}",
                "val": np.random.randint(1, 11),
                "degree": G.degree(node_id)
            }

            # Add cluster information if available
            if 'cluster' in node_data:
                node['group'] = f"cluster_{node_data['cluster']}"
            else:
                node['group'] = f"group_{node_id % 10}"

            nodes.append(node)

        # Convert edges
        for edge in G.edges():
            link = {
                "source": f"n{edge[0]}",
                "target": f"n{edge[1]}",
                "name": f"link_{edge[0]}_{edge[1]}",
                "weight": np.random.uniform(0.1, 5.0)
            }
            links.append(link)

        return GraphData(nodes=nodes, links=links)
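# --- Illustrative usage sketch (not part of the original file). The helper
# name `_demo_generate_clustered` is hypothetical; it only shows how the
# static generator methods above compose with the GraphData result.
def _demo_generate_clustered():
    data = LargeGraphGenerator.generate_clustered_graph(num_nodes=1000, num_clusters=10, seed=42)
    # GraphData holds plain dict lists, ready for JSON serialization
    print(f"{len(data.nodes)} nodes, {len(data.links)} links")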
class PyGraphistryService:
    def __init__(self):
        self.initialized = init_graphistry()
        self.generation_tasks = {}  # Background task status, keyed by task_id
        self.executor = ThreadPoolExecutor(max_workers=4)

    async def generate_graph_async(self, request: GraphGenerationRequest, task_id: str):
        """Generate graph asynchronously"""
        try:
            self.generation_tasks[task_id] = GraphGenerationStatus(
                task_id=task_id,
                status="running",
                progress=0.0,
                message="Starting graph generation..."
            )

            start_time = time.time()

            # Update progress
            self.generation_tasks[task_id].progress = 10.0
            self.generation_tasks[task_id].message = f"Generating {request.pattern.value} graph with {request.num_nodes} nodes..."

            # Generate graph based on pattern
            if request.pattern == GraphPattern.RANDOM:
                graph_data = LargeGraphGenerator.generate_random_graph(
                    request.num_nodes, request.avg_degree, request.seed
                )
            elif request.pattern == GraphPattern.SCALE_FREE:
                m = min(request.avg_degree, request.num_nodes - 1) if request.avg_degree else 3
                graph_data = LargeGraphGenerator.generate_scale_free_graph(
                    request.num_nodes, m, request.seed
                )
            elif request.pattern == GraphPattern.SMALL_WORLD:
                graph_data = LargeGraphGenerator.generate_small_world_graph(
                    request.num_nodes,
                    request.small_world_k or 6,
                    request.small_world_p or 0.1,
                    request.seed
                )
            elif request.pattern == GraphPattern.CLUSTERED:
                graph_data = LargeGraphGenerator.generate_clustered_graph(
                    request.num_nodes, request.num_clusters or 100, request.seed
                )
            elif request.pattern == GraphPattern.HIERARCHICAL:
                graph_data = LargeGraphGenerator.generate_hierarchical_graph(
                    request.num_nodes, seed=request.seed
                )
            elif request.pattern == GraphPattern.GRID:
                # Calculate grid dimensions for the given number of nodes
                if request.grid_dimensions:
                    dimensions = request.grid_dimensions
                else:
                    side_length = int(np.sqrt(request.num_nodes))
                    dimensions = [side_length, side_length]
                graph_data = LargeGraphGenerator.generate_grid_graph(dimensions, request.seed)
            else:
                raise ValueError(f"Unknown graph pattern: {request.pattern}")

            # Update progress
            self.generation_tasks[task_id].progress = 80.0
            self.generation_tasks[task_id].message = "Computing graph statistics..."

            # Calculate statistics
            generation_time = time.time() - start_time
            stats = {
                "node_count": len(graph_data.nodes),
                "edge_count": len(graph_data.links),
                "generation_time": generation_time,
                "density": len(graph_data.links) / (len(graph_data.nodes) * (len(graph_data.nodes) - 1) / 2) if len(graph_data.nodes) > 1 else 0,
                "avg_degree": 2 * len(graph_data.links) / len(graph_data.nodes) if len(graph_data.nodes) > 0 else 0,
                "pattern": request.pattern.value,
                "parameters": request.model_dump()
            }

            # Complete task
            self.generation_tasks[task_id].status = "completed"
            self.generation_tasks[task_id].progress = 100.0
            self.generation_tasks[task_id].message = f"Generated {stats['node_count']} nodes and {stats['edge_count']} edges in {generation_time:.2f}s"
            self.generation_tasks[task_id].result = {
                "graph_data": graph_data.model_dump(),
                "stats": stats
            }

            logger.info(f"Graph generation completed for task {task_id}: {stats}")

        except Exception as e:
            logger.error(f"Graph generation failed for task {task_id}: {e}")
            self.generation_tasks[task_id].status = "failed"
            self.generation_tasks[task_id].error = str(e)
            self.generation_tasks[task_id].message = f"Generation failed: {e}"

    async def start_graph_generation(self, request: GraphGenerationRequest) -> str:
        """Start graph generation as a background task"""
        task_id = f"gen_{int(time.time() * 1000)}"

        # Run generation in the thread pool to avoid blocking the event loop;
        # asyncio.run gives the worker thread its own loop for the coroutine.
        loop = asyncio.get_running_loop()
        loop.run_in_executor(
            self.executor,
            lambda: asyncio.run(self.generate_graph_async(request, task_id))
        )

        return task_id

    def get_generation_status(self, task_id: str) -> Optional[GraphGenerationStatus]:
        """Get status of a graph generation task"""
        return self.generation_tasks.get(task_id)
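    # --- Illustrative polling sketch (not part of the original file): the
    # intended start-then-poll flow for the background generation task. The
    # coroutine name `_demo_poll` is hypothetical.
    #
    #   async def _demo_poll(service, request):
    #       task_id = await service.start_graph_generation(request)
    #       while service.get_generation_status(task_id).status == "running":
    #           await asyncio.sleep(0.5)
    #       return service.get_generation_status(task_id).result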
    async def process_graph_data(self, request: VisualizationRequest) -> Dict[str, Any]:
        """Process graph data with PyGraphistry GPU acceleration"""
        try:
            if not self.initialized:
                raise HTTPException(status_code=500, detail="PyGraphistry not initialized")

            # Convert to pandas DataFrames for PyGraphistry
            nodes_df = pd.DataFrame(request.graph_data.nodes)
            edges_df = pd.DataFrame(request.graph_data.links)

            # Ensure required columns exist
            if 'id' not in nodes_df.columns:
                nodes_df['id'] = nodes_df.index
            if 'source' not in edges_df.columns or 'target' not in edges_df.columns:
                raise HTTPException(status_code=400, detail="Links must have source and target columns")

            logger.info(f"Processing graph with {len(nodes_df)} nodes and {len(edges_df)} edges")

            # Create PyGraphistry graph object
            try:
                g = graphistry.edges(edges_df, 'source', 'target').nodes(nodes_df, 'id')
                logger.info("Created PyGraphistry graph object")
            except Exception as e:
                logger.error(f"Failed to create PyGraphistry graph: {e}")
                raise HTTPException(status_code=500, detail=f"Graph creation failed: {e}")

            # Apply GPU-accelerated processing
            if request.gpu_acceleration:
                g = await self._apply_gpu_acceleration(g, request)

            # Apply clustering if requested
            if request.clustering:
                g = await self._apply_clustering(g)

            # Generate layout
            g = await self._generate_layout(g, request.layout_type)

            # Extract processed data
            try:
                processed_nodes = g._nodes.to_dict('records') if g._nodes is not None else nodes_df.to_dict('records')
                processed_edges = g._edges.to_dict('records') if g._edges is not None else edges_df.to_dict('records')
                logger.info(f"Extracted {len(processed_nodes)} nodes and {len(processed_edges)} edges")
            except Exception as e:
                logger.warning(f"Data extraction failed, using original data: {e}")
                processed_nodes = nodes_df.to_dict('records')
                processed_edges = edges_df.to_dict('records')

            # Generate embedding URL for interactive visualization
            embed_url = None
            local_viz_data = None

            try:
                embed_url = g.plot(render=False)
                logger.info(f"Generated PyGraphistry embed URL: {embed_url}")
            except Exception as e:
                logger.warning(f"Could not generate embed URL (likely running in local mode): {e}")

                # Create local visualization data as fallback
                try:
                    local_viz_data = self._create_local_viz_data(g, processed_nodes, processed_edges)
                    logger.info("Generated local visualization data as fallback")
                except Exception as viz_e:
                    logger.warning(f"Could not generate local visualization data: {viz_e}")

            return {
                "processed_nodes": processed_nodes,
                "processed_edges": processed_edges,
                "embed_url": embed_url,
                "local_viz_data": local_viz_data,
                "stats": {
                    "node_count": len(processed_nodes),
                    "edge_count": len(processed_edges),
                    "gpu_accelerated": request.gpu_acceleration,
                    "clustered": request.clustering,
                    "layout_type": request.layout_type,
                    "has_embed_url": embed_url is not None,
                    "has_local_viz": local_viz_data is not None
                },
                "timestamp": datetime.now().isoformat()
            }

        except Exception as e:
            logger.error(f"Error processing graph data: {e}")
            raise HTTPException(status_code=500, detail=str(e))
    async def _apply_gpu_acceleration(self, g, request: VisualizationRequest):
        """Apply GPU acceleration using PyGraphistry's vector processing"""
        try:
            if not request.gpu_acceleration:
                logger.info("GPU acceleration disabled by request")
                return g

            logger.info("=== GPU ACCELERATION ATTEMPT ===")
            logger.info(f"PyGraphistry object type: {type(g)}")
            logger.info(f"Available methods: {[method for method in dir(g) if not method.startswith('_')]}")

            # Check which GPU methods are actually available
            has_compute_igraph = hasattr(g, 'compute_igraph')
            has_umap = hasattr(g, 'umap')
            logger.info(f"Has compute_igraph: {has_compute_igraph}")
            logger.info(f"Has UMAP: {has_umap}")

            gpu_operations_successful = 0
            total_gpu_operations = 0

            # Compute centrality measures if available
            total_gpu_operations += 1
            try:
                if has_compute_igraph and len(g._nodes) < 50000:  # Limit for performance
                    logger.info("Attempting PageRank computation...")
                    g = g.compute_igraph('pagerank', out_col='pagerank')
                    logger.info("✓ SUCCESS: Computed PageRank centrality with GPU")
                    gpu_operations_successful += 1
                else:
                    reason = "too many nodes" if len(g._nodes) >= 50000 else "compute_igraph not available"
                    logger.warning(f"✗ SKIPPED: PageRank computation ({reason})")
            except Exception as e:
                logger.warning(f"✗ FAILED: PageRank computation failed: {e}")

            # Apply UMAP for node positioning if available and beneficial
            total_gpu_operations += 1
            try:
                if has_umap and 100 < len(g._nodes) < 10000:
                    logger.info("Attempting UMAP for node positioning...")
                    g = g.umap()
                    logger.info("✓ SUCCESS: Applied UMAP for node positioning")
                    gpu_operations_successful += 1
                else:
                    reason = ("UMAP not available" if not has_umap else
                              "too few nodes" if len(g._nodes) <= 100 else "too many nodes")
                    logger.warning(f"✗ SKIPPED: UMAP processing ({reason})")
            except Exception as e:
                logger.warning(f"✗ FAILED: UMAP processing failed: {e}")

            logger.info("=== GPU ACCELERATION SUMMARY ===")
            logger.info(f"GPU operations successful: {gpu_operations_successful}/{total_gpu_operations}")
            logger.info(f"GPU utilization: {(gpu_operations_successful / total_gpu_operations) * 100:.1f}%")

            return g
        except Exception as e:
            logger.warning(f"GPU acceleration failed completely, falling back to CPU: {e}")
            return g
    async def _apply_clustering(self, g):
        """Apply GPU-accelerated clustering"""
        try:
            logger.info("=== CLUSTERING ATTEMPT ===")

            # Use PyGraphistry's built-in clustering if available
            if hasattr(g, 'compute_igraph') and len(g._nodes) < 20000:  # Limit for performance
                logger.info("Attempting Leiden community detection...")
                try:
                    g = g.compute_igraph('community_leiden', out_col='cluster')
                    logger.info("✓ SUCCESS: Applied Leiden community detection")
                    return g
                except Exception as e:
                    logger.warning(f"✗ FAILED: Leiden clustering failed: {e}")
                    logger.info("Attempting Louvain community detection as fallback...")
                    try:
                        g = g.compute_igraph('community_louvain', out_col='cluster')
                        logger.info("✓ SUCCESS: Applied Louvain community detection")
                        return g
                    except Exception as e2:
                        logger.warning(f"✗ FAILED: Louvain clustering also failed: {e2}")
            else:
                reason = "too many nodes" if len(g._nodes) >= 20000 else "compute_igraph not available"
                logger.warning(f"✗ SKIPPED: Clustering ({reason})")

            logger.info("=== CLUSTERING SUMMARY: No clustering applied ===")
            return g
        except Exception as e:
            logger.warning(f"Clustering failed completely: {e}")
            return g
    async def _generate_layout(self, g, layout_type: str = "force"):
        """Generate layout using PyGraphistry's algorithms"""
        try:
            logger.info(f"Generating {layout_type} layout")

            # Only apply layout computation for reasonable graph sizes
            if len(g._nodes) > 50000:
                logger.info("Skipping layout computation for very large graph")
                return g

            if hasattr(g, 'compute_igraph'):
                try:
                    if layout_type == "force":
                        g = g.compute_igraph('layout_fruchterman_reingold', out_cols=['x', 'y'])
                        logger.info("Applied Fruchterman-Reingold force layout")
                    elif layout_type == "circular":
                        g = g.compute_igraph('layout_circle', out_cols=['x', 'y'])
                        logger.info("Applied circular layout")
                    elif layout_type == "hierarchical":
                        g = g.compute_igraph('layout_sugiyama', out_cols=['x', 'y'])
                        logger.info("Applied hierarchical layout")
                    else:
                        # Default to force-directed
                        g = g.compute_igraph('layout_fruchterman_reingold', out_cols=['x', 'y'])
                        logger.info("Applied default force layout")
                except Exception as e:
                    logger.warning(f"Layout computation failed: {e}")
            else:
                logger.info("Layout computation not available, using default positioning")

            return g
        except Exception as e:
            logger.warning(f"Layout generation failed: {e}")
            return g
    def _create_local_viz_data(self, g, processed_nodes: List[Dict], processed_edges: List[Dict]) -> Dict[str, Any]:
        """Create local visualization data when an embed URL cannot be generated"""
        try:
            # Extract layout positions if available
            positions = {}
            if g._nodes is not None and 'x' in g._nodes.columns and 'y' in g._nodes.columns:
                for _, row in g._nodes.iterrows():
                    node_id = row.get('id', row.name)
                    positions[str(node_id)] = {
                        'x': float(row['x']) if pd.notna(row['x']) else 0,
                        'y': float(row['y']) if pd.notna(row['y']) else 0
                    }

            # Add cluster information if available
            clusters = {}
            if g._nodes is not None and 'cluster' in g._nodes.columns:
                for _, row in g._nodes.iterrows():
                    node_id = row.get('id', row.name)
                    if pd.notna(row['cluster']):
                        clusters[str(node_id)] = int(row['cluster'])

            # Create enhanced nodes with layout and cluster info
            enhanced_nodes = []
            for node in processed_nodes:
                enhanced_node = node.copy()
                node_id = str(node.get('id', ''))

                if node_id in positions:
                    enhanced_node.update(positions[node_id])

                if node_id in clusters:
                    enhanced_node['cluster'] = clusters[node_id]

                enhanced_nodes.append(enhanced_node)

            return {
                "nodes": enhanced_nodes,
                "edges": processed_edges,
                "positions": positions,
                "clusters": clusters,
                "layout_computed": len(positions) > 0,
                "clusters_computed": len(clusters) > 0
            }

        except Exception as e:
            logger.error(f"Failed to create local visualization data: {e}")
            return {
                "nodes": processed_nodes,
                "edges": processed_edges,
                "positions": {},
                "clusters": {},
                "layout_computed": False,
                "clusters_computed": False
            }
    async def get_graph_stats(self, graph_data: GraphData) -> Dict[str, Any]:
        """Get GPU-accelerated graph statistics"""
        try:
            nodes_df = pd.DataFrame(graph_data.nodes)
            edges_df = pd.DataFrame(graph_data.links)

            g = graphistry.edges(edges_df, 'source', 'target').nodes(nodes_df, 'id')

            # Compute various graph metrics using GPU acceleration.
            # For an undirected graph, density = E / (N * (N - 1) / 2),
            # matching the formula used during generation.
            n = len(nodes_df)
            stats = {
                "node_count": n,
                "edge_count": len(edges_df),
                "density": len(edges_df) / (n * (n - 1) / 2) if n > 1 else 0,
                "timestamp": datetime.now().isoformat()
            }

            # Add centrality measures if possible
            try:
                if n < 10000 and hasattr(g, 'compute_igraph'):  # Only for reasonably sized graphs
                    g_with_metrics = g.compute_igraph('pagerank', out_col='pagerank')

                    if g_with_metrics._nodes is not None and 'pagerank' in g_with_metrics._nodes.columns:
                        pagerank_data = g_with_metrics._nodes['pagerank'].to_list()
                        stats.update({
                            "avg_pagerank": float(np.mean(pagerank_data)),
                            "max_pagerank": float(np.max(pagerank_data))
                        })
                        logger.info("Computed PageRank statistics")
            except Exception as e:
                logger.warning(f"Could not compute centrality measures: {e}")

            return stats

        except Exception as e:
            logger.error(f"Error computing graph stats: {e}")
            raise HTTPException(status_code=500, detail=str(e))
# FastAPI app
app = FastAPI(title="PyGraphistry GPU Visualization Service", version="1.0.0")
service = PyGraphistryService()

@app.post("/api/generate")
async def generate_graph(request: GraphGenerationRequest):
    """Start graph generation as a background task"""
    if request.num_nodes > 1000000:
        raise HTTPException(status_code=400, detail="Maximum 1 million nodes allowed")

    task_id = await service.start_graph_generation(request)
    return {"task_id": task_id, "status": "started"}

@app.get("/api/generate/{task_id}")
async def get_generation_status(task_id: str):
    """Get status of a graph generation task"""
    status = service.get_generation_status(task_id)
    if not status:
        raise HTTPException(status_code=404, detail="Task not found")
    return status

@app.post("/api/visualize")
async def visualize_graph(request: VisualizationRequest):
    """Process graph data with PyGraphistry GPU acceleration"""
    return await service.process_graph_data(request)

@app.post("/api/stats")
async def get_graph_statistics(graph_data: GraphData):
    """Get GPU-accelerated graph statistics"""
    return await service.get_graph_stats(graph_data)

@app.get("/api/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "pygraphistry_initialized": service.initialized,
        "timestamp": datetime.now().isoformat()
    }

@app.get("/api/patterns")
async def get_available_patterns():
    """Get list of available graph generation patterns"""
    return {
        "patterns": [
            {
                "name": pattern.value,
                "description": {
                    GraphPattern.RANDOM: "Random graph using Erdős–Rényi model",
                    GraphPattern.SCALE_FREE: "Scale-free graph using Barabási–Albert model",
                    GraphPattern.SMALL_WORLD: "Small-world graph using Watts-Strogatz model",
                    GraphPattern.CLUSTERED: "Clustered graph with community structure",
                    GraphPattern.HIERARCHICAL: "Hierarchical tree-like graph with cross-links",
                    GraphPattern.GRID: "2D or 3D grid graph"
                }[pattern]
            } for pattern in GraphPattern
        ]
    }

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)
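The service above exposes a small REST surface. As a minimal client-side smoke test, the sketch below assumes the server is running locally on port 8080 (per the `__main__` block), that the `requests` package is installed, and that `"clustered"` is the string value of `GraphPattern.CLUSTERED`; the payload fields mirror the request fields read in `generate_graph_async`.

```python
import time
import requests

BASE = "http://localhost:8080"  # port from the service's __main__ block

# Kick off a clustered-graph generation task
resp = requests.post(f"{BASE}/api/generate",
                     json={"pattern": "clustered", "num_nodes": 5000,
                           "num_clusters": 50, "seed": 42})
task_id = resp.json()["task_id"]

# Poll until the background task completes or fails
while True:
    status = requests.get(f"{BASE}/api/generate/{task_id}").json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(0.5)
print(status["message"])
```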
@ -0,0 +1,800 @@
#!/usr/bin/env python3
"""
Remote WebGPU Clustering Service - CuPy Version with Semantic Clustering

Provides GPU-accelerated graph clustering using CuPy instead of cuDF to avoid segfaults.
Uses stable CuPy operations for GPU clustering while maintaining the same API.
Enhanced with semantic clustering based on node names and content similarity.
"""

import os
import json
import uuid
import asyncio
import logging
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple, Union
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import time
import threading
from concurrent.futures import ThreadPoolExecutor
import base64
from io import BytesIO

# Import semantic clustering
from semantic_clustering_service import SemanticClusteringEngine, cluster_nodes_by_similarity

# GPU-accelerated imports
try:
    import cupy as cp
    HAS_CUPY = True
    print("✓ CuPy available for stable GPU clustering")
except ImportError:
    HAS_CUPY = False
    print("⚠ CuPy not available, falling back to CPU")

# Optional cuGraph for force simulation (avoid cuDF operations)
try:
    import cugraph
    import cudf
    HAS_CUGRAPH = True
    print("✓ cuGraph available for force simulation")
except ImportError:
    HAS_CUGRAPH = False
    print("⚠ cuGraph not available")

import networkx as nx

# WebRTC streaming imports
try:
    import cv2
    import PIL.Image as PILImage
    HAS_OPENCV = True
    print("✓ OpenCV available for WebRTC streaming")
except ImportError:
    HAS_OPENCV = False
    print("⚠ OpenCV not available, WebRTC streaming disabled")

# WebGL rendering imports
try:
    import matplotlib.pyplot as plt
    import matplotlib
    matplotlib.use('Agg')
    import plotly.graph_objects as go
    import plotly.io as pio
    pio.renderers.default = "json"
    HAS_PLOTTING = True
    print("✓ Plotting libraries available for server-side rendering")
except ImportError:
    HAS_PLOTTING = False
    print("⚠ Plotting libraries not available")

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GraphData(BaseModel):
    nodes: List[Dict[str, Any]]
    links: List[Dict[str, Any]]

class ClusteringMode(str):
    # Simple string-constant namespace for the supported modes
    HYBRID = "hybrid"
    WEBRTC_STREAM = "webrtc_stream"

class RemoteClusteringRequest(BaseModel):
    graph_data: GraphData
    mode: str = ClusteringMode.HYBRID
    cluster_dimensions: List[int] = [32, 18, 24]
    force_simulation: bool = True
    max_iterations: int = 100
    webrtc_options: Optional[Dict[str, Any]] = None

    # Semantic clustering options
    clustering_method: str = "hybrid"  # "spatial", "semantic", "hybrid"
    semantic_algorithm: str = "hierarchical"  # "hierarchical", "kmeans", "dbscan"
    n_clusters: Optional[int] = None
    similarity_threshold: float = 0.7

    # Hybrid clustering weights
    name_weight: float = 0.6
    content_weight: float = 0.3
    spatial_weight: float = 0.1

class ClusteringResult(BaseModel):
    clustered_nodes: List[Dict[str, Any]]
    cluster_info: Dict[str, Any]
    processing_time: float
    mode: str
    session_id: Optional[str] = None

class WebRTCSession(BaseModel):
    session_id: str
    client_id: str
    created_at: datetime
    last_frame_time: datetime
    is_active: bool = True
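# --- Illustrative request payload (not part of the original file): values are
# assumptions based on the RemoteClusteringRequest defaults above.
#
#   {
#     "graph_data": {"nodes": [{"id": "a", "x": 1, "y": 2, "z": 0}], "links": []},
#     "mode": "hybrid",
#     "clustering_method": "semantic",
#     "semantic_algorithm": "kmeans",
#     "n_clusters": 8,
#     "similarity_threshold": 0.7
#   }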
class CuPyClusteringEngine:
    """
    Stable GPU clustering using CuPy arrays instead of cuDF to avoid segfaults
    """

    def __init__(self, cluster_dimensions: Tuple[int, int, int] = (32, 18, 24)):
        self.cluster_dimensions = cluster_dimensions
        self.cluster_count = cluster_dimensions[0] * cluster_dimensions[1] * cluster_dimensions[2]
        self.has_gpu = HAS_CUPY
        logger.info(f"CuPy clustering engine initialized with {self.cluster_count} clusters")
    def cluster_nodes_gpu(self, nodes: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """
        Perform advanced GPU-accelerated clustering using RAPIDS cuML algorithms
        """
        if not self.has_gpu:
            return self._cluster_nodes_cpu(nodes)

        try:
            # Only KMeans is actually used below (DBSCAN/HDBSCAN were imported
            # but never called), so import just what is needed
            from cuml.cluster import KMeans
            import cupy as cp

            start_time = time.time()

            # Extract coordinates and prepare feature matrix
            coordinates = []
            for node in nodes:
                x = float(node.get('x', 0))
                y = float(node.get('y', 0))
                z = float(node.get('z', 0))
                coordinates.append([x, y, z])

            # Create GPU feature matrix
            X = cp.array(coordinates, dtype=cp.float32)
            n_samples = X.shape[0]

            print(f"🚀 GPU clustering {n_samples} nodes with RAPIDS cuML...")

            # Choose clustering parameters scaled to dataset size;
            # KMeans is fastest and works well for most graph clustering scenarios
            if n_samples < 5000:
                # Small datasets: moderate cluster count
                n_clusters = min(max(int(np.sqrt(n_samples / 2)), 3), 25)
                clusterer = KMeans(n_clusters=n_clusters, random_state=42, max_iter=100)
            elif n_samples < 25000:
                # Medium datasets: higher cluster count for better granularity
                n_clusters = min(max(int(np.sqrt(n_samples / 1.5)), 10), 50)
                clusterer = KMeans(n_clusters=n_clusters, random_state=42, max_iter=150)
            else:
                # Large datasets: many clusters but capped for performance
                n_clusters = min(max(int(np.sqrt(n_samples)), 20), 100)
                clusterer = KMeans(n_clusters=n_clusters, random_state=42, max_iter=200)
            algorithm_name = f"KMeans(k={n_clusters})"

            # Perform GPU clustering
            cluster_labels = clusterer.fit_predict(X)

            # Convert results back to CPU
            if hasattr(cluster_labels, 'get'):
                cluster_result = cluster_labels.get()
            else:
                cluster_result = cp.asarray(cluster_labels).get()

            # Update nodes with clustering results
            clustered_nodes = []
            for i, node in enumerate(nodes):
                cluster_id = int(cluster_result[i])

                clustered_node = {
                    **node,
                    'cluster_index': cluster_id,
                    'node_index': i
                }
                clustered_nodes.append(clustered_node)

            # Generate cluster statistics
            unique_clusters = len(np.unique(cluster_result))
            noise_points = 0  # KMeans doesn't produce noise points
            processing_time = time.time() - start_time

            print(f"✅ {algorithm_name} completed: {unique_clusters} clusters, {noise_points} noise points in {processing_time:.4f}s")

            # Apply intelligent subsampling for large datasets
            if len(nodes) > 10000:
                print(f"🎯 Large dataset detected ({len(nodes)} nodes), applying cluster-based subsampling...")
                clustered_nodes = self._apply_cluster_subsampling(clustered_nodes, cluster_result, target_nodes=5000)
                print(f"✅ Subsampled to {len(clustered_nodes)} representative nodes")

            cluster_info = {
                'total_clusters': self.cluster_count,
                'used_clusters': unique_clusters,
                'cluster_dimensions': self.cluster_dimensions,
                'processing_time': processing_time,
                'gpu_accelerated': True,
                'engine': 'RAPIDS cuML',
                'algorithm': algorithm_name,
                'noise_points': int(noise_points),
                'original_node_count': len(nodes),
                'rendered_node_count': len(clustered_nodes),
                'subsampled': len(nodes) > 10000
            }

            logger.info(f"CuPy GPU clustering completed in {processing_time:.3f}s for {len(nodes)} nodes -> {unique_clusters} clusters")
            return clustered_nodes, cluster_info

        except Exception as e:
            logger.error(f"RAPIDS cuML GPU clustering failed: {e}")
            import traceback
            logger.error(f"Full traceback: {traceback.format_exc()}")
            print(f"❌ GPU clustering error: {e}")
            print(f"Traceback: {traceback.format_exc()}")
            return self._cluster_nodes_cpu(nodes)
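    # --- Worked example of the k heuristic above (added for clarity, not in
    # the original file): for n_samples = 10_000 the medium branch applies, so
    #   n_clusters = min(max(int(sqrt(10000 / 1.5)), 10), 50)
    #              = min(max(81, 10), 50) = 50,
    # i.e. the cap, not the square-root term, decides k for dense mid-size graphs.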
    def _apply_cluster_subsampling(self, clustered_nodes: List[Dict[str, Any]], cluster_labels: np.ndarray, target_nodes: int = 5000) -> List[Dict[str, Any]]:
        """
        Apply intelligent cluster-based subsampling to reduce rendering load while preserving cluster structure.

        Strategy:
        1. Keep cluster centroids (most representative nodes)
        2. Keep boundary nodes (cluster edges for visual separation)
        3. Sample remaining nodes proportionally from each cluster
        4. Always keep noise points (outliers are important)
        """
        # Separate nodes by cluster
        cluster_groups = {}
        noise_nodes = []

        for i, node in enumerate(clustered_nodes):
            cluster_id = cluster_labels[i]
            if cluster_id == -1:  # Noise points
                noise_nodes.append(node)
            else:
                if cluster_id not in cluster_groups:
                    cluster_groups[cluster_id] = []
                cluster_groups[cluster_id].append((i, node))

        # Calculate sampling allocation
        total_clusters = len(cluster_groups)
        noise_count = len(noise_nodes)

        # Reserve space for noise points and ensure minimum representation
        available_nodes = max(target_nodes - noise_count, total_clusters * 3)  # At least 3 nodes per cluster

        selected_nodes = []

        # Include noise points if they exist (DBSCAN/HDBSCAN only)
        if noise_nodes:
            selected_nodes.extend(noise_nodes)
            print(f" 📍 Kept {len(noise_nodes)} noise points")
        else:
            print(" 📍 No noise points (KMeans clustering)")

        # Process each cluster
        for cluster_id, cluster_nodes in cluster_groups.items():
            cluster_size = len(cluster_nodes)

            if cluster_size == 0:
                continue

            # Calculate how many nodes to keep from this cluster.
            # Larger clusters get more representation, but with diminishing returns.
            cluster_weight = min(cluster_size / len(clustered_nodes), 0.1)  # Cap at 10% weight
            target_from_cluster = max(3, int(available_nodes * cluster_weight))  # Minimum 3 per cluster
            target_from_cluster = min(target_from_cluster, cluster_size)  # Don't exceed cluster size

            if target_from_cluster >= cluster_size:
                # Keep all nodes from small clusters
                selected_nodes.extend([node for _, node in cluster_nodes])
            else:
                # Intelligent sampling for large clusters
                cluster_coords = np.array([[float(node.get('x', 0)), float(node.get('y', 0)), float(node.get('z', 0))] for _, node in cluster_nodes])

                # Find cluster centroid
                centroid = np.mean(cluster_coords, axis=0)

                # Calculate distances from centroid
                distances = np.linalg.norm(cluster_coords - centroid, axis=1)

                # Select representative nodes:
                # 1. Centroid node (closest to center)
                centroid_idx = np.argmin(distances)
                selected_indices = {centroid_idx}

                # 2. Boundary nodes (furthest from center for cluster separation)
                if target_from_cluster > 1:
                    boundary_count = min(2, target_from_cluster - 1)
                    boundary_indices = np.argsort(distances)[-boundary_count:]
                    selected_indices.update(boundary_indices)

                # 3. Random sampling for remaining slots
                remaining_slots = target_from_cluster - len(selected_indices)
                if remaining_slots > 0:
                    available_indices = set(range(len(cluster_nodes))) - selected_indices
                    if available_indices:
                        random_indices = np.random.choice(list(available_indices),
                                                          size=min(remaining_slots, len(available_indices)),
                                                          replace=False)
                        selected_indices.update(random_indices)

                # Add selected nodes
                for idx in selected_indices:
                    selected_nodes.append(cluster_nodes[idx][1])

        print(f" 🎨 Cluster sampling: {len(cluster_groups)} clusters, {len(selected_nodes)} total nodes")
        return selected_nodes
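    # --- Worked allocation example (added for clarity, not in the original
    # file): with 20_000 input nodes, target_nodes=5000, no noise points, and
    # one cluster of 2_000 nodes: cluster_weight = min(2000/20000, 0.1) = 0.1,
    # so target_from_cluster = max(3, int(5000 * 0.1)) = 500; the centroid node,
    # up to 2 boundary nodes, and 497 random members of that cluster are kept.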
    def _cluster_nodes_cpu(self, nodes: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """CPU fallback clustering implementation"""
        start_time = time.time()

        clustered_nodes = []
        for i, node in enumerate(nodes):
            # Apply the same clustering logic as the GPU version
            x = float(node.get('x', 0))
            y = float(node.get('y', 0))
            z = float(node.get('z', 0))

            # Normalize positions
            norm_x = max(0.0, min(0.999, x / 100.0 + 0.5))
            norm_y = max(0.0, min(0.999, y / 100.0 + 0.5))
            norm_z = max(0.001, min(0.999, z / 100.0 + 0.5))

            # Apply logarithmic scaling to Z
            log_z = max(0.0, min(0.999, np.log(norm_z) / np.log(0.999)))

            # Calculate cluster indices
            cluster_x = min(self.cluster_dimensions[0] - 1, int(norm_x * self.cluster_dimensions[0]))
            cluster_y = min(self.cluster_dimensions[1] - 1, int(norm_y * self.cluster_dimensions[1]))
            cluster_z = min(self.cluster_dimensions[2] - 1, int(log_z * self.cluster_dimensions[2]))

            cluster_index = (cluster_x +
                             cluster_y * self.cluster_dimensions[0] +
                             cluster_z * self.cluster_dimensions[0] * self.cluster_dimensions[1])

            clustered_node = {
                **node,
                'cluster_index': cluster_index,
                'node_index': i
            }
            clustered_nodes.append(clustered_node)

        processing_time = time.time() - start_time

        cluster_info = {
            'total_clusters': self.cluster_count,
            'cluster_dimensions': self.cluster_dimensions,
            'processing_time': processing_time,
            'gpu_accelerated': False,
            'engine': 'CPU'
        }

        logger.info(f"CPU clustering completed in {processing_time:.3f}s for {len(nodes)} nodes")
        return clustered_nodes, cluster_info
class ForceSimulationEngine:
    """
    GPU-accelerated force simulation for graph layout
    """

    def __init__(self):
        self.has_gpu = HAS_CUGRAPH

    def simulate_forces(self, nodes: List[Dict[str, Any]], links: List[Dict[str, Any]], max_iterations: int = 100) -> List[Dict[str, Any]]:
        """Run force-directed layout simulation"""
        if not self.has_gpu or not links:
            return self._simulate_forces_cpu(nodes, links, max_iterations)

        try:
            return self._simulate_forces_gpu(nodes, links, max_iterations)
        except Exception as e:
            logger.error(f"GPU force simulation failed: {e}")
            return self._simulate_forces_cpu(nodes, links, max_iterations)

    def _simulate_forces_gpu(self, nodes: List[Dict[str, Any]], links: List[Dict[str, Any]], max_iterations: int) -> List[Dict[str, Any]]:
        """GPU-accelerated force simulation using cuGraph (avoid cuDF operations)"""
        # Create simple edge list for cuGraph
        edge_list = []
        for link in links:
            source_id = str(link.get('source', ''))
            target_id = str(link.get('target', ''))
            edge_list.append([source_id, target_id])

        if not edge_list:
            return nodes

        try:
            # The cuGraph path is currently stubbed out: edge_list is built but
            # unused, and layout is delegated to the NetworkX CPU implementation
            # for stability.
            return self._simulate_forces_cpu(nodes, links, max_iterations)
        except Exception as e:
            logger.warning(f"Force simulation failed: {e}")
            return nodes

    def _simulate_forces_cpu(self, nodes: List[Dict[str, Any]], links: List[Dict[str, Any]], max_iterations: int) -> List[Dict[str, Any]]:
        """CPU fallback force simulation using NetworkX"""
        G = nx.Graph()

        # Add nodes
        for node in nodes:
            G.add_node(str(node.get('id', '')), **node)

        # Add edges
        for link in links:
            source = str(link.get('source', ''))
            target = str(link.get('target', ''))
            G.add_edge(source, target)

        # Compute spring layout
        pos = nx.spring_layout(G, iterations=max_iterations, k=1.0)

        # Update node positions (spring_layout returns coordinates in roughly
        # [-1, 1], so scale up to the service's coordinate range)
        updated_nodes = []
        for node in nodes:
            node_id = str(node.get('id', ''))
            if node_id in pos:
                x, y = pos[node_id]
                updated_node = {**node, 'x': float(x * 100), 'y': float(y * 100)}
            else:
                updated_node = node
            updated_nodes.append(updated_node)

        return updated_nodes
class WebRTCStreamingEngine:
    """WebRTC streaming engine for real-time graph visualization streaming"""

    def __init__(self):
        self.has_rendering = HAS_PLOTTING and HAS_OPENCV
        self.active_sessions: Dict[str, WebRTCSession] = {}
        self.frame_buffer: Dict[str, bytes] = {}

    def create_session(self, client_id: str) -> str:
        """Create new WebRTC streaming session"""
        session_id = str(uuid.uuid4())
        session = WebRTCSession(
            session_id=session_id,
            client_id=client_id,
            created_at=datetime.now(),
            last_frame_time=datetime.now()
        )
        self.active_sessions[session_id] = session
        logger.info(f"Created WebRTC session {session_id} for client {client_id}")
        return session_id

    def render_graph_frame(self, session_id: str, nodes: List[Dict[str, Any]], links: List[Dict[str, Any]]) -> bool:
        """Render graph to frame buffer for streaming"""
        if not self.has_rendering:
            return False

        if session_id not in self.active_sessions:
            return False

        try:
            # Create 3D plotly visualization
            node_x = [node.get('x', 0) for node in nodes]
            node_y = [node.get('y', 0) for node in nodes]
            node_z = [node.get('z', 0) for node in nodes]
            node_text = [node.get('name', f"Node {i}") for i, node in enumerate(nodes)]
            node_colors = [node.get('cluster_index', 0) for node in nodes]

            # Create node trace
            node_trace = go.Scatter3d(
                x=node_x, y=node_y, z=node_z,
                mode='markers',
                marker=dict(size=8, color=node_colors, colorscale='Viridis', showscale=True),
                text=node_text,
                hovertemplate='%{text}<br>(%{x:.1f}, %{y:.1f}, %{z:.1f})<extra></extra>',
                name='Nodes'
            )

            # Create edge traces. Note: this linear scan per link is O(N * E);
            # a node-id -> index dict would make the endpoint lookup O(1).
            edge_traces = []
            for link in links:
                source_idx = None
                target_idx = None

                for i, node in enumerate(nodes):
                    if str(node.get('id', '')) == str(link.get('source', '')):
                        source_idx = i
                    if str(node.get('id', '')) == str(link.get('target', '')):
                        target_idx = i

                if source_idx is not None and target_idx is not None:
                    edge_trace = go.Scatter3d(
                        x=[node_x[source_idx], node_x[target_idx], None],
                        y=[node_y[source_idx], node_y[target_idx], None],
                        z=[node_z[source_idx], node_z[target_idx], None],
                        mode='lines',
                        line=dict(color='gray', width=2),
                        showlegend=False,
                        hoverinfo='none'
                    )
                    edge_traces.append(edge_trace)

            # Create figure
            fig = go.Figure(data=[node_trace] + edge_traces)
            fig.update_layout(
                title='GPU-Clustered Knowledge Graph (CuPy)',
                scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z', bgcolor='rgb(10, 10, 10)'),
                showlegend=False,
                paper_bgcolor='rgb(10, 10, 10)',
                plot_bgcolor='rgb(10, 10, 10)',
                font=dict(color='white')
            )

            # Convert to image
            img_bytes = pio.to_image(fig, format='png', width=1200, height=800, engine='kaleido')

            # Store frame in buffer
            self.frame_buffer[session_id] = img_bytes
            self.active_sessions[session_id].last_frame_time = datetime.now()

            return True

        except Exception as e:
            logger.error(f"Frame rendering failed for session {session_id}: {e}")
            return False

    def get_frame(self, session_id: str) -> Optional[bytes]:
        return self.frame_buffer.get(session_id)

    def cleanup_session(self, session_id: str):
        if session_id in self.active_sessions:
            del self.active_sessions[session_id]
        if session_id in self.frame_buffer:
            del self.frame_buffer[session_id]
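# --- Illustrative streaming flow (not part of the original file): a session is
# created and rendered server-side, then the PNG frame is fetched over HTTP.
#
#   engine = WebRTCStreamingEngine()
#   sid = engine.create_session("client-1")
#   engine.render_graph_frame(sid, clustered_nodes, links)
#   # The client then polls GET /api/stream/{sid} (defined below) for the frame.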
class RemoteWebGPUService:
    """Main service class with stable CuPy clustering"""

    def __init__(self):
        self.clustering_engine = CuPyClusteringEngine()
        self.force_engine = ForceSimulationEngine()
        self.webrtc_engine = WebRTCStreamingEngine()
        self.active_connections: List[WebSocket] = []
        self.executor = ThreadPoolExecutor(max_workers=4)

    async def process_clustering_request(self, request: RemoteClusteringRequest) -> ClusteringResult:
        """Process remote clustering request with semantic clustering support"""
        start_time = time.time()

        try:
            nodes = request.graph_data.nodes
            links = request.graph_data.links

            # Apply force simulation if requested
            if request.force_simulation:
                logger.info("Running force simulation...")
                nodes = self.force_engine.simulate_forces(nodes, links, request.max_iterations)

            # Choose clustering method based on the request
            if request.clustering_method == "spatial":
                # Use traditional spatial clustering
                logger.info(f"Spatial clustering {len(nodes)} nodes in {request.mode} mode...")
                clustered_nodes, cluster_info = self.clustering_engine.cluster_nodes_gpu(nodes)

            elif request.clustering_method == "semantic":
                # Use semantic clustering based on node names/content
                logger.info(f"Semantic clustering {len(nodes)} nodes using {request.semantic_algorithm}...")
                semantic_result = await cluster_nodes_by_similarity(
                    nodes,
                    method="name" if request.semantic_algorithm != "content" else "content",
                    algorithm=request.semantic_algorithm,
                    n_clusters=request.n_clusters,
                    similarity_threshold=request.similarity_threshold
                )
                clustered_nodes = semantic_result.clustered_nodes
                cluster_info = semantic_result.cluster_info

            elif request.clustering_method == "hybrid":
                # Use hybrid clustering (semantic + spatial)
                logger.info(f"Hybrid clustering {len(nodes)} nodes...")
                semantic_result = await cluster_nodes_by_similarity(
                    nodes,
                    method="hybrid",
                    algorithm=request.semantic_algorithm,
                    n_clusters=request.n_clusters,
                    name_weight=request.name_weight,
                    content_weight=request.content_weight,
                    spatial_weight=request.spatial_weight
                )
                clustered_nodes = semantic_result.clustered_nodes
                cluster_info = semantic_result.cluster_info

            else:
                # Fall back to spatial clustering
                logger.warning(f"Unknown clustering method '{request.clustering_method}', using spatial")
                clustered_nodes, cluster_info = self.clustering_engine.cluster_nodes_gpu(nodes)

            processing_time = time.time() - start_time

            # Add clustering method info to the result
            cluster_info['clustering_method'] = request.clustering_method
            cluster_info['total_processing_time'] = processing_time

            result = ClusteringResult(
                clustered_nodes=clustered_nodes,
                cluster_info=cluster_info,
                processing_time=processing_time,
                mode=request.mode
            )

            # Handle WebRTC streaming mode
            if request.mode == ClusteringMode.WEBRTC_STREAM:
                session_id = self.webrtc_engine.create_session("remote_client")
                success = self.webrtc_engine.render_graph_frame(session_id, clustered_nodes, links)
                if success:
                    result.session_id = session_id

            return result

        except Exception as e:
            logger.error(f"Clustering request failed: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    async def broadcast_update(self, data: Dict[str, Any]):
        """Broadcast updates to connected WebSocket clients"""
        if not self.active_connections:
            return

        disconnected = []
        for connection in self.active_connections:
            try:
                await connection.send_json(data)
            except Exception:
                disconnected.append(connection)

        for connection in disconnected:
            self.active_connections.remove(connection)
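# --- Summary of the method dispatch above (added for clarity, not in the
# original file): "spatial" uses the CuPy/cuML coordinate-based engine, while
# "semantic" and "hybrid" delegate to cluster_nodes_by_similarity from
# semantic_clustering_service; every branch's cluster_info is then annotated
# with 'clustering_method' and 'total_processing_time'.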
# FastAPI app setup
app = FastAPI(
    title="Remote WebGPU Clustering Service (CuPy)",
    description="Stable GPU-accelerated graph clustering using CuPy",
    version="1.1.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

service = RemoteWebGPUService()

@app.post("/api/cluster", response_model=ClusteringResult)
async def cluster_graph(request: RemoteClusteringRequest):
    """Process graph clustering request"""
    result = await service.process_clustering_request(request)

    await service.broadcast_update({
        "type": "clustering_complete",
        "data": result.model_dump()  # .dict() is deprecated in pydantic v2
    })

    return result

@app.get("/api/capabilities")
async def get_capabilities():
    """Get service capabilities"""
    return {
        "modes": {
            "hybrid": {
                "available": True,
                "description": "GPU clustering on server, CPU rendering on client"
            },
            "webrtc_stream": {
                "available": service.webrtc_engine.has_rendering,
                "description": "Full GPU rendering streamed to client browser"
            }
        },
        "clustering_methods": {
            "spatial": {
                "available": True,
                "description": "Traditional spatial/coordinate-based clustering"
            },
            "semantic": {
                "available": True,
                "description": "Semantic clustering based on node names and content similarity"
            },
            "hybrid": {
                "available": True,
                "description": "Combined semantic and spatial clustering with configurable weights"
            }
        },
        "clustering_algorithms": {
            "hierarchical": {
                "available": True,
                "description": "Hierarchical agglomerative clustering"
            },
            "kmeans": {
                "available": True,
                "description": "K-means clustering (GPU accelerated when available)"
            },
            "dbscan": {
                "available": True,
                "description": "Density-based spatial clustering"
            }
        },
        "gpu_acceleration": {
            "cupy_available": HAS_CUPY,
            "cugraph_available": HAS_CUGRAPH,
            "opencv_available": HAS_OPENCV,
            "plotting_available": HAS_PLOTTING,
            "semantic_gpu": HAS_CUPY
        },
        "cluster_dimensions": service.clustering_engine.cluster_dimensions,
        "max_cluster_count": service.clustering_engine.cluster_count
    }

@app.get("/api/stream/{session_id}")
async def stream_frame(session_id: str):
    """Stream rendered frame for WebRTC session"""
    frame_data = service.webrtc_engine.get_frame(session_id)
    if not frame_data:
        raise HTTPException(status_code=404, detail="Frame not found")

    return StreamingResponse(
        BytesIO(frame_data),
        media_type="image/png",
        headers={"Cache-Control": "no-cache"}
    )

@app.delete("/api/stream/{session_id}")
async def cleanup_stream(session_id: str):
    """Clean up WebRTC streaming session"""
    service.webrtc_engine.cleanup_session(session_id)
    return {"status": "cleaned up"}

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time updates"""
    await websocket.accept()
    service.active_connections.append(websocket)

    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        service.active_connections.remove(websocket)

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "gpu_available": HAS_CUPY,
        "webrtc_available": service.webrtc_engine.has_rendering,
        "active_sessions": len(service.webrtc_engine.active_sessions),
        "active_connections": len(service.active_connections),
        "engine": "RAPIDS cuML"
    }

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 8083))
    logger.info(f"Starting Remote WebGPU Clustering Service (RAPIDS cuML) on port {port}")
    logger.info(f"CuPy GPU acceleration: {'✓' if HAS_CUPY else '✗'}")
    logger.info(f"WebRTC streaming: {'✓' if service.webrtc_engine.has_rendering else '✗'}")

    uvicorn.run(
        "remote_webgpu_clustering_service:app",
        host="0.0.0.0",
        port=port,
        log_level="info",
        reload=False
    )
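As with the first service, a quick client-side check validates the endpoints. The sketch below is a minimal example, assuming the service runs on its default port 8083 and that the `requests` package is installed; the payload fields come from `RemoteClusteringRequest`, and the node keys (`id`, `name`, `x`, `y`, `z`) match what the clustering engines read.

```python
import requests

BASE = "http://localhost:8083"  # default PORT from the __main__ block

print(requests.get(f"{BASE}/health").json())
print(requests.get(f"{BASE}/api/capabilities").json()["clustering_methods"])

payload = {
    "graph_data": {
        "nodes": [{"id": "a", "name": "alpha", "x": 1.0, "y": 2.0, "z": 0.0},
                  {"id": "b", "name": "beta", "x": -3.0, "y": 0.5, "z": 1.0}],
        "links": [{"source": "a", "target": "b"}],
    },
    "clustering_method": "spatial",
    "force_simulation": False,
}
result = requests.post(f"{BASE}/api/cluster", json=payload).json()
print(result["cluster_info"])
```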
@ -0,0 +1,582 @@
#!/usr/bin/env python3
"""
Remote WebGPU Clustering Service - CuPy Version

Provides GPU-accelerated graph clustering using CuPy instead of cuDF to avoid segfaults.
Uses stable CuPy operations for GPU clustering while maintaining the same API.
"""

import os
import json
import uuid
import asyncio
import logging
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple, Union
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import time
import threading
from concurrent.futures import ThreadPoolExecutor
import base64
from io import BytesIO

# GPU-accelerated imports
try:
    import cupy as cp
    HAS_CUPY = True
    print("✓ CuPy available for stable GPU clustering")
except ImportError:
    HAS_CUPY = False
    print("⚠ CuPy not available, falling back to CPU")

# Optional cuGraph for force simulation (avoid cuDF operations)
try:
    import cugraph
    import cudf
    HAS_CUGRAPH = True
    print("✓ cuGraph available for force simulation")
except ImportError:
    HAS_CUGRAPH = False
    print("⚠ cuGraph not available")

import networkx as nx

# WebRTC streaming imports
try:
    import cv2
    import PIL.Image as PILImage
    HAS_OPENCV = True
    print("✓ OpenCV available for WebRTC streaming")
except ImportError:
    HAS_OPENCV = False
    print("⚠ OpenCV not available, WebRTC streaming disabled")

# WebGL rendering imports
try:
    import matplotlib.pyplot as plt
    import matplotlib
    matplotlib.use('Agg')
    import plotly.graph_objects as go
    import plotly.io as pio
    pio.renderers.default = "json"
    HAS_PLOTTING = True
    print("✓ Plotting libraries available for server-side rendering")
except ImportError:
    HAS_PLOTTING = False
    print("⚠ Plotting libraries not available")

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GraphData(BaseModel):
    nodes: List[Dict[str, Any]]
    links: List[Dict[str, Any]]

class ClusteringMode(str):
    HYBRID = "hybrid"
    WEBRTC_STREAM = "webrtc_stream"

class RemoteClusteringRequest(BaseModel):
    graph_data: GraphData
    mode: str = ClusteringMode.HYBRID
    cluster_dimensions: List[int] = [32, 18, 24]
    force_simulation: bool = True
    max_iterations: int = 100
    webrtc_options: Optional[Dict[str, Any]] = None

class ClusteringResult(BaseModel):
    clustered_nodes: List[Dict[str, Any]]
    cluster_info: Dict[str, Any]
    processing_time: float
    mode: str
    session_id: Optional[str] = None

class WebRTCSession(BaseModel):
    session_id: str
    client_id: str
    created_at: datetime
    last_frame_time: datetime
    is_active: bool = True
class CuPyClusteringEngine:
    """
    Stable GPU clustering using CuPy arrays instead of cuDF to avoid segfaults
    """

    def __init__(self, cluster_dimensions: Tuple[int, int, int] = (32, 18, 24)):
        self.cluster_dimensions = cluster_dimensions
        self.cluster_count = cluster_dimensions[0] * cluster_dimensions[1] * cluster_dimensions[2]
        self.has_gpu = HAS_CUPY
        logger.info(f"CuPy clustering engine initialized with {self.cluster_count} clusters")

    def cluster_nodes_gpu(self, nodes: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """
        Perform stable GPU-accelerated clustering using CuPy
        """
        if not self.has_gpu:
            return self._cluster_nodes_cpu(nodes)

        try:
            start_time = time.time()

            # Extract coordinates using CuPy arrays (stable)
            x_vals = cp.array([float(node.get('x', 0)) for node in nodes])
            y_vals = cp.array([float(node.get('y', 0)) for node in nodes])
            z_vals = cp.array([float(node.get('z', 0)) for node in nodes])

            # Apply clustering algorithm (same as the WebGPU shader)
            norm_x = cp.clip((x_vals / 100.0 + 0.5), 0.0, 0.999)
            norm_y = cp.clip((y_vals / 100.0 + 0.5), 0.0, 0.999)
            norm_z = cp.clip((z_vals / 100.0 + 0.5), 0.001, 0.999)

            # Apply logarithmic scaling to the Z dimension
            log_z = cp.clip(cp.log(norm_z) / cp.log(0.999), 0.0, 0.999)

            # Calculate per-axis cluster indices
            cluster_x = cp.clip((norm_x * self.cluster_dimensions[0]).astype(cp.int32), 0, self.cluster_dimensions[0] - 1)
            cluster_y = cp.clip((norm_y * self.cluster_dimensions[1]).astype(cp.int32), 0, self.cluster_dimensions[1] - 1)
            cluster_z = cp.clip((log_z * self.cluster_dimensions[2]).astype(cp.int32), 0, self.cluster_dimensions[2] - 1)

            # Calculate final (row-major) cluster index
            cluster_indices = (cluster_x +
                               cluster_y * self.cluster_dimensions[0] +
                               cluster_z * self.cluster_dimensions[0] * self.cluster_dimensions[1])

            # Convert back to CPU for results
            cluster_result = cluster_indices.get()

            # Update nodes with clustering results
            clustered_nodes = []
            for i, node in enumerate(nodes):
                clustered_node = {
                    **node,
                    'cluster_index': int(cluster_result[i]),
                    'node_index': i
                }
                clustered_nodes.append(clustered_node)

            # Generate cluster statistics
            unique_clusters = len(np.unique(cluster_result))
            processing_time = time.time() - start_time

            cluster_info = {
                'total_clusters': self.cluster_count,
                'used_clusters': unique_clusters,
                'cluster_dimensions': self.cluster_dimensions,
                'processing_time': processing_time,
                'gpu_accelerated': True,
                'engine': 'CuPy'
            }

            logger.info(f"CuPy GPU clustering completed in {processing_time:.3f}s for {len(nodes)} nodes -> {unique_clusters} clusters")
            return clustered_nodes, cluster_info

        except Exception as e:
            logger.error(f"CuPy GPU clustering failed: {e}")
            return self._cluster_nodes_cpu(nodes)
    def _cluster_nodes_cpu(self, nodes: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """CPU fallback clustering implementation"""
        start_time = time.time()

        clustered_nodes = []
        for i, node in enumerate(nodes):
            # Apply the same clustering logic as the GPU version
            x = float(node.get('x', 0))
            y = float(node.get('y', 0))
            z = float(node.get('z', 0))

            # Normalize positions
            norm_x = max(0.0, min(0.999, x / 100.0 + 0.5))
            norm_y = max(0.0, min(0.999, y / 100.0 + 0.5))
            norm_z = max(0.001, min(0.999, z / 100.0 + 0.5))

            # Apply logarithmic scaling to Z
            log_z = max(0.0, min(0.999, np.log(norm_z) / np.log(0.999)))

            # Calculate cluster indices
            cluster_x = min(self.cluster_dimensions[0] - 1, int(norm_x * self.cluster_dimensions[0]))
            cluster_y = min(self.cluster_dimensions[1] - 1, int(norm_y * self.cluster_dimensions[1]))
            cluster_z = min(self.cluster_dimensions[2] - 1, int(log_z * self.cluster_dimensions[2]))

            cluster_index = (cluster_x +
                             cluster_y * self.cluster_dimensions[0] +
                             cluster_z * self.cluster_dimensions[0] * self.cluster_dimensions[1])

            clustered_node = {
                **node,
                'cluster_index': cluster_index,
                'node_index': i
            }
            clustered_nodes.append(clustered_node)

        processing_time = time.time() - start_time

        cluster_info = {
            'total_clusters': self.cluster_count,
            'cluster_dimensions': self.cluster_dimensions,
            'processing_time': processing_time,
            'gpu_accelerated': False,
            'engine': 'CPU'
        }

        logger.info(f"CPU clustering completed in {processing_time:.3f}s for {len(nodes)} nodes")
        return clustered_nodes, cluster_info
class ForceSimulationEngine:
|
||||
"""
|
||||
GPU-accelerated force simulation for graph layout
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.has_gpu = HAS_CUGRAPH
|
||||
|
||||
def simulate_forces(self, nodes: List[Dict[str, Any]], links: List[Dict[str, Any]], max_iterations: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Run force-directed layout simulation"""
|
||||
|
||||
if not self.has_gpu or not links:
|
||||
return self._simulate_forces_cpu(nodes, links, max_iterations)
|
||||
|
||||
try:
|
||||
return self._simulate_forces_gpu(nodes, links, max_iterations)
|
||||
except Exception as e:
|
||||
logger.error(f"GPU force simulation failed: {e}")
|
||||
return self._simulate_forces_cpu(nodes, links, max_iterations)
|
||||
|
||||
def _simulate_forces_gpu(self, nodes: List[Dict[str, Any]], links: List[Dict[str, Any]], max_iterations: int) -> List[Dict[str, Any]]:
|
||||
"""GPU-accelerated force simulation using cuGraph (avoid cuDF operations)"""
|
||||
|
||||
# Create simple edge list for cuGraph
|
||||
edge_list = []
|
||||
for link in links:
|
||||
source_id = str(link.get('source', ''))
|
||||
target_id = str(link.get('target', ''))
|
||||
edge_list.append([source_id, target_id])
|
||||
|
||||
if not edge_list:
|
||||
return nodes
|
||||
|
||||
try:
|
||||
# Use NetworkX for safer force simulation
|
||||
return self._simulate_forces_cpu(nodes, links, max_iterations)
|
||||
except Exception as e:
|
||||
logger.warning(f"Force simulation failed: {e}")
|
||||
return nodes
|
||||
|
||||
def _simulate_forces_cpu(self, nodes: List[Dict[str, Any]], links: List[Dict[str, Any]], max_iterations: int) -> List[Dict[str, Any]]:
|
||||
"""CPU fallback force simulation using NetworkX"""
|
||||
|
||||
import networkx as nx
|
||||
|
||||
G = nx.Graph()
|
||||
|
||||
# Add nodes
|
||||
for node in nodes:
|
||||
G.add_node(str(node.get('id', '')), **node)
|
||||
|
||||
# Add edges
|
||||
for link in links:
|
||||
source = str(link.get('source', ''))
|
||||
target = str(link.get('target', ''))
|
||||
G.add_edge(source, target)
|
||||
|
||||
# Compute spring layout
|
||||
pos = nx.spring_layout(G, iterations=max_iterations, k=1.0)
|
||||
|
||||
# Update node positions
|
||||
updated_nodes = []
|
||||
for node in nodes:
|
||||
node_id = str(node.get('id', ''))
|
||||
if node_id in pos:
|
||||
x, y = pos[node_id]
|
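                # NetworkX spring_layout returns coordinates roughly in [-1, 1];
                # scaling by 100 keeps positions in the range the clustering
                # engine expects (it normalizes with x / 100.0 + 0.5 above).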
                updated_node = {**node, 'x': float(x * 100), 'y': float(y * 100)}
            else:
                updated_node = node
            updated_nodes.append(updated_node)

        return updated_nodes


class WebRTCStreamingEngine:
    """WebRTC streaming engine for real-time graph visualization streaming"""

    def __init__(self):
        self.has_rendering = HAS_PLOTTING and HAS_OPENCV
        self.active_sessions: Dict[str, WebRTCSession] = {}
        self.frame_buffer: Dict[str, bytes] = {}

    def create_session(self, client_id: str) -> str:
        """Create new WebRTC streaming session"""
        session_id = str(uuid.uuid4())
        session = WebRTCSession(
            session_id=session_id,
            client_id=client_id,
            created_at=datetime.now(),
            last_frame_time=datetime.now()
        )
        self.active_sessions[session_id] = session
        logger.info(f"Created WebRTC session {session_id} for client {client_id}")
        return session_id

    def render_graph_frame(self, session_id: str, nodes: List[Dict[str, Any]], links: List[Dict[str, Any]]) -> bool:
        """Render graph to frame buffer for streaming"""
        if not self.has_rendering:
            return False

        if session_id not in self.active_sessions:
            return False

        try:
            # Create 3D plotly visualization
            node_x = [node.get('x', 0) for node in nodes]
            node_y = [node.get('y', 0) for node in nodes]
            node_z = [node.get('z', 0) for node in nodes]
            node_text = [node.get('name', f"Node {i}") for i, node in enumerate(nodes)]
            node_colors = [node.get('cluster_index', 0) for node in nodes]

            # Create node trace
            node_trace = go.Scatter3d(
                x=node_x, y=node_y, z=node_z,
                mode='markers',
                marker=dict(size=8, color=node_colors, colorscale='Viridis', showscale=True),
                text=node_text,
                hovertemplate='%{text}<br>(%{x:.1f}, %{y:.1f}, %{z:.1f})<extra></extra>',
                name='Nodes'
            )

            # Create edge traces
            edge_traces = []
            for link in links:
                source_idx = None
                target_idx = None

                for i, node in enumerate(nodes):
                    if str(node.get('id', '')) == str(link.get('source', '')):
                        source_idx = i
                    if str(node.get('id', '')) == str(link.get('target', '')):
                        target_idx = i

                if source_idx is not None and target_idx is not None:
                    edge_trace = go.Scatter3d(
                        x=[node_x[source_idx], node_x[target_idx], None],
                        y=[node_y[source_idx], node_y[target_idx], None],
                        z=[node_z[source_idx], node_z[target_idx], None],
                        mode='lines',
                        line=dict(color='gray', width=2),
                        showlegend=False,
                        hoverinfo='none'
                    )
                    edge_traces.append(edge_trace)

            # Create figure
            fig = go.Figure(data=[node_trace] + edge_traces)
            fig.update_layout(
                title='GPU-Clustered Knowledge Graph (CuPy)',
                scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z', bgcolor='rgb(10, 10, 10)'),
                showlegend=False,
                paper_bgcolor='rgb(10, 10, 10)',
                plot_bgcolor='rgb(10, 10, 10)',
                font=dict(color='white')
            )

            # Convert to image
            img_bytes = pio.to_image(fig, format='png', width=1200, height=800, engine='kaleido')

            # Store frame in buffer
            self.frame_buffer[session_id] = img_bytes
            self.active_sessions[session_id].last_frame_time = datetime.now()

            return True

        except Exception as e:
            logger.error(f"Frame rendering failed for session {session_id}: {e}")
            return False

    def get_frame(self, session_id: str) -> Optional[bytes]:
        return self.frame_buffer.get(session_id)

    def cleanup_session(self, session_id: str):
        if session_id in self.active_sessions:
            del self.active_sessions[session_id]
        if session_id in self.frame_buffer:
            del self.frame_buffer[session_id]


class RemoteWebGPUService:
    """Main service class with stable CuPy clustering"""

    def __init__(self):
        self.clustering_engine = CuPyClusteringEngine()
        self.force_engine = ForceSimulationEngine()
        self.webrtc_engine = WebRTCStreamingEngine()
        self.active_connections: List[WebSocket] = []
        self.executor = ThreadPoolExecutor(max_workers=4)

    async def process_clustering_request(self, request: RemoteClusteringRequest) -> ClusteringResult:
        """Process remote clustering request"""
        start_time = time.time()

        try:
            nodes = request.graph_data.nodes
            links = request.graph_data.links

            # Apply force simulation if requested
            if request.force_simulation:
                logger.info("Running force simulation...")
                nodes = self.force_engine.simulate_forces(nodes, links, request.max_iterations)

            # Perform clustering
            logger.info(f"Clustering {len(nodes)} nodes in {request.mode} mode...")
            clustered_nodes, cluster_info = self.clustering_engine.cluster_nodes_gpu(nodes)

            processing_time = time.time() - start_time

            result = ClusteringResult(
                clustered_nodes=clustered_nodes,
                cluster_info=cluster_info,
                processing_time=processing_time,
                mode=request.mode
            )

            # Handle WebRTC streaming mode
            if request.mode == ClusteringMode.WEBRTC_STREAM:
                session_id = self.webrtc_engine.create_session("remote_client")
                success = self.webrtc_engine.render_graph_frame(session_id, clustered_nodes, links)
                if success:
                    result.session_id = session_id

            return result

        except Exception as e:
            logger.error(f"Clustering request failed: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    async def broadcast_update(self, data: Dict[str, Any]):
        """Broadcast updates to connected WebSocket clients"""
        if not self.active_connections:
            return

        disconnected = []
        for connection in self.active_connections:
            try:
                await connection.send_json(data)
            except Exception:
                disconnected.append(connection)

        for connection in disconnected:
            self.active_connections.remove(connection)


# FastAPI app setup
app = FastAPI(
    title="Remote WebGPU Clustering Service (CuPy)",
    description="Stable GPU-accelerated graph clustering using CuPy",
    version="1.1.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

service = RemoteWebGPUService()


@app.post("/api/cluster", response_model=ClusteringResult)
async def cluster_graph(request: RemoteClusteringRequest):
    """Process graph clustering request"""
    result = await service.process_clustering_request(request)

    await service.broadcast_update({
        "type": "clustering_complete",
        "data": result.dict()
    })

    return result

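# Example request against this endpoint (a sketch; assumes the service is
# running locally on the default port 8083 configured below, and a payload
# shaped like the RemoteClusteringRequest model used above):
#
#   curl -X POST http://localhost:8083/api/cluster \
#     -H 'Content-Type: application/json' \
#     -d '{"graph_data": {"nodes": [{"id": "1", "x": 10, "y": 0, "z": 5}],
#          "links": []}, "mode": "hybrid"}'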
@app.get("/api/capabilities")
async def get_capabilities():
    """Get service capabilities"""
    return {
        "modes": {
            "hybrid": {
                "available": True,
                "description": "GPU clustering on server, CPU rendering on client"
            },
            "webrtc_stream": {
                "available": service.webrtc_engine.has_rendering,
                "description": "Full GPU rendering streamed to client browser"
            }
        },
        "gpu_acceleration": {
            "cupy_available": HAS_CUPY,
            "cugraph_available": HAS_CUGRAPH,
            "opencv_available": HAS_OPENCV,
            "plotting_available": HAS_PLOTTING
        },
        "cluster_dimensions": service.clustering_engine.cluster_dimensions,
        "max_cluster_count": service.clustering_engine.cluster_count
    }


@app.get("/api/stream/{session_id}")
async def stream_frame(session_id: str):
    """Stream rendered frame for WebRTC session"""
    frame_data = service.webrtc_engine.get_frame(session_id)
    if not frame_data:
        raise HTTPException(status_code=404, detail="Frame not found")

    return StreamingResponse(
        BytesIO(frame_data),
        media_type="image/png",
        headers={"Cache-Control": "no-cache"}
    )


@app.delete("/api/stream/{session_id}")
async def cleanup_stream(session_id: str):
    """Clean up WebRTC streaming session"""
    service.webrtc_engine.cleanup_session(session_id)
    return {"status": "cleaned up"}


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time updates"""
    await websocket.accept()
    service.active_connections.append(websocket)

    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        service.active_connections.remove(websocket)


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "gpu_available": HAS_CUPY,
        "webrtc_available": service.webrtc_engine.has_rendering,
        "active_sessions": len(service.webrtc_engine.active_sessions),
        "active_connections": len(service.active_connections),
        "engine": "CuPy"
    }

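# Quick liveness check (assuming the default port configured below):
#   curl http://localhost:8083/health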
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 8083))
    logger.info(f"Starting Remote WebGPU Clustering Service (CuPy) on port {port}")
    logger.info(f"CuPy GPU acceleration: {'✓' if HAS_CUPY else '✗'}")
    logger.info(f"WebRTC streaming: {'✓' if service.webrtc_engine.has_rendering else '✗'}")

    uvicorn.run(
        "remote_webgpu_clustering_service_cupy:app",
        host="0.0.0.0",
        port=port,
        log_level="info",
        reload=False
    )
@ -0,0 +1,30 @@
# Remote WebGPU Clustering Service Dependencies
# For GPU-accelerated graph clustering and WebRTC streaming

# Core FastAPI and web service dependencies
fastapi==0.104.1
uvicorn==0.24.0
websockets==12.0
python-multipart==0.0.6

# RAPIDS dependencies (GPU-accelerated data processing)
# These are included in the NVIDIA PyTorch container
# cudf, cugraph, cuml, cupy are pre-installed

# Data processing and scientific computing
numpy>=1.24.0
pandas>=2.0.0
networkx>=3.0

# WebRTC streaming and visualization dependencies
opencv-python-headless==4.8.1.78
plotly>=5.17.0
kaleido>=0.2.1
Pillow>=10.0.0

# Redis for session management (optional)
redis>=5.0.0

# Additional utilities
pydantic>=2.0.0
python-dotenv>=1.0.0
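# Install with pip against this file, e.g. (adjust the path to wherever this
# requirements file lives):
#   pip install -r requirements.txt
# The RAPIDS packages noted above come from the base container, not from pip.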
@ -0,0 +1,14 @@
graphistry>=0.32.0
pandas>=2.0.0
numpy>=1.24.0
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
pydantic>=2.0.0
networkx>=3.0  # For efficient graph generation algorithms
# cudf, cuml, cugraph are already included in PyG container
# cupy>=12.0.0  # Already included in PyG container
igraph>=0.10.0  # For additional graph algorithms
scikit-learn>=1.3.0  # For additional ML features
requests>=2.31.0
aiofiles>=23.0.0
python-multipart>=0.0.6
@ -0,0 +1,600 @@
"""
|
||||
Semantic Clustering Service for Knowledge Graphs
|
||||
Groups nodes by semantic similarity of names, types, and content rather than just spatial coordinates
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Any, Tuple, Set, Optional
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
import re
|
||||
from difflib import SequenceMatcher
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import networkx as nx
|
||||
|
||||
# Try to import GPU libraries
|
||||
try:
|
||||
import cupy as cp
|
||||
import cuml
|
||||
from cuml.cluster import KMeans as cuKMeans, DBSCAN as cuDBSCAN
|
||||
HAS_GPU = True
|
||||
print("✅ GPU libraries (CuPy, cuML) available for semantic clustering")
|
||||
except ImportError:
|
||||
HAS_GPU = False
|
||||
print("⚠️ GPU libraries not available, using CPU for semantic clustering")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class SemanticClusterResult:
|
||||
"""Result of semantic clustering operation"""
|
||||
clustered_nodes: List[Dict[str, Any]]
|
||||
cluster_info: Dict[str, Any]
|
||||
similarity_matrix: Optional[np.ndarray] = None
|
||||
cluster_labels: Optional[np.ndarray] = None
|
||||
|
||||
class SemanticSimilarityCalculator:
|
||||
"""Calculate semantic similarity between node names and content"""
|
||||
|
||||
def __init__(self):
|
||||
self.tfidf_vectorizer = TfidfVectorizer(
|
||||
max_features=1000,
|
||||
stop_words='english',
|
||||
ngram_range=(1, 2),
|
||||
lowercase=True
|
||||
)
|
||||
self.fitted = False
|
||||
|
||||
def calculate_name_similarity(self, name1: str, name2: str) -> float:
|
||||
"""Calculate similarity between two node names using multiple methods"""
|
||||
if not name1 or not name2:
|
||||
return 0.0
|
||||
|
||||
name1_clean = self._clean_name(name1)
|
||||
name2_clean = self._clean_name(name2)
|
||||
|
||||
# Method 1: Exact match
|
||||
if name1_clean == name2_clean:
|
||||
return 1.0
|
||||
|
||||
# Method 2: Substring match
|
||||
if name1_clean in name2_clean or name2_clean in name1_clean:
|
||||
return 0.8
|
||||
|
||||
# Method 3: Sequence similarity (Levenshtein-based)
|
||||
seq_similarity = SequenceMatcher(None, name1_clean, name2_clean).ratio()
|
||||
|
||||
# Method 4: Word overlap (Jaccard similarity)
|
||||
words1 = set(name1_clean.split())
|
||||
words2 = set(name2_clean.split())
|
||||
if words1 and words2:
|
||||
jaccard_sim = len(words1.intersection(words2)) / len(words1.union(words2))
|
||||
else:
|
||||
jaccard_sim = 0.0
|
||||
|
||||
# Method 5: Common prefix/suffix
|
||||
prefix_sim = self._prefix_similarity(name1_clean, name2_clean)
|
||||
suffix_sim = self._suffix_similarity(name1_clean, name2_clean)
|
||||
|
||||
# Combine similarities with weights
|
||||
combined_similarity = (
|
||||
seq_similarity * 0.3 +
|
||||
jaccard_sim * 0.4 +
|
||||
prefix_sim * 0.15 +
|
||||
suffix_sim * 0.15
|
||||
)
|
||||
|
||||
return min(combined_similarity, 1.0)
|
||||
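    # Worked example (illustrative): for "deep learning" vs "machine learning",
    # the Jaccard term is 1/3 (one shared word out of three distinct words) and
    # the common suffix " learning" contributes 9/13 to the suffix term, so the
    # weighted score lands well below the 0.8 awarded to a substring match.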
    def calculate_content_similarity(self, nodes: List[Dict[str, Any]]) -> np.ndarray:
        """Calculate content similarity matrix using TF-IDF"""
        # Extract text content from nodes
        texts = []
        for node in nodes:
            text_parts = []

            # Add node name
            if node.get('name'):
                text_parts.append(str(node['name']))

            # Add node type/group
            if node.get('group') or node.get('type'):
                text_parts.append(str(node.get('group', node.get('type', ''))))

            # Add any description or content
            for key in ['description', 'content', 'label', 'properties']:
                if node.get(key):
                    text_parts.append(str(node[key]))

            # Combine all text
            combined_text = ' '.join(text_parts)
            texts.append(combined_text if combined_text.strip() else node.get('name', 'unnamed'))

        # Calculate TF-IDF similarity
        if not self.fitted and texts:
            tfidf_matrix = self.tfidf_vectorizer.fit_transform(texts)
            self.fitted = True
        else:
            tfidf_matrix = self.tfidf_vectorizer.transform(texts)

        # Calculate cosine similarity matrix
        similarity_matrix = cosine_similarity(tfidf_matrix)
        return similarity_matrix

    def _clean_name(self, name: str) -> str:
        """Clean and normalize node name"""
        if not name:
            return ""

        # Convert to lowercase
        cleaned = name.lower().strip()

        # Remove special characters but keep spaces and alphanumeric
        cleaned = re.sub(r'[^\w\s-]', ' ', cleaned)

        # Normalize whitespace
        cleaned = re.sub(r'\s+', ' ', cleaned)

        return cleaned.strip()

    def _prefix_similarity(self, name1: str, name2: str) -> float:
        """Calculate similarity based on common prefix"""
        min_len = min(len(name1), len(name2))
        if min_len == 0:
            return 0.0

        common_prefix = 0
        for i in range(min_len):
            if name1[i] == name2[i]:
                common_prefix += 1
            else:
                break

        return common_prefix / min_len

    def _suffix_similarity(self, name1: str, name2: str) -> float:
        """Calculate similarity based on common suffix"""
        min_len = min(len(name1), len(name2))
        if min_len == 0:
            return 0.0

        common_suffix = 0
        for i in range(1, min_len + 1):
            if name1[-i] == name2[-i]:
                common_suffix += 1
            else:
                break

        return common_suffix / min_len


class SemanticClusteringEngine:
    """Main semantic clustering engine"""

    def __init__(self, use_gpu: bool = None):
        self.use_gpu = use_gpu if use_gpu is not None else HAS_GPU
        self.similarity_calc = SemanticSimilarityCalculator()
        logger.info(f"Semantic clustering engine initialized (GPU: {self.use_gpu})")

    def cluster_by_name_similarity(
        self,
        nodes: List[Dict[str, Any]],
        algorithm: str = "hierarchical",
        n_clusters: Optional[int] = None,
        similarity_threshold: float = 0.7
    ) -> SemanticClusterResult:
        """
        Cluster nodes based on name similarity

        Args:
            nodes: List of node dictionaries
            algorithm: 'hierarchical', 'kmeans', 'dbscan'
            n_clusters: Number of clusters (for kmeans/hierarchical)
            similarity_threshold: Minimum similarity for clustering (for dbscan)
        """
        start_time = time.time()
        n_nodes = len(nodes)

        logger.info(f"🧠 Starting semantic clustering of {n_nodes} nodes using {algorithm}")

        if n_nodes < 2:
            return self._create_single_cluster_result(nodes, start_time)

        # Calculate name similarity matrix
        similarity_matrix = self._calculate_name_similarity_matrix(nodes)

        # Convert similarity to distance matrix
        distance_matrix = 1.0 - similarity_matrix

        # Apply clustering algorithm
        if algorithm == "hierarchical":
            cluster_labels = self._hierarchical_clustering(
                distance_matrix, n_clusters or min(10, n_nodes // 2)
            )
        elif algorithm == "kmeans":
            cluster_labels = self._kmeans_clustering(
                similarity_matrix, n_clusters or min(10, n_nodes // 2)
            )
        elif algorithm == "dbscan":
            cluster_labels = self._dbscan_clustering(
                distance_matrix, similarity_threshold
            )
        else:
            raise ValueError(f"Unknown clustering algorithm: {algorithm}")

        # Create clustered nodes
        clustered_nodes = []
        for i, node in enumerate(nodes):
            clustered_node = {
                **node,
                'cluster_id': int(cluster_labels[i]),
                'node_index': i
            }
            clustered_nodes.append(clustered_node)

        processing_time = time.time() - start_time

        # Calculate cluster statistics
        unique_clusters = len(set(cluster_labels))
        cluster_sizes = defaultdict(int)
        for label in cluster_labels:
            cluster_sizes[label] += 1

        cluster_info = {
            'algorithm': f'semantic_{algorithm}',
            'total_clusters': unique_clusters,
            'processing_time': processing_time,
            'gpu_accelerated': self.use_gpu,
            'cluster_sizes': dict(cluster_sizes),
            'average_cluster_size': n_nodes / unique_clusters if unique_clusters > 0 else 0,
            'similarity_threshold': similarity_threshold if algorithm == 'dbscan' else None
        }

        logger.info(f"✅ Semantic clustering completed: {unique_clusters} clusters in {processing_time:.3f}s")

        return SemanticClusterResult(
            clustered_nodes=clustered_nodes,
            cluster_info=cluster_info,
            similarity_matrix=similarity_matrix,
            cluster_labels=cluster_labels
        )

    def cluster_by_content_similarity(
        self,
        nodes: List[Dict[str, Any]],
        algorithm: str = "kmeans",
        n_clusters: Optional[int] = None
    ) -> SemanticClusterResult:
        """Cluster nodes based on content similarity using TF-IDF"""
        start_time = time.time()
        n_nodes = len(nodes)

        logger.info(f"📄 Starting content-based clustering of {n_nodes} nodes")

        if n_nodes < 2:
            return self._create_single_cluster_result(nodes, start_time)

        # Calculate content similarity
        similarity_matrix = self.similarity_calc.calculate_content_similarity(nodes)

        # Apply clustering
        if algorithm == "kmeans":
            n_clusters = n_clusters or min(10, n_nodes // 2)
            if self.use_gpu and HAS_GPU:
                cluster_labels = self._gpu_kmeans_clustering(similarity_matrix, n_clusters)
            else:
                cluster_labels = self._kmeans_clustering(similarity_matrix, n_clusters)
        else:
            distance_matrix = 1.0 - similarity_matrix
            cluster_labels = self._hierarchical_clustering(
                distance_matrix, n_clusters or min(10, n_nodes // 2)
            )

        # Create result
        clustered_nodes = []
        for i, node in enumerate(nodes):
            clustered_node = {
                **node,
                'cluster_id': int(cluster_labels[i]),
                'node_index': i
            }
            clustered_nodes.append(clustered_node)

        processing_time = time.time() - start_time
        unique_clusters = len(set(cluster_labels))

        cluster_info = {
            'algorithm': f'content_{algorithm}',
            'total_clusters': unique_clusters,
            'processing_time': processing_time,
            'gpu_accelerated': self.use_gpu and algorithm == 'kmeans',
            'average_cluster_size': n_nodes / unique_clusters if unique_clusters > 0 else 0
        }

        logger.info(f"✅ Content clustering completed: {unique_clusters} clusters in {processing_time:.3f}s")

        return SemanticClusterResult(
            clustered_nodes=clustered_nodes,
            cluster_info=cluster_info,
            similarity_matrix=similarity_matrix,
            cluster_labels=cluster_labels
        )

    def hybrid_clustering(
        self,
        nodes: List[Dict[str, Any]],
        name_weight: float = 0.6,
        content_weight: float = 0.3,
        spatial_weight: float = 0.1,
        algorithm: str = "hierarchical",
        n_clusters: Optional[int] = None
    ) -> SemanticClusterResult:
        """
        Hybrid clustering combining name, content, and spatial similarities

        Args:
            name_weight: Weight for name similarity (0.0-1.0)
            content_weight: Weight for content similarity (0.0-1.0)
            spatial_weight: Weight for spatial similarity (0.0-1.0)
        """
        start_time = time.time()
        n_nodes = len(nodes)

        logger.info(f"🔄 Starting hybrid clustering of {n_nodes} nodes")
        logger.info(f"   Weights: name={name_weight}, content={content_weight}, spatial={spatial_weight}")

        if n_nodes < 2:
            return self._create_single_cluster_result(nodes, start_time)

        # Normalize weights
        total_weight = name_weight + content_weight + spatial_weight
        if total_weight > 0:
            name_weight /= total_weight
            content_weight /= total_weight
            spatial_weight /= total_weight

        # Calculate different similarity matrices
        similarities = []
        weights = []

        if name_weight > 0:
            name_similarity = self._calculate_name_similarity_matrix(nodes)
            similarities.append(name_similarity)
            weights.append(name_weight)

        if content_weight > 0:
            content_similarity = self.similarity_calc.calculate_content_similarity(nodes)
            similarities.append(content_similarity)
            weights.append(content_weight)

        if spatial_weight > 0:
            spatial_similarity = self._calculate_spatial_similarity_matrix(nodes)
            similarities.append(spatial_similarity)
            weights.append(spatial_weight)

        # Combine similarities
        if not similarities:
            return self._create_single_cluster_result(nodes, start_time)

        combined_similarity = np.zeros((n_nodes, n_nodes))
        for similarity, weight in zip(similarities, weights):
            combined_similarity += similarity * weight

        # Apply clustering
        distance_matrix = 1.0 - combined_similarity

        if algorithm == "hierarchical":
            cluster_labels = self._hierarchical_clustering(
                distance_matrix, n_clusters or min(10, n_nodes // 2)
            )
        elif algorithm == "kmeans":
            cluster_labels = self._kmeans_clustering(
                combined_similarity, n_clusters or min(10, n_nodes // 2)
            )
        else:
            cluster_labels = self._dbscan_clustering(distance_matrix, 0.3)

        # Create result
        clustered_nodes = []
        for i, node in enumerate(nodes):
            clustered_node = {
                **node,
                'cluster_id': int(cluster_labels[i]),
                'node_index': i
            }
            clustered_nodes.append(clustered_node)

        processing_time = time.time() - start_time
        unique_clusters = len(set(cluster_labels))

        cluster_info = {
            'algorithm': f'hybrid_{algorithm}',
            'total_clusters': unique_clusters,
            'processing_time': processing_time,
            'gpu_accelerated': self.use_gpu,
            'weights': {
                'name': name_weight,
                'content': content_weight,
                'spatial': spatial_weight
            },
            'average_cluster_size': n_nodes / unique_clusters if unique_clusters > 0 else 0
        }

        logger.info(f"✅ Hybrid clustering completed: {unique_clusters} clusters in {processing_time:.3f}s")

        return SemanticClusterResult(
            clustered_nodes=clustered_nodes,
            cluster_info=cluster_info,
            similarity_matrix=combined_similarity,
            cluster_labels=cluster_labels
        )

    def _calculate_name_similarity_matrix(self, nodes: List[Dict[str, Any]]) -> np.ndarray:
        """Calculate pairwise name similarity matrix"""
        n_nodes = len(nodes)
        similarity_matrix = np.zeros((n_nodes, n_nodes))

        for i in range(n_nodes):
            for j in range(i, n_nodes):
                if i == j:
                    similarity_matrix[i, j] = 1.0
                else:
                    name1 = nodes[i].get('name', '')
                    name2 = nodes[j].get('name', '')
                    similarity = self.similarity_calc.calculate_name_similarity(name1, name2)
                    similarity_matrix[i, j] = similarity
                    similarity_matrix[j, i] = similarity  # Symmetric

        return similarity_matrix

    def _calculate_spatial_similarity_matrix(self, nodes: List[Dict[str, Any]]) -> np.ndarray:
        """Calculate spatial similarity based on node positions"""
        n_nodes = len(nodes)
        similarity_matrix = np.zeros((n_nodes, n_nodes))

        # Extract coordinates
        coords = []
        for node in nodes:
            x = float(node.get('x', 0))
            y = float(node.get('y', 0))
            z = float(node.get('z', 0))
            coords.append([x, y, z])

        coords = np.array(coords)

        # Calculate pairwise distances
        for i in range(n_nodes):
            for j in range(i, n_nodes):
                if i == j:
                    similarity_matrix[i, j] = 1.0
                else:
                    # Euclidean distance
                    dist = np.linalg.norm(coords[i] - coords[j])
                    # Convert distance to similarity (closer = more similar)
                    # Use exponential decay: similarity = exp(-distance/scale)
                    scale = 50.0  # Adjust based on your coordinate system
                    similarity = np.exp(-dist / scale)
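                    # e.g. two nodes 50 units apart (dist == scale) score
                    # e**-1 ≈ 0.37; at 150 units the score drops to ≈ 0.05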
                    similarity_matrix[i, j] = similarity
                    similarity_matrix[j, i] = similarity

        return similarity_matrix

    def _hierarchical_clustering(self, distance_matrix: np.ndarray, n_clusters: int) -> np.ndarray:
        """Apply hierarchical clustering"""
        clusterer = AgglomerativeClustering(
            n_clusters=n_clusters,
            metric='precomputed',
            linkage='average'
        )
        return clusterer.fit_predict(distance_matrix)

    def _kmeans_clustering(self, similarity_matrix: np.ndarray, n_clusters: int) -> np.ndarray:
        """Apply K-means clustering"""
        clusterer = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        return clusterer.fit_predict(similarity_matrix)

    def _gpu_kmeans_clustering(self, similarity_matrix: np.ndarray, n_clusters: int) -> np.ndarray:
        """Apply GPU-accelerated K-means clustering"""
        try:
            gpu_matrix = cp.array(similarity_matrix, dtype=cp.float32)
            clusterer = cuKMeans(n_clusters=n_clusters, random_state=42)
            labels = clusterer.fit_predict(gpu_matrix)
            return cp.asnumpy(labels)
        except Exception as e:
            logger.warning(f"GPU K-means failed, falling back to CPU: {e}")
            return self._kmeans_clustering(similarity_matrix, n_clusters)

    def _dbscan_clustering(self, distance_matrix: np.ndarray, eps: float) -> np.ndarray:
        """Apply DBSCAN clustering"""
        clusterer = DBSCAN(eps=eps, metric='precomputed', min_samples=2)
        labels = clusterer.fit_predict(distance_matrix)

        # DBSCAN uses -1 for noise points, convert to positive integers
        unique_labels = set(labels)
        if -1 in unique_labels:
            # Assign noise points to individual clusters
            max_label = max(labels) if len(unique_labels) > 1 else -1
            noise_cluster = max_label + 1
            labels = np.array([noise_cluster if label == -1 else label for label in labels])

        return labels

    def _create_single_cluster_result(self, nodes: List[Dict[str, Any]], start_time: float) -> SemanticClusterResult:
        """Create result for single cluster (when too few nodes)"""
        clustered_nodes = []
        for i, node in enumerate(nodes):
            clustered_node = {
                **node,
                'cluster_id': 0,
                'node_index': i
            }
            clustered_nodes.append(clustered_node)

        processing_time = time.time() - start_time

        cluster_info = {
            'algorithm': 'single_cluster',
            'total_clusters': 1,
            'processing_time': processing_time,
            'gpu_accelerated': False,
            'average_cluster_size': len(nodes)
        }

        return SemanticClusterResult(
            clustered_nodes=clustered_nodes,
            cluster_info=cluster_info,
            similarity_matrix=None,
            cluster_labels=np.zeros(len(nodes), dtype=int)
        )


# Convenience functions for easy integration
async def cluster_nodes_by_similarity(
    nodes: List[Dict[str, Any]],
    method: str = "hybrid",
    algorithm: str = "hierarchical",
    n_clusters: Optional[int] = None,
    **kwargs
) -> SemanticClusterResult:
    """
    Main entry point for semantic clustering

    Args:
        nodes: List of node dictionaries
        method: 'name', 'content', 'hybrid'
        algorithm: 'hierarchical', 'kmeans', 'dbscan'
        n_clusters: Number of clusters (if applicable)
        **kwargs: Additional parameters for specific methods
    """
    engine = SemanticClusteringEngine()

    if method == "name":
        return engine.cluster_by_name_similarity(nodes, algorithm, n_clusters, **kwargs)
    elif method == "content":
        return engine.cluster_by_content_similarity(nodes, algorithm, n_clusters, **kwargs)
    elif method == "hybrid":
        return engine.hybrid_clustering(nodes, algorithm=algorithm, n_clusters=n_clusters, **kwargs)
    else:
        raise ValueError(f"Unknown clustering method: {method}")


if __name__ == "__main__":
    # Example usage
    test_nodes = [
        {"name": "Machine Learning", "x": 0, "y": 0, "z": 0, "group": "AI"},
        {"name": "Deep Learning", "x": 10, "y": 5, "z": 2, "group": "AI"},
        {"name": "Neural Networks", "x": 15, "y": 8, "z": 3, "group": "AI"},
        {"name": "Data Science", "x": 20, "y": 10, "z": 5, "group": "Data"},
        {"name": "Statistics", "x": 25, "y": 15, "z": 8, "group": "Math"},
        {"name": "Linear Algebra", "x": 30, "y": 20, "z": 10, "group": "Math"},
    ]

    async def test():
        result = await cluster_nodes_by_similarity(test_nodes, method="hybrid")
        print("Cluster Result:", result.cluster_info)
        for node in result.clustered_nodes:
            print(f"  {node['name']} -> Cluster {node['cluster_id']}")

    asyncio.run(test())
@ -0,0 +1,120 @@
#!/usr/bin/env python3
"""
Simple WebGPU clustering test service
Minimal implementation to test basic functionality
"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from typing import Dict, List, Any, Optional
import time

# Simple data models
class GraphData(BaseModel):
    nodes: List[Dict[str, Any]]
    links: List[Dict[str, Any]]

class SimpleClusteringRequest(BaseModel):
    graph_data: GraphData
    mode: str = "hybrid"

class SimpleClusteringResult(BaseModel):
    clustered_nodes: List[Dict[str, Any]]
    processing_time: float
    mode: str
    session_id: Optional[str] = None

# FastAPI app
app = FastAPI(title="Simple WebGPU Test Service", version="1.0.0")

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "gpu_available": True,
        "webrtc_available": True,
        "active_sessions": 0,
        "active_connections": 0
    }

@app.get("/api/capabilities")
async def get_capabilities():
    return {
        "modes": {
            "hybrid": {
                "available": True,
                "description": "GPU clustering on server, CPU rendering on client"
            },
            "webrtc_stream": {
                "available": True,
                "description": "Full GPU rendering streamed to client browser"
            }
        },
        "gpu_acceleration": {
            "rapids_available": True,
            "opencv_available": True,
            "plotting_available": True
        },
        "cluster_dimensions": [32, 18, 24],
        "max_cluster_count": 13824
    }

@app.post("/api/cluster", response_model=SimpleClusteringResult)
async def cluster_graph(request: SimpleClusteringRequest):
    """Simple clustering implementation for testing"""
    try:
        start_time = time.time()

        # Simple clustering - just add cluster_index to each node
        clustered_nodes = []
        for i, node in enumerate(request.graph_data.nodes):
            clustered_node = {**node, "cluster_index": i % 10, "node_index": i}
            clustered_nodes.append(clustered_node)

        processing_time = time.time() - start_time

        result = SimpleClusteringResult(
            clustered_nodes=clustered_nodes,
            processing_time=processing_time,
            mode=request.mode,
            session_id="test-session-123" if request.mode == "webrtc_stream" else None
        )

        return result

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
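# Smoke test (a sketch; assumes the service is running on port 8083 as below):
#   curl -X POST http://localhost:8083/api/cluster \
#     -H 'Content-Type: application/json' \
#     -d '{"graph_data": {"nodes": [{"id": "1"}, {"id": "2"}], "links": []},
#          "mode": "hybrid"}'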
@app.get("/api/stream/{session_id}")
async def stream_frame(session_id: str):
    """Simple streaming endpoint - returns a placeholder"""
    # Return a simple 1x1 PNG pixel as placeholder
    png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\nIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01\r\n-\xdb\x00\x00\x00\x00IEND\xaeB`\x82'

    from fastapi.responses import Response
    return Response(
        content=png_data,
        media_type="image/png",
        headers={"Cache-Control": "no-cache"}
    )

if __name__ == "__main__":
    print("Starting Simple WebGPU Test Service...")
    uvicorn.run(
        "simple_webgpu_test:app",
        host="0.0.0.0",
        port=8083,
        log_level="info",
        reload=False
    )
nvidia/txt2kg/assets/deploy/services/gpu-viz/start_remote_gpu_services.sh (Executable file, 166 lines)
@ -0,0 +1,166 @@
#!/bin/bash

# Start Remote GPU Rendering Services
# This script starts the custom remote GPU rendering service as an alternative to PyGraphistry cloud

echo "🚀 Starting Remote GPU Rendering Services"
echo "========================================="

# Check if we're in a RAPIDS/cuGraph environment
if python -c "import cudf, cugraph" 2>/dev/null; then
    echo "✓ RAPIDS/cuGraph environment detected"
    GPU_AVAILABLE=true
else
    echo "⚠ RAPIDS/cuGraph not available - will use CPU fallback"
    GPU_AVAILABLE=false
fi

# Check if Redis is available (optional for session storage)
if command -v redis-server >/dev/null 2>&1; then
    echo "✓ Redis available for session storage"

    # Start Redis if not running
    if ! pgrep -x "redis-server" > /dev/null; then
        echo "Starting Redis server..."
        redis-server --daemonize yes --port 6379
        sleep 2
    fi
else
    echo "⚠ Redis not available - using in-memory session storage"
fi

# Set environment variables
export REDIS_HOST=${REDIS_HOST:-localhost}
export REDIS_PORT=${REDIS_PORT:-6379}

# Create log directory
mkdir -p logs

echo ""
echo "🎯 Service Configuration:"
echo "   GPU Processing: $GPU_AVAILABLE"
echo "   Session Storage: ${REDIS_HOST:-memory}:${REDIS_PORT:-N/A}"
echo "   Service Port: 8082"
echo ""

# Function to start service with proper error handling
start_service() {
    local service_name=$1
    local script_path=$2
    local port=$3
    local log_file=$4

    echo "Starting $service_name on port $port..."

    # Kill existing process if running
    if lsof -Pi :$port -sTCP:LISTEN -t >/dev/null; then
        echo "  Killing existing process on port $port"
        kill $(lsof -t -i:$port) 2>/dev/null || true
        sleep 2
    fi

    # Start the service
    python $script_path > logs/$log_file 2>&1 &
    local pid=$!

    # Wait a moment and check if it started successfully
    sleep 3
    if kill -0 $pid 2>/dev/null; then
        echo "  ✓ $service_name started successfully (PID: $pid)"
        echo $pid > logs/${service_name,,}_pid.txt
        return 0
    else
        echo "  ✗ Failed to start $service_name"
        echo "  Check logs/$log_file for details"
        return 1
    fi
}

# Start Remote GPU Rendering Service
echo "📊 Starting Remote GPU Rendering Service..."
start_service "RemoteGPURenderer" "remote_gpu_rendering_service.py" 8082 "remote_gpu_rendering.log"

if [ $? -eq 0 ]; then
    echo ""
    echo "✅ Remote GPU Rendering Service is ready!"
    echo ""
    echo "🎯 Available endpoints:"
    echo "   Process graph:        POST http://localhost:8082/api/render"
    echo "   Iframe visualization: GET  http://localhost:8082/embed/{session_id}"
    echo "   Session status:       GET  http://localhost:8082/api/session/{session_id}"
    echo "   Real-time updates:    WS   ws://localhost:8082/ws/{session_id}"
    echo "   Health check:         GET  http://localhost:8082/api/health"
    echo ""
    echo "📋 Usage examples:"
    echo ""
    echo "   # Test health check"
    echo "   curl http://localhost:8082/api/health"
    echo ""
    echo "   # Process a sample graph"
    echo "   curl -X POST http://localhost:8082/api/render \\"
    echo "     -H 'Content-Type: application/json' \\"
    echo "     -d '{"
    echo "       \"graph_data\": {"
    echo "         \"nodes\": [{\"id\": \"1\", \"name\": \"Node 1\"}, {\"id\": \"2\", \"name\": \"Node 2\"}],"
    echo "         \"links\": [{\"source\": \"1\", \"target\": \"2\", \"name\": \"edge_1_2\"}]"
    echo "       },"
    echo "       \"layout_algorithm\": \"force_atlas2\","
    echo "       \"clustering_algorithm\": \"leiden\","
    echo "       \"compute_centrality\": true,"
    echo "       \"render_quality\": \"high\","
    echo "       \"interactive_mode\": true"
    echo "     }'"
    echo ""
    echo "📁 Logs are available in:"
    echo "   Remote GPU Rendering: logs/remote_gpu_rendering.log"
    echo ""
    echo "🛠️ Integration with frontend:"
    echo "   import { RemoteGPUViewer } from '@/components/remote-gpu-viewer'"
    echo "   <RemoteGPUViewer"
    echo "     graphData={graphData}"
    echo "     remoteServiceUrl=\"http://localhost:8082\""
    echo "     onError={(err) => console.error(err)}"
    echo "   />"
    echo ""
    echo "⚡ Performance tips:"
    echo "   - Use 'ultra' quality for 1M+ node graphs"
    echo "   - Enable Redis for production session storage"
    echo "   - Run on GPU server for maximum performance"
    echo "   - Use iframe embedding to isolate visualization"
    echo ""

    # Start a simple monitoring script
    echo "🔍 Starting service monitor..."
    monitor_services() {
        while true; do
            sleep 30

            # Check if services are still running
            if [ -f logs/remotegpurenderer_pid.txt ]; then
                pid=$(cat logs/remotegpurenderer_pid.txt)
                if ! kill -0 $pid 2>/dev/null; then
                    echo "$(date): Remote GPU Rendering Service died, restarting..."
                    start_service "RemoteGPURenderer" "remote_gpu_rendering_service.py" 8082 "remote_gpu_rendering.log"
                fi
            fi
        done
    }

    # Run monitor in background
    monitor_services &
    echo $! > logs/monitor_pid.txt

    echo "✅ All services started and monitoring enabled!"
    echo ""
    echo "To stop all services, run: ./stop_remote_gpu_services.sh"
    echo "To view logs in real-time: tail -f logs/remote_gpu_rendering.log"

else
    echo ""
    echo "❌ Failed to start Remote GPU Rendering Service"
    echo "Check the logs for details and ensure dependencies are installed:"
    echo "  - FastAPI: pip install fastapi uvicorn"
    echo "  - RAPIDS (optional): conda install -c rapidsai cudf cugraph"
    echo "  - Redis (optional): sudo apt-get install redis-server"
    exit 1
fi
nvidia/txt2kg/assets/deploy/services/gpu-viz/stop_remote_gpu_services.sh (Executable file, 64 lines)
@ -0,0 +1,64 @@
#!/bin/bash

# Stop Remote GPU Rendering Services

echo "🛑 Stopping Remote GPU Rendering Services"
echo "========================================="

# Function to stop service by PID file
stop_service() {
    local service_name=$1
    local pid_file=$2

    if [ -f "$pid_file" ]; then
        local pid=$(cat "$pid_file")
        if kill -0 "$pid" 2>/dev/null; then
            echo "Stopping $service_name (PID: $pid)..."
            kill "$pid"

            # Wait for graceful shutdown
            local count=0
            while kill -0 "$pid" 2>/dev/null && [ $count -lt 10 ]; do
                sleep 1
                count=$((count + 1))
            done

            # Force kill if still running
            if kill -0 "$pid" 2>/dev/null; then
                echo "  Force killing $service_name..."
                kill -9 "$pid"
            fi

            echo "  ✓ $service_name stopped"
        else
            echo "  $service_name was not running"
        fi
        rm -f "$pid_file"
    else
        echo "  No PID file found for $service_name"
    fi
}

# Stop services
stop_service "Remote GPU Renderer" "logs/remotegpurenderer_pid.txt"
stop_service "Service Monitor" "logs/monitor_pid.txt"

# Stop any remaining processes on the service ports
echo ""
echo "🔍 Checking for remaining processes on service ports..."

ports=(8082)
for port in "${ports[@]}"; do
    if lsof -Pi :$port -sTCP:LISTEN -t >/dev/null 2>&1; then
        echo "Killing process on port $port..."
        kill $(lsof -t -i:$port) 2>/dev/null || true
    fi
done

echo ""
echo "✅ All Remote GPU Rendering Services stopped"
echo ""
echo "📁 Log files are preserved in logs/ directory:"
echo "   - logs/remote_gpu_rendering.log"
echo ""
echo "To restart services, run: ./start_remote_gpu_services.sh"
@ -0,0 +1,243 @@
# True GPU Rendering vs Current Approach
|
||||
|
||||
## 🎯 **Current Remote GPU Service**
|
||||
|
||||
### **What Uses GPU (✅)**
|
||||
- **Graph Layout**: cuGraph Force Atlas 2, Spectral Layout
|
||||
- **Clustering**: cuGraph Leiden, Louvain algorithms
|
||||
- **Centrality**: cuGraph PageRank, Betweenness Centrality
|
||||
- **Data Processing**: Node positioning, edge bundling
|
||||
|
||||
### **What Uses CPU (❌)**
|
||||
- **Visual Rendering**: D3.js SVG/Canvas drawing
|
||||
- **Animation**: D3.js transitions and transforms
|
||||
- **Interaction**: DOM event handling, hover, zoom
|
||||
- **Text Rendering**: Node labels, tooltips
|
||||
|
||||
## 🔥 **True GPU Rendering (Like PyGraphistry)**
|
||||
|
||||
### **What Would Need GPU Acceleration**
|
||||
|
||||
#### **1. WebGL Compute Shaders**
|
||||
```glsl
|
||||
// Vertex shader for node positioning
|
||||
attribute vec2 position;
|
||||
attribute float size;
|
||||
attribute vec3 color;
|
||||
|
||||
uniform mat4 projectionMatrix;
|
||||
uniform float time;
|
||||
|
||||
void main() {
|
||||
// GPU-accelerated node positioning
|
||||
vec2 pos = position + computeForceLayout(time);
|
||||
gl_Position = projectionMatrix * vec4(pos, 0.0, 1.0);
|
||||
gl_PointSize = size;
|
||||
}
|
||||
```
|
||||
|
||||
#### **2. GPU Particle Systems**
|
||||
```javascript
|
||||
// WebGL-based node rendering
|
||||
class GPUNodeRenderer {
|
||||
constructor(gl, nodeCount) {
|
||||
this.nodeCount = nodeCount;
|
||||
|
||||
// Create vertex buffers for GPU processing
|
||||
this.positionBuffer = gl.createBuffer();
|
||||
this.colorBuffer = gl.createBuffer();
|
||||
this.sizeBuffer = gl.createBuffer();
|
||||
|
||||
// Compile GPU shaders
|
||||
this.program = this.createShaderProgram(gl);
|
||||
}
|
||||
|
||||
render(nodes) {
|
||||
// Update GPU buffers - no CPU iteration
|
||||
gl.bindBuffer(gl.ARRAY_BUFFER, this.positionBuffer);
|
||||
gl.bufferData(gl.ARRAY_BUFFER, new Float32Array(positions), gl.DYNAMIC_DRAW);
|
||||
|
||||
// GPU draws all nodes in single call
|
||||
gl.drawArrays(gl.POINTS, 0, this.nodeCount);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### **3. GPU-Based Interaction**
|
||||
```javascript
|
||||
// GPU picking for node selection
|
||||
class GPUPicker {
|
||||
constructor(gl, nodeCount) {
|
||||
// Render nodes to off-screen framebuffer with unique colors
|
||||
this.pickingFramebuffer = gl.createFramebuffer();
|
||||
this.pickingTexture = gl.createTexture();
|
||||
}
|
||||
|
||||
getNodeAtPosition(x, y) {
|
||||
// Read single pixel from GPU framebuffer
|
||||
const pixel = new Uint8Array(4);
|
||||
gl.readPixels(x, y, 1, 1, gl.RGBA, gl.UNSIGNED_BYTE, pixel);
|
||||
|
||||
// Decode node ID from color
|
||||
return this.colorToNodeId(pixel);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 📊 **Performance Comparison**
|
||||
|
||||
### **Current D3.js CPU Rendering**
|
||||
```javascript
|
||||
// CPU-bound operations
|
||||
nodes.forEach(node => {
|
||||
// For each node, update DOM element
|
||||
d3.select(`#node-${node.id}`)
|
||||
.attr("cx", node.x)
|
||||
.attr("cy", node.y)
|
||||
.attr("r", node.size);
|
||||
});
|
||||
|
||||
// Performance: O(n) DOM operations
|
||||
// 10k nodes = 10k DOM updates per frame
|
||||
// Maximum ~60fps with heavy optimization
|
||||
```
|
||||
|
||||
### **GPU WebGL Rendering**
|
||||
```javascript
|
||||
// GPU-accelerated operations
|
||||
class GPURenderer {
|
||||
updateNodes(nodeData) {
|
||||
// Single buffer update for all nodes
|
||||
gl.bufferSubData(gl.ARRAY_BUFFER, 0, nodeData);
|
||||
|
||||
// Single draw call for all nodes
|
||||
gl.drawArraysInstanced(gl.TRIANGLES, 0, 6, nodeCount);
|
||||
}
|
||||
}
|
||||
|
||||
// Performance: O(1) GPU operations
|
||||
// 1M nodes = 1 GPU draw call
|
||||
// Can maintain 60fps with millions of nodes
|
||||
```
|
||||
|
||||
## 🛠️ **Implementation Options**
|
||||
|
||||
### **Option 1: WebGL2 + Compute Shaders**
|
||||
```html
|
||||
<!-- Enhanced HTML template with WebGL -->
|
||||
<canvas id="gpu-canvas" width="800" height="600"></canvas>
|
||||
<script>
|
||||
const canvas = document.getElementById('gpu-canvas');
|
||||
const gl = canvas.getContext('webgl2');
|
||||
|
||||
// Load compute shaders for layout animation
|
||||
const computeShader = gl.createShader(gl.COMPUTE_SHADER);
|
||||
gl.shaderSource(computeShader, computeShaderSource);
|
||||
|
||||
// Render loop using GPU
|
||||
function animate() {
|
||||
// Update node positions on GPU
|
||||
gl.useProgram(computeProgram);
|
||||
gl.dispatchCompute(Math.ceil(nodeCount / 64), 1, 1);
|
||||
|
||||
// Render nodes on GPU
|
||||
gl.useProgram(renderProgram);
|
||||
gl.drawArraysInstanced(gl.POINTS, 0, 1, nodeCount);
|
||||
|
||||
requestAnimationFrame(animate);
|
||||
}
|
||||
</script>
|
||||
```
|
||||
|
||||
### **Option 2: WebGPU (Future)**
|
||||
```javascript
|
||||
// Next-generation WebGPU API
|
||||
const adapter = await navigator.gpu.requestAdapter();
|
||||
const device = await adapter.requestDevice();
|
||||
|
||||
// GPU compute pipeline for layout
|
||||
const computePipeline = device.createComputePipeline({
|
||||
compute: {
|
||||
module: device.createShaderModule({ code: layoutComputeShader }),
|
||||
entryPoint: 'main'
|
||||
}
|
||||
});
|
||||
|
||||
// GPU render pipeline
|
||||
const renderPipeline = device.createRenderPipeline({
|
||||
vertex: { module: vertexShaderModule, entryPoint: 'main' },
|
||||
fragment: { module: fragmentShaderModule, entryPoint: 'main' },
|
||||
primitive: { topology: 'point-list' }
|
||||
});
|
||||
```
|
||||
|
||||
### **Option 3: Three.js GPU Optimization**
|
||||
```javascript
|
||||
// Use Three.js InstancedMesh for GPU instancing
|
||||
import * as THREE from 'three';
|
||||
|
||||
class GPUGraphRenderer {
|
||||
constructor(nodeCount) {
|
||||
// Single geometry instanced for all nodes
|
||||
const geometry = new THREE.CircleGeometry(1, 8);
|
||||
const material = new THREE.MeshBasicMaterial();
|
||||
|
||||
// GPU-instanced mesh for all nodes
|
||||
this.instancedMesh = new THREE.InstancedMesh(
|
||||
geometry, material, nodeCount
|
||||
);
|
||||
|
||||
// Position matrix for each instance
|
||||
this.matrix = new THREE.Matrix4();
|
||||
}
|
||||
|
||||
updateNode(index, x, y, scale, color) {
|
||||
// Update single instance matrix
|
||||
this.matrix.makeScale(scale, scale, 1);
|
||||
this.matrix.setPosition(x, y, 0);
|
||||
this.instancedMesh.setMatrixAt(index, this.matrix);
|
||||
this.instancedMesh.setColorAt(index, color);
|
||||
}
|
||||
|
||||
render() {
|
||||
// Single GPU draw call for all nodes
|
||||
this.instancedMesh.instanceMatrix.needsUpdate = true;
|
||||
this.instancedMesh.instanceColor.needsUpdate = true;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 🎯 **Recommendation**

### **Current Approach is Good For:**
- ✅ **Rapid development** - Standard D3.js patterns
- ✅ **Small-medium graphs** (<50k nodes)
- ✅ **Interactive features** - Easy DOM manipulation
- ✅ **Debugging** - Standard web dev tools
- ✅ **Compatibility** - Works in all browsers

### **True GPU Rendering Needed For:**
- 🚀 **Million+ node graphs** with smooth 60fps
- 🚀 **Real-time layout animation**
- 🚀 **Complex visual effects** (particles, trails)
- 🚀 **VR/AR graph visualization**
- 🚀 **Multi-touch interaction** on large displays

## 💡 **Hybrid Solution**

The optimal approach combines both:

```javascript
// Intelligent renderer selection
const selectRenderer = (nodeCount) => {
    if (nodeCount < 10000) {
        return new D3SVGRenderer();   // CPU DOM rendering
    } else if (nodeCount < 100000) {
        return new ThreeJSRenderer(); // WebGL with Three.js
    } else {
        return new WebGLRenderer();   // Custom GPU shaders
    }
};
```

**Current Status:** Your remote service provides **GPU-accelerated data processing** with **CPU-based rendering** - which is perfect for most use cases and much easier to develop/maintain than full GPU rendering.

@ -0,0 +1,773 @@
#!/usr/bin/env python3
"""
Unified GPU Graph Visualization Service

Combines PyGraphistry cloud processing and local GPU processing with cuGraph
into a single FastAPI service for maximum flexibility.
"""

import os
import json
import numpy as np
import pandas as pd
from typing import Dict, List, Any, Optional, Tuple
import asyncio
import logging
from datetime import datetime
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
import uvicorn
import time
from concurrent.futures import ThreadPoolExecutor
import networkx as nx
from enum import Enum

# PyGraphistry imports
import graphistry

# GPU-accelerated imports (available in NVIDIA PyG container)
try:
    import cudf
    import cugraph
    import cupy as cp
    from cuml import UMAP
    HAS_RAPIDS = True
    print("✓ RAPIDS cuGraph/cuDF/cuML available")
except ImportError:
    HAS_RAPIDS = False
    print("⚠ RAPIDS not available, falling back to CPU")

try:
    import torch
    import torch_geometric
    HAS_TORCH_GEOMETRIC = True
    print("✓ PyTorch Geometric available")
except ImportError:
    HAS_TORCH_GEOMETRIC = False
    print("⚠ PyTorch Geometric not available")

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ProcessingMode(str, Enum):
    PYGRAPHISTRY_CLOUD = "pygraphistry_cloud"
    LOCAL_GPU = "local_gpu"
    LOCAL_CPU = "local_cpu"

class GraphPattern(str, Enum):
    RANDOM = "random"
    SCALE_FREE = "scale-free"
    SMALL_WORLD = "small-world"
    CLUSTERED = "clustered"
    HIERARCHICAL = "hierarchical"
    GRID = "grid"
class GraphData(BaseModel):
    nodes: List[Dict[str, Any]]
    links: List[Dict[str, Any]]

class GraphGenerationRequest(BaseModel):
    num_nodes: int
    pattern: GraphPattern = GraphPattern.SCALE_FREE
    avg_degree: Optional[int] = 5
    num_clusters: Optional[int] = 100
    small_world_k: Optional[int] = 6
    small_world_p: Optional[float] = 0.1
    grid_dimensions: Optional[List[int]] = [100, 100]
    seed: Optional[int] = None

class UnifiedVisualizationRequest(BaseModel):
    graph_data: GraphData
    processing_mode: ProcessingMode = ProcessingMode.PYGRAPHISTRY_CLOUD

    # PyGraphistry Cloud options
    layout_type: Optional[str] = "force"
    gpu_acceleration: Optional[bool] = True
    clustering: Optional[bool] = False

    # Local GPU options
    layout_algorithm: Optional[str] = "force_atlas2"
    clustering_algorithm: Optional[str] = "leiden"
    compute_centrality: Optional[bool] = True

class GraphGenerationStatus(BaseModel):
    task_id: str
    status: str  # "running", "completed", "failed"
    progress: float
    message: str
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None

# Import graph generation classes (keeping existing code)
from pygraphistry_service import LargeGraphGenerator, init_graphistry
class LocalGPUProcessor:
    """GPU-accelerated graph processing using cuGraph"""

    def __init__(self):
        self.use_gpu = HAS_RAPIDS
        logger.info(f"Local GPU Processor initialized (GPU: {self.use_gpu})")

    def create_cugraph_from_data(self, nodes: List[Dict], edges: List[Dict]) -> Tuple['cugraph.Graph', 'cudf.DataFrame']:
        """Create cuGraph from node/edge data"""
        if not self.use_gpu:
            raise RuntimeError("GPU libraries not available")

        # Create edge dataframe
        edge_data = []
        for edge in edges:
            edge_data.append({
                'src': edge['source'],
                'dst': edge['target'],
                'weight': edge.get('weight', 1.0)
            })

        # Convert to cuDF
        edges_df = cudf.DataFrame(edge_data)

        # Create cuGraph
        G = cugraph.Graph()
        G.from_cudf_edgelist(edges_df, source='src', destination='dst', edge_attr='weight')

        return G, edges_df
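
    # Expected input shape for the method above (illustrative example only):
    #   nodes = [{"id": "1", "name": "Central Hub"}, ...]
    #   edges = [{"source": "1", "target": "2", "weight": 1.0}, ...]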

    def compute_gpu_layout(self, G, algorithm: str = "force_atlas2") -> Dict[str, Tuple[float, float]]:
        """Compute GPU-accelerated graph layout"""
        try:
            if algorithm == "force_atlas2":
                layout_df = cugraph.force_atlas2(G)
            elif algorithm == "fruchterman_reingold":
                # No Fruchterman-Reingold in this code path; fall back to spectral layout
                layout_df = cugraph.spectral_layout(G, dim=2)
            else:  # spectral
                layout_df = cugraph.spectral_layout(G, dim=2)

            # Convert to dictionary
            positions = {}
            for _, row in layout_df.iterrows():
                node_id = str(row['vertex'])
                positions[node_id] = (float(row['x']), float(row['y']))

            logger.info(f"Computed {algorithm} layout for {len(positions)} nodes on GPU")
            return positions

        except Exception as e:
            logger.error(f"GPU layout computation failed: {e}")
            return {}
    def compute_gpu_clustering(self, G, algorithm: str = "leiden") -> Dict[str, int]:
        """Compute GPU-accelerated community detection"""
        try:
            if algorithm == "leiden":
                clusters_df, modularity = cugraph.leiden(G)
            elif algorithm == "louvain":
                clusters_df, modularity = cugraph.louvain(G)
            else:  # spectral clustering
                clusters_df = cugraph.spectral_clustering(G, n_clusters=10)
                modularity = 0.0

            # Convert to dictionary
            clusters = {}
            for _, row in clusters_df.iterrows():
                node_id = str(row['vertex'])
                clusters[node_id] = int(row['partition'])

            logger.info(f"Computed {algorithm} clustering on GPU (modularity: {modularity:.3f})")
            return clusters

        except Exception as e:
            logger.error(f"GPU clustering failed: {e}")
            return {}

    def compute_gpu_centrality(self, G) -> Dict[str, Dict[str, float]]:
        """Compute GPU-accelerated centrality measures"""
        centrality_data = {}

        try:
            # PageRank
            pagerank_df = cugraph.pagerank(G)
            pagerank = {}
            for _, row in pagerank_df.iterrows():
                pagerank[str(row['vertex'])] = float(row['pagerank'])
            centrality_data['pagerank'] = pagerank

            # Betweenness centrality (for smaller graphs)
            if G.number_of_vertices() < 5000:
                betweenness_df = cugraph.betweenness_centrality(G)
                betweenness = {}
                for _, row in betweenness_df.iterrows():
                    betweenness[str(row['vertex'])] = float(row['betweenness_centrality'])
                centrality_data['betweenness'] = betweenness

            logger.info("Computed centrality measures on GPU")
            return centrality_data

        except Exception as e:
            logger.error(f"GPU centrality computation failed: {e}")
            return {}
class PyGraphistryProcessor:
    """PyGraphistry cloud processing (existing functionality)"""

    def __init__(self):
        self.initialized = init_graphistry()

    async def process_graph_data(self, request: UnifiedVisualizationRequest) -> Dict[str, Any]:
        """Process graph data with PyGraphistry GPU acceleration"""
        try:
            if not self.initialized:
                raise HTTPException(status_code=500, detail="PyGraphistry not initialized")

            # Convert to pandas DataFrames for PyGraphistry
            nodes_df = pd.DataFrame(request.graph_data.nodes)
            edges_df = pd.DataFrame(request.graph_data.links)

            # Ensure required columns exist
            if 'id' not in nodes_df.columns:
                nodes_df['id'] = nodes_df.index
            if 'source' not in edges_df.columns or 'target' not in edges_df.columns:
                raise HTTPException(status_code=400, detail="Links must have source and target columns")

            logger.info(f"Processing graph with {len(nodes_df)} nodes and {len(edges_df)} edges")

            # Create PyGraphistry graph object
            g = graphistry.edges(edges_df, 'source', 'target').nodes(nodes_df, 'id')

            # Apply GPU-accelerated processing
            if request.gpu_acceleration:
                g = await self._apply_gpu_acceleration(g, request)

            # Apply clustering if requested
            if request.clustering:
                g = await self._apply_clustering(g)

            # Generate layout
            g = await self._generate_layout(g, request.layout_type)

            # Extract processed data
            processed_nodes = g._nodes.to_dict('records') if g._nodes is not None else nodes_df.to_dict('records')
            processed_edges = g._edges.to_dict('records') if g._edges is not None else edges_df.to_dict('records')

            # Generate embedding URL for interactive visualization
            embed_url = None
            local_viz_data = None

            try:
                embed_url = g.plot(render=False)
                logger.info(f"Generated PyGraphistry embed URL: {embed_url}")
            except Exception as e:
                logger.warning(f"Could not generate embed URL (likely running in local mode): {e}")

            # Create local visualization data as fallback
            try:
                local_viz_data = self._create_local_viz_data(g, processed_nodes, processed_edges)
                logger.info("Generated local visualization data as fallback")
            except Exception as viz_e:
                logger.warning(f"Could not generate local visualization data: {viz_e}")

            return {
                "processed_nodes": processed_nodes,
                "processed_edges": processed_edges,
                "embed_url": embed_url,
                "local_viz_data": local_viz_data,
                "processing_mode": ProcessingMode.PYGRAPHISTRY_CLOUD,
                "stats": {
                    "node_count": len(processed_nodes),
                    "edge_count": len(processed_edges),
                    "gpu_accelerated": request.gpu_acceleration,
                    "clustered": request.clustering,
                    "layout_type": request.layout_type,
                    "has_embed_url": embed_url is not None,
                    "has_local_viz": local_viz_data is not None
                },
                "timestamp": datetime.now().isoformat()
            }

        except Exception as e:
            logger.error(f"Error processing graph data: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    # ... (include other PyGraphistry methods from original service)
    # The stubs below pass the graph through unchanged so the pipeline above
    # keeps working until the real implementations are copied in.
    async def _apply_gpu_acceleration(self, g, request):
        # Implementation from original service; passthrough stub
        return g

    async def _apply_clustering(self, g):
        # Implementation from original service; passthrough stub
        return g

    async def _generate_layout(self, g, layout_type):
        # Implementation from original service; passthrough stub
        return g

    def _create_local_viz_data(self, g, processed_nodes, processed_edges):
        # Implementation from original service; the stub produces no local viz data
        return None
class UnifiedGPUService:
    """Unified service offering both PyGraphistry cloud and local GPU processing"""

    def __init__(self):
        self.pygraphistry_processor = PyGraphistryProcessor()
        self.local_gpu_processor = LocalGPUProcessor()
        self.generation_tasks = {}
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.active_connections: List[WebSocket] = []

    async def process_graph(self, request: UnifiedVisualizationRequest) -> Dict[str, Any]:
        """Process graph with selected processing mode"""

        if request.processing_mode == ProcessingMode.PYGRAPHISTRY_CLOUD:
            return await self.pygraphistry_processor.process_graph_data(request)

        elif request.processing_mode == ProcessingMode.LOCAL_GPU:
            return await self._process_with_local_gpu(request)

        else:  # LOCAL_CPU
            return await self._process_with_local_cpu(request)
    async def _process_with_local_gpu(self, request: UnifiedVisualizationRequest) -> Dict[str, Any]:
        """Process graph with local GPU acceleration"""
        try:
            nodes = request.graph_data.nodes
            edges = request.graph_data.links

            result = {
                "processed_nodes": nodes.copy(),
                "processed_edges": edges.copy(),
                "processing_mode": ProcessingMode.LOCAL_GPU,
                "gpu_processed": False,
                "layout_positions": {},
                "clusters": {},
                "centrality": {},
                "stats": {},
                "timestamp": datetime.now().isoformat()
            }

            if self.local_gpu_processor.use_gpu:
                logger.info("=== LOCAL GPU PROCESSING START ===")

                # Initialize so the stats block below never hits an unbound name
                positions, clusters, centrality = {}, {}, {}

                # Create cuGraph
                G, edges_df = self.local_gpu_processor.create_cugraph_from_data(nodes, edges)

                # Compute layout on GPU
                positions = self.local_gpu_processor.compute_gpu_layout(G, request.layout_algorithm)
                if positions:
                    result["layout_positions"] = positions
                    # Add positions to nodes
                    for node in result["processed_nodes"]:
                        node_id = str(node["id"])
                        if node_id in positions:
                            node["x"], node["y"] = positions[node_id]

                # Compute clustering on GPU
                clusters = self.local_gpu_processor.compute_gpu_clustering(G, request.clustering_algorithm)
                if clusters:
                    result["clusters"] = clusters
                    # Add cluster info to nodes
                    for node in result["processed_nodes"]:
                        node_id = str(node["id"])
                        if node_id in clusters:
                            node["cluster"] = clusters[node_id]

                # Compute centrality on GPU
                if request.compute_centrality:
                    centrality = self.local_gpu_processor.compute_gpu_centrality(G)
                    result["centrality"] = centrality
                    # Add centrality to nodes
                    for node in result["processed_nodes"]:
                        node_id = str(node["id"])
                        for metric, values in centrality.items():
                            if node_id in values:
                                node[metric] = values[node_id]

                result["gpu_processed"] = True
                result["stats"] = {
                    "node_count": len(nodes),
                    "edge_count": len(edges),
                    "gpu_accelerated": True,
                    "layout_computed": len(positions) > 0,
                    "clusters_computed": len(clusters) > 0,
                    "centrality_computed": len(centrality) > 0
                }

                logger.info("=== LOCAL GPU PROCESSING COMPLETE ===")

            return result

        except Exception as e:
            logger.error(f"Local GPU processing failed: {e}")
            raise HTTPException(status_code=500, detail=str(e))
    async def _process_with_local_cpu(self, request: UnifiedVisualizationRequest) -> Dict[str, Any]:
        """Process graph with local CPU (NetworkX fallback)"""
        # Passthrough fallback; a NetworkX layout/clustering step could be
        # plugged in here, but currently the data is returned unchanged.
        nodes = request.graph_data.nodes
        edges = request.graph_data.links

        return {
            "processed_nodes": nodes,
            "processed_edges": edges,
            "processing_mode": ProcessingMode.LOCAL_CPU,
            "gpu_processed": False,
            "stats": {
                "node_count": len(nodes),
                "edge_count": len(edges),
                "gpu_accelerated": False
            },
            "timestamp": datetime.now().isoformat()
        }

    async def broadcast_update(self, data: Dict[str, Any]):
        """Broadcast updates to all connected WebSocket clients"""
        if self.active_connections:
            message = json.dumps(data)
            for connection in self.active_connections.copy():
                try:
                    await connection.send_text(message)
                except WebSocketDisconnect:
                    self.active_connections.remove(connection)
# FastAPI app
app = FastAPI(title="Unified GPU Graph Visualization Service", version="2.0.0")
service = UnifiedGPUService()

@app.post("/api/visualize")
async def visualize_graph(request: UnifiedVisualizationRequest):
    """Process graph with unified service (supports all processing modes)"""
    result = await service.process_graph(request)

    # Broadcast to connected WebSocket clients
    await service.broadcast_update({
        "type": "graph_processed",
        "data": result
    })

    return result
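
# Illustrative client call for the endpoint above (a sketch; assumes the
# `requests` package and the host/port configured in uvicorn.run below):
#
#   import requests
#   sample = requests.get("http://localhost:8080/api/sample-graph").json()
#   resp = requests.post("http://localhost:8080/api/visualize",
#                        json={"graph_data": sample, "processing_mode": "local_gpu"})
#   print(resp.json()["stats"])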

@app.post("/api/generate")
async def generate_graph(request: GraphGenerationRequest):
    """Start graph generation as background task"""
    if request.num_nodes > 1000000:
        raise HTTPException(status_code=400, detail="Maximum 1 million nodes allowed")

    # Use existing graph generation logic
    task_id = f"gen_{int(time.time() * 1000)}"
    # Implementation would go here...
    return {"task_id": task_id, "status": "started"}

@app.get("/api/capabilities")
async def get_capabilities():
    """Get GPU capabilities and available processing modes"""
    return {
        "processing_modes": {
            "pygraphistry_cloud": {
                "available": service.pygraphistry_processor.initialized,
                "description": "PyGraphistry cloud GPU processing with interactive embeds"
            },
            "local_gpu": {
                "available": HAS_RAPIDS,
                "description": "Local GPU processing with cuGraph/RAPIDS"
            },
            "local_cpu": {
                "available": True,
                "description": "Local CPU fallback processing"
            }
        },
        "has_rapids": HAS_RAPIDS,
        "has_torch_geometric": HAS_TORCH_GEOMETRIC,
        "gpu_available": HAS_RAPIDS,
        "supported_layouts": ["force_atlas2", "spectral", "fruchterman_reingold"],
        "supported_clustering": ["leiden", "louvain", "spectral"]
    }
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time updates"""
    await websocket.accept()
    service.active_connections.append(websocket)

    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        service.active_connections.remove(websocket)

@app.get("/api/sample-graph")
async def get_sample_graph():
    """Get a sample graph for testing"""
    return {
        "nodes": [
            {"id": "1", "name": "Central Hub", "group": "core"},
            {"id": "2", "name": "Data Source A", "group": "input"},
            {"id": "3", "name": "Data Source B", "group": "input"},
            {"id": "4", "name": "Processing Unit", "group": "compute"},
            {"id": "5", "name": "Output A", "group": "output"},
            {"id": "6", "name": "Output B", "group": "output"},
            {"id": "7", "name": "Analytics", "group": "analysis"},
            {"id": "8", "name": "Storage", "group": "storage"}
        ],
        "links": [
            {"source": "2", "target": "1", "name": "data_feed"},
            {"source": "3", "target": "1", "name": "data_feed"},
            {"source": "1", "target": "4", "name": "process"},
            {"source": "4", "target": "5", "name": "output"},
            {"source": "4", "target": "6", "name": "output"},
            {"source": "1", "target": "7", "name": "analyze"},
            {"source": "1", "target": "8", "name": "store"}
        ]
    }

@app.get("/api/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "pygraphistry_initialized": service.pygraphistry_processor.initialized,
        "local_gpu_available": HAS_RAPIDS,
        "torch_geometric": HAS_TORCH_GEOMETRIC,
        "timestamp": datetime.now().isoformat()
    }

@app.get("/", response_class=HTMLResponse)
async def get_visualization_page():
    """Serve the interactive visualization page"""
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Unified GPU Graph Visualization</title>
        <script src="https://d3js.org/d3.v7.min.js"></script>
        <style>
            body { margin: 0; font-family: Arial, sans-serif; background: #1a1a1a; color: white; }
            #controls { position: absolute; top: 10px; left: 10px; z-index: 100; background: rgba(0,0,0,0.8); padding: 10px; border-radius: 5px; }
            #graph { width: 100vw; height: 100vh; }
            .node { cursor: pointer; }
            .link { stroke: #999; stroke-opacity: 0.6; }
            button { margin: 5px; padding: 5px 10px; }
            select { margin: 5px; padding: 5px; }
        </style>
    </head>
    <body>
        <div id="controls">
            <h3>⚡ Unified GPU Visualization</h3>
            <div>
                <label>Processing Mode:</label>
                <select id="processingMode">
                    <option value="pygraphistry_cloud">PyGraphistry Cloud</option>
                    <option value="local_gpu">Local GPU (cuGraph)</option>
                    <option value="local_cpu">Local CPU</option>
                </select>
            </div>
            <button onclick="loadSampleGraph()">Load Sample Graph</button>
            <div id="status">Ready - Select processing mode and load graph</div>
        </div>
        <div id="graph"></div>
        <script>
            let currentGraph = null;
            let simulation = null;

            async function loadSampleGraph() {
                const mode = document.getElementById('processingMode').value;
                document.getElementById("status").innerHTML = `Loading sample graph...`;

                try {
                    // Get sample graph data
                    const graphResponse = await fetch('/api/sample-graph');
                    const graphData = await graphResponse.json();

                    document.getElementById("status").innerHTML = `Processing with ${mode}...`;

                    // Process the graph with selected mode
                    const processResponse = await fetch('/api/visualize', {
                        method: 'POST',
                        headers: {
                            'Content-Type': 'application/json',
                        },
                        body: JSON.stringify({
                            graph_data: graphData,
                            processing_mode: mode,
                            layout_algorithm: 'force_atlas2',
                            clustering_algorithm: 'leiden',
                            compute_centrality: true
                        })
                    });

                    const result = await processResponse.json();

                    if (result.processed_nodes && result.processed_edges) {
                        document.getElementById("status").innerHTML =
                            `✅ Processed ${result.processed_nodes.length} nodes, ${result.processed_edges.length} edges (GPU: ${result.gpu_processed})`;

                        // Visualize with D3.js
                        visualizeGraph({
                            nodes: result.processed_nodes,
                            links: result.processed_edges
                        });
                    } else {
                        document.getElementById("status").innerHTML = `❌ Error: ${result.detail || 'Unknown error'}`;
                    }

                } catch (error) {
                    document.getElementById("status").innerHTML = `❌ Error: ${error.message}`;
                    console.error('Error:', error);
                }
            }
            function visualizeGraph(graph) {
                // Clear previous visualization
                d3.select("#graph").selectAll("*").remove();

                const width = window.innerWidth;
                const height = window.innerHeight;

                const svg = d3.select("#graph")
                    .append("svg")
                    .attr("width", width)
                    .attr("height", height);

                // Create force simulation
                simulation = d3.forceSimulation(graph.nodes)
                    .force("link", d3.forceLink(graph.links).id(d => d.id).distance(100))
                    .force("charge", d3.forceManyBody().strength(-300))
                    .force("center", d3.forceCenter(width / 2, height / 2));

                // Create links
                const link = svg.append("g")
                    .selectAll("line")
                    .data(graph.links)
                    .enter().append("line")
                    .attr("class", "link")
                    .attr("stroke-width", 2);

                // Create nodes
                const node = svg.append("g")
                    .selectAll("circle")
                    .data(graph.nodes)
                    .enter().append("circle")
                    .attr("class", "node")
                    .attr("r", 8)
                    .attr("fill", d => d.group === 'core' ? '#ff6b6b' :
                                       d.group === 'input' ? '#4ecdc4' :
                                       d.group === 'output' ? '#45b7d1' :
                                       d.group === 'compute' ? '#f9ca24' : '#6c5ce7')
                    .call(d3.drag()
                        .on("start", dragstarted)
                        .on("drag", dragged)
                        .on("end", dragended));

                // Add labels
                const label = svg.append("g")
                    .selectAll("text")
                    .data(graph.nodes)
                    .enter().append("text")
                    .text(d => d.name)
                    .attr("font-size", 12)
                    .attr("fill", "white")
                    .attr("text-anchor", "middle")
                    .attr("dy", ".35em");

                // Update positions on tick
                simulation.on("tick", () => {
                    link
                        .attr("x1", d => d.source.x)
                        .attr("y1", d => d.source.y)
                        .attr("x2", d => d.target.x)
                        .attr("y2", d => d.target.y);

                    node
                        .attr("cx", d => d.x)
                        .attr("cy", d => d.y);

                    label
                        .attr("x", d => d.x)
                        .attr("y", d => d.y + 20);
                });
            }

            function dragstarted(event, d) {
                if (!event.active) simulation.alphaTarget(0.3).restart();
                d.fx = d.x;
                d.fy = d.y;
            }

            function dragged(event, d) {
                d.fx = event.x;
                d.fy = event.y;
            }

            function dragended(event, d) {
                if (!event.active) simulation.alphaTarget(0);
                d.fx = null;
                d.fy = null;
            }
        </script>
    </body>
    </html>
    """
def startup_diagnostics():
    """Run startup diagnostics and display system info"""
    print("🚀 Starting Unified GPU-accelerated graph visualization service...")
    print("Container: NVIDIA PyG with cuGraph/RAPIDS support")

    # Check GPU availability
    try:
        import subprocess
        result = subprocess.run(['nvidia-smi', '--query-gpu=gpu_name,memory.total,memory.used', '--format=csv,noheader,nounits'],
                                capture_output=True, text=True, timeout=5)
        if result.returncode == 0:
            print("✓ GPU detected:")
            for line in result.stdout.strip().split('\n'):
                if line.strip():
                    print(f"  {line.strip()}")
        else:
            print("⚠ No GPU detected, will use CPU fallback")
    except Exception:
        print("⚠ No GPU detected, will use CPU fallback")

    # Check RAPIDS availability
    if HAS_RAPIDS:
        print("✓ RAPIDS cuGraph/cuDF/cuML available")
    else:
        print("⚠ RAPIDS not available")

    # Check PyTorch Geometric
    if HAS_TORCH_GEOMETRIC:
        print("✓ PyTorch Geometric available")
    else:
        print("⚠ PyTorch Geometric not available")

    # Check PyGraphistry credentials
    print("Checking PyGraphistry credentials...")
    personal_key = os.getenv('GRAPHISTRY_PERSONAL_KEY')
    secret_key = os.getenv('GRAPHISTRY_SECRET_KEY')
    api_key = os.getenv('GRAPHISTRY_API_KEY')

    if personal_key and secret_key:
        print("✓ PyGraphistry personal key/secret found")
    elif api_key:
        print("✓ PyGraphistry API key found")
    else:
        print("⚠ No PyGraphistry credentials found - cloud mode will be limited")
        print("  Set GRAPHISTRY_PERSONAL_KEY + GRAPHISTRY_SECRET_KEY for full cloud features")

    print("")
    print("🎯 Available Processing Modes:")
    print("  ☁️ PyGraphistry Cloud - Interactive GPU embeds (requires credentials)")
    print("  🚀 Local GPU (cuGraph) - Full local GPU processing")
    print("  💻 Local CPU - NetworkX fallback")
    print("")
    print("📊 Service starting on: http://0.0.0.0:8080")
    print("🎯 API Endpoints:")
    print("  - Unified processing: POST /api/visualize")
    print("  - Processing modes: GET /api/capabilities")
    print("  - Health check: GET /api/health")
    print("  - WebSocket updates: WS /ws")
    print("")

if __name__ == "__main__":
    startup_diagnostics()
    uvicorn.run(app, host="0.0.0.0", port=8080)
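
# Illustrative smoke test once the service is up (assumes `requests` is installed):
#   python -c "import requests; print(requests.get('http://localhost:8080/api/health').json())"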

@ -0,0 +1,661 @@
# WebGL-Enhanced Remote GPU Rendering Service
# Using Three.js for GPU-accelerated visualization

import json
from typing import Dict, Any, List

class WebGLGPUVisualizationService:
    """Enhanced remote GPU service with Three.js WebGL rendering"""

    def _generate_threejs_webgl_html(self, session_data: dict, config: dict) -> str:
        """Generate Three.js WebGL visualization with GPU-accelerated rendering"""

        # Extract data
        nodes = session_data['processed_nodes']
        edges = session_data['processed_edges']
        layout_positions = session_data.get('layout_positions', {})
        clusters = session_data.get('clusters', {})
        centrality = session_data.get('centrality', {})

        # Configuration
        animation_duration = config.get('animation_duration', 3000)
        show_splash = config.get('show_splash', True)
        auto_zoom = config.get('auto_zoom', True)
        show_labels = config.get('show_labels', True)
        background_color = config.get('background_color', '#0a0a0a')
        render_quality = config.get('render_quality', 'high')

        # GPU rendering settings
        gpu_settings = self._get_gpu_rendering_settings(len(nodes), render_quality)
        html_template = f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>GPU-Accelerated WebGL Graph Visualization</title>
    <style>
        body {{
            margin: 0;
            padding: 0;
            background-color: {background_color};
            font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
            overflow: hidden;
            color: #ffffff;
        }}
        #container {{
            width: 100vw;
            height: 100vh;
            position: relative;
        }}
        #webgl-canvas {{
            display: block;
            width: 100%;
            height: 100%;
        }}
        .performance-monitor {{
            position: absolute;
            top: 10px;
            left: 10px;
            background: rgba(0, 0, 0, 0.8);
            padding: 10px;
            border-radius: 5px;
            font-size: 12px;
            color: #76B900;
            z-index: 100;
        }}
        .controls {{
            position: absolute;
            top: 10px;
            right: 10px;
            display: flex;
            gap: 5px;
            z-index: 100;
        }}
        .control-btn {{
            background: rgba(0, 0, 0, 0.8);
            color: #76B900;
            border: 1px solid #76B900;
            padding: 5px 10px;
            border-radius: 3px;
            cursor: pointer;
            font-size: 11px;
        }}
        .control-btn:hover {{
            background: rgba(118, 185, 0, 0.2);
        }}
        .tooltip {{
            position: absolute;
            background: rgba(0, 0, 0, 0.9);
            color: #fff;
            padding: 8px;
            border-radius: 4px;
            font-size: 12px;
            pointer-events: none;
            z-index: 200;
            border: 1px solid #76B900;
            opacity: 0;
            transition: opacity 0.2s ease;
        }}
        {"" if not show_splash else '''
        .splash-screen {
            position: absolute;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
            background: linear-gradient(135deg, #0a0a0a 0%, #1a1a1a 100%);
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            z-index: 1000;
            transition: opacity 0.8s ease;
        }
        .splash-logo {
            font-size: 2rem;
            font-weight: 700;
            color: #76B900;
            margin-bottom: 1rem;
        }
        .loading-progress {
            width: 300px;
            height: 4px;
            background: #333;
            border-radius: 2px;
            overflow: hidden;
            margin-bottom: 1rem;
        }
        .loading-bar {
            height: 100%;
            background: linear-gradient(90deg, #76B900, #a8d45a);
            width: 0%;
            transition: width 0.3s ease;
        }
        '''}
    </style>
</head>
<body>
    <div id="container">
        <canvas id="webgl-canvas"></canvas>

        <!-- Performance Monitor -->
        <div class="performance-monitor">
            <div>🚀 WebGL GPU Rendering</div>
            <div>FPS: <span id="fps">--</span></div>
            <div>Nodes: {len(nodes):,}</div>
            <div>Triangles: <span id="triangles">--</span></div>
            <div>Memory: <span id="memory">--</span>MB</div>
        </div>

        <!-- Controls -->
        <div class="controls">
            <button class="control-btn" onclick="toggleAnimation()">⏸️ Animation</button>
            <button class="control-btn" onclick="resetCamera()">🎯 Reset</button>
            <button class="control-btn" onclick="toggleLabels()">🏷️ Labels</button>
            <button class="control-btn" onclick="toggleClusters()">🎨 Clusters</button>
            <button class="control-btn" onclick="exportImage()">📷 Export</button>
        </div>

        <!-- Tooltip -->
        <div id="tooltip" class="tooltip"></div>

        {"" if not show_splash else f'''
        <div id="splash-screen" class="splash-screen">
            <div class="splash-logo">GPU WebGL Visualization</div>
            <div style="color: #888; margin-bottom: 2rem; text-align: center;">
                Loading {len(nodes):,} nodes with GPU acceleration<br>
                Quality: {render_quality.title()} • WebGL 2.0
            </div>
            <div class="loading-progress">
                <div id="loading-bar" class="loading-bar"></div>
            </div>
            <div id="loading-text" style="color: #888; font-size: 14px;">Initializing WebGL...</div>
        </div>
        '''}
    </div>

    <!-- Three.js Library (matching the project's package.json version) -->
    <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/0.176.0/three.min.js"></script>
    <script>
        // Graph data from GPU processing
        const graphData = {{
            nodes: {json.dumps(nodes)},
            edges: {json.dumps(edges)},
            layoutPositions: {json.dumps(layout_positions)},
            clusters: {json.dumps(clusters)},
            centrality: {json.dumps(centrality)}
        }};

        // GPU rendering configuration
        const gpuConfig = {{
            nodeCount: {len(nodes)},
            edgeCount: {len(edges)},
            maxInstancedNodes: {gpu_settings['max_instanced_nodes']},
            useInstancedMesh: {str(gpu_settings['use_instanced_mesh']).lower()},
            enableLOD: {str(gpu_settings['enable_lod']).lower()},
            frustumCulling: {str(gpu_settings['frustum_culling']).lower()},
            textureAtlasSize: {gpu_settings['texture_atlas_size']},
            animationDuration: {animation_duration},
            showLabels: {str(show_labels).lower()},
            autoZoom: {str(auto_zoom).lower()}
        }};

        class WebGLGraphVisualization {{
            constructor() {{
                this.container = document.getElementById('container');
                this.canvas = document.getElementById('webgl-canvas');

                // Performance monitoring
                this.frameCount = 0;
                this.lastTime = performance.now();
                this.isAnimating = true;
                this.labelsVisible = gpuConfig.showLabels;
                this.clustersVisible = true;
                this.labelCanvases = []; // filled by createLabelVisualization()

                this.init();
                // Kick off the splash loading animation only when the splash exists
                {"this.showLoadingProgress();" if show_splash else ""}
            }}
            init() {{
                // Initialize Three.js WebGL renderer with GPU optimizations
                this.renderer = new THREE.WebGLRenderer({{
                    canvas: this.canvas,
                    antialias: true,
                    alpha: true,
                    powerPreference: "high-performance"
                }});

                this.renderer.setSize(window.innerWidth, window.innerHeight);
                this.renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
                this.renderer.setClearColor(0x0a0a0a, 1);

                // Enable GPU optimizations
                this.renderer.sortObjects = false; // Disable sorting for performance

                // Scene setup
                this.scene = new THREE.Scene();

                // Camera setup with controls
                this.camera = new THREE.PerspectiveCamera(
                    75, window.innerWidth / window.innerHeight, 0.1, 10000
                );
                this.camera.position.z = 1000;

                // Add basic controls
                this.setupControls();

                // Load and process graph data
                this.loadGraphData();

                // Start render loop
                this.animate();

                // Setup interaction
                this.setupInteraction();

                // Start performance monitoring
                this.startPerformanceMonitoring();

                console.log('WebGL Graph Visualization initialized');
            }}

            setupControls() {{
                // Simple camera controls
                this.controls = {{
                    mouseDown: false,
                    mouseX: 0,
                    mouseY: 0,
                    targetX: 0,
                    targetY: 0,
                    zoom: 1
                }};

                this.canvas.addEventListener('mousedown', (e) => {{
                    this.controls.mouseDown = true;
                    this.controls.mouseX = e.clientX;
                    this.controls.mouseY = e.clientY;
                }});

                this.canvas.addEventListener('mousemove', (e) => {{
                    if (this.controls.mouseDown) {{
                        const deltaX = e.clientX - this.controls.mouseX;
                        const deltaY = e.clientY - this.controls.mouseY;

                        this.controls.targetX += deltaX * 2;
                        this.controls.targetY -= deltaY * 2;

                        this.controls.mouseX = e.clientX;
                        this.controls.mouseY = e.clientY;
                    }}
                }});

                this.canvas.addEventListener('mouseup', () => {{
                    this.controls.mouseDown = false;
                }});

                this.canvas.addEventListener('wheel', (e) => {{
                    e.preventDefault();
                    this.controls.zoom *= (1 - e.deltaY * 0.001);
                    this.controls.zoom = Math.max(0.1, Math.min(10, this.controls.zoom));
                }});
            }}
            loadGraphData() {{
                console.log('Loading graph data with WebGL...');

                // Create node geometries and materials
                this.createNodeVisualization();
                this.createEdgeVisualization();

                if (this.labelsVisible) {{
                    this.createLabelVisualization();
                }}

                console.log('Graph data loaded successfully');
            }}

            createNodeVisualization() {{
                const nodeCount = graphData.nodes.length;

                if (gpuConfig.useInstancedMesh && nodeCount > 1000) {{
                    // GPU-accelerated instanced rendering for large graphs
                    console.log('Using GPU instanced mesh for', nodeCount, 'nodes');

                    const geometry = new THREE.CircleGeometry(1, 8);
                    const material = new THREE.MeshBasicMaterial({{
                        vertexColors: true,
                        transparent: true,
                        opacity: 0.8
                    }});

                    this.nodesMesh = new THREE.InstancedMesh(geometry, material, nodeCount);

                    // Set up instance matrices and colors
                    const matrix = new THREE.Matrix4();
                    const color = new THREE.Color();

                    graphData.nodes.forEach((node, i) => {{
                        // Position from GPU-computed layout
                        const pos = graphData.layoutPositions[node.id] || [0, 0];
                        const x = pos[0] - 500; // Center
                        const y = pos[1] - 500;

                        // Size based on centrality
                        const centrality = node.pagerank || 0.001;
                        const size = Math.max(2, Math.sqrt(centrality * 10000) + 3);

                        // Color based on cluster
                        const cluster = node.cluster || 0;
                        const clusterColor = this.getClusterColor(cluster);

                        // Set instance transform
                        matrix.makeScale(size, size, 1);
                        matrix.setPosition(x, y, 0);
                        this.nodesMesh.setMatrixAt(i, matrix);

                        // Set instance color
                        color.setHex(clusterColor);
                        this.nodesMesh.setColorAt(i, color);
                    }});

                    this.nodesMesh.instanceMatrix.needsUpdate = true;
                    this.nodesMesh.instanceColor.needsUpdate = true;

                    this.scene.add(this.nodesMesh);

                }} else {{
                    // Standard mesh rendering for smaller graphs
                    console.log('Using standard mesh rendering for', nodeCount, 'nodes');

                    this.nodesGroup = new THREE.Group();

                    graphData.nodes.forEach((node, i) => {{
                        const pos = graphData.layoutPositions[node.id] || [0, 0];
                        const x = pos[0] - 500;
                        const y = pos[1] - 500;

                        const centrality = node.pagerank || 0.001;
                        const size = Math.max(2, Math.sqrt(centrality * 10000) + 3);

                        const cluster = node.cluster || 0;
                        const clusterColor = this.getClusterColor(cluster);

                        const geometry = new THREE.CircleGeometry(size, 8);
                        const material = new THREE.MeshBasicMaterial({{
                            color: clusterColor,
                            transparent: true,
                            opacity: 0.8
                        }});

                        const nodeMesh = new THREE.Mesh(geometry, material);
                        nodeMesh.position.set(x, y, 0);
                        nodeMesh.userData = {{ nodeData: node, nodeIndex: i }};

                        this.nodesGroup.add(nodeMesh);
                    }});

                    this.scene.add(this.nodesGroup);
                }}
            }}
            createEdgeVisualization() {{
                console.log('Creating edge visualization...');

                const edgeCount = graphData.edges.length;
                const positions = new Float32Array(edgeCount * 6); // 2 vertices * 3 coordinates
                const colors = new Float32Array(edgeCount * 6);    // 2 vertices * 3 colors

                graphData.edges.forEach((edge, i) => {{
                    const sourcePos = graphData.layoutPositions[edge.source] || [0, 0];
                    const targetPos = graphData.layoutPositions[edge.target] || [0, 0];

                    const idx = i * 6;

                    // Source vertex
                    positions[idx] = sourcePos[0] - 500;
                    positions[idx + 1] = sourcePos[1] - 500;
                    positions[idx + 2] = 0;

                    // Target vertex
                    positions[idx + 3] = targetPos[0] - 500;
                    positions[idx + 4] = targetPos[1] - 500;
                    positions[idx + 5] = 0;

                    // Edge color (gray)
                    colors[idx] = colors[idx + 3] = 0.3;
                    colors[idx + 1] = colors[idx + 4] = 0.3;
                    colors[idx + 2] = colors[idx + 5] = 0.3;
                }});

                const edgeGeometry = new THREE.BufferGeometry();
                edgeGeometry.setAttribute('position', new THREE.BufferAttribute(positions, 3));
                edgeGeometry.setAttribute('color', new THREE.BufferAttribute(colors, 3));

                const edgeMaterial = new THREE.LineBasicMaterial({{
                    vertexColors: true,
                    transparent: true,
                    opacity: 0.4
                }});

                this.edgesMesh = new THREE.LineSegments(edgeGeometry, edgeMaterial);
                this.scene.add(this.edgesMesh);
            }}
            createLabelVisualization() {{
                // Canvas-based text rendering for labels
                this.labelCanvases = [];

                graphData.nodes.forEach((node, i) => {{
                    if (i > 500) return; // Limit labels for performance

                    const canvas = document.createElement('canvas');
                    const context = canvas.getContext('2d');
                    canvas.width = 256;
                    canvas.height = 64;

                    context.fillStyle = '#ffffff';
                    context.font = '16px Arial';
                    context.textAlign = 'center';
                    context.fillText(node.name || node.id, 128, 32);

                    const texture = new THREE.CanvasTexture(canvas);
                    const material = new THREE.SpriteMaterial({{ map: texture }});
                    const sprite = new THREE.Sprite(material);

                    const pos = graphData.layoutPositions[node.id] || [0, 0];
                    sprite.position.set(pos[0] - 500, pos[1] - 480, 1);
                    sprite.scale.set(50, 12.5, 1);

                    this.scene.add(sprite);
                    this.labelCanvases.push(sprite);
                }});
            }}

            getClusterColor(cluster) {{
                // Midnight Tokyo inspired color palette - neon colors in hex format for WebGL
                const colors = [
                    0xFF0080, // Hot pink neon
                    0x00FFFF, // Electric cyan
                    0xFF4081, // Neon pink
                    0x8A2BE2, // Electric purple
                    0x00FF41, // Matrix green
                    0xFF6B35, // Neon orange
                    0x1E90FF, // Electric blue
                    0xFF1493, // Deep pink
                    0x00CED1, // Dark turquoise
                    0x9932CC, // Dark orchid
                    0x32CD32, // Lime green
                    0xFF4500, // Orange red
                    0x4169E1, // Royal blue
                    0xDC143C, // Crimson
                    0x00FA9A, // Medium spring green
                    0xFF69B4, // Hot pink
                    0x1E88E5, // Blue
                    0xE91E63, // Pink
                    0x00E676, // Green
                    0xFF5722, // Deep orange
                    0x673AB7, // Deep purple
                    0x03DAC6, // Teal
                    0xBB86FC, // Light purple
                    0xCF6679  // Light pink
                ];
                return colors[cluster % colors.length];
            }}
            animate() {{
                requestAnimationFrame(() => this.animate());

                if (this.isAnimating) {{
                    // Smooth camera movement
                    this.camera.position.x += (this.controls.targetX - this.camera.position.x) * 0.05;
                    this.camera.position.y += (this.controls.targetY - this.camera.position.y) * 0.05;
                    this.camera.zoom += (this.controls.zoom - this.camera.zoom) * 0.05;
                    this.camera.updateProjectionMatrix();
                }}

                // Render with GPU
                this.renderer.render(this.scene, this.camera);

                // Update performance monitor
                this.updatePerformanceMonitor();
            }}

            setupInteraction() {{
                const raycaster = new THREE.Raycaster();
                const mouse = new THREE.Vector2();
                const tooltip = document.getElementById('tooltip');

                this.canvas.addEventListener('mousemove', (event) => {{
                    if (this.controls.mouseDown) return;

                    mouse.x = (event.clientX / window.innerWidth) * 2 - 1;
                    mouse.y = -(event.clientY / window.innerHeight) * 2 + 1;

                    raycaster.setFromCamera(mouse, this.camera);

                    let intersects = [];
                    if (this.nodesGroup) {{
                        intersects = raycaster.intersectObjects(this.nodesGroup.children);
                    }}

                    if (intersects.length > 0) {{
                        const nodeData = intersects[0].object.userData.nodeData;
                        tooltip.innerHTML = `
                            <strong>${{nodeData.name || nodeData.id}}</strong><br>
                            Cluster: ${{nodeData.cluster || 'N/A'}}<br>
                            PageRank: ${{(nodeData.pagerank || 0).toFixed(4)}}
                        `;
                        tooltip.style.left = (event.clientX + 10) + 'px';
                        tooltip.style.top = (event.clientY - 10) + 'px';
                        tooltip.style.opacity = '1';
                    }} else {{
                        tooltip.style.opacity = '0';
                    }}
                }});
            }}

            startPerformanceMonitoring() {{
                setInterval(() => {{
                    const now = performance.now();
                    const fps = Math.round((this.frameCount * 1000) / (now - this.lastTime));

                    document.getElementById('fps').textContent = fps;
                    document.getElementById('triangles').textContent =
                        (this.renderer.info.render.triangles || 0).toLocaleString();
                    document.getElementById('memory').textContent =
                        Math.round(this.renderer.info.memory.geometries + this.renderer.info.memory.textures);

                    this.frameCount = 0;
                    this.lastTime = now;
                }}, 1000);
            }}

            updatePerformanceMonitor() {{
                this.frameCount++;
            }}

            {"showLoadingProgress() { /* Loading animation */ }" if show_splash else ""}
            {"hideSplash() { /* Hide splash */ }" if show_splash else ""}
            resetCamera() {{
                this.controls.targetX = 0;
                this.controls.targetY = 0;
                this.controls.zoom = 1;
            }}

            toggleAnimation() {{
                this.isAnimating = !this.isAnimating;
            }}

            toggleLabels() {{
                this.labelsVisible = !this.labelsVisible;
                this.labelCanvases.forEach(sprite => {{
                    sprite.visible = this.labelsVisible;
                }});
            }}

            toggleClusters() {{
                this.clustersVisible = !this.clustersVisible;
                // Toggle cluster coloring
            }}

            exportImage() {{
                const link = document.createElement('a');
                link.download = 'webgl-graph.png';
                link.href = this.renderer.domElement.toDataURL();
                link.click();
            }}
        }}

        // Global control functions
        window.toggleAnimation = () => window.graphViz.toggleAnimation();
        window.resetCamera = () => window.graphViz.resetCamera();
        window.toggleLabels = () => window.graphViz.toggleLabels();
        window.toggleClusters = () => window.graphViz.toggleClusters();
        window.exportImage = () => window.graphViz.exportImage();

        // Handle window resize
        window.addEventListener('resize', () => {{
            if (window.graphViz) {{
                window.graphViz.camera.aspect = window.innerWidth / window.innerHeight;
                window.graphViz.camera.updateProjectionMatrix();
                window.graphViz.renderer.setSize(window.innerWidth, window.innerHeight);
            }}
        }});

        // Initialize when DOM is ready
        document.addEventListener('DOMContentLoaded', () => {{
            window.graphViz = new WebGLGraphVisualization();
        }});
    </script>
</body>
</html>
"""

        return html_template
    def _get_gpu_rendering_settings(self, node_count: int, quality: str) -> Dict[str, Any]:
        """Get GPU rendering settings based on graph size and quality"""

        base_settings = {
            'max_instanced_nodes': 100000,
            'use_instanced_mesh': node_count > 1000,
            'enable_lod': node_count > 25000,
            'frustum_culling': node_count > 10000,
            'texture_atlas_size': 1024
        }

        quality_multipliers = {
            'low': {'texture_atlas_size': 512, 'max_instanced_nodes': 50000},
            'medium': {'texture_atlas_size': 1024, 'max_instanced_nodes': 75000},
            'high': {'texture_atlas_size': 2048, 'max_instanced_nodes': 100000},
            'ultra': {'texture_atlas_size': 4096, 'max_instanced_nodes': 500000}
        }

        settings = base_settings.copy()
        settings.update(quality_multipliers.get(quality, quality_multipliers['high']))

        return settings
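
# Example (illustrative): a 50,000-node graph at "high" quality gets instanced
# meshes, LOD, and frustum culling enabled, with a 2048px texture atlas and a
# 100,000-instance ceiling.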

nvidia/txt2kg/assets/deploy/services/ollama/Dockerfile (new file, 9 lines)
@ -0,0 +1,9 @@
FROM ollama/ollama:latest

# Copy the entrypoint script
COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

# Set the entrypoint
ENTRYPOINT ["/entrypoint.sh"]

@ -0,0 +1,26 @@
FROM ubuntu:22.04

# Install required packages
RUN apt-get update && apt-get install -y \
    curl \
    docker.io \
    bc \
    && rm -rf /var/lib/apt/lists/*

# Copy the monitoring script
COPY gpu_memory_monitor.sh /usr/local/bin/
RUN chmod +x /usr/local/bin/gpu_memory_monitor.sh

# Create a non-root user
RUN useradd -m -s /bin/bash monitor

# Set environment variables with defaults
ENV CHECK_INTERVAL=60
ENV MIN_AVAILABLE_PERCENT=70
ENV AUTO_FIX=true

# Run as non-root user
USER monitor
WORKDIR /home/monitor

CMD ["/usr/local/bin/gpu_memory_monitor.sh"]

nvidia/txt2kg/assets/deploy/services/ollama/NVIDIA_MPS_GUIDE.md (new file, 252 lines)
@ -0,0 +1,252 @@
# NVIDIA MPS Guide for Ollama GPU Optimization

## 🚀 Overview

NVIDIA Multi-Process Service (MPS) is a game-changing technology that enables multiple processes to share a single GPU context, eliminating expensive context-switching overhead and dramatically improving concurrent workload performance.

This guide documents our discovery: **MPS transforms the DGX Spark from a single-threaded bottleneck into a high-throughput powerhouse**, achieving **3x concurrent performance** with near-perfect scaling.

## 📊 Performance Results Summary

### Triple Extraction Benchmark (llama3.1:8b)

| System | Mode | Individual Performance | Aggregate Throughput | Scaling Efficiency |
|--------|------|----------------------|---------------------|-------------------|
| **RTX 5090** | Single | ~300 tok/s | 300 tok/s | 100% (baseline) |
| **Mac M4 Pro** | Single | ~45 tok/s | 45 tok/s | 100% (baseline) |
| **DGX Spark** | Single (MPS) | 33.3 tok/s | 33.3 tok/s | 100% (baseline) |
| **DGX Spark** | 2x Concurrent | ~33.2 tok/s each | **66.4 tok/s** | **97% efficiency** |
| **DGX Spark** | 3x Concurrent | ~33.1 tok/s each | **99.4 tok/s** | **99% efficiency** |

### 🏆 Key Achievement
**DGX Spark + MPS delivers 2.2x higher aggregate throughput than RTX 5090 in multi-request scenarios!**
## 🛠️ MPS Setup Instructions

### 1. Start MPS Server

```bash
# Set MPS directory
export CUDA_MPS_PIPE_DIRECTORY=/tmp/nvidia-mps
mkdir -p /tmp/nvidia-mps

# Start MPS control daemon
sudo env "CUDA_MPS_PIPE_DIRECTORY=/tmp/nvidia-mps" nvidia-cuda-mps-control -d
```

### 2. Restart Ollama with MPS Support

```bash
# Stop current Ollama
cd /path/to/ollama
docker compose down

# Start Ollama with MPS environment
sudo env "CUDA_MPS_PIPE_DIRECTORY=/tmp/nvidia-mps" docker compose up -d
```

### 3. Verify MPS is Working

```bash
# Check MPS processes
ps aux | grep mps

# Expected output:
# root   nvidia-cuda-mps-control -d
# root   nvidia-cuda-mps-server -force-tegra

# Check Ollama processes show M+C flag
nvidia-smi
# Look for M+C in the Type column for Ollama processes
```

### 4. Stop MPS (when needed)

```bash
sudo nvidia-cuda-mps-control quit
```
## 🔬 Technical Architecture

### CUDA MPS Architecture
```
┌─────────────────────────────────────────┐
│        GPU (Single CUDA Context)        │
│  ├── MPS Server (Resource Manager)      │
│  ├── Ollama Process 1 ──┐               │
│  ├── Ollama Process 2 ──┼── Shared      │
│  └── Ollama Process 3 ──┘   Context     │
└─────────────────────────────────────────┘
```

### Traditional Multi-Process Architecture
```
┌─────────────────────────────────────────┐
│                  GPU                    │
│  ├── Process 1 (Context 1) ─────────────│
│  ├── Process 2 (Context 2) ─────────────│
│  └── Process 3 (Context 3) ─────────────│
│      ↑ Context Switching Overhead       │
└─────────────────────────────────────────┘
```
## ⚖️ MPS vs Multiple API Servers Comparison

### 🚀 CUDA MPS Advantages

**Performance:**
- ✅ No context switching overhead (single shared context)
- ✅ Concurrent kernel execution from different processes
- ✅ Lower latency for small requests
- ✅ Better GPU utilization (kernels can overlap)

**Memory Efficiency:**
- ✅ Shared GPU memory management
- ✅ No duplicate driver overhead per process
- ✅ More efficient memory allocation
- ✅ Can fit more models in same memory

**Resource Management:**
- ✅ Single point of GPU resource control
- ✅ Automatic load balancing across processes
- ✅ Better thermal management
- ✅ Unified monitoring and debugging

### 🏢 Multiple API Servers Advantages

**Isolation & Reliability:**
- ✅ Process isolation (one crash doesn't affect others)
- ✅ Independent scaling per service
- ✅ Different models can have different configurations
- ✅ Easier to update/restart individual services

**Flexibility:**
- ✅ Different frameworks (vLLM, TensorRT-LLM, etc.)
- ✅ Per-service optimization
- ✅ Independent monitoring and logging
- ✅ Service-specific resource limits

**Operational:**
- ✅ Standard container orchestration (K8s, Docker)
- ✅ Familiar DevOps patterns
- ✅ Load balancing at HTTP level
- ✅ Rolling updates and deployments

## 🎯 Decision Framework

### Use CUDA MPS When:
- 🏆 Maximum GPU utilization is critical
- ⚡ Low latency is paramount
- 💰 Cost optimization (more models per GPU)
- 🔄 Same framework/runtime (e.g., all Ollama)
- 📊 Predictable, homogeneous workloads
- 🎮 Single-tenant environments

### Use Multiple API Servers When:
- 🛡️ High availability/fault tolerance required
- 🔧 Different models need different optimizations
- 📈 Independent scaling per service needed
- 🌐 Multi-tenant production environments
- 🔄 Frequent model updates/deployments
- 👥 Different teams managing different models
## 📊 Performance Impact Analysis
|
||||
|
||||
| Metric | CUDA MPS | Multiple Servers |
|
||||
|--------|----------|------------------|
|
||||
| Context Switch Overhead | ~0% | ~5-15% |
|
||||
| Memory Efficiency | ~95% | ~80-85% |
|
||||
| Latency (small requests) | Lower | Higher |
|
||||
| Throughput (concurrent) | Higher | Lower |
|
||||
| Fault Isolation | Lower | Higher |
|
||||
| Operational Complexity | Lower | Higher |
|
||||
|
||||
## 🔍 Memory Capacity Analysis

### Model Memory Requirements
- **llama3.1:8b (Q4_K_M)**: ~4.9GB per instance

### System Comparison
| System | Total Memory | Theoretical Max | Practical Max |
|--------|--------------|-----------------|---------------|
| **RTX 5090** | 24GB VRAM | 4-5 models | 2-3 models |
| **DGX Spark** | 120GB Unified | 20+ models | 10+ models |
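
The "theoretical max" column follows directly from dividing total memory by the per-instance footprint; a quick sketch to reproduce the arithmetic (the 4.9GB figure comes from the table above):

```bash
# Back-of-envelope check of the table above
python3 - <<'EOF'
per_model_gb = 4.9  # llama3.1:8b (Q4_K_M) per instance
for name, mem_gb in [("RTX 5090 (as listed)", 24), ("DGX Spark", 120)]:
    print(f"{name}: {mem_gb / per_model_gb:.1f} instances before driver overhead")
EOF
```
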
### RTX 5090 Limitations:
- ❌ Limited to 24GB VRAM (hard ceiling)
- ❌ Driver overhead reduces available memory
- ❌ Memory fragmentation issues
- ❌ Thermal throttling under concurrent load
- ❌ Context switching still expensive

### DGX Spark Advantages:
- ✅ 5x more memory capacity (120GB vs 24GB)
- ✅ Unified memory architecture
- ✅ Better thermal design for sustained loads
- ✅ Can scale to 10+ concurrent models
- ✅ No VRAM bottleneck

## 🧪 Testing Concurrent Performance

### Single Instance Baseline
```bash
curl -X POST http://localhost:11434/api/chat \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llama3.1:8b",
    "messages": [{"role": "user", "content": "Your prompt here"}],
    "stream": false
  }'
```

### Concurrent Testing
```bash
# Run multiple requests simultaneously (see the concrete example below)
curl [request1] & curl [request2] & curl [request3] & wait
```
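
One concrete version of the placeholder above, reusing the baseline request (the prompt text is arbitrary):

```bash
# Fire three identical chat requests in parallel and wait for all of them
for i in 1 2 3; do
  curl -s -X POST http://localhost:11434/api/chat \
    -H "Content-Type: application/json" \
    -d '{"model": "llama3.1:8b", "messages": [{"role": "user", "content": "Summarize CUDA MPS in one sentence."}], "stream": false}' &
done
wait
```
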
### Expected Results with MPS:
- **1 instance**: 33.3 tok/s
- **2 concurrent**: ~66.4 tok/s total (97% efficiency)
- **3 concurrent**: ~99.4 tok/s total (99% efficiency)

## 🎯 Recommendations

### For Triple Extraction Workloads:
**MPS is the optimal choice because:**
1. **Homogeneous workload** - same model (llama3.1:8b)
2. **Performance critical** - maximum throughput needed
3. **Cost optimization** - more concurrent requests per GPU
4. **Predictable usage** - biomedical triple extraction

### Hybrid Approach:
Consider running:
- **MPS in production** for maximum throughput
- **Separate dev/test servers** for experimentation
- **Different models** on separate instances when needed

## 🚨 Important Notes

1. **MPS requires careful setup** - ensure proper environment variables
2. **Monitor GPU temperature** under heavy concurrent loads
3. **Test thoroughly** before production deployment
4. **Have a fallback plan** to revert to standard single-process mode
5. **Consider workload patterns** - MPS excels with consistent concurrent requests

## 🔗 Related Files

- `docker-compose.yml` - Ollama service configuration
- `ollama_gpu_benchmark.py` - Performance testing script
- `clear_cache_and_restart.sh` - Memory optimization script
- `monitor_gpu_memory.sh` - GPU monitoring script

## 📚 Additional Resources

- [NVIDIA MPS Documentation](https://docs.nvidia.com/deploy/mps/index.html)
- [CUDA Multi-Process Service Guide](https://docs.nvidia.com/cuda/mps/index.html)
- [Ollama Documentation](https://ollama.ai/docs)

---

**Last Updated**: October 2, 2025
**Tested On**: DGX Spark with 120GB unified memory, CUDA 13.0, Ollama latest

@ -0,0 +1,78 @@
# Ollama GPU Memory Monitoring

This setup includes automatic monitoring and fixing of GPU memory detection issues that can occur on unified memory systems (like DGX Spark, Jetson, etc.).

## The Problem

On unified memory systems, Ollama sometimes can't detect the full amount of available GPU memory due to buffer cache not being reclaimable. This causes models to fall back to CPU inference, dramatically reducing performance.

**Symptoms:**
- Ollama logs show low "available" vs "total" GPU memory
- Models show mixed CPU/GPU processing instead of 100% GPU
- Performance is much slower than expected

## The Solution

This Docker Compose setup includes an optional GPU memory monitor that:

1. **Monitors** Ollama's GPU memory detection every 60 seconds
2. **Detects** when available memory drops below 70% of total
3. **Automatically fixes** the issue by clearing buffer cache and restarting Ollama
4. **Logs** all actions for debugging

## Usage

### Standard Setup (Most Systems)
```bash
docker compose up -d
```

### Unified Memory Systems (DGX Spark, Jetson, etc.)
```bash
docker compose --profile unified-memory up -d
```

This will start both Ollama and the GPU memory monitor.

## Configuration

The monitor can be configured via environment variables:

- `CHECK_INTERVAL=60` - How often to check (seconds)
- `MIN_AVAILABLE_PERCENT=70` - Threshold for triggering fixes (percentage)
- `AUTO_FIX=true` - Whether to automatically fix issues
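
If you want different thresholds without editing the main compose file, one option is a compose override file (a sketch: the `gpu-monitor` service name matches the compose file in this directory, and `docker compose` merges `docker-compose.override.yml` automatically):

```bash
# Hypothetical override tightening the monitor's thresholds
cat > docker-compose.override.yml <<'EOF'
services:
  gpu-monitor:
    environment:
      - CHECK_INTERVAL=30
      - MIN_AVAILABLE_PERCENT=80
      - AUTO_FIX=true
EOF
docker compose --profile unified-memory up -d
```
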
## Manual Commands

You can still use the manual scripts if needed:

```bash
# Check current GPU memory status
./monitor_gpu_memory.sh

# Manually clear cache and restart
./clear_cache_and_restart.sh
```

## Monitoring Logs

To see what the monitor is doing:

```bash
docker logs ollama-gpu-monitor -f
```

## When to Use

Use the unified memory profile if you experience:
- Inconsistent Ollama performance
- Models loading on CPU instead of GPU
- GPU memory showing as much lower than system RAM
- Any of the above on a system with unified memory (DGX, Jetson, etc.)

## Performance Impact

The monitor has minimal performance impact:
- Runs one check every 60 seconds
- Only takes action when issues are detected
- Automatic fixes typically resolve issues within 30 seconds

nvidia/txt2kg/assets/deploy/services/ollama/clear_cache_and_restart.sh (Executable file)
@ -0,0 +1,34 @@
#!/bin/bash
#
# Clear buffer cache and restart Ollama to fix unified memory detection
# This script addresses the issue where Ollama can't see full GPU memory
# due to buffer cache not being reclaimable in unified memory systems
#

set -e

echo "🧹 Clearing system buffer cache..."
echo "Current memory status:"
free -h

echo "Stopping Ollama container..."
docker compose -f /home/nvidia/txt2kg/txt2kg/deploy/services/ollama/docker-compose.yml down

echo "Clearing buffer cache..."
sudo sync
sudo sh -c 'echo 1 > /proc/sys/vm/drop_caches'

echo "Memory status after cache clear:"
free -h

echo "Restarting Ollama container..."
docker compose -f /home/nvidia/txt2kg/txt2kg/deploy/services/ollama/docker-compose.yml up -d

echo "Waiting for Ollama to start..."
sleep 10

echo "Checking GPU memory detection..."
timeout 30 bash -c 'while ! docker logs ollama-server 2>&1 | grep -q "inference compute"; do sleep 1; done'
docker logs ollama-server 2>&1 | grep "inference compute" | tail -1

echo "✅ Ollama restarted with cleared cache"

@ -0,0 +1,66 @@
version: '3.8'

services:
  ollama:
    build:
      context: .
      dockerfile: Dockerfile
    image: ollama-custom:latest
    container_name: ollama-server
    ports:
      - "11434:11434"
    volumes:
      - ollama_models:/root/.ollama
    environment:
      - OLLAMA_HOST=0.0.0.0:11434
      - OLLAMA_FLASH_ATTENTION=1
      - OLLAMA_KEEP_ALIVE=30m
      - OLLAMA_CUDA=1
      # Performance tuning for large models like Llama3 70B
      - OLLAMA_LLM_LIBRARY=cuda
      - OLLAMA_NUM_PARALLEL=1        # Favor latency/stability for 70B; increase for smaller models
      - OLLAMA_MAX_LOADED_MODELS=1   # Avoid VRAM contention
      - OLLAMA_KV_CACHE_TYPE=q8_0    # Reduce KV cache VRAM with minimal perf impact
      # Removed restrictive settings for 70B model testing:
      # - OLLAMA_CONTEXT_LENGTH=8192 (let Ollama auto-detect)
      # - OLLAMA_NUM_PARALLEL=4 (let Ollama decide)
      # - OLLAMA_MAX_LOADED=1 (allow multiple models)
      # - OLLAMA_NUM_THREADS=16 (may force CPU usage)
    runtime: nvidia
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:11434/api/tags"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 60s

  # GPU Memory Monitor - only for unified memory systems like DGX Spark
  gpu-monitor:
    build:
      context: .
      dockerfile: Dockerfile.monitor
    container_name: ollama-gpu-monitor
    depends_on:
      - ollama
    volumes:
      - /var/run/docker.sock:/var/run/docker.sock:ro
    environment:
      - CHECK_INTERVAL=60            # Check every 60 seconds
      - MIN_AVAILABLE_PERCENT=70     # Alert if less than 70% GPU memory available
      - AUTO_FIX=true                # Automatically fix buffer cache issues
    privileged: true                 # Required to clear buffer cache and restart containers
    restart: unless-stopped
    profiles:
      - unified-memory               # Only start with --profile unified-memory

volumes:
  ollama_models:
    driver: local

nvidia/txt2kg/assets/deploy/services/ollama/entrypoint.sh (Normal file)
@ -0,0 +1,42 @@
#!/bin/bash
set -e

# Start Ollama server in the background
echo "Starting Ollama server..."
/bin/ollama serve &
OLLAMA_PID=$!

# Wait for Ollama to be ready
echo "Waiting for Ollama to be ready..."
max_attempts=30
attempt=0
while [ $attempt -lt $max_attempts ]; do
    if curl -s http://localhost:11434/api/tags > /dev/null 2>&1; then
        echo "Ollama is ready!"
        break
    fi
    attempt=$((attempt + 1))
    sleep 2
done

if [ $attempt -eq $max_attempts ]; then
    echo "ERROR: Ollama failed to start within the timeout period"
    exit 1
fi

# Check if any models are present (match "models":[] with or without whitespace,
# since the API does not emit a space after the colon)
echo "Checking for existing models..."
if curl -s http://localhost:11434/api/tags | grep -q '"models":\s*\[\]'; then
    echo "No models found. Pulling llama3.1:8b..."
    /bin/ollama pull llama3.1:8b
    echo "Successfully pulled llama3.1:8b"
else
    echo "Models already exist, skipping pull."
fi

# Keep the container running
echo "Setup complete. Ollama is running."
wait $OLLAMA_PID

@ -0,0 +1,108 @@
#!/bin/bash
#
# Ollama GPU Memory Monitor - runs inside a sidecar container
# Automatically detects and fixes unified memory buffer cache issues
#

set -e

# Configuration
CHECK_INTERVAL=${CHECK_INTERVAL:-60}                # Check every 60 seconds
MIN_AVAILABLE_PERCENT=${MIN_AVAILABLE_PERCENT:-70}  # Alert if less than 70% available
AUTO_FIX=${AUTO_FIX:-true}                          # Automatically fix issues

log() {
    echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
}

check_ollama_memory() {
    # Wait for Ollama to be ready
    if ! curl -s http://ollama:11434/api/tags > /dev/null 2>&1; then
        log "Ollama not ready, skipping check"
        return 0
    fi

    # Get Ollama logs to find inference compute info
    local compute_log=$(docker logs ollama-server 2>&1 | grep "inference compute" | tail -1)

    if [ -z "$compute_log" ]; then
        log "No inference compute logs found"
        return 0
    fi

    # Extract memory info
    local total_mem=$(echo "$compute_log" | grep -o 'total="[^"]*"' | cut -d'"' -f2)
    local available_mem=$(echo "$compute_log" | grep -o 'available="[^"]*"' | cut -d'"' -f2)

    if [ -z "$total_mem" ] || [ -z "$available_mem" ]; then
        log "Could not parse memory information"
        return 0
    fi

    # Convert to numeric (assuming GiB)
    local total_num=$(echo "$total_mem" | sed 's/ GiB//')
    local available_num=$(echo "$available_mem" | sed 's/ GiB//')

    # Calculate percentage
    local available_percent=$(echo "scale=1; $available_num * 100 / $total_num" | bc)

    log "GPU Memory: $available_mem / $total_mem available (${available_percent}%)"

    # Check if we need to take action
    if (( $(echo "$available_percent < $MIN_AVAILABLE_PERCENT" | bc -l) )); then
        log "WARNING: Low GPU memory availability detected (${available_percent}%)"

        if [ "$AUTO_FIX" = "true" ]; then
            log "Attempting to fix by clearing buffer cache..."
            fix_memory_issue
        else
            log "Auto-fix disabled. Manual intervention required."
        fi

        return 1
    else
        log "GPU memory availability OK (${available_percent}%)"
        return 0
    fi
}

fix_memory_issue() {
    log "Clearing system buffer cache..."

    # Clear buffer cache from host (requires privileged container)
    echo 1 > /proc/sys/vm/drop_caches 2>/dev/null || {
        log "Cannot clear buffer cache from container. Trying host command..."
        # Alternative: use nsenter to run on host
        nsenter -t 1 -m -p sh -c 'sync && echo 1 > /proc/sys/vm/drop_caches' 2>/dev/null || {
            log "Failed to clear buffer cache. Manual intervention required."
            return 1
        }
    }

    # Wait a moment
    sleep 5

    # Restart Ollama container
    log "Restarting Ollama container..."
    docker restart ollama-server

    # Wait for restart
    sleep 15

    log "Fix applied. Ollama should have better memory detection now."
}

main() {
    log "Starting Ollama GPU Memory Monitor"
    log "Check interval: ${CHECK_INTERVAL}s, Min available: ${MIN_AVAILABLE_PERCENT}%, Auto-fix: ${AUTO_FIX}"

    while true; do
        check_ollama_memory || true  # Don't exit on check failures
        sleep "$CHECK_INTERVAL"
    done
}

# Handle signals gracefully
trap 'log "Shutting down monitor..."; exit 0' SIGTERM SIGINT

main

nvidia/txt2kg/assets/deploy/services/ollama/monitor_gpu_memory.sh (Executable file)
@ -0,0 +1,79 @@
#!/bin/bash
#
# Monitor Ollama GPU memory usage and alert when buffer cache is consuming too much
# This helps detect when the unified memory issue is occurring
#

set -e

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Thresholds
MIN_AVAILABLE_PERCENT=70  # Alert if less than 70% GPU memory available

echo "🔍 Ollama GPU Memory Monitor"
echo "================================"

# Check if Ollama container is running
if ! docker ps | grep -q ollama-server; then
    echo -e "${RED}❌ Ollama container is not running${NC}"
    exit 1
fi

# Get the latest inference compute log
COMPUTE_LOG=$(docker logs ollama-server 2>&1 | grep "inference compute" | tail -1)

if [ -z "$COMPUTE_LOG" ]; then
    echo -e "${YELLOW}⚠️ No inference compute logs found. Model may not be loaded.${NC}"
    exit 1
fi

echo "Latest GPU memory status:"
echo "$COMPUTE_LOG"

# Extract total and available memory
TOTAL_MEM=$(echo "$COMPUTE_LOG" | grep -o 'total="[^"]*"' | cut -d'"' -f2)
AVAILABLE_MEM=$(echo "$COMPUTE_LOG" | grep -o 'available="[^"]*"' | cut -d'"' -f2)

# Convert to numeric values (assuming GiB)
TOTAL_NUM=$(echo "$TOTAL_MEM" | sed 's/ GiB//')
AVAILABLE_NUM=$(echo "$AVAILABLE_MEM" | sed 's/ GiB//')

# Calculate percentage
AVAILABLE_PERCENT=$(echo "scale=1; $AVAILABLE_NUM * 100 / $TOTAL_NUM" | bc)

echo ""
echo "Memory Analysis:"
echo "  Total GPU Memory: $TOTAL_MEM"
echo "  Available Memory: $AVAILABLE_MEM"
echo "  Available Percentage: ${AVAILABLE_PERCENT}%"

# Check if we need to alert
if (( $(echo "$AVAILABLE_PERCENT < $MIN_AVAILABLE_PERCENT" | bc -l) )); then
    echo ""
    echo -e "${RED}🚨 WARNING: Low GPU memory availability detected!${NC}"
    echo -e "${RED}   Only ${AVAILABLE_PERCENT}% of GPU memory is available${NC}"
    echo -e "${YELLOW}   This may cause models to run on CPU instead of GPU${NC}"
    echo ""
    echo "💡 Recommended action:"
    echo "   Run: ./clear_cache_and_restart.sh"
    echo ""

    # Show current system memory usage
    echo "Current system memory usage:"
    free -h

    exit 1
else
    echo ""
    echo -e "${GREEN}✅ GPU memory availability looks good (${AVAILABLE_PERCENT}%)${NC}"
fi

# Show current model status
echo ""
echo "Current loaded models:"
docker exec ollama-server ollama ps

@ -0,0 +1,23 @@
FROM python:3.9-slim

WORKDIR /app

# Copy requirements and install dependencies first for better caching
COPY requirements.txt /app/
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY app.py /app/

# Set default model name
ENV MODEL_NAME="all-MiniLM-L6-v2"
ENV TRANSFORMERS_CACHE="/app/.cache"

# Pre-download the model during build for faster startup
RUN python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('${MODEL_NAME}')"

# Expose the port
EXPOSE 80

# Use Gunicorn for better performance
CMD ["gunicorn", "--bind", "0.0.0.0:80", "--workers", "1", "--threads", "8", "app:app"]

@ -0,0 +1,92 @@
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer
import os
import time
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Get model name from environment variable
model_name = os.environ.get("MODEL_NAME", "all-MiniLM-L6-v2")
logger.info(f"Loading model: {model_name}")

# Load model during startup
start_time = time.time()
try:
    model = SentenceTransformer(model_name)
    logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise

@app.route("/health", methods=["GET"])
def health():
    return jsonify({"status": "healthy", "model": model_name})

@app.route("/embed", methods=["POST"])
def embed():
    try:
        data = request.json
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400

        texts = data.get("texts", [])
        if not texts:
            return jsonify({"error": "No texts provided"}), 400

        # Process in batches if needed
        batch_size = data.get("batch_size", 32)

        start_time = time.time()
        embeddings = model.encode(texts, batch_size=batch_size).tolist()
        processing_time = time.time() - start_time

        logger.info(f"Processed {len(texts)} texts in {processing_time:.2f} seconds")

        return jsonify({
            "embeddings": embeddings,
            "model": model_name,
            "processing_time": processing_time
        })
    except Exception as e:
        logger.error(f"Error generating embeddings: {e}")
        return jsonify({"error": str(e)}), 500

# Add compatibility with the /embeddings endpoint for the EmbeddingsService class
@app.route("/embeddings", methods=["POST"])
def embeddings():
    try:
        data = request.json
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400

        texts = data.get("input", [])
        if not texts:
            return jsonify({"error": "No input texts provided"}), 400

        batch_size = data.get("batch_size", 32)

        start_time = time.time()
        embeddings = model.encode(texts, batch_size=batch_size).tolist()
        processing_time = time.time() - start_time

        # Format response for compatibility with the EmbeddingsService
        response_data = {
            "data": [{"embedding": embedding} for embedding in embeddings],
            "model": model_name,
            "processing_time": processing_time
        }

        logger.info(f"Processed {len(texts)} texts in {processing_time:.2f} seconds for /embeddings endpoint")

        return jsonify(response_data)
    except Exception as e:
        logger.error(f"Error generating embeddings: {e}")
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 80)))
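
A quick smoke test against the two endpoints of this service (a sketch; it assumes the container's port 80 is published as `localhost:8080`, so adjust to your actual mapping):

```bash
# /embed takes {"texts": [...]} and returns {"embeddings": [...]}
curl -s -X POST http://localhost:8080/embed \
  -H "Content-Type: application/json" \
  -d '{"texts": ["knowledge graph", "triple extraction"], "batch_size": 32}'

# /embeddings takes {"input": [...]} and returns {"data": [{"embedding": ...}, ...]}
curl -s -X POST http://localhost:8080/embeddings \
  -H "Content-Type: application/json" \
  -d '{"input": ["knowledge graph"]}'
```
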
@ -0,0 +1,6 @@
sentence-transformers==2.3.1
transformers==4.36.2
torch==2.1.2
flask==2.3.3
gunicorn==21.2.0
numpy==1.26.2

nvidia/txt2kg/assets/deploy/services/vllm/Dockerfile (Normal file)
@ -0,0 +1,27 @@
# Use NVIDIA Triton Inference Server with vLLM - optimized for latest NVIDIA hardware
FROM nvcr.io/nvidia/tritonserver:25.08-vllm-python-py3

# Install curl for health checks
RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*

# Set working directory
WORKDIR /app

# Copy the vLLM startup script
COPY launch_server.sh .

# Make startup script executable
RUN chmod +x launch_server.sh

# Create model directory
RUN mkdir -p /app/models

# Expose the service port
EXPOSE 8001

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:8001/health || exit 1

# Start vLLM's built-in OpenAI API server directly
CMD ["./launch_server.sh"]

@ -0,0 +1,21 @@
FROM python:3.11-slim

WORKDIR /app

# Install required packages (aiohttp only; asyncio and statistics ship with
# the Python standard library and should not be installed from PyPI)
RUN pip install --no-cache-dir aiohttp

# Copy benchmark script
COPY vllm_llama3_benchmark.py /app/

# Create results directory
RUN mkdir -p /app/results

# Make script executable
RUN chmod +x /app/vllm_llama3_benchmark.py

# Default command
CMD ["python", "/app/vllm_llama3_benchmark.py", "--url", "http://vllm-llama3-8b:8001", "--output", "/app/results/benchmark_results.json"]

nvidia/txt2kg/assets/deploy/services/vllm/README.md (Normal file)
@ -0,0 +1,92 @@
# vLLM NVFP4 Deployment

This setup deploys the NVIDIA Llama 4 Scout model with NVFP4 quantization using vLLM, optimized for Blackwell and Hopper GPU architectures.

## Quick Start

1. **Set up your HuggingFace token:**

   ```bash
   cp env.example .env
   # Edit .env and add your HF_TOKEN
   ```

2. **Build and run:**

   ```bash
   docker-compose up --build
   ```

3. **Test the deployment:**

   ```bash
   curl -X POST "http://localhost:8001/v1/chat/completions" \
     -H "Content-Type: application/json" \
     -d '{
       "model": "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
       "messages": [{"role": "user", "content": "Hello! How are you?"}],
       "max_tokens": 100
     }'
   ```

## Model Information

- **Model**: `nvidia/Llama-4-Scout-17B-16E-Instruct-FP4`
- **Quantization**: NVFP4 (optimized for Blackwell architecture)
- **Alternative**: `nvidia/Llama-4-Scout-17B-16E-Instruct-FP8` (for Hopper architecture)

## Performance Tuning

The startup script automatically detects your GPU architecture and applies optimal settings:

### Blackwell (Compute Capability 10.0)
- Enables FlashInfer backend
- Uses NVFP4 quantization
- Enables async scheduling
- Applies fusion optimizations

### Hopper (Compute Capability 9.0)
- Uses FP8 quantization
- Disables async scheduling (due to vLLM limitations)
- Standard optimization settings

### Configuration Options

Adjust these environment variables in your `.env` file:

- `VLLM_TENSOR_PARALLEL_SIZE`: Number of GPUs to use (default: 2)
- `VLLM_MAX_NUM_SEQS`: Batch size (default: 128)
- `VLLM_MAX_NUM_BATCHED_TOKENS`: Token batching limit (default: 8192)
- `VLLM_GPU_MEMORY_UTILIZATION`: GPU memory usage (default: 0.9)

### Performance Scenarios

- **Maximum Throughput**: `VLLM_TENSOR_PARALLEL_SIZE=1`, increase `VLLM_MAX_NUM_SEQS`
- **Minimum Latency**: `VLLM_TENSOR_PARALLEL_SIZE=4-8`, `VLLM_MAX_NUM_SEQS=8`
- **Balanced**: `VLLM_TENSOR_PARALLEL_SIZE=2`, `VLLM_MAX_NUM_SEQS=128` (default)
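
For example, a latency-leaning `.env` might look like this (a sketch combining the variables above; the `HF_TOKEN` value is a placeholder copied over from `env.example`):

```bash
cat > .env <<'EOF'
HF_TOKEN=hf_your_token_here
VLLM_TENSOR_PARALLEL_SIZE=4
VLLM_MAX_NUM_SEQS=8
VLLM_MAX_NUM_BATCHED_TOKENS=8192
VLLM_GPU_MEMORY_UTILIZATION=0.9
EOF
```
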
## Benchmarking

To benchmark performance:

```bash
docker exec -it vllm-nvfp4-server vllm bench serve \
  --host 0.0.0.0 \
  --port 8001 \
  --model nvidia/Llama-4-Scout-17B-16E-Instruct-FP4 \
  --dataset-name random \
  --random-input-len 1024 \
  --random-output-len 1024 \
  --max-concurrency 128 \
  --num-prompts 1280
```

## Requirements

- NVIDIA GPU with Blackwell or Hopper architecture
- CUDA Driver 575 or above
- Docker with NVIDIA Container Toolkit
- HuggingFace token (for model access)

## Troubleshooting

- Check GPU compatibility: `nvidia-smi`
- View logs: `docker-compose logs -f vllm-nvfp4`
- Monitor GPU usage: `nvidia-smi -l 1`

nvidia/txt2kg/assets/deploy/services/vllm/build_image.sh (Executable file)
@ -0,0 +1,23 @@
#!/bin/bash

# Use latest stable vLLM release for better compute capability 12.1 support
# Clone the vLLM GitHub repo and use latest stable release.
git clone https://github.com/vllm-project/vllm.git /tmp/vllm-tutorial
cd /tmp/vllm-tutorial
git checkout $(git describe --tags --abbrev=0)

# Build the docker image using official vLLM Dockerfile.
DOCKER_BUILDKIT=1 docker build . \
  --file docker/Dockerfile \
  --target vllm-openai \
  --build-arg CUDA_VERSION=12.8.1 \
  --build-arg max_jobs=8 \
  --build-arg nvcc_threads=2 \
  --build-arg RUN_WHEEL_CHECK=false \
  --build-arg torch_cuda_arch_list="10.0+PTX;12.1" \
  --build-arg vllm_fa_cmake_gpu_arches="100-real;121-real" \
  -t vllm/vllm-openai:deploy

# Clean up
cd /
rm -rf /tmp/vllm-tutorial

@ -0,0 +1,100 @@
services:
  vllm-llama3-8b:
    image: nvcr.io/nvidia/vllm:25.09-py3
    container_name: vllm-llama3-8b
    ports:
      - "8001:8001"
    environment:
      # Model configuration - Llama3 8B
      - MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
      - TENSOR_PARALLEL_SIZE=1
      - MAX_MODEL_LEN=4096
      - GPU_MEMORY_UTILIZATION=0.9

      # Performance optimizations
      - QUANTIZATION=fp8
      - KV_CACHE_DTYPE=fp8
      - ENABLE_CHUNKED_PREFILL=true
      - MAX_NUM_BATCHED_TOKENS=8192
      - MAX_NUM_SEQS=256

      # Service configuration
      - HOST=0.0.0.0
      - PORT=8001
      - DISABLE_LOG_STATS=false
      - DISABLE_LOG_REQUESTS=false

      # CUDA settings
      - CUDA_VISIBLE_DEVICES=0
      - NCCL_DEBUG=INFO

      # Hugging Face settings
      - HF_HOME=/app/.cache/huggingface
      - TRANSFORMERS_CACHE=/app/.cache/huggingface/transformers

    volumes:
      # Cache Hugging Face models for faster startup
      - ~/.cache/huggingface:/app/.cache/huggingface
      - /tmp:/tmp

    command: >
      python -m vllm.entrypoints.openai.api_server
      --model meta-llama/Llama-3.1-8B-Instruct
      --host 0.0.0.0
      --port 8001
      --tensor-parallel-size 1
      --max-model-len 4096
      --gpu-memory-utilization 0.9
      --quantization fp8
      --kv-cache-dtype fp8
      --enable-chunked-prefill
      --max-num-batched-tokens 8192
      --max-num-seqs 256
      --disable-log-stats
      --trust-remote-code

    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

    restart: unless-stopped

    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8001/v1/models"]
      interval: 30s
      timeout: 10s
      retries: 5
      start_period: 300s  # 5 minutes for model loading

    networks:
      - vllm-network

  # Benchmark runner service
  vllm-benchmark:
    build:
      context: .
      dockerfile: Dockerfile.benchmark
    container_name: vllm-benchmark
    depends_on:
      vllm-llama3-8b:
        condition: service_healthy
    environment:
      - VLLM_URL=http://vllm-llama3-8b:8001
    volumes:
      - ./benchmark_results:/app/results
    networks:
      - vllm-network
    profiles:
      - benchmark  # Only start when explicitly requested

networks:
  vllm-network:
    driver: bridge

volumes:
  vllm_cache:
    driver: local

nvidia/txt2kg/assets/deploy/services/vllm/docker-compose.yml (Normal file)
@ -0,0 +1,51 @@
version: '3.8'

services:
  vllm-nvfp4:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: vllm-nvfp4-server
    ports:
      - "8001:8001"
    environment:
      # HuggingFace configuration
      - HF_TOKEN=${HF_TOKEN}
      - HF_HOME=/app/models/.cache

    volumes:
      # Cache HuggingFace models locally
      - ./models:/app/models
      - huggingface_cache:/app/models/.cache
      # Mount the launch script
      - ./launch_server.sh:/app/launch_server.sh

    # NVIDIA recommended settings for PyTorch
    ipc: host
    ulimits:
      memlock: -1
      stack: 67108864
    shm_size: 2gb

    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]

    restart: unless-stopped

    entrypoint: ["/bin/bash", "/app/launch_server.sh"]

    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8001/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 120s

volumes:
  huggingface_cache:
    driver: local

nvidia/txt2kg/assets/deploy/services/vllm/launch_server.sh (Executable file)
@ -0,0 +1,115 @@
#!/bin/bash

# Launch vLLM with NVIDIA Triton Inference Server optimized build
# This should have proper support for compute capability 12.1 (DGX Spark)

# Enable unified memory usage for DGX Spark
export CUDA_MANAGED_FORCE_DEVICE_ALLOC=1
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Enable CUDA unified memory and oversubscription
export CUDA_VISIBLE_DEVICES=0
export PYTORCH_NO_CUDA_MEMORY_CACHING=0

# Force vLLM to use CPU offloading for large models
export VLLM_CPU_OFFLOAD_GB=50
export VLLM_ALLOW_RUNTIME_LORA_UPDATES_WITH_SGD_LORA=1
export VLLM_SKIP_WARMUP=0

# Optimized environment for performance
export VLLM_LOGGING_LEVEL=INFO
export PYTHONUNBUFFERED=1

# Enable CUDA optimizations
export VLLM_USE_MODELSCOPE=false

# Enable unified memory in vLLM
export VLLM_USE_V1=0

# First, test basic CUDA functionality
echo "=== Testing CUDA functionality ==="
python3 -c "
import torch
print(f'PyTorch version: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'CUDA version: {torch.version.cuda}')
    print(f'GPU count: {torch.cuda.device_count()}')
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(f'GPU {i}: {props.name} (compute capability {props.major}.{props.minor})')
        # Try basic CUDA operation
        try:
            x = torch.randn(10, 10).cuda(i)
            y = torch.matmul(x, x.T)
            print(f'GPU {i}: Basic CUDA operations work')
        except Exception as e:
            print(f'GPU {i}: CUDA operation failed: {e}')
"

echo "=== Starting optimized vLLM server ==="
# Optimized configuration for DGX Spark performance with NVFP4 quantization
# Available quantized models from NVIDIA
NVFP4_MODEL="nvidia/Llama-3.3-70B-Instruct-FP4"
NVFP8_MODEL="nvidia/Llama-3.1-8B-Instruct-FP8"
STANDARD_MODEL="meta-llama/Llama-3.1-70B-Instruct"

# Check GPU compute capability for optimal quantization
COMPUTE_CAPABILITY=$(nvidia-smi -i 0 --query-gpu=compute_cap --format=csv,noheader,nounits 2>/dev/null || echo "unknown")
echo "Detected GPU compute capability: $COMPUTE_CAPABILITY"

# Configure quantization based on GPU architecture
if [[ "$COMPUTE_CAPABILITY" == "12.1" ]] || [[ "$COMPUTE_CAPABILITY" == "10.0" ]]; then
    # Blackwell/DGX Spark architecture - use standard 70B model with CPU offloading
    echo "Using standard Llama-3.1-70B model for Blackwell/DGX Spark with CPU offloading"
    QUANTIZATION_FLAG=""
    MODEL_TO_USE="$STANDARD_MODEL"  # Use standard 70B model
    GPU_MEMORY_UTIL="0.7"           # Lower GPU memory to allow unified memory
    MAX_MODEL_LEN="4096"            # Shorter sequences for memory efficiency
    MAX_NUM_SEQS="16"               # Lower concurrent sequences for 70B
    MAX_BATCHED_TOKENS="4096"
    CPU_OFFLOAD_GB="50"             # Offload 50GB to CPU/unified memory
elif [[ "$COMPUTE_CAPABILITY" == "9.0" ]]; then
    # Hopper architecture - use standard model
    echo "Using standard 70B model for Hopper architecture"
    QUANTIZATION_FLAG=""
    MODEL_TO_USE="$STANDARD_MODEL"
    GPU_MEMORY_UTIL="0.7"
    MAX_MODEL_LEN="4096"
    MAX_NUM_SEQS="16"
    MAX_BATCHED_TOKENS="4096"
    CPU_OFFLOAD_GB="40"
else
    # Other architectures - use standard precision
    echo "Using standard 70B model for GPU architecture: $COMPUTE_CAPABILITY"
    QUANTIZATION_FLAG=""
    MODEL_TO_USE="$STANDARD_MODEL"
    GPU_MEMORY_UTIL="0.7"
    MAX_MODEL_LEN="2048"
    MAX_NUM_SEQS="16"
    MAX_BATCHED_TOKENS="2048"
    CPU_OFFLOAD_GB="40"
fi

echo "Using model: $MODEL_TO_USE"
echo "Quantization: ${QUANTIZATION_FLAG:-'disabled'}"
echo "GPU memory utilization: $GPU_MEMORY_UTIL"
echo "CPU Offload: ${CPU_OFFLOAD_GB}GB"

vllm serve "$MODEL_TO_USE" \
    --host 0.0.0.0 \
    --port 8001 \
    --tensor-parallel-size 1 \
    --max-model-len "$MAX_MODEL_LEN" \
    --max-num-seqs "$MAX_NUM_SEQS" \
    --max-num-batched-tokens "$MAX_BATCHED_TOKENS" \
    --gpu-memory-utilization "$GPU_MEMORY_UTIL" \
    --cpu-offload-gb "$CPU_OFFLOAD_GB" \
    --kv-cache-dtype auto \
    --trust-remote-code \
    --served-model-name "$MODEL_TO_USE" \
    --enable-chunked-prefill \
    --disable-custom-all-reduce \
    --disable-async-output-proc \
    $QUANTIZATION_FLAG

nvidia/txt2kg/assets/deploy/services/vllm/run_benchmark.sh (Executable file)
@ -0,0 +1,199 @@
#!/bin/bash

# vLLM Llama3 8B Benchmark Runner
# Uses NVIDIA vLLM container for optimal performance

set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
VLLM_URL="http://localhost:8001"
RUNS=3
MAX_TOKENS=512
OUTPUT_FILE=""

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

print_header() {
    echo -e "${BLUE}========================================${NC}"
    echo -e "${BLUE} 🚀 vLLM Llama3 8B Benchmark Suite${NC}"
    echo -e "${BLUE}========================================${NC}"
}

print_usage() {
    echo "Usage: $0 [OPTIONS]"
    echo ""
    echo "Options:"
    echo "  -u, --url URL          vLLM service URL (default: http://localhost:8001)"
    echo "  -r, --runs NUMBER      Number of runs per prompt (default: 3)"
    echo "  -t, --max-tokens NUM   Maximum tokens to generate (default: 512)"
    echo "  -o, --output FILE      Output file for detailed results (JSON)"
    echo "  -d, --docker           Run using Docker Compose"
    echo "  -s, --start-service    Start vLLM service first"
    echo "  -h, --health-check     Only run health check"
    echo "  --help                 Show this help message"
    echo ""
    echo "Examples:"
    echo "  $0                                # Run basic benchmark"
    echo "  $0 --docker --start-service      # Start service and run benchmark in Docker"
    echo "  $0 -r 5 -t 1024 -o results.json  # Custom settings with output file"
    echo "  $0 --health-check                # Check if service is running"
}

check_dependencies() {
    if ! command -v python3 &> /dev/null; then
        echo -e "${RED}❌ Python3 is required but not installed${NC}"
        exit 1
    fi

    # aiohttp is the only third-party dependency (asyncio ships with Python)
    if ! python3 -c "import aiohttp" &> /dev/null; then
        echo -e "${YELLOW}⚠️ Installing required Python packages...${NC}"
        pip3 install aiohttp
    fi
}

check_nvidia_docker() {
    if ! command -v docker &> /dev/null; then
        echo -e "${RED}❌ Docker is required but not installed${NC}"
        exit 1
    fi

    if ! docker info | grep -q "nvidia"; then
        echo -e "${YELLOW}⚠️ NVIDIA Docker runtime not detected. Make sure nvidia-container-toolkit is installed${NC}"
    fi
}

start_vllm_service() {
    echo -e "${BLUE}🚀 Starting vLLM Llama3 8B service...${NC}"

    cd "$SCRIPT_DIR"
    docker-compose -f docker-compose.llama3-8b.yml up -d vllm-llama3-8b

    echo -e "${YELLOW}⏳ Waiting for model to load (this may take several minutes)...${NC}"

    # Wait for service to be healthy
    local max_attempts=60  # 10 minutes
    local attempt=1

    while [ $attempt -le $max_attempts ]; do
        if curl -sf "$VLLM_URL/v1/models" > /dev/null 2>&1; then
            echo -e "${GREEN}✅ vLLM service is ready!${NC}"
            return 0
        fi

        echo -e "${YELLOW}⏳ Attempt $attempt/$max_attempts - waiting for service...${NC}"
        sleep 10
        ((attempt++))
    done

    echo -e "${RED}❌ vLLM service failed to start within timeout${NC}"
    echo -e "${YELLOW}📋 Checking service logs:${NC}"
    docker-compose -f docker-compose.llama3-8b.yml logs vllm-llama3-8b
    exit 1
}

run_benchmark() {
    local cmd_args=("--url" "$VLLM_URL" "--runs" "$RUNS" "--max-tokens" "$MAX_TOKENS")

    if [ -n "$OUTPUT_FILE" ]; then
        cmd_args+=("--output" "$OUTPUT_FILE")
    fi

    if [ "$HEALTH_CHECK_ONLY" = true ]; then
        cmd_args+=("--health-check-only")
    fi

    echo -e "${BLUE}🧪 Running vLLM Llama3 8B benchmark...${NC}"
    echo -e "${BLUE}URL: $VLLM_URL${NC}"
    echo -e "${BLUE}Runs per prompt: $RUNS${NC}"
    echo -e "${BLUE}Max tokens: $MAX_TOKENS${NC}"

    if [ "$USE_DOCKER" = true ]; then
        # Run benchmark in Docker
        cd "$SCRIPT_DIR"
        docker-compose -f docker-compose.llama3-8b.yml run --rm vllm-benchmark \
            python /app/vllm_llama3_benchmark.py "${cmd_args[@]}"
    else
        # Run benchmark locally
        python3 "$SCRIPT_DIR/vllm_llama3_benchmark.py" "${cmd_args[@]}"
    fi
}

# Parse command line arguments
USE_DOCKER=false
START_SERVICE=false
HEALTH_CHECK_ONLY=false

while [[ $# -gt 0 ]]; do
    case $1 in
        -u|--url)
            VLLM_URL="$2"
            shift 2
            ;;
        -r|--runs)
            RUNS="$2"
            shift 2
            ;;
        -t|--max-tokens)
            MAX_TOKENS="$2"
            shift 2
            ;;
        -o|--output)
            OUTPUT_FILE="$2"
            shift 2
            ;;
        -d|--docker)
            USE_DOCKER=true
            shift
            ;;
        -s|--start-service)
            START_SERVICE=true
            shift
            ;;
        -h|--health-check)
            HEALTH_CHECK_ONLY=true
            shift
            ;;
        --help)
            print_usage
            exit 0
            ;;
        *)
            echo -e "${RED}❌ Unknown option: $1${NC}"
            print_usage
            exit 1
            ;;
    esac
done

# Main execution
print_header

if [ "$USE_DOCKER" = true ]; then
    check_nvidia_docker

    if [ "$START_SERVICE" = true ]; then
        start_vllm_service
    fi

    run_benchmark
else
    check_dependencies

    if [ "$START_SERVICE" = true ]; then
        echo -e "${YELLOW}⚠️ --start-service only works with --docker flag${NC}"
        exit 1
    fi

    run_benchmark
fi

echo -e "${GREEN}✅ Benchmark completed successfully!${NC}"

if [ -n "$OUTPUT_FILE" ] && [ -f "$OUTPUT_FILE" ]; then
    echo -e "${BLUE}📊 Detailed results saved to: $OUTPUT_FILE${NC}"
fi

nvidia/txt2kg/assets/deploy/services/vllm/run_container.sh (Executable file)
@ -0,0 +1,4 @@
#!/bin/bash

# Follow the official vLLM tutorial run_container.sh exactly
docker run -e HF_TOKEN="$HF_TOKEN" -e HF_HOME="$HF_HOME" --ipc=host --gpus all --entrypoint "/bin/bash" --rm -it vllm/vllm-openai:deploy

nvidia/txt2kg/assets/deploy/services/vllm/start-vllm.sh (Executable file)
@ -0,0 +1,96 @@
#!/bin/bash

# vLLM startup script with NVFP4 quantization support for Llama 4 Scout
# Optimized for NVIDIA Blackwell and Hopper architectures

set -e

# Default configuration - using supported Llama 3.1 model for testing
VLLM_MODEL=${VLLM_MODEL:-"meta-llama/Llama-3.1-8B-Instruct"}
VLLM_PORT=${VLLM_PORT:-8001}
VLLM_HOST=${VLLM_HOST:-"0.0.0.0"}
VLLM_TENSOR_PARALLEL_SIZE=${VLLM_TENSOR_PARALLEL_SIZE:-2}
VLLM_MAX_MODEL_LEN=${VLLM_MAX_MODEL_LEN:-8192}
VLLM_GPU_MEMORY_UTILIZATION=${VLLM_GPU_MEMORY_UTILIZATION:-0.9}
VLLM_MAX_NUM_SEQS=${VLLM_MAX_NUM_SEQS:-128}
VLLM_MAX_NUM_BATCHED_TOKENS=${VLLM_MAX_NUM_BATCHED_TOKENS:-8192}
VLLM_KV_CACHE_DTYPE=${VLLM_KV_CACHE_DTYPE:-"auto"}

# Detect GPU compute capability and set optimizations
COMPUTE_CAPABILITY=$(nvidia-smi -i 0 --query-gpu=compute_cap --format=csv,noheader 2>/dev/null || echo "unknown")

echo "Starting vLLM service with the following configuration:"
echo "Model: $VLLM_MODEL"
echo "Port: $VLLM_PORT"
echo "Host: $VLLM_HOST"
echo "Tensor Parallel Size: $VLLM_TENSOR_PARALLEL_SIZE"
echo "Max Model Length: $VLLM_MAX_MODEL_LEN"
echo "Max Num Seqs: $VLLM_MAX_NUM_SEQS"
echo "Max Batched Tokens: $VLLM_MAX_NUM_BATCHED_TOKENS"
echo "GPU Memory Utilization: $VLLM_GPU_MEMORY_UTILIZATION"
echo "KV Cache Dtype: $VLLM_KV_CACHE_DTYPE"
echo "GPU Compute Capability: $COMPUTE_CAPABILITY"

# Set up environment variables for optimal performance based on GPU architecture
if [ "$COMPUTE_CAPABILITY" = "10.0" ]; then
    echo "Detected Blackwell architecture - enabling NVFP4 optimizations"
    # Use FlashInfer backend for attentions
    export VLLM_ATTENTION_BACKEND=FLASHINFER
    # Use FlashInfer trtllm-gen attention kernels
    export VLLM_USE_TRTLLM_ATTENTION=1
    # Use FlashInfer FP8/FP4 MoE
    export VLLM_USE_FLASHINFER_MOE_FP8=1
    export VLLM_USE_FLASHINFER_MOE_FP4=1
    # Use FlashInfer trtllm-gen MoE backend
    export VLLM_FLASHINFER_MOE_BACKEND="latency"
    # Enable async scheduling
    ASYNC_SCHEDULING_FLAG="--async-scheduling"
    # Enable FlashInfer fusions
    FUSION_FLAG='{"pass_config":{"enable_fi_allreduce_fusion":true,"enable_noop":true},"custom_ops":["+quant_fp8","+rms_norm"],"full_cuda_graph":true}'
elif [ "$COMPUTE_CAPABILITY" = "9.0" ]; then
    echo "Detected Hopper architecture - enabling FP8 optimizations"
    # Disable async scheduling on Hopper architecture due to vLLM limitations
    ASYNC_SCHEDULING_FLAG=""
    # Disable FlashInfer fusions since they are not supported on Hopper architecture
    FUSION_FLAG="{}"
else
    echo "GPU architecture not specifically optimized - using default settings"
    ASYNC_SCHEDULING_FLAG=""
    FUSION_FLAG="{}"
fi

# Check GPU availability
if ! nvidia-smi > /dev/null 2>&1; then
    echo "Warning: NVIDIA GPU not detected. vLLM may not work properly."
fi

# Create model cache directory
mkdir -p /app/models

echo "Starting vLLM's built-in OpenAI API server"

# Build vLLM command with NVFP4 optimizations
VLLM_CMD="vllm serve $VLLM_MODEL \
    --host $VLLM_HOST \
    --port $VLLM_PORT \
    --tensor-parallel-size $VLLM_TENSOR_PARALLEL_SIZE \
    --max-model-len $VLLM_MAX_MODEL_LEN \
    --max-num-seqs $VLLM_MAX_NUM_SEQS \
    --max-num-batched-tokens $VLLM_MAX_NUM_BATCHED_TOKENS \
    --gpu-memory-utilization $VLLM_GPU_MEMORY_UTILIZATION \
    --kv-cache-dtype $VLLM_KV_CACHE_DTYPE \
    --trust-remote-code \
    --served-model-name $VLLM_MODEL"

# Add async scheduling if supported
if [ -n "$ASYNC_SCHEDULING_FLAG" ]; then
    VLLM_CMD="$VLLM_CMD $ASYNC_SCHEDULING_FLAG"
fi

# Add fusion optimizations if available
if [ "$FUSION_FLAG" != "{}" ]; then
    VLLM_CMD="$VLLM_CMD --compilation-config '$FUSION_FLAG'"
fi

# Start vLLM server
exec $VLLM_CMD

nvidia/txt2kg/assets/examples/download_biorxiv_dataset.py (Normal file)
@ -0,0 +1,87 @@
#!/usr/bin/env python3
"""
Download and process the MTEB raw_biorxiv dataset for txt2kg demo.
Filter for genetics/genomics categories and create individual txt files.
"""

import os
import re
from pathlib import Path
from datasets import load_dataset


def sanitize_filename(text, max_length=100):
    """Convert text to a safe filename."""
    # Remove special characters and replace with underscores
    filename = re.sub(r'[^\w\s-]', '', text)
    filename = re.sub(r'[-\s]+', '_', filename)
    filename = filename.strip('_')

    # Truncate if too long
    if len(filename) > max_length:
        filename = filename[:max_length]

    return filename


def main():
    print("Loading MTEB raw_biorxiv dataset...")

    # Load the dataset
    ds = load_dataset("mteb/raw_biorxiv")

    # Get the train split
    train_data = ds['train']

    print(f"Total dataset size: {len(train_data)} papers")

    # Filter for genetics or genomics categories
    genetics_genomics_data = []
    for item in train_data:
        category = item['category'].lower()
        if 'genetic' in category or 'genomic' in category:
            genetics_genomics_data.append(item)

    print(f"Found {len(genetics_genomics_data)} papers with genetics/genomics categories")

    if len(genetics_genomics_data) == 0:
        # Let's check what categories are available
        categories = set(item['category'] for item in train_data)
        print("Available categories:")
        for cat in sorted(categories):
            print(f"  - {cat}")
        return

    # Create output directory
    output_dir = Path("biorxiv_genetics_genomics")
    output_dir.mkdir(exist_ok=True)

    print(f"Creating txt files in {output_dir}/")

    # Process each paper
    for i, item in enumerate(genetics_genomics_data):
        # Create filename from title and ID
        title_part = sanitize_filename(item['title'], max_length=50)
        paper_id = item['id'].replace('/', '_')
        filename = f"{i+1:03d}_{title_part}_{paper_id}.txt"

        # Create file content
        content = f"Title: {item['title']}\n"
        content += f"ID: {item['id']}\n"
        content += f"Category: {item['category']}\n"
        content += f"\nAbstract:\n{item['abstract']}\n"

        # Write to file
        file_path = output_dir / filename
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)

    print(f"Successfully created {len(genetics_genomics_data)} txt files in {output_dir}/")

    # Show some statistics
    categories_found = set(item['category'] for item in genetics_genomics_data)
    print("\nCategories included:")
    for cat in sorted(categories_found):
        count = sum(1 for item in genetics_genomics_data if item['category'] == cat)
        print(f"  - {cat}: {count} papers")


if __name__ == "__main__":
    main()
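
A typical invocation of this script (assumes the Hugging Face `datasets` package is installed; the output directory name comes from the script itself):

```bash
pip install datasets
python download_biorxiv_dataset.py
ls biorxiv_genetics_genomics | head -5
```
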
nvidia/txt2kg/assets/examples/download_cc_biorxiv_dataset.py (Normal file)
@ -0,0 +1,86 @@
#!/usr/bin/env python3
"""
Download and process the marianna13/biorxiv dataset for txt2kg demo.
Filter for Creative Commons licensed papers and create individual txt files.
"""

import os
import re
from pathlib import Path
from datasets import load_dataset


def sanitize_filename(text, max_length=100):
    """Convert text to a safe filename."""
    # Remove special characters and replace with underscores
    filename = re.sub(r'[^\w\s-]', '', text)
    filename = re.sub(r'[-\s]+', '_', filename)
    filename = filename.strip('_')

    # Truncate if too long
    if len(filename) > max_length:
        filename = filename[:max_length]

    return filename


def main():
    print("Loading marianna13/biorxiv dataset...")

    # Load the dataset
    ds = load_dataset("marianna13/biorxiv")

    # Get the train split
    train_data = ds['train']

    print(f"Total dataset size: {len(train_data)} papers")

    # Filter for Creative Commons licensed papers
    cc_papers = train_data.filter(lambda x: x['LICENSE'] == 'creative-commons')

    print(f"Found {len(cc_papers)} Creative Commons licensed papers ({len(cc_papers)/len(train_data)*100:.1f}%)")

    # Take a sample for the demo (full dataset would be too large)
    sample_size = min(1000, len(cc_papers))  # Limit to 1000 papers for demo
    cc_sample = cc_papers.select(range(sample_size))

    print(f"Using sample of {len(cc_sample)} papers for demo")

    # Create output directory
    output_dir = Path("biorxiv_creative_commons")
    output_dir.mkdir(exist_ok=True)

    print(f"Creating txt files in {output_dir}/")

    # Process each paper
    for i, item in enumerate(cc_sample):
        # Create filename from title and DOI
        title_part = sanitize_filename(item['TITLE'], max_length=50)
        doi_part = item['DOI'].replace('/', '_').replace('.', '_')
        filename = f"{i+1:03d}_{title_part}_{doi_part}.txt"

        # Create file content with full text
        content = f"Title: {item['TITLE']}\n"
        content += f"DOI: {item['DOI']}\n"
        content += f"Year: {item['YEAR']}\n"
        content += f"Authors: {'; '.join(item['AUTHORS']) if item['AUTHORS'] else 'N/A'}\n"
        content += f"License: {item['LICENSE']}\n"
        content += f"\nFull Text:\n{item['TEXT']}\n"

        # Write to file
        file_path = output_dir / filename
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)

    print(f"Successfully created {len(cc_sample)} txt files in {output_dir}/")

    # Show some statistics
    years = [item['YEAR'] for item in cc_sample]
    year_range = f"{min(years)} - {max(years)}"

    print("\nDataset Statistics:")
    print(f"  Year range: {year_range}")
    print("  License: Creative Commons (commercial use allowed)")
    print("  Content: Full paper text (not just abstracts)")
    print(f"  Average text length: {sum(len(item['TEXT']) for item in cc_sample) // len(cc_sample):,} characters")


if __name__ == "__main__":
    main()

nvidia/txt2kg/assets/frontend/README.md (Normal file)
@ -0,0 +1,31 @@
# Frontend Application
|
||||
|
||||
This directory contains the Next.js frontend application for the txt2kg project.
|
||||
|
||||
## Structure
|
||||
|
||||
- **app/**: Next.js app directory with pages and routes
|
||||
- **components/**: React components
|
||||
- **contexts/**: React context providers
|
||||
- **hooks/**: Custom React hooks
|
||||
- **lib/**: Utility functions and shared logic
|
||||
- **public/**: Static assets
|
||||
- **styles/**: CSS and styling files
|
||||
- **types/**: TypeScript type definitions
|
||||
|
||||
## Development
|
||||
|
||||
To start the development server:
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
|
||||
## Building for Production
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm run build
|
||||
```
|
||||
@@ -0,0 +1 @@
console.log('This is mock data to avoid SSR issues')
nvidia/txt2kg/assets/frontend/app/api/backend/route.ts (new file, 90 lines)
@@ -0,0 +1,90 @@
import { NextRequest, NextResponse } from 'next/server';
import remoteBackend from '@/lib/remote-backend';
import type { Triple } from '@/types/graph';
import { getGraphDbType } from '../settings/route';

/**
 * Remote backend API that provides endpoints for creating and querying a knowledge graph
 * using the selected graph database, Pinecone, and SentenceTransformer
 */

/**
 * Create a backend from triples
 */
export async function POST(request: NextRequest) {
  try {
    const { triples } = await request.json();

    if (!triples || !Array.isArray(triples) || triples.length === 0) {
      return NextResponse.json(
        { error: 'Triples are required and must be a non-empty array' },
        { status: 400 }
      );
    }

    // Initialize backend with the selected graph database type
    if (!remoteBackend.isInitialized()) {
      const graphDbType = getGraphDbType();
      console.log(`Initializing backend with graph DB type: ${graphDbType}`);
      await remoteBackend.initialize(graphDbType);
    }

    // Create backend from triples
    await remoteBackend.createBackendFromTriples(triples);

    return NextResponse.json({
      success: true,
      message: `Created backend successfully with ${triples.length} triples`,
      graphDbType: getGraphDbType()
    });
  } catch (error) {
    console.error('Error creating backend from triples:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json({ error: errorMessage }, { status: 500 });
  }
}

/**
 * Query the backend with a given query text
 */
export async function GET(request: NextRequest) {
  try {
    const url = new URL(request.url);
    const query = url.searchParams.get('query');

    if (!query) {
      return NextResponse.json({ error: 'Query parameter is required' }, { status: 400 });
    }

    // Parse optional parameters with fallbacks
    const kNeighbors = parseInt(url.searchParams.get('kNeighbors') || '4096', 10);
    const fanout = parseInt(url.searchParams.get('fanout') || '400', 10);
    const numHops = parseInt(url.searchParams.get('numHops') || '2', 10);

    // Initialize backend with the selected graph database type
    if (!remoteBackend.isInitialized()) {
      const graphDbType = getGraphDbType();
      console.log(`Initializing backend with graph DB type: ${graphDbType}`);
      await remoteBackend.initialize(graphDbType);
    }

    // Query the backend
    const relevantTriples = await remoteBackend.query(query, kNeighbors, fanout, numHops);

    return NextResponse.json({
      query,
      triples: relevantTriples,
      count: relevantTriples.length,
      parameters: {
        kNeighbors,
        fanout,
        numHops
      },
      graphDbType: getGraphDbType()
    });
  } catch (error) {
    console.error('Error querying backend:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json({ error: errorMessage }, { status: 500 });
  }
}
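
For reference, here is a minimal client-side sketch of exercising both handlers above. The request shapes follow the handlers; the triple values and query string are illustrative assumptions, not part of the source:

```typescript
// Hypothetical usage sketch for the /api/backend route (values are illustrative).
const triples = [
  { subject: 'TensorRT-LLM', predicate: 'accelerates', object: 'LLM inference' },
];

// Build the backend from extracted triples.
await fetch('/api/backend', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ triples }),
});

// Query it back; kNeighbors, fanout, and numHops fall back to 4096/400/2 when omitted.
const res = await fetch(`/api/backend?query=${encodeURIComponent('what accelerates LLM inference?')}`);
const { triples: relevant, count } = await res.json();
console.log(`Retrieved ${count} relevant triples`, relevant);
```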
nvidia/txt2kg/assets/frontend/app/api/config/route.ts (new file, 14 lines)
@@ -0,0 +1,14 @@
import { NextResponse } from "next/server";

export async function GET() {
  // Only return the necessary configuration data
  return NextResponse.json({
    nvidiaApiKey: process.env.NVIDIA_API_KEY || null,
    // xaiApiKey removed - integration has been removed
    ollamaBaseUrl: process.env.OLLAMA_BASE_URL || 'http://localhost:11434/v1',
    ollamaModel: process.env.OLLAMA_MODEL || 'qwen3:1.7b',
    vllmBaseUrl: process.env.VLLM_BASE_URL || 'http://localhost:8001/v1',
    vllmModel: process.env.VLLM_MODEL || 'meta-llama/Llama-3.2-3B-Instruct',
    // Add other config values as needed
  });
}
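
A one-line sketch of reading this configuration from the browser (assuming the route is served at /api/config):

```typescript
// Hypothetical client-side read of the runtime configuration.
const config = await (await fetch('/api/config')).json();
console.log(config.ollamaBaseUrl); // defaults to 'http://localhost:11434/v1' when unset
```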
nvidia/txt2kg/assets/frontend/app/api/embeddings/route.ts (new file, 133 lines)
@@ -0,0 +1,133 @@
import { NextRequest, NextResponse } from 'next/server';
import { EmbeddingsService } from '@/lib/embeddings';
import { PineconeService } from '@/lib/pinecone';

/**
 * Generate embeddings for text chunks and store them in Pinecone
 */
export async function POST(request: NextRequest) {
  try {
    const { documentId, content, documentName } = await request.json();

    if (!content) {
      return NextResponse.json(
        { error: 'Document content is required' },
        { status: 400 }
      );
    }

    // Initialize embedding service
    const embeddingsService = EmbeddingsService.getInstance();

    // Log which provider we're using
    console.log(`Using embeddings provider: ${process.env.EMBEDDINGS_PROVIDER || 'local'}`);

    // Generate chunks from content
    const chunkSize = 200; // Size of each text chunk, in characters
    const chunks = generateChunks(content, chunkSize);
    console.log(`Generated ${chunks.length} chunks from document`);

    // Create unique IDs for each chunk based on document name and chunk index
    const docPrefix = documentName ?
      documentName.replace(/[^a-zA-Z0-9]/g, '_').substring(0, 20) :
      documentId ? documentId : 'doc';

    const chunkIds = chunks.map((_, index) => `${docPrefix}_chunk_${index}`);

    // Generate embeddings for chunks
    console.log('Generating embeddings for chunks...');
    const embeddings = await embeddingsService.encode(chunks);
    console.log(`Generated ${embeddings.length} embeddings`);

    // Initialize PineconeService
    const pineconeService = PineconeService.getInstance();

    // Check if Pinecone server is running
    const isPineconeRunning = await pineconeService.isPineconeRunning();
    if (!isPineconeRunning) {
      return NextResponse.json(
        { error: 'Pinecone server is not available. Please make sure it is running.' },
        { status: 503 }
      );
    }

    if (!pineconeService.isInitialized()) {
      try {
        await pineconeService.initialize();
      } catch (initError) {
        console.error('Error initializing Pinecone:', initError);
        return NextResponse.json(
          { error: `Failed to initialize Pinecone: ${initError instanceof Error ? initError.message : String(initError)}` },
          { status: 500 }
        );
      }
    }

    // Create maps for embeddings and text content
    const entityEmbeddings = new Map<string, number[]>();
    const textContent = new Map<string, string>();

    // Populate the maps
    for (let i = 0; i < chunkIds.length; i++) {
      entityEmbeddings.set(chunkIds[i], embeddings[i]);
      textContent.set(chunkIds[i], chunks[i]);
    }

    // Store embeddings in PineconeService with retry logic
    try {
      await pineconeService.storeEmbeddings(entityEmbeddings, textContent);
    } catch (storeError) {
      console.error('Error storing embeddings in Pinecone:', storeError);
      return NextResponse.json(
        { error: `Failed to store embeddings in Pinecone: ${storeError instanceof Error ? storeError.message : String(storeError)}` },
        { status: 500 }
      );
    }

    return NextResponse.json({
      success: true,
      documentId: documentId || 'unnamed',
      chunks: chunks.length,
      embeddings: embeddings.length
    });

  } catch (error) {
    console.error('Error generating embeddings:', error);
    return NextResponse.json(
      { error: `Failed to generate embeddings: ${error instanceof Error ? error.message : String(error)}` },
      { status: 500 }
    );
  }
}

/**
 * Generate chunks from text content, splitting at sentence boundaries
 * @param content Text content
 * @param chunkSize Size of each chunk, in characters
 * @param overlap Overlap carried into the next chunk, in words
 * @returns Array of text chunks
 */
function generateChunks(content: string, chunkSize: number, overlap: number = 50): string[] {
  const chunks: string[] = [];
  const sentences = content.split(/(?<=[.!?])\s+/);

  let currentChunk = '';
  for (const sentence of sentences) {
    // If adding this sentence would make the chunk too long, save the current chunk and start a new one
    if (currentChunk.length + sentence.length > chunkSize && currentChunk.length > 0) {
      chunks.push(currentChunk.trim());
      // Take the last part of the current chunk as overlap for the next chunk
      const words = currentChunk.split(' ');
      currentChunk = words.slice(Math.max(0, words.length - overlap)).join(' ');
    }

    currentChunk += ' ' + sentence;
  }

  // Add the last chunk if it's not empty
  if (currentChunk.trim().length > 0) {
    chunks.push(currentChunk.trim());
  }

  return chunks;
}
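
A sketch of pushing a document through this endpoint. The field names match the handler above; documentId and documentName are optional, and the text is an illustrative placeholder:

```typescript
// Hypothetical usage of /api/embeddings: chunk, embed, and store a document.
await fetch('/api/embeddings', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    documentId: 'doc-1',                  // optional; used as a chunk-ID prefix fallback
    documentName: 'biorxiv sample paper', // optional; sanitized into the chunk-ID prefix
    content: 'Full document text to be split into ~200-character chunks and embedded...',
  }),
});
```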
@@ -0,0 +1,92 @@
import { NextRequest, NextResponse } from 'next/server';
import { RemoteBackendService } from '@/lib/remote-backend';

/**
 * API endpoint for enhanced RAG query with LangChain features
 * POST /api/enhanced-query
 */
export async function POST(req: NextRequest) {
  try {
    // Parse request body
    const body = await req.json();
    const { query, kNeighbors, fanout, numHops, topK, queryMode, useTraditional } = body;

    if (!query || typeof query !== 'string') {
      return NextResponse.json({ error: 'Query is required' }, { status: 400 });
    }

    // Initialize the backend service
    const backend = RemoteBackendService.getInstance();

    // Prepare parameters with defaults
    const params = {
      kNeighbors: kNeighbors || 4096,
      fanout: fanout || 400,
      numHops: numHops || 2,
      topK: topK || 5
    };

    console.log(`Enhanced RAG query: "${query}" with params:`, params);
    console.log(`Query mode: ${queryMode}, useTraditional: ${useTraditional}`);

    // Determine search method - if traditional is specified, use that
    const shouldUseTraditional = useTraditional || (queryMode === 'traditional');

    if (shouldUseTraditional) {
      console.log('Using traditional search for enhanced query');
      // Call the regular query method with traditional flag
      const relevantTriples = await backend.query(
        query,
        params.kNeighbors,
        params.fanout,
        params.numHops,
        {
          topk: params.topK,
          topk_e: params.topK,
          cost_e: 0.5,
          num_clusters: 2
        },
        true // Use traditional search
      );

      // Return the results
      return NextResponse.json({
        relevantTriples,
        count: relevantTriples.length,
        metadata: {
          searchType: 'traditional'
        },
        success: true
      });
    }

    // Use the enhanced query with metadata for vector search
    const { relevantTriples, queryMetadata } = await backend.enhancedQuery(
      query,
      params.kNeighbors,
      params.fanout,
      params.numHops,
      {
        topk: params.topK,
        topk_e: params.topK,
        cost_e: 0.5,
        num_clusters: 2
      }
    );

    // Return the results
    return NextResponse.json({
      relevantTriples,
      count: relevantTriples.length,
      metadata: queryMetadata,
      success: true
    });
  } catch (error) {
    console.error('Error in enhanced RAG query:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json(
      { error: `Failed to execute enhanced query: ${errorMessage}` },
      { status: 500 }
    );
  }
}
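
A sketch of calling the enhanced query route. The parameter names match the destructured body above, and the numeric values shown are the handler's own defaults; everything else is illustrative:

```typescript
// Hypothetical POST to /api/enhanced-query; omitted fields fall back to defaults.
const res = await fetch('/api/enhanced-query', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    query: 'How are triples stored?',
    kNeighbors: 4096,
    fanout: 400,
    numHops: 2,
    topK: 5,
    // Any queryMode other than 'traditional' (and useTraditional unset)
    // takes the enhancedQuery vector-search path.
  }),
});
const { relevantTriples, metadata } = await res.json();
```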
nvidia/txt2kg/assets/frontend/app/api/extract-triples/route.ts (new file, 207 lines)
@@ -0,0 +1,207 @@
import { NextRequest, NextResponse } from 'next/server';
import { processDocument, TextProcessor } from '@/lib/text-processor';
import { llmService } from '@/lib/llm-service';

// Configure route for dynamic operations and long-running requests
export const dynamic = 'force-dynamic';
export const maxDuration = 1800; // 30 minutes for large model processing

/**
 * API endpoint for extracting triples from text using the LangChain-based pipeline
 * POST /api/extract-triples
 */
export async function POST(req: NextRequest) {
  const startTime = Date.now();
  console.log(`[${new Date().toISOString()}] extract-triples: Request received`);

  try {
    // Parse request body
    const body = await req.json();
    console.log(`[${new Date().toISOString()}] extract-triples: Body parsed, text length: ${body.text?.length || 0}`);
    const {
      text,
      useLangChain = false,
      useGraphTransformer = false,
      systemPrompt,
      extractionPrompt,
      graphTransformerPrompt,
      llmProvider,
      ollamaModel,
      ollamaBaseUrl,
      vllmModel,
      vllmBaseUrl
    } = body;

    if (!text || typeof text !== 'string') {
      return NextResponse.json({ error: 'Text is required' }, { status: 400 });
    }

    // If Ollama is specified, call llmService directly (avoid internal fetch timeout)
    if (llmProvider === 'ollama') {
      console.log(`[${new Date().toISOString()}] extract-triples: Processing with Ollama model: ${ollamaModel || 'llama3.1:8b'}`);
      const llmStartTime = Date.now();

      try {
        const model = ollamaModel || 'llama3.1:8b';
        const messages = [
          {
            role: 'system' as const,
            content: 'You are a knowledge graph builder. Extract subject-predicate-object triples from text and return them as a JSON array.'
          },
          {
            role: 'user' as const,
            content: `Extract triples from this text:\n\n${text}`
          }
        ];

        console.log(`[${new Date().toISOString()}] extract-triples: Calling llmService.generateOllamaCompletion directly`);
        const response = await llmService.generateOllamaCompletion(
          model,
          messages,
          { temperature: 0.1, maxTokens: 8192 }
        );

        const llmDuration = ((Date.now() - llmStartTime) / 1000).toFixed(2);
        console.log(`[${new Date().toISOString()}] extract-triples: LLM completion received after ${llmDuration}s, response length: ${response?.length || 0}`);

        // Parse the response to extract triples
        let triples = [];
        try {
          const jsonMatch = response.match(/\[[\s\S]*\]/);
          if (jsonMatch) {
            triples = JSON.parse(jsonMatch[0]);
          } else {
            // Fallback parser
            triples = parseTriplesFallback(response);
          }
        } catch (parseError) {
          console.warn('Failed to parse JSON response, using fallback parser:', parseError);
          triples = parseTriplesFallback(response);
        }

        const totalDuration = ((Date.now() - llmStartTime) / 1000).toFixed(2);
        console.log(`[${new Date().toISOString()}] extract-triples: Returning ${triples.length} triples, total duration: ${totalDuration}s`);

        return NextResponse.json({
          triples: triples.map((triple) => ({
            ...triple,
            confidence: 0.8,
            metadata: {
              entityTypes: [],
              source: text.substring(0, 100) + '...',
              context: text.substring(0, 200) + '...',
              extractionMethod: 'ollama',
              model: model
            }
          })),
          count: triples.length,
          success: true,
          method: 'ollama',
          model: model
        });
      } catch (llmError) {
        const llmDuration = ((Date.now() - llmStartTime) / 1000).toFixed(2);
        console.error(`[${new Date().toISOString()}] extract-triples: Ollama processing failed after ${llmDuration}s:`, llmError);
        throw llmError;
      }
    }

    // If vLLM is specified, use the vLLM API endpoint
    if (llmProvider === 'vllm') {
      const vllmResponse = await fetch(`${req.nextUrl.origin}/api/vllm`, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          text,
          model: vllmModel || 'meta-llama/Llama-3.2-3B-Instruct',
          temperature: 0.1,
          maxTokens: 8192
        })
      });

      if (!vllmResponse.ok) {
        throw new Error(`vLLM API error: ${vllmResponse.statusText}`);
      }

      const vllmResult = await vllmResponse.json();
      return NextResponse.json(vllmResult);
    }

    // Configure TextProcessor for the specified LLM provider
    const processor = TextProcessor.getInstance();
    if (llmProvider && ['ollama', 'nvidia', 'vllm'].includes(llmProvider)) {
      processor.setLLMProvider(llmProvider as 'ollama' | 'nvidia' | 'vllm', {
        ollamaModel: ollamaModel,
        ollamaBaseUrl: ollamaBaseUrl,
        vllmModel: vllmModel,
        vllmBaseUrl: vllmBaseUrl
      });
    }

    // Process the text to extract triples using either default pipeline or LangChain transformer
    // When both useLangChain and useGraphTransformer are true, use the GraphTransformer
    // When only useLangChain is true, use the default LangChain pipeline
    // Pass custom prompts if provided
    const options = {
      systemPrompt,
      extractionPrompt,
      graphTransformerPrompt
    };

    const triples = await processDocument(text, useLangChain, useGraphTransformer, options);

    // Return the extracted triples
    return NextResponse.json({
      triples,
      count: triples.length,
      success: true,
      method: useGraphTransformer
        ? 'langchain_graphtransformer'
        : useLangChain
          ? 'langchain_default'
          : 'standard_pipeline',
      llmProvider: processor.getLLMProvider(),
      customPromptUsed: !!(systemPrompt || extractionPrompt || graphTransformerPrompt)
    });
  } catch (error) {
    console.error('Error in triple extraction:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json(
      { error: `Failed to extract triples: ${errorMessage}` },
      { status: 500 }
    );
  }
}

// Helper function to parse triples from text when JSON parsing fails
function parseTriplesFallback(text: string): Array<{subject: string, predicate: string, object: string}> {
  const triples = [];
  const lines = text.split('\n');

  for (const line of lines) {
    // Look for patterns like "Subject - Predicate - Object" or similar
    const tripleMatch = line.match(/^[\s\-\*\d\.]*(.+?)\s*[\-\|]\s*(.+?)\s*[\-\|]\s*(.+)$/);
    if (tripleMatch) {
      triples.push({
        subject: tripleMatch[1].trim(),
        predicate: tripleMatch[2].trim(),
        object: tripleMatch[3].trim()
      });
    }

    // Also look for JSON-like objects in the text
    const jsonObjectMatch = line.match(/\{\s*"subject"\s*:\s*"([^"]+)"\s*,\s*"predicate"\s*:\s*"([^"]+)"\s*,\s*"object"\s*:\s*"([^"]+)"\s*\}/);
    if (jsonObjectMatch) {
      triples.push({
        subject: jsonObjectMatch[1],
        predicate: jsonObjectMatch[2],
        object: jsonObjectMatch[3]
      });
    }
  }

  return triples;
}
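
A sketch of driving the extraction route through its direct-Ollama branch. The body fields match the handler's destructuring, and the model name mirrors the handler's own default; the sample text is illustrative:

```typescript
// Hypothetical POST to /api/extract-triples using the direct-Ollama branch.
const res = await fetch('/api/extract-triples', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    text: 'Marie Curie discovered polonium. Polonium is a radioactive element.',
    llmProvider: 'ollama',
    ollamaModel: 'llama3.1:8b', // handler default when omitted
  }),
});
const { triples, method } = await res.json(); // method === 'ollama'
```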
nvidia/txt2kg/assets/frontend/app/api/fix-query-logs/route.ts (new file, 117 lines)
@@ -0,0 +1,117 @@
import { NextRequest, NextResponse } from 'next/server';
import queryLoggerService, { QueryLogEntry } from '@/lib/query-logger';
import fs from 'fs';
import path from 'path';
import { promises as fsPromises } from 'fs';

interface QueryLogData {
  query: string;
  count: number;
}

interface FixResults {
  fixed: number;
  data: QueryLogData[];
}

/**
 * API endpoint to check and fix query logs
 */
export async function GET(request: NextRequest) {
  try {
    console.log('Checking and fixing query logs');

    // Initialize logger if not already
    if (!queryLoggerService.isInitialized()) {
      await queryLoggerService.initialize();
    }

    let results: FixResults = { fixed: 0, data: [] };

    try {
      // Get the log file path
      const logFilePath = path.join(process.cwd(), 'data', 'query-logs.json');

      // Check if log file exists
      if (!fs.existsSync(logFilePath)) {
        console.log('Log file does not exist, creating empty file');
        await fsPromises.mkdir(path.dirname(logFilePath), { recursive: true });
        await fsPromises.writeFile(logFilePath, JSON.stringify([]));
        return NextResponse.json({
          success: true,
          results,
          message: 'Created new empty log file'
        });
      }

      // Read existing logs
      const logsRaw = await fsPromises.readFile(logFilePath, 'utf-8');
      let logs: QueryLogEntry[] = JSON.parse(logsRaw || '[]');

      console.log(`Found ${logs.length} query log entries`);

      // Create a summary of existing logs
      const querySummary = new Map<string, number>();
      logs.forEach(log => {
        const count = querySummary.get(log.query) || 0;
        querySummary.set(log.query, count + 1);
      });

      // Convert to array for response
      results.data = Array.from(querySummary.entries()).map(([query, count]) => ({
        query,
        count
      }));

      // If there are no logs, add a default test log
      if (logs.length === 0) {
        console.log('No logs found, adding a default test log');

        const defaultLog: QueryLogEntry = {
          query: 'Test query for metrics',
          queryMode: 'traditional',
          timestamp: new Date().toISOString(),
          metrics: {
            executionTimeMs: 0,
            relevanceScore: 0,
            precision: 0,
            recall: 0,
            resultCount: 0
          }
        };

        logs.push(defaultLog);
        results.fixed++;

        // Update results data
        results.data.push({
          query: defaultLog.query,
          count: 1
        });

        // Write back to file
        await fsPromises.writeFile(logFilePath, JSON.stringify(logs, null, 2));
        console.log('Added default test log');
      }

      // Return the fixed results
      return NextResponse.json({
        success: true,
        results,
        message: `Fixed ${results.fixed} query logs`
      });
    } catch (error) {
      console.error('Error during fix operation:', error);
      return NextResponse.json({
        success: false,
        error: error instanceof Error ? error.message : String(error)
      }, { status: 500 });
    }
  } catch (error) {
    console.error('Error fixing query logs:', error);
    return NextResponse.json({
      success: false,
      error: error instanceof Error ? error.message : String(error)
    }, { status: 500 });
  }
}
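
Since this is a maintenance endpoint, a plain GET is enough to run the check-and-fix pass. A sketch; the response shape follows the handler above:

```typescript
// Hypothetical invocation of the log-repair endpoint.
const { success, results, message } = await (await fetch('/api/fix-query-logs')).json();
if (success) console.log(message, `(${results.fixed} entries added)`);
```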
nvidia/txt2kg/assets/frontend/app/api/graph-data/route.ts (new file, 139 lines)
@@ -0,0 +1,139 @@
import { type NextRequest, NextResponse } from "next/server"

// Utility function to generate UUID with fallback
const generateUUID = (): string => {
  // Check if crypto.randomUUID is available
  if (typeof crypto !== 'undefined' && crypto.randomUUID) {
    try {
      return crypto.randomUUID();
    } catch (error) {
      console.warn('crypto.randomUUID failed, using fallback:', error);
    }
  }

  // Fallback UUID generation
  return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, function(c) {
    const r = Math.random() * 16 | 0;
    const v = c == 'x' ? r : (r & 0x3 | 0x8);
    return v.toString(16);
  });
};

// Create a more persistent storage mechanism (still in-memory but more reliable)
// This will be a global variable that persists between API calls
// In a production environment, you would use a database instead
const graphDataStore = new Map<string, { triples: any[]; documentName: string }>()

// Sample graph data for when no ID is provided
const sampleGraphData = {
  nodes: [
    { id: "1", name: "Document 1", group: "document" },
    { id: "2", name: "Machine Learning", group: "concept" },
    { id: "3", name: "Neural Networks", group: "concept" },
    { id: "4", name: "Deep Learning", group: "concept" },
    { id: "5", name: "Computer Vision", group: "concept" },
    { id: "6", name: "Natural Language Processing", group: "concept" },
    { id: "7", name: "Reinforcement Learning", group: "concept" },
    { id: "8", name: "Supervised Learning", group: "concept" },
    { id: "9", name: "Unsupervised Learning", group: "concept" },
    { id: "10", name: "Semi-supervised Learning", group: "concept" },
    { id: "11", name: "Transfer Learning", group: "concept" },
    { id: "12", name: "GPT-4", group: "important" },
    { id: "13", name: "BERT", group: "concept" },
    { id: "14", name: "Transformers", group: "concept" },
    { id: "15", name: "CNN", group: "concept" },
    { id: "16", name: "RNN", group: "concept" },
    { id: "17", name: "LSTM", group: "concept" },
    { id: "18", name: "GAN", group: "concept" },
    { id: "19", name: "Diffusion Models", group: "important" },
    { id: "20", name: "Document 2", group: "document" },
  ],
  links: [
    { source: "1", target: "2", name: "mentions" },
    { source: "1", target: "3", name: "discusses" },
    { source: "1", target: "4", name: "explains" },
    { source: "2", target: "3", name: "includes" },
    { source: "2", target: "4", name: "includes" },
    { source: "2", target: "5", name: "related_to" },
    { source: "2", target: "6", name: "related_to" },
    { source: "2", target: "7", name: "includes" },
    { source: "2", target: "8", name: "includes" },
    { source: "2", target: "9", name: "includes" },
    { source: "2", target: "10", name: "includes" },
    { source: "2", target: "11", name: "includes" },
    { source: "3", target: "15", name: "includes" },
    { source: "3", target: "16", name: "includes" },
    { source: "3", target: "17", name: "includes" },
    { source: "4", target: "12", name: "uses" },
    { source: "4", target: "13", name: "uses" },
    { source: "4", target: "14", name: "uses" },
    { source: "6", target: "12", name: "uses" },
    { source: "6", target: "13", name: "uses" },
    { source: "6", target: "14", name: "uses" },
    { source: "5", target: "15", name: "uses" },
    { source: "5", target: "18", name: "uses" },
    { source: "5", target: "19", name: "uses" },
    { source: "20", target: "6", name: "mentions" },
    { source: "20", target: "12", name: "discusses" },
    { source: "20", target: "19", name: "explains" },
  ]
};

export async function POST(request: NextRequest) {
  try {
    const { triples, documentName } = await request.json()

    if (!triples || !Array.isArray(triples)) {
      return NextResponse.json({ error: "Invalid triples data" }, { status: 400 })
    }

    // Generate a unique ID for this graph data
    const graphId = generateUUID()

    // Store the data
    graphDataStore.set(graphId, { triples, documentName: documentName || "Unnamed Document" })

    console.log(`Stored graph data with ID: ${graphId}, triples count: ${triples.length}`)

    // Return the ID
    return NextResponse.json({ graphId })
  } catch (error) {
    console.error("Error storing graph data:", error)
    return NextResponse.json({ error: "Failed to store graph data" }, { status: 500 })
  }
}

export async function GET(request: NextRequest) {
  try {
    const url = new URL(request.url)
    const graphId = url.searchParams.get("id")

    // If no ID provided, return sample graph data
    if (!graphId) {
      console.log("No graph ID provided, returning sample data")
      return NextResponse.json(sampleGraphData)
    }

    console.log(`Retrieving graph data for ID: ${graphId}`)
    console.log(`Available graph IDs: ${Array.from(graphDataStore.keys()).join(", ")}`)

    const graphData = graphDataStore.get(graphId)

    if (!graphData) {
      console.log(`Graph data not found for ID: ${graphId}. Informing client to use localStorage.`)
      // Instead of a redirect, return a special response that tells the client to use localStorage
      return NextResponse.json({
        redirect: true,
        useLocalStorage: true,
        error: "Graph data not found or has expired"
      }, { status: 404 })
    }

    console.log(`Found graph data with ${graphData.triples.length} triples`)
    return NextResponse.json(graphData)
  } catch (error) {
    console.error("Error retrieving graph data:", error)
    return NextResponse.json({ error: "Failed to retrieve graph data" }, { status: 500 })
  }
}
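
A sketch of the store-then-fetch round trip this route supports. The triple values are illustrative; a missing or expired ID returns the localStorage hint shown in the handler:

```typescript
// Hypothetical round trip through /api/graph-data.
const stored = await fetch('/api/graph-data', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    triples: [{ subject: 'GPT-4', predicate: 'is_a', object: 'LLM' }],
    documentName: 'demo.txt',
  }),
});
const { graphId } = await stored.json();

// Retrieve it; calling without ?id= returns the built-in sample graph instead.
const graph = await (await fetch(`/api/graph-data?id=${graphId}`)).json();
```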
@@ -0,0 +1,41 @@
import { NextRequest, NextResponse } from 'next/server';
import { getGraphDbService } from '@/lib/graph-db-util';
import { getGraphDbType } from '../../settings/route';
import { ArangoDBService } from '@/lib/arangodb';
import { Neo4jService } from '@/lib/neo4j';

/**
 * POST handler for clearing all data from the graph database
 */
export async function POST(request: NextRequest) {
  try {
    // Get the preferred database type from settings
    const graphDbType = getGraphDbType();
    console.log(`Using graph database for clearing: ${graphDbType}`);

    // Get the appropriate service
    const graphDbService = getGraphDbService(graphDbType);

    // Clear the database based on type
    if (graphDbType === 'arangodb') {
      const arangoService = graphDbService as ArangoDBService;
      await arangoService.clearDatabase();
    } else if (graphDbType === 'neo4j') {
      // TODO: Implement Neo4j clear functionality when needed
      throw new Error('Clear database functionality not implemented for Neo4j');
    }

    // Return success response
    return NextResponse.json({
      success: true,
      message: `Successfully cleared all data from ${graphDbType} database`,
      databaseType: graphDbType
    });
  } catch (error) {
    console.error(`Error in clear database handler:`, error);
    return NextResponse.json(
      { error: `Failed to clear database: ${error instanceof Error ? error.message : String(error)}` },
      { status: 500 }
    );
  }
}
@@ -0,0 +1,42 @@
import { NextRequest, NextResponse } from 'next/server';
import { getGraphDbService } from '@/lib/graph-db-util';
import { getGraphDbType } from '../../settings/route';

/**
 * API endpoint for disconnecting from the selected graph database
 * POST /api/graph-db/disconnect
 */
export async function POST(request: NextRequest) {
  try {
    // Get the graph database type from the settings
    const graphDbType = getGraphDbType();
    console.log(`Disconnecting from ${graphDbType}...`);

    // Get the appropriate service
    const graphDbService = getGraphDbService(graphDbType);

    if (graphDbService.isInitialized()) {
      graphDbService.close();
      return NextResponse.json({
        success: true,
        message: `Successfully disconnected from ${graphDbType}`,
        type: graphDbType
      });
    } else {
      return NextResponse.json({
        success: false,
        message: `No active ${graphDbType} connection to disconnect`,
        type: graphDbType
      });
    }
  } catch (error) {
    console.error('Error disconnecting from graph database:', error);
    return NextResponse.json(
      {
        error: `Failed to disconnect from graph database: ${error instanceof Error ? error.message : String(error)}`,
        type: getGraphDbType()
      },
      { status: 500 }
    );
  }
}
nvidia/txt2kg/assets/frontend/app/api/graph-db/route.ts (new file, 155 lines)
@@ -0,0 +1,155 @@
import { NextRequest, NextResponse } from 'next/server';
import { getGraphDbService } from '@/lib/graph-db-util';
import { getGraphDbType } from '../settings/route';
import { GraphDBType } from '@/lib/graph-db-service';

/**
 * Initialize graph database connection with parameters from request
 * @param request Optional request containing connection parameters
 */
async function ensureConnection(request?: NextRequest): Promise<GraphDBType> {
  try {
    // Get the preferred database type from settings or request
    let graphDbType: GraphDBType;

    if (request?.nextUrl.searchParams.has('type')) {
      // Explicitly specified in the request
      graphDbType = request.nextUrl.searchParams.get('type') as GraphDBType;
    } else {
      // Get from settings, with a safe fallback
      graphDbType = getGraphDbType();
    }

    console.log(`Using graph database: ${graphDbType}`);

    // Get the appropriate service
    const graphDbService = getGraphDbService(graphDbType);

    if (graphDbType === 'neo4j') {
      // Neo4j connection params
      let uri = process.env.NEO4J_URI;
      let username = process.env.NEO4J_USER || process.env.NEO4J_USERNAME;
      let password = process.env.NEO4J_PASSWORD;

      // Override with URL parameters if provided
      if (request) {
        const params = request.nextUrl.searchParams;
        if (params.has('url')) uri = params.get('url') as string;
        if (params.has('username')) username = params.get('username') as string;
        if (params.has('password')) password = params.get('password') as string;
      }

      // Connect to Neo4j instance
      graphDbService.initialize(uri, username, password);
    } else if (graphDbType === 'arangodb') {
      // ArangoDB connection params - environment variables take absolute priority
      let url = process.env.ARANGODB_URL;
      let dbName = process.env.ARANGODB_DB;
      let username = process.env.ARANGODB_USER;
      let password = process.env.ARANGODB_PASSWORD;

      // Only use URL parameters if environment variables are not set
      if (request) {
        const params = request.nextUrl.searchParams;
        if (!url && params.has('url')) url = params.get('url') as string;
        if (!dbName && params.has('dbName')) dbName = params.get('dbName') as string;
        if (!username && params.has('username')) username = params.get('username') as string;
        if (!password && params.has('password')) password = params.get('password') as string;
      }

      // Connect to ArangoDB instance
      await (graphDbService as any).initialize(url, dbName, username, password);
    }

    return graphDbType;
  } catch (error) {
    console.error(`Failed to initialize graph database connection:`, error);
    throw error;
  }
}

/**
 * GET handler for retrieving graph data from the selected graph database
 */
export async function GET(request: NextRequest) {
  try {
    // Initialize with connection parameters
    const graphDbType = await ensureConnection(request);
    const graphDbService = getGraphDbService(graphDbType);

    // Get graph data from the database
    const graphData = await graphDbService.getGraphData();

    // Transform to format expected by the frontend
    const nodes = graphData.nodes.map(node => ({
      ...node,
      name: node.name || `Node ${node.id}`,
      label: node.labels?.[0] || 'Entity',
      val: 1, // Default size
      color: node.labels?.includes('Entity') ? '#ff6b6b' : '#4ecdc4'
    }));

    const links = graphData.relationships.map(rel => ({
      ...rel,
      label: rel.type || 'RELATED_TO'
    }));

    // Get the connection URL from request params or env
    const params = request.nextUrl.searchParams;
    const connectionUrl = params.get('url') ||
      (graphDbType === 'neo4j' ? process.env.NEO4J_URI : process.env.ARANGODB_URL) ||
      'Not specified';

    // Convert to the format expected by the application
    return NextResponse.json({
      nodes,
      links,
      connectionUrl,
      databaseType: graphDbType
    });
  } catch (error) {
    console.error(`Error in graph database GET handler:`, error);
    return NextResponse.json(
      { error: `Failed to fetch graph data: ${error instanceof Error ? error.message : String(error)}` },
      { status: 500 }
    );
  }
}

/**
 * POST handler for importing triples into the selected graph database
 */
export async function POST(request: NextRequest) {
  try {
    // Initialize with connection parameters
    const graphDbType = await ensureConnection(request);
    const graphDbService = getGraphDbService(graphDbType);

    // Parse request body
    const body = await request.json();

    // Validate request body
    if (!body.triples || !Array.isArray(body.triples)) {
      return NextResponse.json(
        { error: 'Invalid request: triples array is required' },
        { status: 400 }
      );
    }

    // Import triples into the graph database
    await graphDbService.importTriples(body.triples);

    // Return success response
    return NextResponse.json({
      success: true,
      message: `Successfully imported ${body.triples.length} triples into ${graphDbType}`,
      databaseType: graphDbType
    });
  } catch (error) {
    console.error(`Error in graph database POST handler:`, error);
    return NextResponse.json(
      { error: `Failed to import triples: ${error instanceof Error ? error.message : String(error)}` },
      { status: 500 }
    );
  }
}
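
A sketch of the two handlers above. The ?type= parameter and connection overrides mirror ensureConnection; the triple values are placeholders:

```typescript
// Hypothetical calls to /api/graph-db with an explicit database type.
const data = await (await fetch('/api/graph-db?type=arangodb')).json();
console.log(data.nodes.length, data.links.length, data.databaseType);

// Import triples into the currently configured database.
await fetch('/api/graph-db', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ triples: [{ subject: 'A', predicate: 'relates_to', object: 'B' }] }),
});
```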
nvidia/txt2kg/assets/frontend/app/api/graph-db/triples/route.ts (new file, 180 lines)
@@ -0,0 +1,180 @@
import { NextRequest, NextResponse } from 'next/server';
import { getGraphDbService } from '@/lib/graph-db-util';
import { getGraphDbType } from '../../settings/route';
import type { Triple } from '@/types/graph';
import { GraphDBType } from '@/lib/graph-db-service';

/**
 * API endpoint for fetching all triples from the selected graph database
 * GET /api/graph-db/triples
 */
export async function GET(req: NextRequest) {
  try {
    // Get the database type from settings or request parameter
    const graphDbType = req.nextUrl.searchParams.get('type') as GraphDBType || getGraphDbType();
    console.log(`Using graph database type: ${graphDbType}`);

    // Get the appropriate graph database service
    const graphDbService = getGraphDbService(graphDbType);

    // Initialize the service based on database type
    if (graphDbType === 'neo4j') {
      // Neo4j specific initialization
      const uri = process.env.NEO4J_URI;
      const username = process.env.NEO4J_USER || process.env.NEO4J_USERNAME;
      const password = process.env.NEO4J_PASSWORD;
      graphDbService.initialize(uri, username, password);
    } else if (graphDbType === 'arangodb') {
      // ArangoDB specific initialization
      const url = process.env.ARANGODB_URL;
      const dbName = process.env.ARANGODB_DB;
      const username = process.env.ARANGODB_USER;
      const password = process.env.ARANGODB_PASSWORD;
      await (graphDbService as any).initialize(url, dbName, username, password);
    }

    console.log(`Fetching all triples from ${graphDbType}...`);

    // Get all triples from the graph database
    // We'll use the graphDbService to get the graph data and then extract the triples
    const graphData = await graphDbService.getGraphData();

    // Extract triples from the graph data
    const triples: Triple[] = [];

    // Map of node IDs to names
    const nodeMap = new Map();
    for (const node of graphData.nodes) {
      nodeMap.set(node.id, node.name);
    }

    // Convert relationships to triples
    for (const rel of graphData.relationships) {
      const subject = nodeMap.get(rel.source);
      const object = nodeMap.get(rel.target);
      const predicate = rel.type;

      if (subject && predicate && object) {
        triples.push({
          subject,
          predicate,
          object
        });
      }
    }

    // Deduplicate triples
    const uniqueTriples = deduplicateTriples(triples);

    console.log(`Successfully fetched ${uniqueTriples.length} unique triples from ${graphDbType}`);

    // Return the triples
    return NextResponse.json({
      success: true,
      triples: uniqueTriples,
      count: uniqueTriples.length,
      databaseType: graphDbType
    });

  } catch (error) {
    console.error(`Error fetching triples from graph database:`, error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json(
      { error: `Failed to fetch triples: ${errorMessage}` },
      { status: 500 }
    );
  }
}

/**
 * Helper function to deduplicate triples
 */
function deduplicateTriples(triples: Triple[]): Triple[] {
  const seen = new Set<string>();
  return triples.filter(triple => {
    // Create a string key for this triple
    const key = `${triple.subject.toLowerCase()}|${triple.predicate.toLowerCase()}|${triple.object.toLowerCase()}`;

    // Check if we've seen this triple before
    if (seen.has(key)) {
      return false;
    }

    // Mark this triple as seen
    seen.add(key);
    return true;
  });
}

/**
 * API endpoint for storing triples in the selected graph database
 * POST /api/graph-db/triples
 */
export async function POST(req: NextRequest) {
  try {
    // Parse request body
    const body = await req.json();
    const { triples, documentName } = body;

    if (!triples || !Array.isArray(triples)) {
      return NextResponse.json({ error: 'Triples are required' }, { status: 400 });
    }

    // Get the database type from settings or request parameter
    const graphDbType = req.nextUrl.searchParams.get('type') as GraphDBType || getGraphDbType();
    console.log(`Using graph database type: ${graphDbType}`);

    console.log(`Storing ${triples.length} triples in ${graphDbType} from document "${documentName || 'unnamed'}"`);

    // Get the appropriate graph database service
    const graphDbService = getGraphDbService(graphDbType);

    // Initialize the service based on database type
    if (graphDbType === 'neo4j') {
      // Neo4j specific initialization
      const uri = process.env.NEO4J_URI;
      const username = process.env.NEO4J_USER || process.env.NEO4J_USERNAME;
      const password = process.env.NEO4J_PASSWORD;
      graphDbService.initialize(uri, username, password);
    } else if (graphDbType === 'arangodb') {
      // ArangoDB specific initialization
      const url = process.env.ARANGODB_URL;
      const dbName = process.env.ARANGODB_DB;
      const username = process.env.ARANGODB_USER;
      const password = process.env.ARANGODB_PASSWORD;
      await (graphDbService as any).initialize(url, dbName, username, password);
    }

    // Filter triples to ensure they are valid
    const validTriples = triples.filter((triple: any) => {
      return (
        triple &&
        typeof triple.subject === 'string' && triple.subject.trim() !== '' &&
        typeof triple.predicate === 'string' && triple.predicate.trim() !== '' &&
        typeof triple.object === 'string' && triple.object.trim() !== ''
      );
    }) as Triple[];

    console.log(`Found ${validTriples.length} valid triples to store`);

    // Store triples in the graph database
    await graphDbService.importTriples(validTriples);

    // Return success response
    return NextResponse.json({
      success: true,
      message: `Triples stored successfully in ${graphDbType}`,
      count: validTriples.length,
      documentName,
      databaseType: graphDbType
    });

  } catch (error) {
    console.error('Error storing triples in graph database:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json(
      { error: `Failed to store triples: ${errorMessage}` },
      { status: 500 }
    );
  }
}
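
A sketch of storing and re-reading triples through this route. As in the handlers above, the optional ?type= parameter overrides the settings value; the triple content is illustrative:

```typescript
// Hypothetical usage of /api/graph-db/triples.
await fetch('/api/graph-db/triples?type=neo4j', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    triples: [{ subject: 'Aspirin', predicate: 'treats', object: 'Headache' }],
    documentName: 'pharma-notes.txt',
  }),
});

const { triples, count } = await (await fetch('/api/graph-db/triples')).json();
console.log(`${count} unique triples in the graph`, triples);
```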
nvidia/txt2kg/assets/frontend/app/api/metrics/route.ts (new file, 138 lines)
@@ -0,0 +1,138 @@
import { NextRequest, NextResponse } from 'next/server';
import remoteBackendInstance from '@/lib/remote-backend';
import { Neo4jService } from '@/lib/neo4j';
import neo4jService from '@/lib/neo4j';
import { PineconeService } from '@/lib/pinecone';
import RAGService from '@/lib/rag';
import queryLoggerService, { QueryLogSummary } from '@/lib/query-logger';

/**
 * Metrics API that provides performance statistics about the RAG system
 */
export async function GET(request: NextRequest) {
  try {
    // Initialize services
    const neo4j = Neo4jService.getInstance();
    const pineconeService = PineconeService.getInstance();

    if (!neo4j.isInitialized()) {
      neo4j.initialize();
    }

    // Get graph stats from Neo4j
    const graphData = await neo4j.getGraphData();

    // Get unique entities (nodes)
    const uniqueEntities = new Set<string>();
    graphData.nodes.forEach((node: any) => uniqueEntities.add(node.name));

    // Get total triples (relationships)
    const totalTriples = graphData.relationships.length;

    // Get vector stats from Pinecone if available
    let vectorStats = {
      totalVectors: 0,
      avgQueryTime: 0,
      avgRelevanceScore: 0
    };

    try {
      await pineconeService.initialize();
      const stats = await pineconeService.getStats();

      vectorStats = {
        totalVectors: stats.totalVectorCount || 0,
        avgQueryTime: stats.averageQueryTime || 0,
        avgRelevanceScore: stats.averageRelevanceScore || 0
      };
    } catch (error) {
      console.warn('Could not fetch Pinecone stats:', error);
    }

    // Get real query logs instead of mock data
    let queryLogs: QueryLogSummary[] = [];
    let precision = 0;
    let recall = 0;
    let f1Score = 0;
    let avgQueryTime = vectorStats.avgQueryTime || 0;
    let avgRelevance = 0;

    // Get query logs from file-based logger instead of Neo4j
    try {
      // Initialize query logger if needed
      if (!queryLoggerService.isInitialized()) {
        await queryLoggerService.initialize();
      }

      // Get the logs
      console.log('Getting query logs from file');
      queryLogs = await queryLoggerService.getQueryLogs(25);
      console.log(`Found ${queryLogs.length} query logs from file-based logger`);

      // Calculate metrics from the query logs
      if (queryLogs.length > 0) {
        // Calculate metrics from logs with actual data
        const logsWithMetrics = queryLogs.filter(log =>
          log.metrics.avgPrecision > 0 ||
          log.metrics.avgRecall > 0 ||
          log.metrics.avgExecutionTimeMs > 0
        );

        const logsWithRelevance = queryLogs.filter(log => log.metrics.avgRelevanceScore > 0);

        if (logsWithMetrics.length > 0) {
          precision = logsWithMetrics.reduce((sum, log) => sum + (log.metrics.avgPrecision || 0), 0) / logsWithMetrics.length;
          recall = logsWithMetrics.reduce((sum, log) => sum + (log.metrics.avgRecall || 0), 0) / logsWithMetrics.length;
          avgQueryTime = logsWithMetrics.reduce((sum, log) => sum + (log.metrics.avgExecutionTimeMs || 0), 0) / logsWithMetrics.length;
          f1Score = precision > 0 && recall > 0 ? 2 * (precision * recall) / (precision + recall) : 0;
        }

        if (logsWithRelevance.length > 0) {
          avgRelevance = logsWithRelevance.reduce((sum, log) => sum + (log.metrics.avgRelevanceScore || 0), 0) / logsWithRelevance.length;
        }
      }
    } catch (error) {
      console.warn('Error getting query logs from file:', error);
      // Keep values at 0 instead of using defaults
    }

    // Get top queries from real logs
    const topQueries = queryLogs.length > 0
      ? queryLogs
          .sort((a, b) => b.count - a.count)
          .slice(0, 5)
          .map(log => ({
            query: log.query,
            count: log.count
          }))
      : [];

    // Aggregate metrics
    const metrics = {
      totalTriples,
      totalEntities: uniqueEntities.size,
      avgQueryTime,
      avgRelevance: avgRelevance || vectorStats.avgRelevanceScore || 0, // Use query log relevance score, fallback to vector stats
      precision,
      recall,
      f1Score,
      topQueries,
      // Add metadata about query logs
      queryLogStats: {
        totalQueryLogs: queryLogs.length,
        totalExecutions: queryLogs.reduce((sum, log) => sum + log.executionCount, 0),
        lastQueriedAt: queryLogs.length > 0 ? queryLogs[0].lastQueried : null
      }
    };

    return NextResponse.json(metrics);
  } catch (error) {
    console.error('Error fetching metrics:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json({ error: errorMessage }, { status: 500 });
  }
}

/**
 * Function to calculate precision and recall has been replaced by real data from query logs
 */
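
A sketch of reading the aggregated metrics; the field names follow the metrics object assembled above:

```typescript
// Hypothetical read of /api/metrics.
const m = await (await fetch('/api/metrics')).json();
console.log(`entities=${m.totalEntities} triples=${m.totalTriples} F1=${m.f1Score.toFixed(3)}`);
console.log('top queries:', m.topQueries);
```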
@@ -0,0 +1,46 @@
import { NextRequest, NextResponse } from 'next/server';

/**
 * Legacy Neo4j disconnect endpoint - redirects to the new graph-db/disconnect endpoint
 * @deprecated Use /api/graph-db/disconnect instead with type=neo4j
 */
export async function POST(request: NextRequest) {
  console.log('Redirecting from deprecated /api/neo4j/disconnect to /api/graph-db/disconnect?type=neo4j');

  // Create the new URL with the neo4j type parameter
  const url = new URL(request.url);
  const newUrl = new URL('/api/graph-db/disconnect', url.origin);

  // Copy all query parameters
  url.searchParams.forEach((value, key) => {
    newUrl.searchParams.append(key, value);
  });

  // Add Neo4j type parameter if not present
  if (!newUrl.searchParams.has('type')) {
    newUrl.searchParams.append('type', 'neo4j');
  }

  // Clone the request with the new URL
  const newRequest = new Request(newUrl, {
    method: request.method,
    headers: request.headers,
    body: request.body,
    cache: request.cache,
    credentials: request.credentials,
    integrity: request.integrity,
    keepalive: request.keepalive,
    mode: request.mode,
    redirect: request.redirect,
    referrer: request.referrer,
    referrerPolicy: request.referrerPolicy,
    signal: request.signal,
    duplex: 'half',
  } as RequestInit);

  // Fetch from the new endpoint
  const response = await fetch(newRequest);

  // Return the response
  return response;
}
nvidia/txt2kg/assets/frontend/app/api/neo4j/route.ts (new file, 105 lines)
@@ -0,0 +1,105 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { Neo4jService } from '@/lib/neo4j';
|
||||
|
||||
// Initialize Neo4j service
|
||||
const neo4jService = Neo4jService.getInstance();
|
||||
|
||||
// Initialize connection on first request
|
||||
let isInitialized = false;
|
||||
|
||||
/**
|
||||
* Initialize Neo4j connection if not already initialized
|
||||
* @param request Optional request containing connection parameters
|
||||
*/
|
||||
function ensureConnection(request?: NextRequest) {
|
||||
try {
|
||||
let uri = process.env.NEO4J_URI;
|
||||
let username = process.env.NEO4J_USER;
|
||||
let password = process.env.NEO4J_PASSWORD;
|
||||
|
||||
// Override with URL parameters if provided
|
||||
if (request) {
|
||||
const params = request.nextUrl.searchParams;
|
||||
if (params.has('url')) uri = params.get('url') as string;
|
||||
if (params.has('username')) username = params.get('username') as string;
|
||||
if (params.has('password')) password = params.get('password') as string;
|
||||
}
|
||||
|
||||
// Connect to Neo4j instance
|
||||
neo4jService.initialize(uri, username, password);
|
||||
isInitialized = true;
|
||||
} catch (error) {
|
||||
console.error('Failed to initialize Neo4j connection:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Legacy Neo4j endpoint - redirects to the new graph-db endpoint
|
||||
* @deprecated Use /api/graph-db instead with type=neo4j
|
||||
*/
|
||||
export async function GET(request: NextRequest) {
|
||||
console.log('Redirecting from deprecated /api/neo4j to /api/graph-db?type=neo4j');
|
||||
|
||||
// Create the new URL with the same query parameters
|
||||
const url = new URL(request.url);
|
||||
const newUrl = new URL('/api/graph-db', url.origin);
|
||||
|
||||
// Copy all query parameters
|
||||
url.searchParams.forEach((value, key) => {
|
||||
newUrl.searchParams.append(key, value);
|
||||
});
|
||||
|
||||
// Add Neo4j type parameter if not present
|
||||
if (!newUrl.searchParams.has('type')) {
|
||||
newUrl.searchParams.append('type', 'neo4j');
|
||||
}
|
||||
|
||||
// Return a redirect response
|
||||
return NextResponse.redirect(newUrl);
|
||||
}
|
||||
|
||||
/**
|
||||
* Legacy Neo4j POST endpoint - redirects to the new graph-db endpoint with a type parameter
|
||||
* @deprecated Use /api/graph-db instead with type=neo4j
|
||||
*/
|
||||
export async function POST(request: NextRequest) {
|
||||
console.log('Redirecting from deprecated /api/neo4j to /api/graph-db?type=neo4j');
|
||||
|
||||
// Create the new URL with the neo4j type parameter
|
||||
const url = new URL(request.url);
|
||||
const newUrl = new URL('/api/graph-db', url.origin);
|
||||
|
||||
// Copy all query parameters
|
||||
url.searchParams.forEach((value, key) => {
|
||||
newUrl.searchParams.append(key, value);
|
||||
});
|
||||
|
||||
// Add Neo4j type parameter if not present
|
||||
if (!newUrl.searchParams.has('type')) {
|
||||
newUrl.searchParams.append('type', 'neo4j');
|
||||
}
|
||||
|
||||
// Clone the request with the new URL
|
||||
const newRequest = new Request(newUrl, {
|
||||
method: request.method,
|
||||
headers: request.headers,
|
||||
body: request.body,
|
||||
cache: request.cache,
|
||||
credentials: request.credentials,
|
||||
integrity: request.integrity,
|
||||
keepalive: request.keepalive,
|
||||
mode: request.mode,
|
||||
redirect: request.redirect,
|
||||
referrer: request.referrer,
|
||||
referrerPolicy: request.referrerPolicy,
|
||||
signal: request.signal,
|
||||
duplex: 'half',
|
||||
} as RequestInit);
|
||||
|
||||
// Fetch from the new endpoint
|
||||
const response = await fetch(newRequest);
|
||||
|
||||
// Return the response
|
||||
return response;
|
||||
}
|
||||
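
A minimal sketch of exercising the deprecated route above; the localhost origin is an assumption, and the GET handler answers with a redirect to /api/graph-db:

// sketch: follow the legacy GET redirect (origin assumed)
const legacyRes = await fetch('http://localhost:3000/api/neo4j?username=neo4j', { redirect: 'follow' });
console.log(legacyRes.url); // e.g. .../api/graph-db?username=neo4j&type=neo4j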

nvidia/txt2kg/assets/frontend/app/api/neo4j/triples/route.ts (new file, 71 lines)
@@ -0,0 +1,71 @@
import { NextRequest, NextResponse } from 'next/server';

/**
 * Legacy Neo4j triples endpoint - redirects to the new graph-db/triples endpoint
 * @deprecated Use /api/graph-db/triples instead with type=neo4j
 */
export async function GET(req: NextRequest) {
  console.log('Redirecting from deprecated /api/neo4j/triples to /api/graph-db/triples?type=neo4j');

  // Create the new URL with the same query parameters
  const url = new URL(req.url);
  const newUrl = new URL('/api/graph-db/triples', url.origin);

  // Copy all query parameters
  url.searchParams.forEach((value, key) => {
    newUrl.searchParams.append(key, value);
  });

  // Add Neo4j type parameter if not present
  if (!newUrl.searchParams.has('type')) {
    newUrl.searchParams.append('type', 'neo4j');
  }

  // Return a redirect response
  return NextResponse.redirect(newUrl);
}

/**
 * Legacy Neo4j triples POST endpoint - redirects to the new graph-db/triples endpoint
 * @deprecated Use /api/graph-db/triples instead with type=neo4j
 */
export async function POST(req: NextRequest) {
  console.log('Redirecting from deprecated /api/neo4j/triples to /api/graph-db/triples?type=neo4j');

  // Create the new URL with the neo4j type parameter
  const url = new URL(req.url);
  const newUrl = new URL('/api/graph-db/triples', url.origin);

  // Copy all query parameters
  url.searchParams.forEach((value, key) => {
    newUrl.searchParams.append(key, value);
  });

  // Add Neo4j type parameter if not present
  if (!newUrl.searchParams.has('type')) {
    newUrl.searchParams.append('type', 'neo4j');
  }

  // Clone the request with the new URL
  const newRequest = new Request(newUrl, {
    method: req.method,
    headers: req.headers,
    body: req.body,
    cache: req.cache,
    credentials: req.credentials,
    integrity: req.integrity,
    keepalive: req.keepalive,
    mode: req.mode,
    redirect: req.redirect,
    referrer: req.referrer,
    referrerPolicy: req.referrerPolicy,
    signal: req.signal,
    duplex: 'half',
  } as RequestInit);

  // Fetch from the new endpoint
  const response = await fetch(newRequest);

  // Return the response
  return response;
}

nvidia/txt2kg/assets/frontend/app/api/ollama/batch/route.ts (new file, 184 lines)
@@ -0,0 +1,184 @@
import { NextRequest, NextResponse } from 'next/server';
import { llmService, LLMMessage } from '@/lib/llm-service';

/**
 * API endpoint for batch Ollama operations
 * POST /api/ollama/batch - Process multiple texts in batch for triple extraction
 */

interface BatchTripleRequest {
  texts: string[];
  model?: string;
  temperature?: number;
  maxTokens?: number;
  concurrency?: number;
}

export async function POST(req: NextRequest) {
  try {
    const {
      texts,
      model = 'qwen3:1.7b',
      temperature = 0.1,
      maxTokens = 8192,
      concurrency = 5
    }: BatchTripleRequest = await req.json();

    if (!texts || !Array.isArray(texts) || texts.length === 0) {
      return NextResponse.json({
        error: 'Texts array is required and must not be empty'
      }, { status: 400 });
    }

    if (texts.length > 100) {
      return NextResponse.json({
        error: 'Batch size limited to 100 texts maximum'
      }, { status: 400 });
    }

    // Validate all texts are strings
    const invalidTexts = texts.filter(text => !text || typeof text !== 'string');
    if (invalidTexts.length > 0) {
      return NextResponse.json({
        error: `Invalid texts found at indices: ${texts.map((text, i) =>
          (!text || typeof text !== 'string') ? i : null
        ).filter(i => i !== null).join(', ')}`
      }, { status: 400 });
    }

    console.log(`Starting batch triple extraction for ${texts.length} texts using model ${model}`);

    // Create system prompt for triple extraction
    const systemPrompt = `You are a knowledge graph builder that extracts structured information from text.
Extract subject-predicate-object triples from the following text.

Guidelines:
- Extract only factual triples present in the text
- Normalize entity names to their canonical form
- Return results in JSON format as an array of objects with "subject", "predicate", "object" fields
- Each triple should represent a clear relationship between two entities
- Focus on the most important relationships in the text`;

    // Prepare batch messages
    const messagesBatch: LLMMessage[][] = texts.map(text => [
      {
        role: 'system' as const,
        content: systemPrompt
      },
      {
        role: 'user' as const,
        content: `Extract triples from this text:\n\n${text}`
      }
    ]);

    // Process batch with Ollama
    const batchResult = await llmService.generateOllamaBatchCompletion(
      model,
      messagesBatch,
      { temperature, maxTokens, concurrency }
    );

    // Parse responses to extract triples
    const processedResults = batchResult.results.map((response, index) => {
      let triples = [];

      if (response) {
        try {
          // Try to parse as JSON first
          const jsonMatch = response.match(/\[[\s\S]*\]/);
          if (jsonMatch) {
            triples = JSON.parse(jsonMatch[0]);
          } else {
            // Fallback: parse line by line
            triples = parseTriplesFallback(response);
          }
        } catch (parseError) {
          console.warn(`Failed to parse response for text ${index}:`, parseError);
          triples = parseTriplesFallback(response);
        }
      }

      return {
        textIndex: index,
        originalText: texts[index].substring(0, 200) + (texts[index].length > 200 ? '...' : ''),
        triples: triples.map((triple: any, tripleIndex: number) => ({
          ...triple,
          confidence: 0.8, // Default confidence for Ollama extractions
          metadata: {
            entityTypes: [],
            source: texts[index].substring(0, 100) + '...',
            context: texts[index].substring(0, 200) + '...',
            extractionMethod: 'ollama_batch',
            model: model,
            textIndex: index,
            tripleIndex: tripleIndex
          }
        })),
        tripleCount: triples.length,
        success: !batchResult.errors.some(error => error.index === index)
      };
    });

    // Calculate summary statistics
    const totalTriples = processedResults.reduce((sum, result) => sum + result.tripleCount, 0);
    const successfulTexts = processedResults.filter(result => result.success).length;

    return NextResponse.json({
      results: processedResults,
      summary: {
        totalTexts: texts.length,
        successfulTexts: successfulTexts,
        failedTexts: batchResult.errors.length,
        totalTriples: totalTriples,
        averageTriples: successfulTexts > 0 ? (totalTriples / successfulTexts).toFixed(2) : 0
      },
      batchInfo: {
        model: model,
        concurrency: concurrency,
        processingTime: Date.now(), // Could be enhanced with actual timing
        method: 'ollama_batch'
      },
      errors: batchResult.errors,
      success: true
    });
  } catch (error) {
    console.error('Error in Ollama batch triple extraction:', error);
    return NextResponse.json(
      {
        error: 'Failed to process batch triple extraction with Ollama',
        details: error instanceof Error ? error.message : String(error)
      },
      { status: 500 }
    );
  }
}

// Fallback parser for when JSON parsing fails (reused from single endpoint)
function parseTriplesFallback(text: string): Array<{subject: string, predicate: string, object: string}> {
  const triples = [];
  const lines = text.split('\n');

  for (const line of lines) {
    // Look for patterns like "Subject - Predicate - Object" or similar
    const tripleMatch = line.match(/^[\s\-\*\d\.]*(.+?)\s*[\-\|]\s*(.+?)\s*[\-\|]\s*(.+)$/);
    if (tripleMatch) {
      triples.push({
        subject: tripleMatch[1].trim(),
        predicate: tripleMatch[2].trim(),
        object: tripleMatch[3].trim()
      });
    }

    // Also look for JSON-like objects in the text
    const jsonObjectMatch = line.match(/\{\s*"subject"\s*:\s*"([^"]+)"\s*,\s*"predicate"\s*:\s*"([^"]+)"\s*,\s*"object"\s*:\s*"([^"]+)"\s*\}/);
    if (jsonObjectMatch) {
      triples.push({
        subject: jsonObjectMatch[1],
        predicate: jsonObjectMatch[2],
        object: jsonObjectMatch[3]
      });
    }
  }

  return triples;
}
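
A sketch of a request body the BatchTripleRequest interface above would accept; only texts is required, the other field values shown are illustrative:

// sketch: batch triple extraction call against the route above
const batchRes = await fetch('/api/ollama/batch', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    texts: ['Marie Curie won the Nobel Prize in Physics in 1903.'],
    model: 'qwen3:1.7b',
    concurrency: 5,
  }),
});
const { results, summary } = await batchRes.json();
console.log(summary.totalTriples, results[0].triples);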

nvidia/txt2kg/assets/frontend/app/api/ollama/route.ts (new file, 160 lines)
@@ -0,0 +1,160 @@
import { NextRequest, NextResponse } from 'next/server';
import { llmService } from '@/lib/llm-service';

// Configure route for dynamic operations and long-running requests
export const dynamic = 'force-dynamic';
export const maxDuration = 1800; // 30 minutes for large model processing

/**
 * API endpoint for Ollama-specific operations
 * GET /api/ollama - Test connection and list available models
 * POST /api/ollama/extract-triples - Extract triples using Ollama model
 */

export async function GET(req: NextRequest) {
  try {
    const { searchParams } = new URL(req.url);
    const action = searchParams.get('action');

    if (action === 'test-connection') {
      const result = await llmService.testOllamaConnection();
      return NextResponse.json(result);
    }

    // Default: test connection and return models
    const result = await llmService.testOllamaConnection();
    return NextResponse.json(result);
  } catch (error) {
    console.error('Error in Ollama API:', error);
    return NextResponse.json(
      {
        error: 'Failed to connect to Ollama server',
        details: error instanceof Error ? error.message : String(error)
      },
      { status: 500 }
    );
  }
}

export async function POST(req: NextRequest) {
  const startTime = Date.now();
  console.log(`[${new Date().toISOString()}] /api/ollama: POST request received`);

  try {
    const { text, model = 'qwen3:1.7b', temperature = 0.1, maxTokens = 8192 } = await req.json();
    console.log(`[${new Date().toISOString()}] /api/ollama: Parsed body - model: ${model}, text length: ${text?.length || 0}, maxTokens: ${maxTokens}`);

    if (!text || typeof text !== 'string') {
      return NextResponse.json({ error: 'Text is required' }, { status: 400 });
    }

    // Use the LLM service to generate completion with Ollama
    const messages = [
      {
        role: 'system' as const,
        content: `You are a knowledge graph builder that extracts structured information from text.
Extract subject-predicate-object triples from the following text.

Guidelines:
- Extract only factual triples present in the text
- Normalize entity names to their canonical form
- Return results in JSON format as an array of objects with "subject", "predicate", "object" fields
- Each triple should represent a clear relationship between two entities
- Focus on the most important relationships in the text`
      },
      {
        role: 'user' as const,
        content: `Extract triples from this text:\n\n${text}`
      }
    ];

    console.log(`[${new Date().toISOString()}] /api/ollama: Calling llmService.generateOllamaCompletion with model: ${model}`);
    const llmStartTime = Date.now();

    const response = await llmService.generateOllamaCompletion(
      model,
      messages,
      { temperature, maxTokens }
    );

    const llmDuration = ((Date.now() - llmStartTime) / 1000).toFixed(2);
    console.log(`[${new Date().toISOString()}] /api/ollama: LLM completion received after ${llmDuration}s, response length: ${response?.length || 0}`);

    // Parse the response to extract triples
    let triples = [];
    try {
      // Try to parse as JSON first
      const jsonMatch = response.match(/\[[\s\S]*\]/);
      if (jsonMatch) {
        triples = JSON.parse(jsonMatch[0]);
      } else {
        // Fallback: parse line by line
        triples = parseTriplesFallback(response);
      }
    } catch (parseError) {
      console.warn('Failed to parse JSON response, using fallback parser:', parseError);
      triples = parseTriplesFallback(response);
    }

    const totalDuration = ((Date.now() - startTime) / 1000).toFixed(2);
    console.log(`[${new Date().toISOString()}] /api/ollama: Returning ${triples.length} triples, total duration: ${totalDuration}s`);

    return NextResponse.json({
      triples: triples.map((triple, index) => ({
        ...triple,
        confidence: 0.8, // Default confidence for Ollama extractions
        metadata: {
          entityTypes: [],
          source: text.substring(0, 100) + '...',
          context: text.substring(0, 200) + '...',
          extractionMethod: 'ollama',
          model: model
        }
      })),
      count: triples.length,
      success: true,
      method: 'ollama',
      model: model
    });
  } catch (error) {
    const totalDuration = ((Date.now() - startTime) / 1000).toFixed(2);
    console.error(`[${new Date().toISOString()}] /api/ollama: Error after ${totalDuration}s:`, error);
    return NextResponse.json(
      {
        error: 'Failed to extract triples with Ollama',
        details: error instanceof Error ? error.message : String(error)
      },
      { status: 500 }
    );
  }
}

// Fallback parser for when JSON parsing fails
function parseTriplesFallback(text: string): Array<{subject: string, predicate: string, object: string}> {
  const triples = [];
  const lines = text.split('\n');

  for (const line of lines) {
    // Look for patterns like "Subject - Predicate - Object" or similar
    const tripleMatch = line.match(/^[\s\-\*\d\.]*(.+?)\s*[\-\|]\s*(.+?)\s*[\-\|]\s*(.+)$/);
    if (tripleMatch) {
      triples.push({
        subject: tripleMatch[1].trim(),
        predicate: tripleMatch[2].trim(),
        object: tripleMatch[3].trim()
      });
    }

    // Also look for JSON-like objects in the text
    const jsonObjectMatch = line.match(/\{\s*"subject"\s*:\s*"([^"]+)"\s*,\s*"predicate"\s*:\s*"([^"]+)"\s*,\s*"object"\s*:\s*"([^"]+)"\s*\}/);
    if (jsonObjectMatch) {
      triples.push({
        subject: jsonObjectMatch[1],
        predicate: jsonObjectMatch[2],
        object: jsonObjectMatch[3]
      });
    }
  }

  return triples;
}
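
A sketch of the single-text extraction call handled by the POST route above; omitted fields fall back to the route's defaults (qwen3:1.7b, temperature 0.1, maxTokens 8192):

// sketch: single-text triple extraction
const extractRes = await fetch('/api/ollama', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ text: 'Ada Lovelace wrote the first published algorithm.' }),
});
const { triples, count } = await extractRes.json();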

nvidia/txt2kg/assets/frontend/app/api/ollama/test/route.ts (new file, 82 lines)
@@ -0,0 +1,82 @@
import { NextRequest, NextResponse } from 'next/server';

/**
 * Test endpoint for Ollama integration
 * GET /api/ollama/test - Test Ollama functionality with sample data
 */

export async function GET(req: NextRequest) {
  try {
    const sampleText = `
Apple Inc. is a multinational technology company headquartered in Cupertino, California.
The company was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne in 1976.
Apple designs and develops consumer electronics, computer software, and online services.
Tim Cook is the current CEO of Apple Inc.
`;

    console.log('Testing Ollama with sample text...');

    // Test connection first
    const connectionResponse = await fetch(`${req.nextUrl.origin}/api/ollama?action=test-connection`);
    const connectionResult = await connectionResponse.json();

    if (!connectionResult.connected) {
      return NextResponse.json({
        success: false,
        error: 'Ollama connection failed',
        details: connectionResult.error,
        connectionTest: connectionResult
      });
    }

    // Test triple extraction
    const extractionResponse = await fetch(`${req.nextUrl.origin}/api/ollama`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({
        text: sampleText.trim(),
        model: 'qwen3:1.7b',
        temperature: 0.1,
        maxTokens: 1024
      })
    });

    if (!extractionResponse.ok) {
      const errorText = await extractionResponse.text();
      return NextResponse.json({
        success: false,
        error: 'Triple extraction failed',
        details: errorText,
        connectionTest: connectionResult
      });
    }

    const extractionResult = await extractionResponse.json();

    return NextResponse.json({
      success: true,
      message: 'Ollama integration test completed successfully',
      connectionTest: connectionResult,
      extractionTest: {
        inputText: sampleText.trim(),
        triplesExtracted: extractionResult.triples?.length || 0,
        sampleTriples: (extractionResult.triples || []).slice(0, 3),
        method: extractionResult.method,
        model: extractionResult.model
      },
      fullResult: extractionResult
    });
  } catch (error) {
    console.error('Error in Ollama test:', error);
    return NextResponse.json(
      {
        success: false,
        error: 'Test failed with exception',
        details: error instanceof Error ? error.message : String(error)
      },
      { status: 500 }
    );
  }
}

@@ -0,0 +1,27 @@
import { NextRequest, NextResponse } from 'next/server';
import { PineconeService } from '@/lib/pinecone';

/**
 * Clear all data from the Pinecone vector database
 * POST /api/pinecone-diag/clear
 */
export async function POST() {
  // Get the Pinecone service instance
  const pineconeService = PineconeService.getInstance();

  // Clear all vectors from the database
  const deleteSuccess = await pineconeService.deleteAllEntities();

  // Get updated stats after clearing
  const stats = await pineconeService.getStats();

  // Return response based on operation success
  return NextResponse.json({
    success: deleteSuccess,
    message: deleteSuccess
      ? 'Successfully cleared all data from Pinecone vector database'
      : 'Failed to clear Pinecone database - service may not be available',
    totalVectorCount: stats.totalVectorCount || 0,
    httpHealthy: stats.httpHealthy || false
  });
}
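
A sketch of invoking the diagnostic clear route above; the handler reads no body, so a bare POST suffices (note this is destructive):

// sketch: clear the Pinecone index via the diagnostics route
const clearRes = await fetch('/api/pinecone-diag/clear', { method: 'POST' });
const { success, totalVectorCount } = await clearRes.json();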

@@ -0,0 +1,36 @@
import { NextResponse } from 'next/server';
import { PineconeService } from '@/lib/pinecone';

/**
 * Create Pinecone index API endpoint
 * POST /api/pinecone-diag/create-index
 */
export async function POST() {
  try {
    // Get the Pinecone service instance
    const pineconeService = PineconeService.getInstance();

    // Force re-initialization to create the index
    (pineconeService as any).initialized = false;
    await pineconeService.initialize();

    // Check if initialization was successful by getting stats
    const stats = await pineconeService.getStats();

    return NextResponse.json({
      success: true,
      message: 'Pinecone index created successfully',
      httpHealthy: stats.httpHealthy || false
    });
  } catch (error) {
    console.error('Error creating Pinecone index:', error);

    return NextResponse.json(
      {
        success: false,
        error: `Failed to create Pinecone index: ${error instanceof Error ? error.message : String(error)}`
      },
      { status: 500 }
    );
  }
}

@@ -0,0 +1,42 @@
import { NextRequest, NextResponse } from 'next/server';
import { PineconeService } from '@/lib/pinecone';

/**
 * Get Pinecone vector database stats
 */
export async function GET() {
  try {
    // Initialize Pinecone service
    const pineconeService = PineconeService.getInstance();

    // We can now directly call getStats() which handles initialization and error recovery
    const stats = await pineconeService.getStats();

    return NextResponse.json({
      ...stats,
      timestamp: new Date().toISOString()
    });
  } catch (error) {
    console.error('Error getting Pinecone stats:', error);

    // Return a successful response with error information
    // This prevents the UI from breaking when Pinecone is unavailable
    let errorMessage = error instanceof Error ? error.message : String(error);

    // More specific error message for 404 errors
    if (errorMessage.includes('404')) {
      errorMessage = 'Pinecone server returned 404. The server may not be running or the index does not exist.';
    }

    return NextResponse.json(
      {
        error: `Failed to get Pinecone stats: ${errorMessage}`,
        totalVectorCount: 0,
        source: 'error',
        httpHealthy: false,
        timestamp: new Date().toISOString()
      },
      { status: 200 } // Use 200 instead of 500 to avoid UI errors
    );
  }
}

nvidia/txt2kg/assets/frontend/app/api/process-document/route.ts (new file, 154 lines)
@@ -0,0 +1,154 @@
import { NextRequest, NextResponse } from 'next/server';
import { RemoteBackendService } from '@/lib/remote-backend';
import { EmbeddingsService } from '@/lib/embeddings';
import type { Triple } from '@/types/graph';
import { BackendService } from '@/lib/backend-service';
import { getGraphDbType } from '../settings/route';

/**
 * API endpoint for processing documents with LangChain, generating embeddings,
 * and storing in the knowledge graph
 * POST /api/process-document
 */
export async function POST(req: NextRequest) {
  try {
    // Parse request body
    const body = await req.json();
    const {
      text,
      filename,
      triples,
      useLangChain,
      useGraphTransformer,
      systemPrompt,
      extractionPrompt,
      graphTransformerPrompt
    } = body;

    if (!text || typeof text !== 'string') {
      return NextResponse.json({ error: 'Text is required' }, { status: 400 });
    }

    if (!triples || !Array.isArray(triples)) {
      return NextResponse.json({ error: 'Triples are required' }, { status: 400 });
    }

    // Initialize services
    const backendService = RemoteBackendService.getInstance();
    const embeddingsService = EmbeddingsService.getInstance();

    console.log(`🔍 API: Processing document "${filename || 'unnamed'}" (${text.length} chars)`);
    console.log(`🔍 API: Processing ${triples.length} triples`);
    console.log(`🔍 API: Using LangChain for triple extraction: ${useLangChain ? 'Yes' : 'No'}`);
    console.log(`🔍 API: First few triples:`, triples.slice(0, 3));
    if (useLangChain) {
      console.log(`Using LLMGraphTransformer: ${useGraphTransformer ? 'Yes' : 'No'}`);
    }

    // Log if custom prompts are being used
    if (systemPrompt || extractionPrompt || graphTransformerPrompt) {
      console.log('Using custom prompts for extraction');
      if (systemPrompt) console.log('Custom system prompt provided');
      if (extractionPrompt) console.log('Custom extraction prompt provided');
      if (graphTransformerPrompt) console.log('Custom graph transformer prompt provided');
    }

    // Filter triples to ensure they are valid
    const validTriples = triples.filter((triple: any) => {
      return (
        triple &&
        typeof triple.subject === 'string' && triple.subject.trim() !== '' &&
        typeof triple.predicate === 'string' && triple.predicate.trim() !== '' &&
        typeof triple.object === 'string' && triple.object.trim() !== ''
      );
    }) as Triple[];

    console.log(`Found ${validTriples.length} valid triples`);

    // If useLangChain flag is set, we'll extract triples using the LangChain route
    let triplesForProcessing = validTriples;

    if (useLangChain && !filename?.toLowerCase().endsWith('.csv')) {
      try {
        console.log('Using LangChain for native triple extraction...');
        // Use absolute URL with origin from request to fix URL parsing error
        const baseUrl = new URL(req.url).origin;
        console.log(`Using base URL: ${baseUrl} for LangChain API call`);

        // Call the extract-triples endpoint with useLangChain flag and custom prompts
        const requestBody: any = {
          text,
          useLangChain: true,
          useGraphTransformer
        };

        // Add custom prompts if available
        if (systemPrompt) requestBody.systemPrompt = systemPrompt;
        if (extractionPrompt) requestBody.extractionPrompt = extractionPrompt;
        if (graphTransformerPrompt) requestBody.graphTransformerPrompt = graphTransformerPrompt;

        const langchainResponse = await fetch(`${baseUrl}/api/extract-triples`, {
          method: 'POST',
          headers: { 'Content-Type': 'application/json' },
          body: JSON.stringify(requestBody)
        });

        if (!langchainResponse.ok) {
          const errorText = await langchainResponse.text();
          console.error(`LangChain API error: ${langchainResponse.status} ${langchainResponse.statusText}`, errorText);
          throw new Error(`LangChain extraction failed: ${langchainResponse.statusText} (${langchainResponse.status})`);
        }

        const langchainResult = await langchainResponse.json();
        if (langchainResult.triples && Array.isArray(langchainResult.triples) && langchainResult.triples.length > 0) {
          console.log(`Successfully extracted ${langchainResult.triples.length} triples using LangChain${useGraphTransformer ? ' with GraphTransformer' : ''}`);
          triplesForProcessing = langchainResult.triples;
        } else {
          console.warn('LangChain extraction returned no triples, falling back to provided triples');
        }
      } catch (langchainError) {
        console.error('Error using LangChain for triple extraction:', langchainError);
        console.log('Falling back to provided triples');
      }
    }

    // Check if this is a CSV file - if so, skip processing
    const isCSVFile = filename && filename.toLowerCase().endsWith('.csv');
    const isJSONFile = filename && filename.toLowerCase().endsWith('.json');

    if (isCSVFile) {
      console.log('CSV file detected, skipping text processor');
      // NOTE: Neo4j storage is no longer done automatically
      // This is now handled manually through the "Store in Graph DB" button in the UI
    } else if (isJSONFile) {
      console.log('JSON file detected, processed as unstructured text document - embeddings can be generated manually via the UI');
      // NOTE: Automatic embeddings generation has been disabled for JSON files.
      // Embeddings are now generated only when explicitly requested through the "Generate Embeddings" button in the UI.
    } else {
      // Regular text processing flow - no automatic embeddings generation
      console.log('Document processed successfully - embeddings can be generated manually via the UI');
      // NOTE: Automatic embeddings generation has been disabled.
      // Embeddings are now generated only when explicitly requested through the "Generate Embeddings" button in the UI.
    }

    // Return success response
    return NextResponse.json({
      success: true,
      message: 'Document processed successfully',
      tripleCount: triplesForProcessing.length,
      triples: triplesForProcessing,
      documentName: filename || 'unnamed',
      langchainUsed: useLangChain,
      graphTransformerUsed: useGraphTransformer,
      customPromptsUsed: !!(systemPrompt || extractionPrompt || graphTransformerPrompt),
      graphDbType: getGraphDbType()
    });
  } catch (error) {
    console.error('Error processing document:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json(
      { error: `Failed to process document: ${errorMessage}` },
      { status: 500 }
    );
  }
}
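
A sketch of a process-document request; triples is required but may be empty when useLangChain is set, since the handler then re-extracts via /api/extract-triples:

// sketch: hand a document plus pre-extracted (possibly empty) triples to the route
const docRes = await fetch('/api/process-document', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    text: 'GPUs accelerate matrix multiplication.',
    filename: 'notes.txt',
    triples: [],            // required field; may be empty when useLangChain is true
    useLangChain: true,
    useGraphTransformer: false,
  }),
});
const { tripleCount, graphDbType } = await docRes.json();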

@@ -0,0 +1,42 @@
import { NextRequest, NextResponse } from 'next/server'

const PYGRAPHISTRY_SERVICE_URL = process.env.PYGRAPHISTRY_SERVICE_URL || 'http://localhost:8080'

export async function GET(request: NextRequest, { params }: { params: { taskId: string } }) {
  try {
    const { taskId } = params

    // Forward the request to the PyGraphistry service
    const response = await fetch(`${PYGRAPHISTRY_SERVICE_URL}/api/generate/${taskId}`, {
      method: 'GET',
      headers: {
        'Content-Type': 'application/json',
      },
    })

    if (!response.ok) {
      const errorText = await response.text()
      console.error('PyGraphistry service error:', errorText)
      return NextResponse.json(
        {
          error: 'PyGraphistry service error',
          details: errorText
        },
        { status: response.status }
      )
    }

    const data = await response.json()
    return NextResponse.json(data)

  } catch (error) {
    console.error('Error forwarding to PyGraphistry service:', error)
    return NextResponse.json(
      {
        error: 'Failed to connect to PyGraphistry service',
        details: error instanceof Error ? error.message : 'Unknown error'
      },
      { status: 500 }
    )
  }
}

@@ -0,0 +1,43 @@
import { NextRequest, NextResponse } from 'next/server'

const PYGRAPHISTRY_SERVICE_URL = process.env.PYGRAPHISTRY_SERVICE_URL || 'http://localhost:8080'

export async function POST(request: NextRequest) {
  try {
    const body = await request.json()

    // Forward the request to the PyGraphistry service
    const response = await fetch(`${PYGRAPHISTRY_SERVICE_URL}/api/generate`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(body),
    })

    if (!response.ok) {
      const errorText = await response.text()
      console.error('PyGraphistry service error:', errorText)
      return NextResponse.json(
        {
          error: 'PyGraphistry service error',
          details: errorText
        },
        { status: response.status }
      )
    }

    const data = await response.json()
    return NextResponse.json(data)

  } catch (error) {
    console.error('Error forwarding to PyGraphistry service:', error)
    return NextResponse.json(
      {
        error: 'Failed to connect to PyGraphistry service',
        details: error instanceof Error ? error.message : 'Unknown error'
      },
      { status: 500 }
    )
  }
}

@@ -0,0 +1,42 @@
import { NextRequest, NextResponse } from 'next/server'

const PYGRAPHISTRY_SERVICE_URL = process.env.PYGRAPHISTRY_SERVICE_URL || 'http://localhost:8080'

export async function GET(request: NextRequest) {
  try {
    // Forward the request to the PyGraphistry service
    const response = await fetch(`${PYGRAPHISTRY_SERVICE_URL}/api/health`, {
      method: 'GET',
      headers: {
        'Content-Type': 'application/json',
      },
    })

    if (!response.ok) {
      const errorText = await response.text()
      console.error('PyGraphistry service health check failed:', errorText)
      return NextResponse.json(
        {
          status: 'error',
          error: 'PyGraphistry service unhealthy',
          details: errorText
        },
        { status: response.status }
      )
    }

    const data = await response.json()
    return NextResponse.json(data)

  } catch (error) {
    console.error('Error connecting to PyGraphistry service:', error)
    return NextResponse.json(
      {
        status: 'error',
        error: 'Failed to connect to PyGraphistry service',
        details: error instanceof Error ? error.message : 'Unknown error'
      },
      { status: 500 }
    )
  }
}

@@ -0,0 +1,43 @@
import { NextRequest, NextResponse } from 'next/server'

const PYGRAPHISTRY_SERVICE_URL = process.env.PYGRAPHISTRY_SERVICE_URL || 'http://localhost:8080'

export async function POST(request: NextRequest) {
  try {
    const body = await request.json()

    // Forward the request to the PyGraphistry service
    const response = await fetch(`${PYGRAPHISTRY_SERVICE_URL}/api/stats`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(body),
    })

    if (!response.ok) {
      const errorText = await response.text()
      console.error('PyGraphistry stats service error:', errorText)
      return NextResponse.json(
        {
          error: 'PyGraphistry stats service error',
          details: errorText
        },
        { status: response.status }
      )
    }

    const data = await response.json()
    return NextResponse.json(data)

  } catch (error) {
    console.error('Error forwarding stats request to PyGraphistry service:', error)
    return NextResponse.json(
      {
        error: 'Failed to connect to PyGraphistry service for stats',
        details: error instanceof Error ? error.message : 'Unknown error'
      },
      { status: 500 }
    )
  }
}

@@ -0,0 +1,43 @@
import { NextRequest, NextResponse } from 'next/server'

const PYGRAPHISTRY_SERVICE_URL = process.env.PYGRAPHISTRY_SERVICE_URL || 'http://localhost:8080'

export async function POST(request: NextRequest) {
  try {
    const body = await request.json()

    // Forward the request to the PyGraphistry service
    const response = await fetch(`${PYGRAPHISTRY_SERVICE_URL}/api/visualize`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(body),
    })

    if (!response.ok) {
      const errorText = await response.text()
      console.error('PyGraphistry service error:', errorText)
      return NextResponse.json(
        {
          error: 'PyGraphistry service error',
          details: errorText
        },
        { status: response.status }
      )
    }

    const data = await response.json()
    return NextResponse.json(data)

  } catch (error) {
    console.error('Error forwarding to PyGraphistry service:', error)
    return NextResponse.json(
      {
        error: 'Failed to connect to PyGraphistry service',
        details: error instanceof Error ? error.message : 'Unknown error'
      },
      { status: 500 }
    )
  }
}
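
The five proxy routes above share one pattern: forward to PYGRAPHISTRY_SERVICE_URL and mirror the upstream status code back to the browser. A sketch of driving a visualization through them; the frontend mount path, the payload shape, and the taskId field are assumptions, since the route file names are not shown in this hunk:

// sketch: kick off a visualization via the proxy, then poll the task (paths and fields assumed)
const gen = await fetch('/api/pygraphistry/generate', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ nodes: [], edges: [] }), // payload shape is an assumption
});
const { taskId } = await gen.json();              // field name is an assumption
const status = await fetch(`/api/pygraphistry/generate/${taskId}`).then(r => r.json());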

nvidia/txt2kg/assets/frontend/app/api/query-log/add/route.ts (new file, 88 lines)
@@ -0,0 +1,88 @@
import { NextRequest, NextResponse } from 'next/server';
import neo4jService from '@/lib/neo4j';

/**
 * Simple endpoint to directly add a query log with a high count
 */
export async function GET(request: NextRequest) {
  try {
    // Get the query text from URL params or use a default
    const query = request.nextUrl.searchParams.get('query') || 'How does machine learning work?';
    const count = parseInt(request.nextUrl.searchParams.get('count') || '20');

    // Initialize Neo4j
    if (!neo4jService.isInitialized()) {
      neo4jService.initialize();
    }

    // Execute direct Cypher query to create a query log with a high count
    const session = neo4jService.getSession();

    try {
      const cypher = `
        MERGE (q:QueryLog {query: $query})
        ON CREATE SET
          q.firstQueried = datetime(),
          q.count = $count
        ON MATCH SET
          q.lastQueried = datetime(),
          q.count = $count

        CREATE (e:QueryExecution {
          timestamp: datetime(),
          queryMode: 'traditional',
          executionTimeMs: 0,
          relevanceScore: 0,
          precision: 0,
          recall: 0,
          resultCount: 0
        })

        CREATE (q)-[:HAS_EXECUTION]->(e)

        RETURN q.query as query, q.count as count
      `;

      const result = await session.run(cypher, {
        query,
        count
      });

      const addedQuery = result.records.length > 0 ? {
        query: result.records[0].get('query'),
        count: result.records[0].get('count').toNumber()
      } : null;

      // Also add a few more queries
      if (count >= 10) {
        await session.run(cypher, {
          query: 'What are the applications of artificial intelligence?',
          count: count - 4
        });

        await session.run(cypher, {
          query: 'Explain the principles of deep learning',
          count: count - 8
        });
      }

      // Get the current logs to verify
      const logs = await neo4jService.getQueryLogs(5);

      return NextResponse.json({
        success: true,
        message: `Added query log for "${query}" with count ${count}`,
        addedQuery,
        logs
      });
    } finally {
      // close() returns a promise in the Neo4j driver, so await it to avoid leaking connections
      await session.close();
    }
  } catch (error) {
    console.error('Error adding query log:', error);
    return NextResponse.json({
      success: false,
      error: error instanceof Error ? error.message : String(error)
    }, { status: 500 });
  }
}
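
A sketch of seeding the log via the GET route above; query and count arrive as URL parameters and fall back to the route defaults when omitted:

// sketch: seed a query log entry with an inflated count
const seeded = await fetch('/api/query-log/add?query=How%20does%20RAG%20work%3F&count=20');
const { addedQuery, logs } = await seeded.json();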

nvidia/txt2kg/assets/frontend/app/api/query-log/route.ts (new file, 94 lines)
@@ -0,0 +1,94 @@
import { NextRequest, NextResponse } from 'next/server';
import queryLoggerService from '@/lib/query-logger';

/**
 * API endpoint to log query metrics
 */
export async function POST(request: NextRequest) {
  try {
    const body = await request.json();
    console.log('Received query log request:', JSON.stringify(body));

    // Validate required fields
    if (!body.query) {
      return NextResponse.json(
        { error: 'Missing required field: query' },
        { status: 400 }
      );
    }

    if (!body.queryMode) {
      return NextResponse.json(
        { error: 'Missing required field: queryMode' },
        { status: 400 }
      );
    }

    if (!body.metrics || typeof body.metrics !== 'object') {
      return NextResponse.json(
        { error: 'Missing required field: metrics' },
        { status: 400 }
      );
    }

    // Initialize logger if not already
    if (!queryLoggerService.isInitialized()) {
      console.log('Initializing query logger service');
      await queryLoggerService.initialize();
    }

    // Log the query with metrics
    console.log(`Logging query "${body.query}" with mode "${body.queryMode}"`);
    await queryLoggerService.logQuery(
      body.query,
      body.queryMode,
      {
        executionTimeMs: body.metrics.executionTimeMs || 0,
        relevanceScore: body.metrics.relevanceScore,
        precision: body.metrics.precision,
        recall: body.metrics.recall,
        resultCount: body.metrics.resultCount || 0
      }
    );

    console.log('Query logged successfully to file');
    return NextResponse.json({ success: true });
  } catch (error) {
    console.error('Error logging query:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json(
      { error: errorMessage },
      { status: 500 }
    );
  }
}

/**
 * API endpoint to get query logs
 */
export async function GET(request: NextRequest) {
  try {
    // Initialize logger if not already
    if (!queryLoggerService.isInitialized()) {
      console.log('Initializing query logger service for retrieving logs');
      await queryLoggerService.initialize();
    }

    // Get limit from query params or default to 25
    const limit = parseInt(request.nextUrl.searchParams.get('limit') || '25');
    console.log(`Retrieving up to ${limit} query logs`);

    // Get query logs
    const logs = await queryLoggerService.getQueryLogs(limit);
    console.log(`Retrieved ${logs.length} query logs from file`);

    return NextResponse.json({ logs });
  } catch (error) {
    console.error('Error getting query logs:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json(
      { error: errorMessage },
      { status: 500 }
    );
  }
}

@@ -0,0 +1,55 @@
import { NextRequest, NextResponse } from 'next/server';
import neo4jService from '@/lib/neo4j';

/**
 * API endpoint to create a test query log
 * This is for debugging purposes only
 */
export async function GET(request: NextRequest) {
  try {
    console.log('[Test] Creating test query log');

    // Initialize Neo4j if not already
    if (!neo4jService.isInitialized()) {
      console.log('[Test] Initializing Neo4j service');
      neo4jService.initialize();
    }

    // Get query text from URL parameters or use a default
    const query = request.nextUrl.searchParams.get('query') || 'Test query for debugging';
    const queryMode = (request.nextUrl.searchParams.get('mode') || 'traditional') as 'traditional' | 'vector-search' | 'pure-rag';
    const executionTime = parseInt(request.nextUrl.searchParams.get('time') || '300');
    const resultCount = parseInt(request.nextUrl.searchParams.get('count') || '5');

    console.log(`[Test] Adding test query: "${query}" (${queryMode})`);

    // Log the query with some test metrics
    await neo4jService.logQuery(
      query,
      queryMode,
      {
        executionTimeMs: executionTime,
        relevanceScore: 0,
        precision: 0,
        recall: 0,
        resultCount: resultCount
      }
    );

    // Get current query logs to verify
    const logs = await neo4jService.getQueryLogs(10);

    return NextResponse.json({
      success: true,
      message: `Test query "${query}" added successfully`,
      logs: logs.slice(0, 3) // Return top 3 logs for verification
    });
  } catch (error) {
    console.error('[Test] Error creating test query log:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json(
      { error: errorMessage },
      { status: 500 }
    );
  }
}

nvidia/txt2kg/assets/frontend/app/api/query/route.ts (new file, 53 lines)
@@ -0,0 +1,53 @@
import { NextRequest, NextResponse } from 'next/server';
import backendService from '@/lib/backend-service';
import type { Triple } from '@/types/graph';
import { getGraphDbType } from '../settings/route';

export async function POST(request: NextRequest) {
  try {
    const { query, triples, kNeighbors, fanout, numHops, useTraditional, queryMode } = await request.json();

    if (!query) {
      return NextResponse.json({ error: 'Query is required' }, { status: 400 });
    }

    // Initialize backend if needed with the selected graph DB type
    if (!backendService.isInitialized) {
      const graphDbType = getGraphDbType();
      console.log(`Initializing backend with graph DB type: ${graphDbType}`);
      await backendService.initialize(graphDbType);
    }

    // Process triples if provided
    if (triples && Array.isArray(triples) && triples.length > 0) {
      await backendService.processTriples(triples);
    }

    // Determine if we should use traditional search based on queryMode
    // This allows the frontend to explicitly choose traditional search
    const shouldUseTraditional = useTraditional || (queryMode === 'traditional');

    console.log(`Query mode: ${queryMode}, Using traditional search: ${shouldUseTraditional}`);

    // Query the backend
    const relevantTriples = await backendService.query(
      query,
      kNeighbors || 4096,
      fanout || 400,
      numHops || 2,
      shouldUseTraditional // Pass the flag to use traditional search
    );

    // Return results
    return NextResponse.json({
      query,
      relevantTriples,
      count: relevantTriples.length,
      message: `Found ${relevantTriples.length} relevant triples for query: "${query}"${shouldUseTraditional ? ' using traditional search' : ''}`
    });
  } catch (error) {
    console.error('Error querying backend:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json({ error: errorMessage }, { status: 500 });
  }
}
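
A sketch of the retrieval call; kNeighbors, fanout, and numHops fall back to the route defaults (4096, 400, 2) when omitted:

// sketch: graph retrieval query using the route's default expansion parameters
const queryRes = await fetch('/api/query', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ query: 'What does Apple design?', queryMode: 'traditional' }),
});
const { relevantTriples, count } = await queryRes.json();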

nvidia/txt2kg/assets/frontend/app/api/rag-query/route.ts (new file, 45 lines)
@@ -0,0 +1,45 @@
import { NextRequest, NextResponse } from 'next/server';
import RAGService from '@/lib/rag';

/**
 * API endpoint for RAG-based question answering
 * Uses Pinecone for document retrieval and LangChain for generation
 * POST /api/rag-query
 */
export async function POST(req: NextRequest) {
  try {
    // Parse request body
    const body = await req.json();
    const { query, topK = 5 } = body;

    if (!query || typeof query !== 'string') {
      return NextResponse.json({ error: 'Query is required' }, { status: 400 });
    }

    // Initialize the RAG service
    const ragService = RAGService;
    await ragService.initialize();

    console.log(`Processing RAG query: "${query}" with topK=${topK}`);

    // Retrieve documents and generate answer
    const answer = await ragService.retrievalQA(query, topK);

    // Check if this is a fallback response
    const isGeneralKnowledgeFallback = answer.startsWith('[Note: No specific information was found');

    // Return the results
    return NextResponse.json({
      answer,
      usedFallback: isGeneralKnowledgeFallback,
      success: true
    });
  } catch (error) {
    console.error('Error in RAG query:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json(
      { error: `Failed to execute RAG query: ${errorMessage}` },
      { status: 500 }
    );
  }
}
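
A sketch of the RAG call; the usedFallback flag in the response distinguishes grounded answers from the general-knowledge fallback:

// sketch: retrieval-augmented answer with the default topK of 5
const ragRes = await fetch('/api/rag-query', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ query: 'Who founded Apple?' }),
});
const { answer, usedFallback } = await ragRes.json();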

@@ -0,0 +1,49 @@
import { NextRequest, NextResponse } from 'next/server'

// Proxy route specifically for WebRTC streaming frames
// This handles binary image data streaming from the remote WebGPU service

const REMOTE_WEBGPU_SERVICE_URL = process.env.REMOTE_WEBGPU_SERVICE_URL || 'http://txt2kg-remote-webgpu:8083'

export async function GET(
  request: NextRequest,
  { params }: { params: { sessionId: string } }
) {
  try {
    const sessionId = params.sessionId
    const searchParams = request.nextUrl.searchParams.toString()
    const url = `${REMOTE_WEBGPU_SERVICE_URL}/api/stream/${sessionId}${searchParams ? `?${searchParams}` : ''}`

    console.log(`Proxying WebRTC stream request to: ${url}`)

    const response = await fetch(url, {
      method: 'GET',
    })

    if (!response.ok) {
      throw new Error(`Remote WebGPU service responded with ${response.status}: ${response.statusText}`)
    }

    // Get the image data as array buffer
    const imageBuffer = await response.arrayBuffer()
    const contentType = response.headers.get('content-type') || 'image/png'

    // Return the image with proper headers
    return new NextResponse(imageBuffer, {
      status: 200,
      headers: {
        'Content-Type': contentType,
        'Cache-Control': 'no-cache, no-store, must-revalidate',
        'Pragma': 'no-cache',
        'Expires': '0',
      },
    })

  } catch (error) {
    console.error('WebRTC stream proxy error:', error)
    return NextResponse.json(
      { error: 'Failed to stream from remote WebGPU service', details: String(error) },
      { status: 500 }
    )
  }
}

@@ -0,0 +1,112 @@
import { NextRequest, NextResponse } from 'next/server'

// Proxy route for remote WebGPU clustering service
// This allows the frontend to communicate with the clustering service
// even when running in a remote browser environment

const REMOTE_WEBGPU_SERVICE_URL = process.env.REMOTE_WEBGPU_SERVICE_URL || 'http://txt2kg-remote-webgpu:8083'

export async function GET(
  request: NextRequest,
  { params }: { params: { path: string[] } }
) {
  try {
    const path = params.path.join('/')
    const searchParams = request.nextUrl.searchParams.toString()
    const url = `${REMOTE_WEBGPU_SERVICE_URL}/${path}${searchParams ? `?${searchParams}` : ''}`

    console.log(`Proxying GET request to: ${url}`)

    const response = await fetch(url, {
      method: 'GET',
      headers: {
        'Content-Type': 'application/json',
      },
    })

    if (!response.ok) {
      throw new Error(`Remote WebGPU service responded with ${response.status}: ${response.statusText}`)
    }

    const data = await response.json()
    return NextResponse.json(data)

  } catch (error) {
    console.error('Remote WebGPU proxy error:', error)
    return NextResponse.json(
      { error: 'Failed to communicate with remote WebGPU service', details: String(error) },
      { status: 500 }
    )
  }
}

export async function POST(
  request: NextRequest,
  { params }: { params: { path: string[] } }
) {
  try {
    const path = params.path.join('/')
    const body = await request.json()
    const url = `${REMOTE_WEBGPU_SERVICE_URL}/${path}`

    console.log(`Proxying POST request to: ${url}`)
    console.log(`Request body:`, JSON.stringify(body, null, 2))
    console.log(`Using service URL: ${REMOTE_WEBGPU_SERVICE_URL}`)

    const response = await fetch(url, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(body),
    })

    if (!response.ok) {
      const errorText = await response.text()
      throw new Error(`Remote WebGPU service responded with ${response.status}: ${response.statusText} - ${errorText}`)
    }

    const data = await response.json()
    return NextResponse.json(data)

  } catch (error) {
    console.error('Remote WebGPU proxy error:', error)
    return NextResponse.json(
      { error: 'Failed to communicate with remote WebGPU service', details: String(error) },
      { status: 500 }
    )
  }
}

export async function DELETE(
  request: NextRequest,
  { params }: { params: { path: string[] } }
) {
  try {
    const path = params.path.join('/')
    const url = `${REMOTE_WEBGPU_SERVICE_URL}/${path}`

    console.log(`Proxying DELETE request to: ${url}`)

    const response = await fetch(url, {
      method: 'DELETE',
      headers: {
        'Content-Type': 'application/json',
      },
    })

    if (!response.ok) {
      throw new Error(`Remote WebGPU service responded with ${response.status}: ${response.statusText}`)
    }

    const data = await response.json()
    return NextResponse.json(data)

  } catch (error) {
    console.error('Remote WebGPU proxy error:', error)
    return NextResponse.json(
      { error: 'Failed to communicate with remote WebGPU service', details: String(error) },
      { status: 500 }
    )
  }
}
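
A sketch of reaching the clustering service through the catch-all proxy above; the frontend mount path, the upstream sub-path, and the payload shape are assumptions about the remote service's API:

// sketch: proxied POST to the remote WebGPU service (paths and body are assumptions)
const clusterRes = await fetch('/api/remote-webgpu/api/cluster', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ sessionId: 'demo', algorithm: 'force-atlas' }),
});
const clusterData = await clusterRes.json();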
@ -0,0 +1,86 @@
import { NextRequest, NextResponse } from 'next/server';
import { processSentenceEmbeddings, SentenceEmbedding } from '@/lib/text-processor';
import { PineconeService } from '@/lib/pinecone';

/**
 * API endpoint for splitting text into sentences and generating embeddings
 * POST /api/sentence-embeddings
 */
export async function POST(req: NextRequest) {
  try {
    // Parse request body
    const body = await req.json();
    const { text, documentId } = body;

    if (!text || typeof text !== 'string') {
      return NextResponse.json({ error: 'Text is required' }, { status: 400 });
    }

    console.log(`Processing sentence embeddings for document ${documentId || 'unnamed'}`);
    console.log(`Text length: ${text.length} characters`);

    // Process sentences and generate embeddings
    let sentenceEmbeddings: SentenceEmbedding[] = [];
    try {
      sentenceEmbeddings = await processSentenceEmbeddings(text, documentId);
      console.log(`Generated embeddings for ${sentenceEmbeddings.length} sentences using local sentence-transformers service`);
    } catch (embeddingError) {
      console.error('Error generating embeddings:', embeddingError);
      return NextResponse.json(
        { error: `Failed to generate embeddings: ${embeddingError instanceof Error ? embeddingError.message : String(embeddingError)}` },
        { status: 500 }
      );
    }

    // Optionally store in vector database
    if (sentenceEmbeddings.length > 0) {
      try {
        // Map the embeddings to a format suitable for Pinecone
        const embeddingsMap = new Map<string, number[]>();
        const textContentMap = new Map<string, string>();
        const metadataMap = new Map<string, any>();

        // Create unique keys for each sentence
        sentenceEmbeddings.forEach((item, index) => {
          const key = `${documentId || 'doc'}_sentence_${index}`;
          embeddingsMap.set(key, item.embedding);
          textContentMap.set(key, item.sentence);
          metadataMap.set(key, item.metadata);
        });

        // Store in Pinecone
        const pineconeService = PineconeService.getInstance();
        await pineconeService.storeEmbeddingsWithMetadata(
          embeddingsMap,
          textContentMap,
          metadataMap
        );

        console.log(`Stored ${sentenceEmbeddings.length} sentence embeddings in vector database`);
      } catch (storageError) {
        console.error('Error storing sentence embeddings:', storageError);
        // Continue even if storage fails - we'll still return the embeddings
      }
    }

    // Return a summary to avoid large response sizes
    return NextResponse.json({
      success: true,
      count: sentenceEmbeddings.length,
      documentId: documentId || 'unnamed',
      // Return only the first few embeddings as samples
      samples: sentenceEmbeddings.slice(0, 3).map(item => ({
        sentence: item.sentence,
        metadata: item.metadata,
        embeddingDimensions: item.embedding.length
      }))
    });
  } catch (error) {
    console.error('Error processing sentence embeddings:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json(
      { error: `Failed to process sentence embeddings: ${errorMessage}` },
      { status: 500 }
    );
  }
}
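To exercise this endpoint, a minimal sketch, assuming the Next.js app is running locally on port 3000 (the sample text and document ID are arbitrary):

```typescript
// Sketch: POST /api/sentence-embeddings and inspect the summary response.
// http://localhost:3000 is an assumed local dev address for the app.
const res = await fetch('http://localhost:3000/api/sentence-embeddings', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    text: 'Graphs encode entities and relations. Embeddings make them searchable.',
    documentId: 'demo-doc',
  }),
});

// On success the route returns { success, count, documentId, samples },
// where samples holds at most the first three sentences.
const summary = await res.json();
console.log(summary.count, summary.samples?.[0]?.embeddingDimensions);
```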
79
nvidia/txt2kg/assets/frontend/app/api/settings/route.ts
Normal file
@ -0,0 +1,79 @@
import { NextRequest, NextResponse } from 'next/server';
import { GraphDBType } from '@/lib/graph-db-service';

// In-memory storage for settings
let serverSettings: Record<string, string> = {};

/**
 * API Route to sync client settings to the server
 * This lets the client keep settings in localStorage while server-side
 * API routes read the same values from this in-memory store
 */
export async function POST(request: NextRequest) {
  try {
    const { settings } = await request.json();

    if (!settings || typeof settings !== 'object') {
      return NextResponse.json({ error: 'Settings object is required' }, { status: 400 });
    }

    // Update server settings
    serverSettings = { ...serverSettings, ...settings };

    // Log some important settings for debugging
    if (settings.graph_db_type) {
      console.log(`Setting graph database type to: ${settings.graph_db_type}`);
    }

    return NextResponse.json({
      success: true,
      message: 'Settings updated successfully'
    });
  } catch (error) {
    console.error('Error updating settings:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json({ error: errorMessage }, { status: 500 });
  }
}

/**
 * GET /api/settings
 * Retrieve settings from the server side
 */
export async function GET(request: NextRequest) {
  try {
    const url = new URL(request.url);
    const key = url.searchParams.get('key');

    if (key) {
      // Return specific setting
      return NextResponse.json({
        [key]: serverSettings[key] || null
      });
    }

    // Return all settings (may want to filter sensitive ones in production)
    return NextResponse.json({
      settings: serverSettings
    });
  } catch (error) {
    console.error('Error retrieving settings:', error);
    const errorMessage = error instanceof Error ? error.message : 'Unknown error';
    return NextResponse.json({ error: errorMessage }, { status: 500 });
  }
}

/**
 * Helper function to get a setting value
 * For use in other API routes
 */
export function getSetting(key: string): string | null {
  return serverSettings[key] || null;
}

/**
 * Get the currently selected graph database type
 */
export function getGraphDbType(): GraphDBType {
  return (serverSettings.graph_db_type as GraphDBType) || 'arangodb';
}
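A sketch of the round trip this route enables: push client-held settings to the server, then read one back. `graph_db_type: 'arangodb'` matches the route's documented default above; everything else is illustrative.

```typescript
// Sketch: sync a setting to the server-side in-memory store, then query it.
// 'arangodb' is the only value confirmed by this diff; others are assumptions.
await fetch('/api/settings', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ settings: { graph_db_type: 'arangodb' } }),
});

const res = await fetch('/api/settings?key=graph_db_type');
const { graph_db_type } = await res.json();
console.log(graph_db_type); // 'arangodb'
```

Note that `serverSettings` is module-scoped memory: values reset on every server restart and aren't shared between instances, so this pattern suits single-process development rather than production deployments.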