#!/usr/bin/env python3
import argparse
import gc
import json
import os

import torch
from glob import glob
from itertools import chain
from tqdm import tqdm

from arango import ArangoClient  # provided by the python-arango package

# Import the necessary modules from PyTorch Geometric
from torch_geometric import seed_everything
from torch_geometric.nn import SentenceTransformer
from torch_geometric.utils.rag.backend_utils import (
    create_remote_backend_from_triplets,
    make_pcst_filter,
    preprocess_triplet,
)
from torch_geometric.utils.rag.feature_store import ModernBertFeatureStore
from torch_geometric.utils.rag.graph_store import NeighborSamplingRAGGraphStore
from torch_geometric.loader import RAGQueryLoader

# Define constants for better readability
NV_NIM_MODEL_DEFAULT = "nvidia/llama-3.1-nemotron-70b-instruct"
CHUNK_SIZE_DEFAULT = 512
DEFAULT_ENDPOINT_URL = "https://integrate.api.nvidia.com/v1"

# ArangoDB defaults from docker-compose.yml
ARANGO_URL_DEFAULT = "http://localhost:8529"
ARANGO_DB_DEFAULT = "txt2kg"
ARANGO_USER_DEFAULT = ""
ARANGO_PASSWORD_DEFAULT = ""

# File paths and directories
DATASET_FILE = "tech_qa.pt"
TRIPLES_FILE = "tech_qa_just_triples.pt"
CHECKPOINT_FILE = "checkpoint_kg.pt"
TRAIN_DATA_FILE = "train.json"
CORPUS_DIR = "corpus"
BACKEND_PATH = "backend"
OUTPUT_DIR = "output"


def parse_args():
    parser = argparse.ArgumentParser()

    # Data processing related arguments
    parser.add_argument('--NV_NIM_MODEL', type=str, default=NV_NIM_MODEL_DEFAULT,
                        help="The NIM LLM to use for TXT2KG and for the LLM judge")
    parser.add_argument('--NV_NIM_KEY', type=str, default="", help="NVIDIA API key")
    parser.add_argument(
        '--ENDPOINT_URL', type=str, default=DEFAULT_ENDPOINT_URL,
        help="The URL hosting your model, in case you are not using the public NIM.")
    parser.add_argument(
        '--chunk_size', type=int, default=CHUNK_SIZE_DEFAULT,
        help="When splitting context documents for txt2kg, "
             "the maximum number of characters per chunk.")
    parser.add_argument('--checkpointing', action="store_true")

    # ArangoDB-specific arguments
    parser.add_argument('--arango_url', type=str, default=ARANGO_URL_DEFAULT, help="ArangoDB URL")
    parser.add_argument('--arango_db', type=str, default=ARANGO_DB_DEFAULT, help="ArangoDB database name")
    parser.add_argument('--arango_user', type=str, default=ARANGO_USER_DEFAULT, help="ArangoDB username")
    parser.add_argument('--arango_password', type=str, default=ARANGO_PASSWORD_DEFAULT, help="ArangoDB password")
    parser.add_argument('--use_arango', action="store_true",
                        help="Load triples from ArangoDB instead of generating them with TXT2KG")

    # File path arguments
    parser.add_argument('--dataset_file', type=str, default=DATASET_FILE, help="Path to save/load dataset")
    parser.add_argument('--triples_file', type=str, default=TRIPLES_FILE, help="Path to save/load triples")
    parser.add_argument('--checkpoint_file', type=str, default=CHECKPOINT_FILE, help="Path to save/load checkpoint")
    parser.add_argument('--train_data_file', type=str, default=TRAIN_DATA_FILE, help="Path to training data file")
    parser.add_argument('--corpus_dir', type=str, default=CORPUS_DIR, help="Directory containing corpus documents")
    parser.add_argument('--backend_path', type=str, default=BACKEND_PATH, help="Path for backend storage")
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help="Directory for output files")

    return parser.parse_args()
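
# Example invocations (illustrative only; the script name and flag values below are placeholders):
#   python make_techqa_dataset.py --use_arango --arango_url http://localhost:8529 --arango_db txt2kg
#   python make_techqa_dataset.py --NV_NIM_KEY <your-key> --checkpointing   # generate triples with TXT2KG instead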


def load_triples_from_arangodb(arango_url, arango_db, arango_user, arango_password):
    """
    Load triples from ArangoDB for use with the TechQA KG-RAG dataset.

    Args:
        arango_url: ArangoDB connection URL
        arango_db: ArangoDB database name
        arango_user: ArangoDB username
        arango_password: ArangoDB password

    Returns:
        List of triples in the format expected by create_remote_backend_from_triplets
    """
    try:
        # Connect to ArangoDB
        client = ArangoClient(hosts=arango_url)

        # Get database (no auth in our docker setup)
        if arango_user and arango_password:
            db = client.db(arango_db, username=arango_user, password=arango_password)
        else:
            db = client.db(arango_db)

        # Query to get all triples from ArangoDB as structured objects;
        # trim whitespace and drop empty values.
        aql_query = """
        FOR e IN relationships
            LET subject = TRIM(DOCUMENT(e._from).name)
            LET object = TRIM(DOCUMENT(e._to).name)
            LET predicate = TRIM(e.type)
            FILTER subject != "" AND predicate != "" AND object != ""
            RETURN {
                subject: subject,
                predicate: predicate,
                object: object
            }
        """

        # Execute the query
        cursor = db.aql.execute(aql_query)
        triple_dicts = list(cursor)

        # Format triples as strings in the format expected by PyTorch Geometric:
        # a list of strings of the form "subject predicate object"
        triples = format_triples_for_pytorch_geometric(triple_dicts)

        print(f"Loaded {len(triples)} triples from ArangoDB")
        # Print sample triples for debugging
        if len(triples) > 0:
            print("Sample triples:")
            for i in range(min(3, len(triples))):
                print(f"  {triples[i]}")

        return triples
    except Exception as error:
        print(f"Error loading triples from ArangoDB: {error}")
        raise error


def format_triples_for_pytorch_geometric(triple_dicts):
    """
    Format triples from ArangoDB into the format expected by PyTorch Geometric.

    Args:
        triple_dicts: List of dictionaries with subject, predicate, object keys

    Returns:
        List of strings in the format "subject predicate object"
    """
    triples = []
    # Create a set to avoid duplicates
    unique_triples = set()

    for triple_dict in triple_dicts:
        # Skip any triple with empty values
        if not triple_dict['subject'] or not triple_dict['predicate'] or not triple_dict['object']:
            continue

        # Create a space-separated string in the format that preprocess_triplet expects
        triple_str = f"{triple_dict['subject']} {triple_dict['predicate']} {triple_dict['object']}"

        # Only add if not already in the set
        if triple_str not in unique_triples:
            unique_triples.add(triple_str)
            triples.append(triple_str)

    return triples
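
# Illustrative example of the conversion above (values are made up):
#   input:  {"subject": "ProductA", "predicate": "supports", "object": "FeatureB"}
#   output: "ProductA supports FeatureB"
# Duplicate strings are dropped while preserving first-seen order.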


def get_data(args):
    # need a JSON dict of questions and answers; see below for how it's used
    with open(args.train_data_file) as file:
        json_obj = json.load(file)

    text_contexts = []
    # need a folder of text files to use for RAG and to make a KG from
    for file_path in glob(f"{args.corpus_dir}/*"):
        with open(file_path, "r") as f:
            text_contexts.append(f.read())

    return json_obj, text_contexts


def validate_triple_format(triples):
    """
    Validate and fix triple format if needed to ensure compatibility with preprocess_triplet.

    Args:
        triples: List of triples to validate

    Returns:
        Fixed list of triples in the format expected by preprocess_triplet
    """
    validated_triples = []

    print(f"Validating {len(triples)} triples...")
    for i, triple in enumerate(triples):
        # If the triple is already a string with subject, predicate and object
        if isinstance(triple, str):
            parts = triple.split()
            # Ensure there are at least 3 parts (subject, predicate, object)
            if len(parts) >= 3:
                # For strings with more than 3 parts, use the first as subject,
                # the second as predicate, and join the rest as the object
                subject = parts[0]
                predicate = parts[1]
                obj = ' '.join(parts[2:])
                validated_triple = f"{subject} {predicate} {obj}"
                validated_triples.append(validated_triple)
            else:
                print(f"Warning: Triple at index {i} has fewer than 3 parts: {triple}")
        # If the triple is a dictionary with subject, predicate, object keys
        elif isinstance(triple, dict) and 'subject' in triple and 'predicate' in triple and 'object' in triple:
            validated_triple = f"{triple['subject']} {triple['predicate']} {triple['object']}"
            validated_triples.append(validated_triple)
        # If the triple is a tuple or list of length 3
        elif isinstance(triple, (tuple, list)) and len(triple) == 3:
            validated_triple = f"{triple[0]} {triple[1]} {triple[2]}"
            validated_triples.append(validated_triple)
        else:
            print(f"Warning: Skipping triple at index {i} with invalid format: {triple}")

    print(f"Validation complete. {len(validated_triples)} valid triples out of {len(triples)}")
    return validated_triples
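
# Sketch of the shapes accepted above (hypothetical values):
#   "deviceA connects_to deviceB"                                           -> kept as "deviceA connects_to deviceB"
#   {"subject": "deviceA", "predicate": "connects_to", "object": "deviceB"} -> "deviceA connects_to deviceB"
#   ("deviceA", "connects_to", "deviceB")                                   -> "deviceA connects_to deviceB"
# Anything else, or a string with fewer than three tokens, is skipped with a warning.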


def make_dataset(args):
    """Modified make_dataset function that can use ArangoDB as a data source."""
    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)
    dataset_path = os.path.join(args.output_dir, args.dataset_file)
    triples_path = os.path.join(args.output_dir, args.triples_file)
    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_file)

    if os.path.exists(dataset_path):
        print(f"Re-using Saved TechQA KG-RAG Dataset from {dataset_path}...")
        return torch.load(dataset_path, weights_only=False)
    else:
        qa_pairs, context_docs = get_data(args)
        print("Number of Docs in our VectorDB =", len(context_docs))
        data_lists = {"train": [], "validation": [], "test": []}

        # Load triples either from a saved file or from a triple source
        triples = []
        if os.path.exists(triples_path):
            triples = torch.load(triples_path, weights_only=False)
        else:
            if args.use_arango:
                # Load triples from ArangoDB instead of generating them with TXT2KG
                print("Loading triples from ArangoDB...")
                triples = load_triples_from_arangodb(
                    args.arango_url,
                    args.arango_db,
                    args.arango_user,
                    args.arango_password
                )
                # Validate and fix the triple format if needed
                triples = validate_triple_format(triples)
                # Save triples for future use
                torch.save(triples, triples_path)
            else:
                # Original TXT2KG code path
                from torch_geometric.nn import TXT2KG
                kg_maker = TXT2KG(
                    NVIDIA_NIM_MODEL=args.NV_NIM_MODEL,
                    NVIDIA_API_KEY=args.NV_NIM_KEY,
                    ENDPOINT_URL=args.ENDPOINT_URL,
                    chunk_size=args.chunk_size
                )
                print(
                    "Note: if the TXT2KG process is too slow for your liking on the public NIM, "
                    "consider deploying the model yourself using the local_lm flag of TXT2KG, or use "
                    "https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct?snippet_tab=Docker "
                    "to deploy a private endpoint, which you can pass to this script with the --ENDPOINT_URL flag."
                )

                total_tqdm_count = len(context_docs)
                initial_tqdm_count = 0
                if os.path.exists(checkpoint_path):
                    print(f"Restoring KG from checkpoint at {checkpoint_path}...")
                    saved_relevant_triples = torch.load(checkpoint_path, weights_only=False)
                    kg_maker.relevant_triples = saved_relevant_triples
                    kg_maker.doc_id_counter = len(saved_relevant_triples)
                    initial_tqdm_count = kg_maker.doc_id_counter
                    context_docs = context_docs[(kg_maker.doc_id_counter - 1):]

                if args.checkpointing:
                    interval = 10
                    count = 0

                for context_doc in tqdm(context_docs, total=total_tqdm_count,
                                        initial=initial_tqdm_count, desc="Extracting KG triples"):
                    kg_maker.add_doc_2_KG(txt=context_doc)
                    if args.checkpointing:
                        count += 1
                        if count == interval:
                            print(f" checkpointing KG to {checkpoint_path}...")
                            count = 0
                            kg_maker.save_kg(checkpoint_path)

                relevant_triples = kg_maker.relevant_triples
                triples.extend(
                    list(
                        chain.from_iterable(
                            triple_set for triple_set in relevant_triples.values()
                        )
                    )
                )
                # De-duplicate while preserving order
                triples = list(dict.fromkeys(triples))
                torch.save(triples, triples_path)

            if args.checkpointing and os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)

        print("Number of triples in our GraphDB =", len(triples))

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        sent_trans_batch_size = 256
        model = SentenceTransformer(
            model_name='Alibaba-NLP/gte-modernbert-base').to(device)

        backend_path = os.path.join(args.output_dir, args.backend_path)
        fs, gs = create_remote_backend_from_triplets(
            triplets=triples,
            node_embedding_model=model,
            node_method_to_call="encode",
            path=backend_path,
            pre_transform=preprocess_triplet,
            node_method_kwargs={
                "batch_size": min(len(triples), sent_trans_batch_size)
            },
            graph_db=NeighborSamplingRAGGraphStore,
            feature_db=ModernBertFeatureStore
        ).load()

        # Encode the raw context docs
        embedded_docs = model.encode(
            context_docs,
            output_device=device,
            batch_size=int(sent_trans_batch_size / 4),
            verbose=True
        )

        # k for KNN seed-node selection
        knn_neighsample_bs = 1024
        # number of neighbors sampled for each seed node selected by KNN
        fanout = 100
        # number of hops for neighbor sampling
        num_hops = 2

        local_filter_kwargs = {
            "topk": 5,  # top-k nodes
            "topk_e": 5,  # top-k edges
            "cost_e": .5,  # edge cost
            "num_clusters": 10,  # number of clusters
        }

        print("Now to retrieve context for each query from our Vector and Graph DBs...")
        # GraphDB retrieval is done with KNN + NeighborSampling + PCST
        # (PCST = Prize-Collecting Steiner Tree);
        # VectorDB retrieval is just vanilla RAG.
        query_loader = RAGQueryLoader(
            data=(fs, gs),
            seed_nodes_kwargs={"k_nodes": knn_neighsample_bs},
            sampler_kwargs={"num_neighbors": [fanout] * num_hops},
            local_filter=make_pcst_filter(triples, model),
            local_filter_kwargs=local_filter_kwargs,
            raw_docs=context_docs,
            embedded_docs=embedded_docs
        )
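
        # Rough sketch of a single retrieval (the question text is hypothetical):
        #   subgraph = query_loader.query("How do I update firmware on my appliance?")
        # Each returned object is the KNN-seeded, neighbor-sampled, PCST-pruned
        # subgraph for that question; its `triples` are counted below and the
        # ground-truth answer is attached as its `label`.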

        total_data_list = []
        extracted_triple_sizes = []

        for data_point in tqdm(qa_pairs, desc="Building un-split dataset"):
            if data_point["is_impossible"]:
                continue

            QA_pair = (data_point["question"], data_point["answer"])
            q = QA_pair[0]
            subgraph = query_loader.query(q)
            subgraph.label = QA_pair[1]
            total_data_list.append(subgraph)
            extracted_triple_sizes.append(len(subgraph.triples))

        import random
        random.shuffle(total_data_list)

        print("Min # of Retrieved Triples =", min(extracted_triple_sizes))
        print("Max # of Retrieved Triples =", max(extracted_triple_sizes))
        print("Average # of Retrieved Triples =", sum(extracted_triple_sizes) / len(extracted_triple_sizes))

        # 60:20:20 split
        data_lists["train"] = total_data_list[:int(.6 * len(total_data_list))]
        data_lists["validation"] = total_data_list[
            int(.6 * len(total_data_list)):int(.8 * len(total_data_list))]
        data_lists["test"] = total_data_list[int(.8 * len(total_data_list)):]

        torch.save(data_lists, dataset_path)

        del model
        gc.collect()
        torch.cuda.empty_cache()

        return data_lists


if __name__ == '__main__':
    # for reproducibility
    seed_everything(50)
    args = parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Process and save dataset
    data_lists = make_dataset(args)
    print(f"Dataset processed and saved to {os.path.join(args.output_dir, args.dataset_file)}")
    print("Training data size:", len(data_lists["train"]))
    print("Validation data size:", len(data_lists["validation"]))
    print("Testing data size:", len(data_lists["test"]))