GNN Training Pipeline (Experimental)
Status: This is an experimental feature for training Graph Neural Network models for enhanced RAG retrieval.
This pipeline provides a two-stage process for training GNN-based knowledge graph retrieval models:
- Data Preprocessing (preprocess_data.py): Extracts knowledge graph triples from ArangoDB and prepares training datasets.
- Model Training & Testing (train_test_gnn.py): Trains and evaluates a GNN-based retriever model using PyTorch Geometric.
Prerequisites
- Python 3.8+
- ArangoDB running (can be set up using the provided docker-compose.yml)
- PyTorch and PyTorch Geometric installed
- All dependencies listed in requirements.txt
Installation
- Install the required dependencies:
pip install -r scripts/requirements.txt
- Ensure ArangoDB is running. You can use the main start script:
# From project root
./start.sh
Usage
Stage 1: Data Preprocessing
Run the preprocessing script to prepare the dataset:
python scripts/preprocess_data.py --use_arango --output_dir ./output
Loading data from ArangoDB
Use the --use_arango flag to load triples from ArangoDB instead of generating them with TXT2KG:
python scripts/preprocess_data.py --use_arango
The script will connect to ArangoDB using the default settings from docker-compose.yml:
- URL: http://localhost:8529
- Database: txt2kg
- No auth (username and password are empty)
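As a rough sketch of what this connection step looks like, the snippet below loads triples using the defaults above. It assumes the python-arango client library; the entities/relationships collection names and the name/type fields follow the schema described later in this README, and the helper names are illustrative, not the script's actual internals:

```python
def format_triple(subject, predicate, obj):
    # Render one edge as the "subject predicate object" string the
    # preprocessing stage expects.
    return f"{subject} {predicate} {obj}"

def load_triples(url="http://localhost:8529", db_name="txt2kg",
                 username="", password=""):
    from arango import ArangoClient  # pip install python-arango
    client = ArangoClient(hosts=url)
    db = client.db(db_name, username=username, password=password)
    # Join each relationship edge to its endpoint entity documents.
    cursor = db.aql.execute(
        """
        FOR e IN relationships
            LET s = DOCUMENT(e._from)
            LET o = DOCUMENT(e._to)
            RETURN [s.name, e.type, o.name]
        """
    )
    return [format_triple(*row) for row in cursor]
```

With no auth configured, empty username and password strings are passed through unchanged.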
Custom ArangoDB Connection
You can specify custom ArangoDB connection parameters:
python scripts/preprocess_data.py --use_arango --arango_url "http://localhost:8529" --arango_db "your_db" --arango_user "username" --arango_password "password"
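A hypothetical sketch of how these flags could be declared with argparse; the real preprocess_data.py may differ, but the flag names and defaults match the commands and settings shown above:

```python
import argparse

parser = argparse.ArgumentParser(description="Prepare GNN training data")
parser.add_argument("--use_arango", action="store_true",
                    help="load triples from ArangoDB instead of TXT2KG")
parser.add_argument("--arango_url", default="http://localhost:8529")
parser.add_argument("--arango_db", default="txt2kg")
parser.add_argument("--arango_user", default="")
parser.add_argument("--arango_password", default="")
parser.add_argument("--output_dir", default="./output")

# Flags you omit fall back to the docker-compose.yml defaults.
args = parser.parse_args(["--use_arango", "--arango_db", "your_db"])
```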
Using Direct Triple Extraction
If you don't pass the --use_arango flag, the script will extract triples directly using the configured LLM provider.
Stage 2: Model Training & Testing
After preprocessing the data, train and test the model:
python scripts/train_test_gnn.py --output_dir ./output
Training Options
You can customize training with options:
python scripts/train_test_gnn.py --output_dir ./output --gnn_hidden_channels 2048 --num_gnn_layers 6 --epochs 5 --batch_size 2
Evaluation Only
To evaluate a previously trained model without retraining:
python scripts/train_test_gnn.py --output_dir ./output --eval_only
Expected Data Format in ArangoDB
The script expects ArangoDB to have:
- A document collection named entities containing nodes with a name property
- An edge collection named relationships where:
  - Edges have a type property (the predicate/relationship type)
  - Edges connect from and to entities in the entities collection
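The documents below illustrate this layout; the field values are made up, and only the keys (name, type, _from, _to) and the entities/relationships collection names come from the description above:

```python
entity_a = {"_key": "paris", "name": "Paris"}
entity_b = {"_key": "france", "name": "France"}
edge = {
    "_from": "entities/paris",   # subject entity
    "_to": "entities/france",    # object entity
    "type": "capital_of",        # predicate / relationship type
}

def edge_is_valid(e):
    # Both endpoints must live in the entities collection and the
    # edge must carry a type property.
    return (e.get("_from", "").startswith("entities/")
            and e.get("_to", "").startswith("entities/")
            and "type" in e)
```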
How It Works
Data Preprocessing (preprocess_data.py)
- Connects to ArangoDB and queries all triples in the format "subject predicate object" (or generates them with TXT2KG)
- Creates a knowledge graph from these triples
- Prepares the dataset with training, validation, and test splits
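The split step above can be sketched as a shuffle-and-cut over the prepared examples; the 80/10/10 ratios and function name here are assumptions for illustration, not the script's actual values:

```python
import random

def split_dataset(items, train=0.8, val=0.1, seed=0):
    # Shuffle deterministically, then slice into three portions.
    items = list(items)
    random.Random(seed).shuffle(items)
    n = len(items)
    n_train = int(n * train)
    n_val = int(n * val)
    return (items[:n_train],
            items[n_train:n_train + n_val],
            items[n_train + n_val:])
```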
Model Training & Testing (train_test_gnn.py)
- Loads the preprocessed dataset
- Initializes a GNN model (GAT architecture) and an LLM for generation
- Trains the model on the training set, validating on the validation set
- Evaluates the trained model on the test set using the LLMJudge for scoring
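Schematically, the loop above trains each epoch, keeps the checkpoint with the best validation score, and reports that checkpoint's test score. The sketch below captures only this control flow; the model, metrics, and LLMJudge scoring are stand-in callables so it runs without PyTorch installed:

```python
def train_and_test(train_step, validate, test, epochs=5):
    # Track the epoch whose validation score was best.
    best_val, best_epoch = float("-inf"), -1
    for epoch in range(epochs):
        train_step(epoch)            # one pass over the training set
        score = validate(epoch)      # score on the validation set
        if score > best_val:
            best_val, best_epoch = score, epoch
    # Evaluate the best checkpoint on the held-out test set.
    return test(best_epoch)
```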
Limitations
- The script assumes that your ArangoDB instance contains data in the format described above
- You need to have both question-answer pairs and corpus documents available
- Make sure your ArangoDB contains knowledge graph triples relevant to your corpus
- Large LLM models require significant GPU memory for training