Mirror of https://github.com/NVIDIA/dgx-spark-playbooks.git, synced 2026-04-22 18:13:52 +00:00
Commit 7a6477a4bf: Merge 8740b3e30b into d0dbd18840
@@ -51,8 +51,8 @@ ALl files required for fine-tuning are included in the folder in [the GitHub rep

* **Time estimate:** 30-45 minutes for setup and running fine-tuning. Fine-tuning run time varies with model size.
* **Risks:** Model downloads can be large (several GB); ARM64 package compatibility issues may require troubleshooting.
* **Last Updated:** 01/02/2025
  * Add two-Spark distributed fine-tuning example
* **Last Updated:** 01/12/2025
  * Bug fix to ensure torch.compile does not break with LoRA.

## Instructions

@@ -111,15 +111,75 @@ cd dgx-spark-playbooks/nvidia/pytorch-fine-tune/assets

## Step 7: Run the fine-tuning recipes

To run LoRA on Llama3-8B use the following command:
### Available Fine-Tuning Scripts

The following fine-tuning scripts are provided, each targeting a different model size and training approach:

| Script | Model | Fine-Tuning Type | Description |
|--------|-------|------------------|-------------|
| `Llama3_3B_full_finetuning.py` | Llama 3.2 3B | Full SFT | Full supervised fine-tuning (all parameters trainable) |
| `Llama3_8B_LoRA_finetuning.py` | Llama 3.1 8B | LoRA | Low-Rank Adaptation (parameter-efficient) |
| `Llama3_70B_LoRA_finetuning.py` | Llama 3.1 70B | LoRA | Low-Rank Adaptation with FSDP support |
| `Llama3_70B_qLoRA_finetuning.py` | Llama 3.1 70B | QLoRA | Quantized LoRA (4-bit quantization for memory efficiency) |

### Basic Usage

Run any script with default settings:

```bash
# Full fine-tuning on Llama 3.2 3B
python Llama3_3B_full_finetuning.py

# LoRA fine-tuning on Llama 3.1 8B
python Llama3_8B_LoRA_finetuning.py

# LoRA fine-tuning on Llama 3.1 70B
python Llama3_70B_LoRA_finetuning.py
```

To run full fine-tuning on Llama3-3B use the following command:
### Common Command-Line Arguments

All scripts support the following command-line arguments for customization:

#### Model Configuration
- `--model_name`: Model name or path (default: varies by script)
- `--dtype`: Model precision, one of `float32`, `float16`, or `bfloat16` (default: `bfloat16`)

#### Training Configuration
- `--batch_size`: Per-device training batch size (default: varies by script)
- `--seq_length`: Maximum sequence length (default: `2048`)
- `--num_epochs`: Number of training epochs (default: `1`)
- `--gradient_accumulation_steps`: Number of batches accumulated before each optimizer step (default: `1`)
- `--learning_rate`: Learning rate (default: varies by script)
- `--gradient_checkpointing`: Enable gradient checkpointing to save memory at the cost of extra compute (flag)

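To see how the training settings above interact, the following sketch estimates the effective batch size and the number of optimizer steps per epoch. It assumes the common convention that the effective batch is `batch_size` times `gradient_accumulation_steps` times the number of devices, and that a trailing partial batch still counts as one step. The helper name `training_schedule`, the `num_devices` parameter, and the example values are illustrative and are not taken from the scripts; the trainer's exact step count may differ slightly depending on how it handles the final partial accumulation.

```python
import math


def training_schedule(dataset_size, batch_size, gradient_accumulation_steps=1,
                      num_devices=1, num_epochs=1):
    """Estimate effective batch size and optimizer steps (hypothetical helper)."""
    # Samples consumed per optimizer step across all devices.
    effective_batch = batch_size * gradient_accumulation_steps * num_devices
    # Assumption: a trailing partial batch still produces one optimizer step.
    steps_per_epoch = math.ceil(dataset_size / effective_batch)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": steps_per_epoch * num_epochs,
    }


if __name__ == "__main__":
    # 500 is the documented --dataset_size default; the batch size and
    # accumulation values here are example inputs only.
    print(training_schedule(dataset_size=500, batch_size=4,
                            gradient_accumulation_steps=2))
```

Raising `--gradient_accumulation_steps` instead of `--batch_size` keeps the effective batch the same while lowering peak memory, which is often the more practical knob on a single device.
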
#### LoRA Configuration (LoRA and QLoRA scripts only)
- `--lora_rank`: LoRA rank; higher values mean more trainable parameters (default: `8`)

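As a rough illustration of why a higher rank means more trainable parameters: LoRA adds two low-rank matrices to each adapted linear layer, one of shape `rank x d_in` and one of shape `d_out x rank`, so the added parameter count grows linearly with the rank. The sketch below computes that count for a single layer. The function name `lora_added_params` and the 4096 x 4096 layer size are illustrative assumptions, not values read from the scripts or from any model config.

```python
def lora_added_params(d_in, d_out, rank):
    """Parameters LoRA adds to one d_out x d_in linear layer.

    The adapter consists of A (rank x d_in) and B (d_out x rank).
    """
    return rank * (d_in + d_out)


if __name__ == "__main__":
    # Illustrative layer size only; real projection sizes depend on the model.
    d_in = d_out = 4096
    full = d_in * d_out
    for rank in (8, 16, 32):
        added = lora_added_params(d_in, d_out, rank)
        share = 100 * added / full
        print(f"rank={rank:>2}: {added:,} adapter params ({share:.2f}% of the full matrix)")
```

The total for a model also depends on which layers are adapted, which this single-layer sketch does not cover.
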
#### Dataset Configuration
- `--dataset_size`: Number of samples to use from the Alpaca dataset (default: `500`)

#### Logging Configuration
- `--logging_steps`: Log metrics every N steps (default: `1`)
- `--log_dir`: Directory for TensorBoard logs (default: `logs`)

#### Model Saving
- `--output_dir`: Directory to save the fine-tuned model (default: `None`, in which case the model is not saved)

#### Performance Optimization
- `--use_torch_compile`: Enable `torch.compile()` for faster training (flag)

> [!WARNING]
> **Important:** The `--use_torch_compile` flag is **not compatible with QLoRA** (`Llama3_70B_qLoRA_finetuning.py`).
> Only use this flag with full fine-tuning and standard LoRA scripts.

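The arguments listed above could be declared with `argparse` roughly as follows. This is a minimal sketch under stated assumptions, not the scripts' actual parser: the function names `build_parser` and `parse_args` and the `is_qlora` switch are hypothetical, and the defaults marked "varies by script" are left as `None` because the real values are not given here. The guard at the end shows one way the QLoRA restriction from the warning might be enforced.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented flags; the real scripts may differ."""
    p = argparse.ArgumentParser(description="Fine-tuning options (illustrative sketch)")
    # Model configuration
    p.add_argument("--model_name", type=str, default=None)  # real default varies by script
    p.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16")
    # Training configuration
    p.add_argument("--batch_size", type=int, default=None)  # real default varies by script
    p.add_argument("--seq_length", type=int, default=2048)
    p.add_argument("--num_epochs", type=int, default=1)
    p.add_argument("--gradient_accumulation_steps", type=int, default=1)
    p.add_argument("--learning_rate", type=float, default=None)  # real default varies by script
    p.add_argument("--gradient_checkpointing", action="store_true")
    # LoRA, dataset, logging, saving, performance
    p.add_argument("--lora_rank", type=int, default=8)
    p.add_argument("--dataset_size", type=int, default=500)
    p.add_argument("--logging_steps", type=int, default=1)
    p.add_argument("--log_dir", type=str, default="logs")
    p.add_argument("--output_dir", type=str, default=None)
    p.add_argument("--use_torch_compile", action="store_true")
    return p


def parse_args(argv=None, is_qlora=False):
    """Parse flags and reject the documented incompatible combination."""
    parser = build_parser()
    args = parser.parse_args(argv)
    # Assumption: a QLoRA script would pass is_qlora=True to enforce the warning.
    if is_qlora and args.use_torch_compile:
        parser.error("--use_torch_compile is not compatible with QLoRA")
    return args


if __name__ == "__main__":
    # Explicit argv keeps the demo independent of the real command line.
    print(parse_args(["--dataset_size", "100", "--num_epochs", "1", "--batch_size", "2"]))
```

Failing fast at argument-parsing time surfaces the incompatibility before any model weights are downloaded or loaded.
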
### Usage Examples
```bash
python Llama3_3B_full_finetuning.py

python Llama3_8B_LoRA_finetuning.py \
  --dataset_size 100 \
  --num_epochs 1 \
  --batch_size 2
```

## Run on two Sparks

@@ -77,6 +77,7 @@ def main(args):
        lora_dropout=0,
        task_type=TaskType.CAUSAL_LM
    )
    model = get_peft_model(model, peft_config)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
@@ -130,7 +131,6 @@ def main(args):
        processing_class=tokenizer,
        train_dataset=dataset,
        args=SFTConfig(**config),
        peft_config=peft_config,
    )

    trainer_stats = trainer.train()
@@ -220,4 +220,4 @@ if __name__ == "__main__":
    print(f"Torch compile: {args.use_torch_compile}")
    print(f"{'='*60}\n")

    main(args)
    main(args)
@@ -62,6 +62,8 @@ def main(args):
        lora_alpha=16,
        lora_dropout=0,
        task_type=TaskType.CAUSAL_LM)
    model = get_peft_model(model, peft_config)

    print(f"Trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # Load and preprocess the dataset
@@ -110,7 +112,6 @@ def main(args):
        processing_class=tokenizer,
        train_dataset=dataset,
        args=SFTConfig(**config),
        peft_config=peft_config,
    )

    trainer_stats = trainer.train()
@@ -183,4 +184,4 @@ if __name__ == "__main__":
    print(f"Torch compile: {args.use_torch_compile}")
    print(f"{'='*60}\n")

    main(args)
    main(args)