dgx-spark-playbooks/nvidia/txt2kg/assets/scripts/gnn/preprocess_data.py

412 lines
17 KiB
Python
Raw Normal View History

2025-10-06 17:05:41 +00:00
#!/usr/bin/env python3
import argparse
import gc
import json
import os
import random
from glob import glob
from itertools import chain

import torch
from tqdm import tqdm

# The python-arango distribution installs its importable package as
# ``arango``; the ``python_arango`` import is kept only as a fallback.
try:
    from arango import ArangoClient
except ImportError:
    from python_arango import ArangoClient

# Import the necessary modules from PyTorch Geometric
from torch_geometric import seed_everything
from torch_geometric.loader import RAGQueryLoader
from torch_geometric.nn import SentenceTransformer
from torch_geometric.utils.rag.backend_utils import (
    create_remote_backend_from_triplets,
    make_pcst_filter,
    preprocess_triplet,
)
from torch_geometric.utils.rag.feature_store import ModernBertFeatureStore
from torch_geometric.utils.rag.graph_store import NeighborSamplingRAGGraphStore
# --- NVIDIA NIM / TXT2KG defaults -------------------------------------------
# LLM used by TXT2KG for triple extraction.
NV_NIM_MODEL_DEFAULT: str = "nvidia/llama-3.1-nemotron-70b-instruct"
# Public NIM endpoint; override with --ENDPOINT_URL for a private deployment.
DEFAULT_ENDPOINT_URL: str = "https://integrate.api.nvidia.com/v1"
# Maximum number of characters per chunk when TXT2KG splits a document.
CHUNK_SIZE_DEFAULT: int = 512

# --- ArangoDB defaults (mirroring docker-compose.yml) -----------------------
ARANGO_URL_DEFAULT: str = "http://localhost:8529"
ARANGO_DB_DEFAULT: str = "txt2kg"
# Empty credentials: no username/password is passed when connecting.
ARANGO_USER_DEFAULT: str = ""
ARANGO_PASSWORD_DEFAULT: str = ""

# --- Inputs (used as given on the command line) -----------------------------
TRAIN_DATA_FILE: str = "train.json"
CORPUS_DIR: str = "corpus"

# --- Outputs (joined onto --output_dir by make_dataset) ---------------------
OUTPUT_DIR: str = "output"
DATASET_FILE: str = "tech_qa.pt"
TRIPLES_FILE: str = "tech_qa_just_triples.pt"
CHECKPOINT_FILE: str = "checkpoint_kg.pt"
BACKEND_PATH: str = "backend"
def parse_args(argv=None):
    """Parse the command-line options of the preprocessing script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            makes the parser usable from tests or other scripts.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    # Data processing related arguments
    parser.add_argument('--NV_NIM_MODEL', type=str, default=NV_NIM_MODEL_DEFAULT,
                        help="The NIM LLM to use for TXT2KG for LLMJudge")
    parser.add_argument('--NV_NIM_KEY', type=str, default="", help="NVIDIA API key")
    parser.add_argument(
        '--ENDPOINT_URL', type=str, default=DEFAULT_ENDPOINT_URL,
        help="The URL hosting your model, in case you are not using the public NIM.")
    # Use the shared constant instead of repeating the literal, and build the
    # help text with implicit concatenation: a backslash continuation inside a
    # string literal splices the next source line's indentation into the text.
    parser.add_argument(
        '--chunk_size', type=int, default=CHUNK_SIZE_DEFAULT,
        help="When splitting context documents for txt2kg, "
             "the maximum number of characters per chunk.")
    parser.add_argument('--checkpointing', action="store_true")
    # Add ArangoDB-specific arguments
    parser.add_argument('--arango_url', type=str, default=ARANGO_URL_DEFAULT, help="ArangoDB URL")
    parser.add_argument('--arango_db', type=str, default=ARANGO_DB_DEFAULT, help="ArangoDB database name")
    parser.add_argument('--arango_user', type=str, default=ARANGO_USER_DEFAULT, help="ArangoDB username")
    parser.add_argument('--arango_password', type=str, default=ARANGO_PASSWORD_DEFAULT, help="ArangoDB password")
    parser.add_argument('--use_arango', action="store_true", help="Use ArangoDB instead of TXT2KG")
    # Add file path arguments
    parser.add_argument('--dataset_file', type=str, default=DATASET_FILE, help="Path to save/load dataset")
    parser.add_argument('--triples_file', type=str, default=TRIPLES_FILE, help="Path to save/load triples")
    parser.add_argument('--checkpoint_file', type=str, default=CHECKPOINT_FILE, help="Path to save/load checkpoint")
    parser.add_argument('--train_data_file', type=str, default=TRAIN_DATA_FILE, help="Path to training data file")
    parser.add_argument('--corpus_dir', type=str, default=CORPUS_DIR, help="Directory containing corpus documents")
    parser.add_argument('--backend_path', type=str, default=BACKEND_PATH, help="Path for backend storage")
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help="Directory for output files")
    return parser.parse_args(argv)
def load_triples_from_arangodb(arango_url, arango_db, arango_user, arango_password,
                               collection="relationships"):
    """Load knowledge-graph triples from an ArangoDB edge collection.

    For every edge document in ``collection``, one triple is built from:
    - the ``name`` of its ``_from`` vertex;
    - the edge's ``type``;
    - the ``name`` of its ``_to`` vertex.

    Each value is whitespace-trimmed by the query (no case normalization is
    done here). Edges with any empty value are filtered out server-side.

    Args:
        arango_url: ArangoDB connection URL, e.g. ``http://localhost:8529``.
        arango_db: Name of the database to query.
        arango_user: Username. When empty, no credentials are passed and the
            python-arango client's own defaults apply.
        arango_password: Password for ``arango_user``; may be empty.
        collection: Edge collection to read. Defaults to ``"relationships"``,
            the collection this script has always used.

    Returns:
        The de-duplicated triples produced by
        ``format_triples_for_pytorch_geometric``.

    Raises:
        Any error raised while connecting or querying is printed and re-raised.
    """
    try:
        client = ArangoClient(hosts=arango_url)
        # Authenticate whenever a username is supplied. An empty password is a
        # valid credential, so it must not silently drop the username.
        if arango_user:
            db = client.db(arango_db, username=arango_user, password=arango_password)
        else:
            db = client.db(arango_db)

        # ``@@edges`` is an AQL collection bind parameter, so the collection
        # name is never spliced into the query text.
        aql_query = """
        FOR e IN @@edges
            LET subject = TRIM(DOCUMENT(e._from).name)
            LET object = TRIM(DOCUMENT(e._to).name)
            LET predicate = TRIM(e.type)
            FILTER subject != "" AND predicate != "" AND object != ""
            RETURN {
                subject: subject,
                predicate: predicate,
                object: object
            }
        """
        cursor = db.aql.execute(aql_query, bind_vars={"@edges": collection})
        triple_dicts = list(cursor)

        triples = format_triples_for_pytorch_geometric(triple_dicts)
        print(f"Loaded {len(triples)} triples from ArangoDB")
        # Print a few sample triples for debugging.
        if triples:
            print("Sample triples:")
            for sample in triples[:3]:
                print(f" {sample}")
        return triples
    except Exception as error:
        print(f"Error loading triples from ArangoDB: {error}")
        raise
def format_triples_for_pytorch_geometric(triple_dicts):
    """Convert ArangoDB triple records into PyTorch Geometric triplets.

    PyG's RAG backend helpers (e.g. ``preprocess_triplet``) unpack each
    triplet as ``h, r, t``. Each triple is therefore returned as a
    ``(subject, predicate, object)`` tuple rather than a space-joined string.
    Joining would also make multi-word entity names impossible to split back
    apart.

    Args:
        triple_dicts: Iterable of dicts with ``subject``, ``predicate`` and
            ``object`` keys.

    Returns:
        List of ``(subject, predicate, object)`` tuples. Records with a
        missing or empty (falsy) field are skipped. Exact duplicates are
        removed, keeping the first occurrence's position.
    """
    candidates = (
        (record.get('subject'), record.get('predicate'), record.get('object'))
        for record in triple_dicts
    )
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(triple for triple in candidates if all(triple)))
def get_data(args):
    """Load the question/answer records and the raw corpus documents.

    Args:
        args: Namespace providing:
            - ``train_data_file``: path to a JSON file of QA records
              (consumed by ``make_dataset``);
            - ``corpus_dir``: directory of text documents used for RAG and
              KG construction.

    Returns:
        ``(json_obj, text_contexts)``. ``json_obj`` is the decoded JSON.
        ``text_contexts`` holds the text of every regular file directly
        inside ``corpus_dir``, in sorted path order.
    """
    with open(args.train_data_file, encoding="utf-8") as file:
        json_obj = json.load(file)

    # glob() returns paths in file-system-dependent order. Sort so the
    # document order is stable across runs: make_dataset resumes TXT2KG
    # checkpoints by skipping documents by position. Subdirectories are
    # skipped because open() cannot read them.
    doc_paths = sorted(
        path for path in glob(os.path.join(args.corpus_dir, "*")) if os.path.isfile(path)
    )
    text_contexts = []
    for file_path in doc_paths:
        # Open read-only; write access ("r+") is not needed and fails on
        # read-only corpora.
        with open(file_path, "r", encoding="utf-8") as f:
            text_contexts.append(f.read())
    return json_obj, text_contexts
def validate_triple_format(triples):
    """Normalize triples into ``(subject, predicate, object)`` string tuples.

    ``preprocess_triplet`` (PyG) unpacks each triplet as ``h, r, t``, so
    every accepted triple is returned as a 3-tuple of strings.

    Accepted input shapes:
    - ``str``: split on whitespace. The first token is the subject, the
      second the predicate, and the remainder (re-joined with single spaces)
      the object. This is a best-effort heuristic: multi-word subjects or
      predicates cannot be recovered from a joined string.
    - ``dict`` with ``subject``, ``predicate`` and ``object`` keys.
    - ``tuple``/``list`` of length 3.

    Anything else, including strings with fewer than 3 tokens, is skipped
    with a printed warning.

    Args:
        triples: Sequence of triples in any of the shapes above.

    Returns:
        List of ``(str, str, str)`` tuples, in input order.
    """
    validated_triples = []
    print(f"Validating {len(triples)} triples...")
    for i, triple in enumerate(triples):
        if isinstance(triple, str):
            parts = triple.split()
            if len(parts) < 3:
                print(f"Warning: Triple at index {i} has fewer than 3 parts: {triple}")
                continue
            validated_triples.append((parts[0], parts[1], ' '.join(parts[2:])))
        elif isinstance(triple, dict) and all(key in triple for key in ('subject', 'predicate', 'object')):
            validated_triples.append(
                (str(triple['subject']), str(triple['predicate']), str(triple['object']))
            )
        elif isinstance(triple, (tuple, list)) and len(triple) == 3:
            validated_triples.append(tuple(str(part) for part in triple))
        else:
            print(f"Warning: Skipping triple at index {i} with invalid format: {triple}")
    print(f"Validation complete. {len(validated_triples)} valid triples out of {len(triples)}")
    return validated_triples
def _extract_triples_with_txt2kg(args, context_docs, checkpoint_path):
    """Extract KG triples from ``context_docs`` with TXT2KG.

    Resumes from ``checkpoint_path`` when one exists. Returns the triples
    de-duplicated, keeping first occurrence order. ``context_docs`` itself is
    never modified.
    """
    # Imported lazily so TXT2KG is only loaded when extraction is needed.
    from torch_geometric.nn import TXT2KG
    kg_maker = TXT2KG(
        NVIDIA_NIM_MODEL=args.NV_NIM_MODEL,
        NVIDIA_API_KEY=args.NV_NIM_KEY,
        ENDPOINT_URL=args.ENDPOINT_URL,
        chunk_size=args.chunk_size,
    )
    print(
        "Note that if the TXT2KG process is too slow for you're liking using the public NIM, "
        "consider deploying yourself using local_lm flag of TXT2KG or using "
        "https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct?snippet_tab=Docker "
        "to deploy to a private endpoint, which you can pass to this script w/ --ENDPOINT_URL flag."
    )
    docs_done = 0
    if os.path.exists(checkpoint_path):
        print(f"Restoring KG from checkpoint at {checkpoint_path}...")
        saved_relevant_triples = torch.load(checkpoint_path, weights_only=False)
        kg_maker.relevant_triples = saved_relevant_triples
        # The checkpoint's entry count is treated as the number of documents
        # already processed (it also seeds the progress bar below).
        kg_maker.doc_id_counter = len(saved_relevant_triples)
        docs_done = kg_maker.doc_id_counter
    # Resume at index ``docs_done``, the first unprocessed document.
    # Starting at ``docs_done - 1`` would re-extract an already-processed doc
    # and, for an empty checkpoint, wrap around to only the last doc. A new
    # name is used so the caller's full corpus stays intact for the vector DB.
    remaining_docs = context_docs[docs_done:]
    checkpoint_interval = 10
    docs_since_checkpoint = 0
    for context_doc in tqdm(remaining_docs, total=len(context_docs),
                            initial=docs_done, desc="Extracting KG triples"):
        kg_maker.add_doc_2_KG(txt=context_doc)
        if not args.checkpointing:
            continue
        docs_since_checkpoint += 1
        if docs_since_checkpoint == checkpoint_interval:
            print(f" checkpointing KG to {checkpoint_path}...")
            docs_since_checkpoint = 0
            kg_maker.save_kg(checkpoint_path)
    extracted = chain.from_iterable(kg_maker.relevant_triples.values())
    return list(dict.fromkeys(extracted))


def _load_or_build_triples(args, context_docs, triples_path, checkpoint_path):
    """Return cached triples from ``triples_path``, or build and cache them.

    Triples are built from ArangoDB (``--use_arango``) or from TXT2KG.
    """
    if os.path.exists(triples_path):
        return torch.load(triples_path, weights_only=False)

    if args.use_arango:
        # Load triples from ArangoDB instead of generating them with TXT2KG.
        print("Loading triples from ArangoDB...")
        triples = load_triples_from_arangodb(
            args.arango_url,
            args.arango_db,
            args.arango_user,
            args.arango_password,
        )
        triples = validate_triple_format(triples)
        torch.save(triples, triples_path)
        return triples

    triples = _extract_triples_with_txt2kg(args, context_docs, checkpoint_path)
    torch.save(triples, triples_path)
    # The cached triples file supersedes the partial checkpoint.
    if args.checkpointing and os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
    return triples


def _split_dataset(samples):
    """Split ``samples`` 60:20:20 into train/validation/test lists (order preserved)."""
    total = len(samples)
    train_end = int(.6 * total)
    val_end = int(.8 * total)
    return {
        "train": samples[:train_end],
        "validation": samples[train_end:val_end],
        "test": samples[val_end:],
    }


def make_dataset(args):
    """Build (or reload) the TechQA KG-RAG train/validation/test splits.

    If ``<output_dir>/<dataset_file>`` already exists, it is loaded and
    returned. Otherwise this function:
    1. reads the QA pairs and corpus via ``get_data``;
    2. loads cached triples or builds them from ArangoDB or TXT2KG;
    3. builds the feature/graph stores from the triples and embeds the corpus;
    4. retrieves a subgraph for every question not flagged ``is_impossible``;
    5. shuffles, splits 60:20:20, saves, and returns the splits.

    Args:
        args: Parsed namespace from ``parse_args``.

    Returns:
        Dict with ``"train"``, ``"validation"`` and ``"test"`` lists of the
        objects returned by ``RAGQueryLoader.query``, each with ``.label``
        set to the question's answer.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    dataset_path = os.path.join(args.output_dir, args.dataset_file)
    triples_path = os.path.join(args.output_dir, args.triples_file)
    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_file)

    if os.path.exists(dataset_path):
        print(f"Re-using Saved TechQA KG-RAG Dataset from {dataset_path}...")
        return torch.load(dataset_path, weights_only=False)

    qa_pairs, context_docs = get_data(args)
    print("Number of Docs in our VectorDB =", len(context_docs))

    triples = _load_or_build_triples(args, context_docs, triples_path, checkpoint_path)
    print("Number of triples in our GraphDB =", len(triples))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sent_trans_batch_size = 256
    model = SentenceTransformer(
        model_name='Alibaba-NLP/gte-modernbert-base').to(device)
    backend_path = os.path.join(args.output_dir, args.backend_path)
    fs, gs = create_remote_backend_from_triplets(
        triplets=triples,
        node_embedding_model=model,
        node_method_to_call="encode",
        path=backend_path,
        pre_transform=preprocess_triplet,
        node_method_kwargs={
            "batch_size": min(len(triples), sent_trans_batch_size)
        },
        graph_db=NeighborSamplingRAGGraphStore,
        feature_db=ModernBertFeatureStore
    ).load()

    # Encode the full raw corpus for vanilla vector retrieval.
    embedded_docs = model.encode(
        context_docs,
        output_device=device,
        batch_size=sent_trans_batch_size // 4,
        verbose=True
    )

    knn_neighsample_bs = 1024  # k for the KNN seed-node search
    fanout = 100  # neighbors sampled per seed node, per hop
    num_hops = 2  # number of neighbor-sampling hops
    local_filter_kwargs = {
        "topk": 5,  # nodes
        "topk_e": 5,  # edges
        "cost_e": .5,  # edge cost
        "num_clusters": 10,  # num clusters
    }
    print("Now to retrieve context for each query from our Vector and Graph DBs...")
    # GraphDB retrieval: KNN + NeighborSampling + PCST (Prize Collecting
    # Steiner Tree). VectorDB retrieval: vanilla RAG over embedded_docs.
    query_loader = RAGQueryLoader(
        data=(fs, gs),
        seed_nodes_kwargs={"k_nodes": knn_neighsample_bs},
        sampler_kwargs={"num_neighbors": [fanout] * num_hops},
        local_filter=make_pcst_filter(triples, model),
        local_filter_kwargs=local_filter_kwargs,
        raw_docs=context_docs,
        embedded_docs=embedded_docs
    )

    total_data_list = []
    extracted_triple_sizes = []
    for data_point in tqdm(qa_pairs, desc="Building un-split dataset"):
        if data_point["is_impossible"]:
            continue
        question, answer = data_point["question"], data_point["answer"]
        subgraph = query_loader.query(question)
        subgraph.label = answer
        total_data_list.append(subgraph)
        extracted_triple_sizes.append(len(subgraph.triples))

    # Shuffle order is reproducible because __main__ seeds all RNGs.
    random.shuffle(total_data_list)
    if extracted_triple_sizes:
        print("Min # of Retrieved Triples =", min(extracted_triple_sizes))
        print("Max # of Retrieved Triples =", max(extracted_triple_sizes))
        print("Average # of Retrieved Triples =",
              sum(extracted_triple_sizes) / len(extracted_triple_sizes))
    else:
        # min()/max() would raise on an empty list; report instead of crashing.
        print("Warning: no answerable questions found; all dataset splits are empty.")

    data_lists = _split_dataset(total_data_list)
    torch.save(data_lists, dataset_path)

    # Free the embedding model (and any cached GPU memory) before returning.
    del model
    gc.collect()
    torch.cuda.empty_cache()
    return data_lists
if __name__ == '__main__':
    # Seed every RNG up front so the dataset shuffle/split is reproducible.
    seed_everything(50)
    cli_args = parse_args()
    os.makedirs(cli_args.output_dir, exist_ok=True)

    # Build (or reload) the splits, then report where they live and how big they are.
    splits = make_dataset(cli_args)
    saved_to = os.path.join(cli_args.output_dir, cli_args.dataset_file)
    print(f"Dataset processed and saved to {saved_to}")
    for label, split_key in (("Training", "train"), ("Validation", "validation"), ("Testing", "test")):
        print(f"{label} data size:", len(splits[split_key]))