#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse

import torch
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training

# Define prompt templates
ALPACA_PROMPT_TEMPLATE = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction: {}
### Input: {}
### Response: {}"""
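
# For illustration only (made-up values, not from the dataset), formatting
# ("Name three primary colors", "", "Red, yellow, blue.") yields:
#   Below is an instruction that describes a task, ... completes the request.
#   ### Instruction: Name three primary colors
#   ### Input:
#   ### Response: Red, yellow, blue.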


def get_alpaca_dataset(eos_token, dataset_size=500):
    # Preprocess the dataset: fill the prompt template and append the EOS token
    def preprocess(x):
        texts = [
            ALPACA_PROMPT_TEMPLATE.format(instruction, input, output) + eos_token
            for instruction, input, output in zip(x["instruction"], x["input"], x["output"])
        ]
        return {"text": texts}

    dataset = load_dataset("tatsu-lab/alpaca", split="train").select(range(dataset_size)).shuffle(seed=42)
    return dataset.map(preprocess, remove_columns=dataset.column_names, batched=True)
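
# Example usage (illustrative; the EOS string shown is hypothetical and should come
# from the tokenizer, as done in main()):
#   ds = get_alpaca_dataset(eos_token="<|eos|>", dataset_size=8)
#   print(ds[0]["text"])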


def main(args):
    # Load the model and tokenizer
    print(f"Loading model: {args.model_name}")
    print("Training mode: LoRA")
    # When using FSDP, don't pass device_map, which would load the full model on one device;
    # FSDP handles sharding and device placement.
    # With fsdp_cpu_ram_efficient_loading=true in the accelerate config, the model is loaded
    # on the meta device first, then FSDP shards and materializes it across all devices/nodes.
    # This prevents OOM when a single node can't hold the full model.
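    # A matching accelerate FSDP config might look like this (illustrative values,
    # not part of this script):
    #   distributed_type: FSDP
    #   fsdp_config:
    #     fsdp_cpu_ram_efficient_loading: true
    #     fsdp_sharding_strategy: FULL_SHARD
    #     fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP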
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        dtype=args.dtype,
        low_cpu_mem_usage=True,  # Use lazy loading to reduce memory usage
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
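
    # The BitsAndBytesConfig / prepare_model_for_kbit_training imports support an optional
    # QLoRA-style path that is not enabled here. A minimal sketch, assuming bitsandbytes
    # is installed (values are illustrative):
    #   bnb_config = BitsAndBytesConfig(
    #       load_in_4bit=True,
    #       bnb_4bit_quant_type="nf4",
    #       bnb_4bit_compute_dtype=torch.bfloat16,
    #   )
    #   model = AutoModelForCausalLM.from_pretrained(args.model_name, quantization_config=bnb_config)
    #   model = prepare_model_for_kbit_training(model)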

    # Prepare the model for LoRA training
    print(f"Preparing model for LoRA with rank {args.lora_rank}...")
    peft_config = LoraConfig(
        r=args.lora_rank,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,
        task_type=TaskType.CAUSAL_LM,
    )
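    # LoRA keeps the base weights frozen and trains low-rank updates: each targeted
    # projection W is applied as W + (lora_alpha / r) * B @ A, where A (r x d_in) and
    # B (d_out x r) are the only trainable matrices added below.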
    model = get_peft_model(model, peft_config)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")

    # Load and preprocess the dataset
    print(f"Loading dataset with {args.dataset_size} samples...")
    dataset = get_alpaca_dataset(tokenizer.eos_token, args.dataset_size)

    # Configure the SFT trainer
    config = {
        "per_device_train_batch_size": args.batch_size,
        "num_train_epochs": 0.01,  # Short warmup run; overwritten with args.num_epochs below
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "learning_rate": args.learning_rate,
        "optim": "adamw_torch",
        "save_strategy": "no",
        "remove_unused_columns": False,
        "seed": 42,
        "dataset_text_field": "text",
        "packing": False,
        "max_length": args.seq_length,
        "torch_compile": False,
        "report_to": "none",
        "logging_dir": args.log_dir,
        "logging_steps": args.logging_steps,
        "gradient_checkpointing": args.gradient_checkpointing,
    }

    # Compile the model if requested
    if args.use_torch_compile:
        print("Compiling model with torch.compile()...")
        model = torch.compile(model)
        # Warmup run: train for a tiny fraction of an epoch so compilation happens
        # before the timed training run below
        print("Running warmup for torch.compile()...")
        SFTTrainer(
            model=model,
            processing_class=tokenizer,
            train_dataset=dataset,
            args=SFTConfig(**config),
        ).train()

    # Train the model
    print(f"\nStarting LoRA fine-tuning for {args.num_epochs} epoch(s)...")
    config["num_train_epochs"] = args.num_epochs
    config["report_to"] = "tensorboard"
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        args=SFTConfig(**config),
    )
    trainer_stats = trainer.train()

    # Print training statistics
    print(f"\n{'=' * 60}")
    print("TRAINING COMPLETED")
    print(f"{'=' * 60}")
    print(f"Training runtime: {trainer_stats.metrics['train_runtime']:.2f} seconds")
    print(f"Samples per second: {trainer_stats.metrics['train_samples_per_second']:.2f}")
    print(f"Steps per second: {trainer_stats.metrics['train_steps_per_second']:.2f}")
    print(f"Train loss: {trainer_stats.metrics['train_loss']:.4f}")
    print(f"{'=' * 60}\n")

    # Save the model if requested
    if args.output_dir:
        print(f"Saving model to {args.output_dir}...")
        trainer.save_model(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
        print("Model saved successfully!")


def parse_arguments():
    parser = argparse.ArgumentParser(description="Llama 3.1 70B Fine-tuning with LoRA")
    # Model configuration
    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.1-70B-Instruct",
                        help="Model name or path")
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        help="Model dtype (e.g., float32, float16, bfloat16)")
    # Training configuration
    parser.add_argument("--batch_size", type=int, default=4,
                        choices=[1, 2, 4, 8, 16, 32],
                        help="Per-device training batch size")
    parser.add_argument("--seq_length", type=int, default=2048,
                        choices=[256, 512, 1024, 2048, 4096, 8192],
                        help="Maximum sequence length")
    parser.add_argument("--num_epochs", type=int, default=1,
                        help="Number of training epochs")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Gradient accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=1e-4,
                        help="Learning rate")
    parser.add_argument("--gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True,
                        help="Enable gradient checkpointing to save memory (default: enabled)")
    # LoRA configuration
    parser.add_argument("--lora_rank", type=int, default=8,
                        help="LoRA rank")
    # Dataset configuration
    parser.add_argument("--dataset_size", type=int, default=500,
                        help="Number of samples to use from the dataset")
    # Logging configuration
    parser.add_argument("--logging_steps", type=int, default=1,
                        help="Log every N steps")
    parser.add_argument("--log_dir", type=str, default="logs",
                        help="Directory for logs")
    # Compilation and saving
    parser.add_argument("--use_torch_compile", action="store_true",
                        help="Use torch.compile() for faster training")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Directory to save the fine-tuned model")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_arguments()
    print(f"\n{'=' * 60}")
    print("LLAMA 3.1 70B LoRA FINE-TUNING")
    print(f"{'=' * 60}")
    print(f"Model: {args.model_name}")
    print("Training mode: LoRA")
    print(f"Batch size: {args.batch_size}")
    print(f"Gradient accumulation: {args.gradient_accumulation_steps}")
    print(f"Effective batch size (per device): {args.batch_size * args.gradient_accumulation_steps}")
    print(f"Sequence length: {args.seq_length}")
    print(f"Number of epochs: {args.num_epochs}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"LoRA rank: {args.lora_rank}")
    print(f"Dataset size: {args.dataset_size}")
    print(f"Gradient checkpointing: {args.gradient_checkpointing}")
    print(f"Torch compile: {args.use_torch_compile}")
    print(f"{'=' * 60}\n")
    main(args)