lramesh-2409 2026-01-14 16:09:10 +00:00 committed by GitHub
commit 7a6477a4bf
3 changed files with 71 additions and 10 deletions

View File

@ -51,8 +51,8 @@ All files required for fine-tuning are included in the folder in [the GitHub rep
* **Time estimate:** 30-45 mins for setup and running fine-tuning. Fine-tuning run time varies depending on model size
* **Risks:** Model downloads can be large (several GB), ARM64 package compatibility issues may require troubleshooting.
* **Last Updated:** 01/02/2025
* Add two-Spark distributed fine-tuning example
* **Last Updated:** 01/12/2025
* Bug fix to ensure torch.compile does not break with LoRA.
## Instructions
@ -111,15 +111,75 @@ cd dgx-spark-playbooks/nvidia/pytorch-fine-tune/assets
## Step 7: Run the fine-tuning recipes
To run LoRA on Llama3-8B use the following command:
### Available Fine-Tuning Scripts
The following fine-tuning scripts are provided, each optimized for different model sizes and training approaches:
| Script | Model | Fine-Tuning Type | Description |
|--------|-------|------------------|-------------|
| `Llama3_3B_full_finetuning.py` | Llama 3.2 3B | Full SFT | Full supervised fine-tuning (all parameters trainable) |
| `Llama3_8B_LoRA_finetuning.py` | Llama 3.1 8B | LoRA | Low-Rank Adaptation (parameter-efficient) |
| `Llama3_70B_LoRA_finetuning.py` | Llama 3.1 70B | LoRA | Low-Rank Adaptation with FSDP support |
| `Llama3_70B_qLoRA_finetuning.py` | Llama 3.1 70B | QLoRA | Quantized LoRA (4-bit quantization for memory efficiency) |
### Basic Usage
Run any script with default settings:
```bash
# Full fine-tuning on Llama 3.2 3B
python Llama3_3B_full_finetuning.py
# LoRA fine-tuning on Llama 3.1 8B
python Llama3_8B_LoRA_finetuning.py
# LoRA fine-tuning on Llama 3.1 70B
python Llama3_70B_LoRA_finetuning.py
```
To run full fine-tuning on llama3-3B use the following command:
### Common Command-Line Arguments
All scripts support the following command-line arguments for customization:
#### Model Configuration
- `--model_name`: Model name or path (default: varies by script)
- `--dtype`: Model precision - `float32`, `float16`, or `bfloat16` (default: `bfloat16`)
#### Training Configuration
- `--batch_size`: Per-device training batch size (default: varies by script)
- `--seq_length`: Maximum sequence length (default: `2048`)
- `--num_epochs`: Number of training epochs (default: `1`)
- `--gradient_accumulation_steps`: Gradient accumulation steps (default: `1`)
- `--learning_rate`: Learning rate (default: varies by script)
- `--gradient_checkpointing`: Enable gradient checkpointing to save memory (flag)
#### LoRA Configuration (LoRA and QLoRA scripts only)
- `--lora_rank`: LoRA rank - higher values = more trainable parameters (default: `8`)
#### Dataset Configuration
- `--dataset_size`: Number of samples to use from the Alpaca dataset (default: `500`)
#### Logging Configuration
- `--logging_steps`: Log metrics every N steps (default: `1`)
- `--log_dir`: Directory for TensorBoard logs (default: `logs`)
#### Model Saving
- `--output_dir`: Directory to save the fine-tuned model (default: `None` - model not saved)
#### Performance Optimization
- `--use_torch_compile`: Enable `torch.compile()` for faster training (flag)
> [!WARNING]
> **Important:** The `--use_torch_compile` flag is **not compatible with QLoRA** (`Llama3_70B_qLoRA_finetuning.py`).
> Only use this flag with full fine-tuning and standard LoRA scripts.
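The options above can be combined as needed, subject to the QLoRA restriction in the warning. As a rough illustration using only the flags listed here (the values are examples; exact defaults live in the scripts themselves):
```bash
# Illustrative flag combination for the 8B LoRA script (not a prescribed recipe)
python Llama3_8B_LoRA_finetuning.py \
    --lora_rank 16 \
    --gradient_checkpointing \
    --use_torch_compile \
    --dtype bfloat16 \
    --output_dir ./llama3-8b-lora-alpaca
```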
### Usage Examples
```bash
python Llama3_3B_full_finetuning.py
```

```bash
python Llama3_8B_LoRA_finetuning.py \
    --dataset_size 100 \
    --num_epochs 1 \
    --batch_size 2
```
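Metrics logged every `--logging_steps` steps land in the TensorBoard directory set by `--log_dir` (default `logs`); assuming TensorBoard is installed in the environment, they can be viewed with:
```bash
tensorboard --logdir logs
```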
## Run on two Sparks

View File

@ -77,6 +77,7 @@ def main(args):
lora_dropout=0,
task_type=TaskType.CAUSAL_LM
)
model = get_peft_model(model, peft_config)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
@ -130,7 +131,6 @@ def main(args):
processing_class=tokenizer,
train_dataset=dataset,
args=SFTConfig(**config),
peft_config=peft_config,
)
trainer_stats = trainer.train()
@ -220,4 +220,4 @@ if __name__ == "__main__":
print(f"Torch compile: {args.use_torch_compile}")
print(f"{'='*60}\n")
main(args)
main(args)

View File

@ -62,6 +62,8 @@ def main(args):
lora_alpha=16,
lora_dropout=0,
task_type=TaskType.CAUSAL_LM)
model = get_peft_model(model, peft_config)
print(f"Trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
# Load and preprocess the dataset
@ -110,7 +112,6 @@ def main(args):
processing_class=tokenizer,
train_dataset=dataset,
args=SFTConfig(**config),
peft_config=peft_config,
)
trainer_stats = trainer.train()
@ -183,4 +184,4 @@ if __name__ == "__main__":
print(f"Torch compile: {args.use_torch_compile}")
print(f"{'='*60}\n")
main(args)
main(args)
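Both LoRA scripts above receive the same two-line change: the model is wrapped with `get_peft_model()` before the trainer is built, and `peft_config` is no longer passed to `SFTTrainer`, so `torch.compile` traces the already-adapted model. A minimal sketch of the resulting pattern, with placeholder model name, dataset slice, and `SFTConfig` values rather than the scripts' exact settings:
```python
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

# Placeholders for illustration; the playbook scripts set these via CLI arguments.
model_name = "meta-llama/Llama-3.1-8B"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("tatsu-lab/alpaca", split="train[:500]")  # cf. --dataset_size

peft_config = LoraConfig(
    r=8,                      # cf. --lora_rank
    lora_alpha=16,
    lora_dropout=0,
    task_type=TaskType.CAUSAL_LM,
)
# The added line: attach the LoRA adapters to the model up front.
model = get_peft_model(model, peft_config)

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(output_dir="outputs"),
    # The removed line: peft_config is intentionally no longer passed here.
)
trainer.train()
```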