mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-04-22 18:13:52 +00:00
187 lines
6.8 KiB
Python
187 lines
6.8 KiB
Python
#
|
|
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import torch
|
|
import argparse
|
|
from datasets import load_dataset
|
|
from trl import SFTConfig, SFTTrainer
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
from peft import get_peft_model, LoraConfig, TaskType
|
|
|
|
|
|
# Define prompt templates
# Alpaca-style prompt with three positional ``{}`` slots. get_alpaca_dataset()
# fills them, in order, with each example's "instruction", "input" and "output"
# fields and then appends the tokenizer's EOS token.
ALPACA_PROMPT_TEMPLATE = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction: {}

### Input: {}

### Response: {}"""
|
|
|
|
def get_alpaca_dataset(eos_token, dataset_size=500):
    """Load a shuffled subset of the Alpaca ``train`` split as prompt-formatted text.

    Args:
        eos_token: String appended to every formatted example.
        dataset_size: Number of rows taken from the start of the ``train`` split
            (the subset is shuffled afterwards with ``seed=42``). Values larger
            than the split are clamped to its length.

    Returns:
        A ``datasets.Dataset`` whose only column is ``"text"``.

    Raises:
        ValueError: If ``dataset_size`` is less than 1 (an empty dataset would
            only fail later, inside the trainer).
    """
    if dataset_size < 1:
        raise ValueError(f"dataset_size must be a positive integer, got {dataset_size}")

    # Preprocess the dataset
    def preprocess(batch):
        # Called with batched=True, so every column arrives as a list of equal length.
        texts = [
            ALPACA_PROMPT_TEMPLATE.format(instruction, context, output) + eos_token
            for instruction, context, output in zip(batch["instruction"], batch["input"], batch["output"])
        ]
        return {"text": texts}

    dataset = load_dataset("tatsu-lab/alpaca", split="train")
    # Dataset.select() raises on out-of-range indices, so never ask for more rows than exist.
    subset_size = min(dataset_size, len(dataset))
    dataset = dataset.select(range(subset_size)).shuffle(seed=42)
    return dataset.map(preprocess, remove_columns=dataset.column_names, batched=True)
|
|
|
|
|
|
def _sft_config_kwargs(args):
    """Return the ``SFTConfig`` keyword arguments shared by the warmup and training runs.

    ``num_train_epochs`` and ``report_to`` are intentionally omitted: they differ
    between the two runs, so callers add them.

    Args:
        args: Parsed command-line namespace from ``parse_arguments()``.

    Returns:
        A fresh ``dict`` that callers may extend without affecting other runs.
    """
    return {
        "per_device_train_batch_size": args.batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "learning_rate": args.learning_rate,
        "optim": "adamw_torch",
        "save_strategy": "no",
        "remove_unused_columns": False,
        "seed": 42,
        "dataset_text_field": "text",
        "packing": False,
        "max_length": args.seq_length,
        # Compilation is done explicitly in main() when --use_torch_compile is set,
        # so the Trainer's built-in torch_compile integration stays off.
        "torch_compile": False,
        "logging_dir": args.log_dir,
        "logging_steps": args.logging_steps,
    }


def _print_training_stats(metrics):
    """Print the runtime, throughput and loss entries of a ``trainer.train()`` metrics dict."""
    rule = "=" * 60
    print(f"\n{rule}")
    print("TRAINING COMPLETED")
    print(rule)
    print(f"Training runtime: {metrics['train_runtime']:.2f} seconds")
    print(f"Samples per second: {metrics['train_samples_per_second']:.2f}")
    print(f"Steps per second: {metrics['train_steps_per_second']:.2f}")
    print(f"Train loss: {metrics['train_loss']:.4f}")
    print(f"{rule}\n")


def main(args):
    """Fine-tune ``args.model_name`` on an Alpaca subset using LoRA adapters.

    Loads the model and tokenizer, attaches LoRA adapters, and loads the dataset.
    If requested, it then ``torch.compile``s the model and runs a short warmup.
    Finally it trains for ``args.num_epochs`` epochs (reporting to TensorBoard)
    and prints the resulting metrics.

    Args:
        args: Parsed command-line namespace from ``parse_arguments()``.
    """
    # Load the model and tokenizer
    print(f"Loading model: {args.model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        dtype=args.dtype,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Reuse EOS as the padding token.
    # NOTE(review): presumably the tokenizer ships without a pad token of its own.
    tokenizer.pad_token = tokenizer.eos_token

    # Attach the LoRA adapters before any training run. get_peft_model() freezes
    # the base weights, so the trainable-parameter count below reports only the
    # adapters. It also means the optional compile warmup trains adapters only,
    # rather than running a full AdamW fine-tune of every base-model parameter.
    peft_config = LoraConfig(
        r=args.lora_rank,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, peft_config)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters = {trainable_params:,}")

    # Load and preprocess the dataset
    print(f"Loading dataset with {args.dataset_size} samples...")
    dataset = get_alpaca_dataset(tokenizer.eos_token, args.dataset_size)

    base_config = _sft_config_kwargs(args)

    # Compile model if requested
    if args.use_torch_compile:
        print("Compiling model with torch.compile()...")
        model = torch.compile(model)

        # Short run to trigger compilation before the real training run.
        # learning_rate=0.0 keeps AdamW from updating the adapter weights (its
        # decoupled weight-decay term also scales with lr). The real run therefore
        # starts from the same freshly initialised adapters as a non-compiled run.
        print("Running warmup for torch.compile()...")
        warmup_config = {
            **base_config,
            "num_train_epochs": 0.01,
            "learning_rate": 0.0,
            "report_to": "none",
        }
        SFTTrainer(
            model=model,
            processing_class=tokenizer,
            train_dataset=dataset,
            args=SFTConfig(**warmup_config),
        ).train()

    # Train the model. The adapters are already attached, so no peft_config is passed.
    print(f"\nStarting LoRA fine-tuning for {args.num_epochs} epoch(s)...")
    train_config = {
        **base_config,
        "num_train_epochs": args.num_epochs,
        "report_to": "tensorboard",
    }
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        args=SFTConfig(**train_config),
    )
    trainer_stats = trainer.train()

    _print_training_stats(trainer_stats.metrics)
|
|
|
|
|
|
def _positive_int(value):
    """argparse ``type`` callable: parse *value* as an ``int`` and require it to be >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_arguments(argv=None):
    """Build the command-line parser and parse the arguments.

    Count-like options (batch size, sequence length, epochs, accumulation steps,
    LoRA rank, dataset size, logging interval) must be positive integers. Invalid
    values are rejected here with a usage error instead of failing later in training.

    Args:
        argv: Argument list to parse. ``None`` (the default, used by the script
            entry point) makes argparse read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Llama 3.1 8B Fine-tuning with LoRA")

    # Model configuration
    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.1-8B-Instruct",
                        help="Model name or path")
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        choices=["float32", "float16", "bfloat16"],
                        help="Model dtype")

    # Training configuration
    parser.add_argument("--batch_size", type=_positive_int, default=4,
                        help="Per device training batch size")
    parser.add_argument("--seq_length", type=_positive_int, default=2048,
                        help="Maximum sequence length")
    parser.add_argument("--num_epochs", type=_positive_int, default=1,
                        help="Number of training epochs")
    parser.add_argument("--gradient_accumulation_steps", type=_positive_int, default=1,
                        help="Gradient accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=1e-4,
                        help="Learning rate")

    # LoRA configuration
    parser.add_argument("--lora_rank", type=_positive_int, default=8,
                        help="LoRA rank")

    # Dataset configuration
    parser.add_argument("--dataset_size", type=_positive_int, default=500,
                        help="Number of samples to use from dataset")

    # Logging configuration
    parser.add_argument("--logging_steps", type=_positive_int, default=1,
                        help="Log every N steps")
    parser.add_argument("--log_dir", type=str, default="logs",
                        help="Directory for logs")

    # Compilation
    parser.add_argument("--use_torch_compile", action="store_true",
                        help="Use torch.compile() for faster training")

    return parser.parse_args(argv)
|
|
|
|
|
|
if __name__ == "__main__":
    args = parse_arguments()

    # Echo the effective configuration (one "Label: value" line per setting)
    # between two 60-character rules before handing off to main().
    rule = "=" * 60
    settings = (
        ("Model", args.model_name),
        ("Batch size", args.batch_size),
        ("Sequence length", args.seq_length),
        ("Number of epochs", args.num_epochs),
        ("Learning rate", args.learning_rate),
        ("LoRA rank", args.lora_rank),
        ("Dataset size", args.dataset_size),
        ("Torch compile", args.use_torch_compile),
    )
    print("\n" + rule)
    print("LLAMA 3.1 8B LoRA FINE-TUNING CONFIGURATION")
    print(rule)
    for label, value in settings:
        print(f"{label}: {value}")
    print(rule + "\n")

    main(args)
|