dgx-spark-playbooks/nvidia/txt2kg/assets/scripts/gnn/preprocess_data.py
2026-01-14 16:05:35 +00:00

349 lines
14 KiB
Python

#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import gc
import json
import os
import torch
from glob import glob
from itertools import chain
from tqdm import tqdm
from python_arango import ArangoClient
# Import the necessary modules from PyTorch Geometric
from torch_geometric import seed_everything
from torch_geometric.nn import SentenceTransformer
from torch_geometric.utils.rag.backend_utils import (
create_remote_backend_from_triplets,
make_pcst_filter,
preprocess_triplet,
)
from torch_geometric.utils.rag.feature_store import ModernBertFeatureStore
from torch_geometric.utils.rag.graph_store import NeighborSamplingRAGGraphStore
from torch_geometric.loader import RAGQueryLoader
# ---------------------------------------------------------------------------
# Module-level defaults. Each one backs a command-line option declared in
# parse_args(), except CHUNK_SIZE_DEFAULT (see its note below).
# ---------------------------------------------------------------------------

# NVIDIA NIM settings handed to TXT2KG.
DEFAULT_ENDPOINT_URL = "https://integrate.api.nvidia.com/v1"
NV_NIM_MODEL_DEFAULT = "nvidia/llama-3.1-nemotron-70b-instruct"
# Same value as the default of --chunk_size (help text there: maximum number
# of characters per chunk).
CHUNK_SIZE_DEFAULT = 512

# ArangoDB connection defaults (per the original note: taken from
# docker-compose.yml). Both credentials default to the empty string.
ARANGO_DB_DEFAULT = "txt2kg"
ARANGO_URL_DEFAULT = "http://localhost:8529"
ARANGO_USER_DEFAULT = ARANGO_PASSWORD_DEFAULT = ""

# Input locations: the QA json file and the folder of raw text documents.
TRAIN_DATA_FILE = "train.json"
CORPUS_DIR = "corpus"

# Output directory plus the artifact names that make_dataset() joins onto it.
OUTPUT_DIR = "output"
BACKEND_PATH = "backend"
DATASET_FILE = "tech_qa.pt"
TRIPLES_FILE = "tech_qa_just_triples.pt"
CHECKPOINT_FILE = "checkpoint_kg.pt"
def parse_args(argv=None):
    """Parse the command-line options that configure preprocessing.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing ``parse_args()``
            calls behave exactly as before.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.
    """
    parser = argparse.ArgumentParser()

    # Data processing related arguments
    parser.add_argument(
        '--NV_NIM_MODEL', type=str, default=NV_NIM_MODEL_DEFAULT,
        help="The NIM LLM to use for TXT2KG for LLMJudge")
    parser.add_argument(
        '--NV_NIM_KEY', type=str, default="", help="NVIDIA API key")
    parser.add_argument(
        '--ENDPOINT_URL', type=str, default=DEFAULT_ENDPOINT_URL,
        help="The URL hosting your model, in case you are not using the public NIM.")
    # Use the shared module constant instead of repeating the literal 512, and
    # build the help text by implicit concatenation so the sentence keeps the
    # space after the comma.
    parser.add_argument(
        '--chunk_size', type=int, default=CHUNK_SIZE_DEFAULT,
        help="When splitting context documents for txt2kg, "
             "the maximum number of characters per chunk.")
    parser.add_argument('--checkpointing', action="store_true")

    # ArangoDB-specific arguments
    parser.add_argument(
        '--arango_url', type=str, default=ARANGO_URL_DEFAULT,
        help="ArangoDB URL")
    parser.add_argument(
        '--arango_db', type=str, default=ARANGO_DB_DEFAULT,
        help="ArangoDB database name")
    parser.add_argument(
        '--arango_user', type=str, default=ARANGO_USER_DEFAULT,
        help="ArangoDB username")
    parser.add_argument(
        '--arango_password', type=str, default=ARANGO_PASSWORD_DEFAULT,
        help="ArangoDB password")
    parser.add_argument(
        '--use_arango', action="store_true",
        help="Use ArangoDB instead of TXT2KG")

    # File path arguments
    parser.add_argument(
        '--dataset_file', type=str, default=DATASET_FILE,
        help="Path to save/load dataset")
    parser.add_argument(
        '--triples_file', type=str, default=TRIPLES_FILE,
        help="Path to save/load triples")
    parser.add_argument(
        '--checkpoint_file', type=str, default=CHECKPOINT_FILE,
        help="Path to save/load checkpoint")
    parser.add_argument(
        '--train_data_file', type=str, default=TRAIN_DATA_FILE,
        help="Path to training data file")
    parser.add_argument(
        '--corpus_dir', type=str, default=CORPUS_DIR,
        help="Directory containing corpus documents")
    parser.add_argument(
        '--backend_path', type=str, default=BACKEND_PATH,
        help="Path for backend storage")
    parser.add_argument(
        '--output_dir', type=str, default=OUTPUT_DIR,
        help="Directory for output files")

    return parser.parse_args(argv)
def load_triples_from_arangodb(arango_url: str, arango_db: str, arango_user: str, arango_password: str) -> list[str]:
    """
    Load triples from ArangoDB for use with the TXT2KG dataset.

    Args:
        arango_url: ArangoDB connection URL
        arango_db: ArangoDB database name
        arango_user: ArangoDB username. When empty, the client library's
            default credentials are used.
        arango_password: ArangoDB password. Only sent when ``arango_user`` is
            non-empty; an empty password is passed through as-is.

    Returns:
        List of triples in the format "subject predicate object"

    Raises:
        Exception: any error raised while connecting or querying is printed
            and then re-raised unchanged.
    """
    # For every edge in `relationships`, look up the `name` of both endpoint
    # documents and the edge `type`, trim surrounding whitespace, drop rows
    # with an empty component, and let COLLECT remove exact duplicates.
    # No case folding is applied, so de-duplication is case-sensitive.
    aql_query = """
    FOR e IN relationships
        LET subject = TRIM(DOCUMENT(e._from).name)
        LET object = TRIM(DOCUMENT(e._to).name)
        LET predicate = TRIM(e.type)
        FILTER subject != "" AND predicate != "" AND object != ""
        COLLECT s = subject, p = predicate, o = object
        RETURN CONCAT_SEPARATOR(" ", s, p, o)
    """
    try:
        # NOTE(review): python-arango is normally imported as `arango`;
        # confirm that the module-level `python_arango` import resolves.
        client = ArangoClient(hosts=arango_url)
        # Authenticate whenever a username is supplied. Requiring a non-empty
        # password as well would silently ignore a user whose password is
        # empty and connect with the defaults instead.
        if arango_user:
            db = client.db(arango_db, username=arango_user, password=arango_password)
        else:
            db = client.db(arango_db)
        # Stream the result in batches, then materialise it as a list.
        cursor = db.aql.execute(aql_query, stream=True, batch_size=1000)
        triples = list(cursor)
    except Exception as error:
        print(f"Error loading triples from ArangoDB: {error}")
        # Bare raise keeps the original traceback intact.
        raise

    print(f"Loaded {len(triples)} triples from ArangoDB")
    # Print up to three sample triples for debugging
    if triples:
        print("Sample triples:")
        for triple in triples[:3]:
            print(f" {triple}")
    return triples
def get_data(args):
    """Load the QA pairs and the raw corpus documents from disk.

    Args:
        args: Namespace providing ``train_data_file`` (path to a JSON file of
            questions and answers) and ``corpus_dir`` (folder of text files
            used for RAG and for building the KG).

    Returns:
        Tuple ``(json_obj, text_contexts)``: the parsed JSON content and the
        list of document texts, one entry per regular file directly inside
        ``corpus_dir``, ordered by sorted file path.
    """
    with open(args.train_data_file, encoding="utf-8") as file:
        json_obj = json.load(file)

    text_contexts = []
    # Sort the paths: glob() gives no ordering guarantee, while the checkpoint
    # resume logic in make_dataset() addresses documents by list position.
    for file_path in sorted(glob(f"{args.corpus_dir}/*")):
        # Sub-directories cannot be read as documents; skip them.
        if not os.path.isfile(file_path):
            continue
        # Plain read-only mode: "r+" would demand write permission on files
        # that are only ever read here.
        with open(file_path, "r", encoding="utf-8") as f:
            text_contexts.append(f.read())
    return json_obj, text_contexts
def _extract_triples_with_txt2kg(args, context_docs, checkpoint_path):
    """Extract de-duplicated triples from ``context_docs`` with TXT2KG.

    Resumes from ``checkpoint_path`` when that file exists. ``context_docs``
    is never modified or rebound, so the caller keeps the full document list.
    """
    # Imported lazily: only this code path needs TXT2KG.
    from torch_geometric.nn import TXT2KG

    kg_maker = TXT2KG(
        NVIDIA_NIM_MODEL=args.NV_NIM_MODEL,
        NVIDIA_API_KEY=args.NV_NIM_KEY,
        ENDPOINT_URL=args.ENDPOINT_URL,
        chunk_size=args.chunk_size,
    )
    print(
        "Note that if the TXT2KG process is too slow for you're liking using the public NIM, "
        "consider deploying yourself using local_lm flag of TXT2KG or using "
        "https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct?snippet_tab=Docker "
        "to deploy to a private endpoint, which you can pass to this script w/ --ENDPOINT_URL flag."
    )

    total_tqdm_count = len(context_docs)
    initial_tqdm_count = 0
    docs_to_process = context_docs
    if os.path.exists(checkpoint_path):
        print(f"Restoring KG from checkpoint at {checkpoint_path}...")
        # NOTE: weights_only=False unpickles arbitrary objects. Acceptable only
        # because the checkpoint is written by this script; never load an
        # untrusted file here.
        saved_relevant_triples = torch.load(checkpoint_path, weights_only=False)
        kg_maker.relevant_triples = saved_relevant_triples
        kg_maker.doc_id_counter = len(saved_relevant_triples)
        initial_tqdm_count = kg_maker.doc_id_counter
        # Resume right after the last checkpointed document. Slicing from
        # ``counter - 1`` would redo one document, overshoot the progress-bar
        # total, and turn into ``[-1:]`` for an empty checkpoint.
        # Assumes one relevant_triples entry per processed document --
        # TODO confirm against TXT2KG.add_doc_2_KG.
        docs_to_process = context_docs[kg_maker.doc_id_counter:]

    interval = 10  # write a checkpoint after every `interval` documents
    count = 0
    for context_doc in tqdm(docs_to_process, total=total_tqdm_count,
                            initial=initial_tqdm_count, desc="Extracting KG triples"):
        kg_maker.add_doc_2_KG(txt=context_doc)
        if args.checkpointing:
            count += 1
            if count == interval:
                print(f" checkpointing KG to {checkpoint_path}...")
                count = 0
                kg_maker.save_kg(checkpoint_path)

    # Flatten the per-document triple collections, then drop duplicates while
    # keeping first-seen order.
    flattened = chain.from_iterable(kg_maker.relevant_triples.values())
    return list(dict.fromkeys(flattened))


def _load_or_create_triples(args, context_docs, triples_path, checkpoint_path):
    """Return the triples: cached file first, else ArangoDB or TXT2KG."""
    if os.path.exists(triples_path):
        # NOTE: pickle-based load; only safe for files this script produced.
        return torch.load(triples_path, weights_only=False)

    if args.use_arango:
        # Load triples from ArangoDB instead of generating with TXT2KG
        print("Loading triples from ArangoDB...")
        triples = load_triples_from_arangodb(
            args.arango_url,
            args.arango_db,
            args.arango_user,
            args.arango_password,
        )
        # Save triples for future use
        torch.save(triples, triples_path)
        return triples

    triples = _extract_triples_with_txt2kg(args, context_docs, checkpoint_path)
    torch.save(triples, triples_path)
    # The finished triples file supersedes the intermediate checkpoint.
    if args.checkpointing and os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
    return triples


def make_dataset(args):
    """Build the TechQA KG-RAG dataset, or reload it if it was already saved.

    Triples come from a previously saved triples file, from ArangoDB
    (``args.use_arango``) or from TXT2KG. Every answerable question is then
    turned into a retrieved subgraph, and the shuffled results are split
    60:20:20 and saved.

    Args:
        args: Namespace as produced by ``parse_args()``.

    Returns:
        On a fresh build, a dict with keys ``"train"``, ``"validation"`` and
        ``"test"``, each a list of the objects returned by
        ``RAGQueryLoader.query`` with ``label`` set to the answer. When the
        dataset file already exists, whatever that file contains.
    """
    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)
    dataset_path = os.path.join(args.output_dir, args.dataset_file)
    triples_path = os.path.join(args.output_dir, args.triples_file)
    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_file)

    if os.path.exists(dataset_path):
        print(f"Re-using Saved TechQA KG-RAG Dataset from {dataset_path}...")
        # NOTE: pickle-based load; only safe for files this script produced.
        return torch.load(dataset_path, weights_only=False)

    qa_pairs, context_docs = get_data(args)
    print("Number of Docs in our VectorDB =", len(context_docs))

    # The helper never alters context_docs, so the embedding and retrieval
    # steps below always see the complete corpus, even after a checkpoint
    # resume.
    triples = _load_or_create_triples(args, context_docs, triples_path, checkpoint_path)
    print("Number of triples in our GraphDB =", len(triples))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sent_trans_batch_size = 256
    model = SentenceTransformer(
        model_name='Alibaba-NLP/gte-modernbert-base').to(device)
    backend_path = os.path.join(args.output_dir, args.backend_path)
    fs, gs = create_remote_backend_from_triplets(
        triplets=triples,
        node_embedding_model=model,
        node_method_to_call="encode",
        path=backend_path,
        pre_transform=preprocess_triplet,
        node_method_kwargs={
            "batch_size": min(len(triples), sent_trans_batch_size)
        },
        graph_db=NeighborSamplingRAGGraphStore,
        feature_db=ModernBertFeatureStore,
    ).load()

    # encode the raw context docs
    embedded_docs = model.encode(
        context_docs,
        output_device=device,
        batch_size=int(sent_trans_batch_size / 4),
        verbose=True,
    )

    # k for KNN
    knn_neighsample_bs = 1024
    # number of neighbors for each seed node selected by KNN
    fanout = 100
    # number of hops for neighborsampling
    num_hops = 2
    local_filter_kwargs = {
        "topk": 5,  # nodes
        "topk_e": 5,  # edges
        "cost_e": .5,  # edge cost
        "num_clusters": 10,  # num clusters
    }
    print("Now to retrieve context for each query from our Vector and Graph DBs...")
    # GraphDB retrieval done with KNN+NeighborSampling+PCST
    # PCST = Prize Collecting Steiner Tree
    # VectorDB retrieval just vanilla RAG
    query_loader = RAGQueryLoader(
        data=(fs, gs),
        seed_nodes_kwargs={"k_nodes": knn_neighsample_bs},
        sampler_kwargs={"num_neighbors": [fanout] * num_hops},
        local_filter=make_pcst_filter(triples, model),
        local_filter_kwargs=local_filter_kwargs,
        raw_docs=context_docs,
        embedded_docs=embedded_docs,
    )

    total_data_list = []
    extracted_triple_sizes = []
    for data_point in tqdm(qa_pairs, desc="Building un-split dataset"):
        # Questions flagged as impossible are left out of the dataset.
        if data_point["is_impossible"]:
            continue
        question, answer = data_point["question"], data_point["answer"]
        subgraph = query_loader.query(question)
        subgraph.label = answer
        total_data_list.append(subgraph)
        extracted_triple_sizes.append(len(subgraph.triples))

    import random
    random.shuffle(total_data_list)

    # min()/max() and the average are undefined for an empty list, so only
    # report statistics when at least one question was processed.
    if extracted_triple_sizes:
        print("Min # of Retrieved Triples =", min(extracted_triple_sizes))
        print("Max # of Retrieved Triples =", max(extracted_triple_sizes))
        print("Average # of Retrieved Triples =", sum(extracted_triple_sizes) / len(extracted_triple_sizes))
    else:
        print("No answerable QA pairs were found; every dataset split is empty.")

    # 60:20:20 split
    train_end = int(.6 * len(total_data_list))
    validation_end = int(.8 * len(total_data_list))
    data_lists = {
        "train": total_data_list[:train_end],
        "validation": total_data_list[train_end:validation_end],
        "test": total_data_list[validation_end:],
    }
    torch.save(data_lists, dataset_path)

    # Drop the local reference to the embedding model and ask torch to
    # release cached GPU memory.
    del model
    gc.collect()
    torch.cuda.empty_cache()
    return data_lists
if __name__ == '__main__':
    # Seed first so that repeated runs are reproducible.
    seed_everything(50)
    args = parse_args()
    # Ensure the output directory is present before the dataset is built.
    os.makedirs(args.output_dir, exist_ok=True)
    # Build the dataset (or reload a saved one) and report where it lives.
    data_lists = make_dataset(args)
    saved_to = os.path.join(args.output_dir, args.dataset_file)
    print(f"Dataset processed and saved to {saved_to}")
    # One summary line per split, in train / validation / test order.
    for heading, split_name in (
        ("Training data size:", "train"),
        ("Validation data size:", "validation"),
        ("Testing data size:", "test"),
    ):
        print(heading, len(data_lists[split_name]))