# dgx-spark-playbooks/nvidia/station-nvfp4-pretraining/assets/pretrain_llama.py
# 2026-05-26 18:25:53 +00:00
#
# 94 lines
# 2.8 KiB
# Python

import argparse
import torch
from megatron.bridge.recipes.llama import llama3_8b_pretrain_config
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, bf16_mixed, bf16_with_nvfp4_mixed
def nvfp4_mixed_precision() -> MixedPrecisionConfig:
    """Build the NVFP4 mixed-precision config, keeping the final 4 layers in BF16.

    Starts from ``bf16_with_nvfp4_mixed()`` and turns on the first/last-layer
    BF16 option with zero layers at the start and four layers at the end.

    Returns:
        The adjusted mixed-precision configuration object.
    """
    precision = bf16_with_nvfp4_mixed()
    bf16_layer_overrides = {
        "first_last_layers_bf16": True,
        "num_layers_at_start_in_bf16": 0,
        "num_layers_at_end_in_bf16": 4,
    }
    # Applied in insertion order, i.e. the same order as plain assignments.
    for option, value in bf16_layer_overrides.items():
        setattr(precision, option, value)
    return precision
def _parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse and sanity-check the command-line options of the pretraining run.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The parsed namespace with ``disable_fp4``, ``train_iters``,
        ``warmup_iters``, ``global_batch_size``, ``micro_batch_size`` and
        ``seq_length``.

    Raises:
        SystemExit: Via ``parser.error`` (exit status 2) when an option has a
            value that cannot describe a valid run.
    """
    parser = argparse.ArgumentParser(description="Llama 3.1 8B NVFP4 pretraining")
    parser.add_argument(
        "--disable-fp4",
        action="store_true",
        help="Disable NVFP4; use plain BF16 mixed precision as a baseline",
    )
    parser.add_argument(
        "--train-iters",
        type=int,
        default=50,
        help="Number of training iterations",
    )
    parser.add_argument(
        "--warmup-iters",
        type=int,
        default=2,
        help="Number of warmup iterations",
    )
    parser.add_argument(
        "--global-batch-size",
        type=int,
        default=64,
        help="Global batch size",
    )
    parser.add_argument(
        "--micro-batch-size",
        type=int,
        default=4,
        help="Micro batch size (drives peak VRAM; increase to use more memory)",
    )
    parser.add_argument(
        "--seq-length",
        type=int,
        default=4096,
        help="Sequence length (recipe default is 8192; halved here to fit single GPU)",
    )
    args = parser.parse_args(argv)

    # Fail fast with a readable usage error instead of letting a nonsensical
    # value surface as an obscure failure once the training stack has started.
    for flag, value in (
        ("--train-iters", args.train_iters),
        ("--warmup-iters", args.warmup_iters),
    ):
        if value < 0:
            parser.error(f"{flag} must be >= 0 (got {value})")
    for flag, value in (
        ("--global-batch-size", args.global_batch_size),
        ("--micro-batch-size", args.micro_batch_size),
        ("--seq-length", args.seq_length),
    ):
        if value < 1:
            parser.error(f"{flag} must be >= 1 (got {value})")

    # A global batch is made of whole micro batches, so a global size that is
    # not a multiple of the micro size can never be scheduled. The positivity
    # checks above guarantee the divisor is non-zero here.
    if args.global_batch_size % args.micro_batch_size != 0:
        parser.error(
            f"--global-batch-size ({args.global_batch_size}) must be a multiple of "
            f"--micro-batch-size ({args.micro_batch_size})"
        )
    return args


def main(argv: "list[str] | None" = None) -> None:
    """Run Llama 3 8B pretraining with NVFP4 (or plain BF16) mixed precision.

    Loads the ``llama3_8b_pretrain_config`` recipe, applies the command-line
    overrides, selects the mixed-precision config, calls ``pretrain`` and
    finally prints a CUDA memory summary on rank 0.

    Args:
        argv: Optional argument strings; ``None`` parses ``sys.argv[1:]``, so
            calling ``main()`` behaves as a normal script entry point.

    Raises:
        SystemExit: When the command-line options are invalid.
    """
    args = _parse_args(argv)

    config = llama3_8b_pretrain_config()

    # Single-GPU override: recipe defaults to context_parallel_size=2
    config.model.context_parallel_size = 1

    # The model and the dataset must agree on the sequence length.
    config.model.seq_length = args.seq_length
    config.dataset.sequence_length = args.seq_length

    config.train.train_iters = args.train_iters
    config.scheduler.lr_warmup_iters = args.warmup_iters
    config.train.global_batch_size = args.global_batch_size
    config.train.micro_batch_size = args.micro_batch_size

    # Log every iteration and disable evaluation iterations.
    config.logger.log_interval = 1
    config.dataset.num_workers = 2
    config.train.eval_iters = 0

    if args.disable_fp4:
        config.mixed_precision = bf16_mixed()
    else:
        config.mixed_precision = nvfp4_mixed_precision()

    pretrain(config=config, forward_step_func=forward_step)

    # NOTE(review): this summary is only printed while the process group is
    # still initialised after pretrain() returns - confirm that pretrain()
    # does not tear it down, otherwise the summary is silently skipped.
    if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
        print(
            f"FINAL mem-reserved-gigabytes: {torch.cuda.memory_reserved() / 1e9:.3f} | "
            f"mem-max-reserved-gigabytes: {torch.cuda.max_memory_reserved() / 1e9:.3f} | "
            f"mem-max-allocated-gigabytes: {torch.cuda.max_memory_allocated() / 1e9:.3f}",
            flush=True,
        )
# Script entry point: start the run only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()