mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-04-23 10:33:51 +00:00
130 lines
4.3 KiB
Python
130 lines
4.3 KiB
Python
#!/usr/bin/env python3
|
|
#
|
|
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import os
import threading

import torch
import torch_geometric
from flask import Flask, request, jsonify
from torch_geometric.nn import GAT, LLM, GRetriever
|
|
|
|
# Flask application object; the route handlers in this module are registered on it.
app = Flask(__name__)
|
|
|
|
# Runtime configuration: each value can be overridden through the environment
# variable of the same name; the second argument is the fallback.
MODEL_PATH = os.getenv('MODEL_PATH', '/app/models/tech-qa-model.pt')
LLM_GENERATOR_NAME = os.getenv('LLM_GENERATOR_NAME', 'meta-llama/Meta-Llama-3.1-8B-Instruct')
GNN_HID_CHANNELS = int(os.getenv('GNN_HID_CHANNELS', '1024'))
GNN_LAYERS = int(os.getenv('GNN_LAYERS', '4'))
|
|
|
|
# Prompt template for questions.
# Filled in with str.format(); the two placeholders are {question} and {context}.
prompt_template = """Answer this question based on retrieved contexts. Just give the answer without explanation.
[QUESTION] {question} [END_QUESTION]
[RETRIEVED_CONTEXTS] {context} [END_RETRIEVED_CONTEXTS]
Answer: """
|
|
|
|
# Load the model
|
|
def load_model():
    """Build the GRetriever model and load its trained weights if present.

    The GNN (a GAT) and the LLM are constructed from the module-level
    configuration constants and combined into a GRetriever. When a file
    exists at MODEL_PATH its state dict is loaded into the model; otherwise
    a warning is printed and the freshly constructed model is used.

    Returns:
        The GRetriever instance, always switched to evaluation mode.
    """
    print(f"Loading model from {MODEL_PATH}")

    # Create the GNN component
    gnn = GAT(
        in_channels=768,
        hidden_channels=GNN_HID_CHANNELS,
        out_channels=1024,
        num_layers=GNN_LAYERS,
        heads=4,
    )

    # Create the LLM component
    llm = LLM(model_name=LLM_GENERATOR_NAME)

    # Create the GRetriever model
    model = GRetriever(llm=llm, gnn=gnn)

    # Load trained weights
    if os.path.exists(MODEL_PATH):
        # map_location="cpu" keeps the load independent of the device the
        # checkpoint was saved from; load_state_dict then copies each tensor
        # into the parameter that already lives on the model's own device.
        state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        print("Model loaded successfully")
    else:
        print(f"WARNING: Model file not found at {MODEL_PATH}. Using untrained model.")

    # This service only runs inference, so put the model in eval mode on both
    # paths (the untrained fallback used to be left in training mode).
    model.eval()

    return model
|
|
|
|
# Initialize model
model = None

# Serializes the one-time load so that concurrent first requests do not each
# construct their own copy of the model.
_model_lock = threading.Lock()


@app.before_request
def initialize():
    """Load the global model once, before the first request is handled.

    Registered with ``before_request`` because ``before_first_request`` was
    removed in Flask 2.3. The ``model is None`` guard turns every later call
    into a no-op, so the model is still loaded only once. Returns None, which
    lets Flask continue with normal request handling.
    """
    global model
    if model is not None:
        return
    with _model_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if model is None:
            model = load_model()
|
|
|
|
@app.route('/health', methods=['GET'])
def health_check():
    """Health probe endpoint: responds with the JSON body {"status": "healthy"}."""
    payload = {"status": "healthy"}
    return jsonify(payload)
|
|
|
|
@app.route('/predict', methods=['POST'])
def predict():
    """Answer a question from the supplied context using the global model.

    Expects a JSON object with the keys ``question`` and ``context``.

    Returns:
        On success, a JSON response with ``question`` and ``answer``.
        Otherwise a JSON ``error`` message with status 400 (request is not a
        JSON object or a required key is missing), 503 (model not loaded) or
        500 (graph construction or generation raised an exception).
    """
    if not request.is_json:
        return jsonify({"error": "Request must be JSON"}), 400

    # silent=True makes a malformed body come back as None instead of raising,
    # so it is answered with the same JSON error shape as the other checks.
    data = request.get_json(silent=True)

    # A valid JSON body can still be a list, string or number; indexing those
    # by key below would raise outside the try block.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    if 'question' not in data:
        return jsonify({"error": "Question is required"}), 400

    if 'context' not in data:
        return jsonify({"error": "Context is required"}), 400

    if model is None:
        return jsonify({"error": "Model is not loaded"}), 503

    question = data['question']
    context = data['context']

    # Format the question with context using the prompt template
    formatted_question = prompt_template.format(question=question, context=context)

    # Prepare input for the model
    # Note: In a real implementation, you'd need to convert text to graph structure
    # Here we're assuming a simplified interface for demonstration
    try:
        # Build the graph input from the context text.
        # This is simplified and would need to be adapted to your actual graph structure
        graph_data = create_graph_from_text(context)

        # Generate prediction without tracking gradients
        with torch.no_grad():
            prediction = model.generate(
                input_question=[formatted_question],
                input_graph=graph_data
            )[0]  # Get first prediction since we're processing one sample

        return jsonify({
            "question": question,
            "answer": prediction
        })

    except Exception as e:
        # Request boundary: record the traceback server-side instead of
        # dropping it, then report the failure to the client as before.
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500
|
|
|
|
def create_graph_from_text(text):
    """Placeholder for converting text into a graph structure for the GNN.

    No conversion is implemented yet; the actual logic has to be written
    according to the graph construction approach in use.

    Raises:
        NotImplementedError: on every call.
    """
    message = "Graph creation from text needs to be implemented"
    raise NotImplementedError(message)
|
|
|
|
if __name__ == '__main__':
    # Serve on all interfaces; the port is read from PORT, defaulting to 5000.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)))