# Mirror of https://github.com/NVIDIA/dgx-spark-playbooks.git
# Synced 2026-06-22 14:19:30 +00:00 (94 lines, 2.8 KiB, Python)
import argparse
|
|
|
|
import torch
|
|
|
|
from megatron.bridge.recipes.llama import llama3_8b_pretrain_config
|
|
from megatron.bridge.training.gpt_step import forward_step
|
|
from megatron.bridge.training.pretrain import pretrain
|
|
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, bf16_mixed, bf16_with_nvfp4_mixed
|
|
|
|
|
|
def nvfp4_mixed_precision() -> MixedPrecisionConfig:
    """Return the NVFP4 mixed-precision config with the last 4 layers in BF16.

    Starts from ``bf16_with_nvfp4_mixed()`` and then turns on the
    first/last-layer BF16 option, requesting zero BF16 layers at the start
    of the network and four at the end.
    """
    precision = bf16_with_nvfp4_mixed()

    # Applied in this order: enable the option, then set the per-end counts.
    bf16_layer_overrides = {
        "first_last_layers_bf16": True,
        "num_layers_at_start_in_bf16": 0,
        "num_layers_at_end_in_bf16": 4,
    }
    for field_name, value in bf16_layer_overrides.items():
        setattr(precision, field_name, value)

    return precision
|
|
|
|
|
|
def main() -> None:
    """Run a short Llama 3 8B pretraining job configured from the command line.

    Steps, in order:
      1. Parse the CLI options (precision toggle plus integer sizing knobs).
      2. Load the ``llama3_8b_pretrain_config()`` recipe and override it with
         the parsed values and a few fixed settings.
      3. Select plain BF16 or NVFP4 mixed precision.
      4. Call ``pretrain`` with ``forward_step``.
      5. On rank 0 of an initialized process group, print CUDA memory
         statistics (bytes divided by 1e9).
    """
    cli = argparse.ArgumentParser(description="Llama 3.1 8B NVFP4 pretraining")
    cli.add_argument(
        "--disable-fp4",
        action="store_true",
        help="Disable NVFP4; use plain BF16 mixed precision as a baseline",
    )

    # Every remaining option is a plain integer: (flag, default, help text).
    integer_options = (
        ("--train-iters", 50, "Number of training iterations"),
        ("--warmup-iters", 2, "Number of warmup iterations"),
        ("--global-batch-size", 64, "Global batch size"),
        (
            "--micro-batch-size",
            4,
            "Micro batch size (drives peak VRAM; increase to use more memory)",
        ),
        (
            "--seq-length",
            4096,
            "Sequence length (recipe default is 8192; halved here to fit single GPU)",
        ),
    )
    for flag, default_value, description in integer_options:
        cli.add_argument(flag, type=int, default=default_value, help=description)

    opts = cli.parse_args()

    recipe = llama3_8b_pretrain_config()

    # Single-GPU override: recipe defaults to context_parallel_size=2
    recipe.model.context_parallel_size = 1

    # The model and the dataset each carry their own sequence-length field;
    # both are set from the same CLI value.
    recipe.model.seq_length = opts.seq_length
    recipe.dataset.sequence_length = opts.seq_length

    recipe.train.train_iters = opts.train_iters
    recipe.scheduler.lr_warmup_iters = opts.warmup_iters
    recipe.train.global_batch_size = opts.global_batch_size
    recipe.train.micro_batch_size = opts.micro_batch_size

    # Fixed settings for this script: log every iteration, two dataset
    # workers, and zero evaluation iterations.
    recipe.logger.log_interval = 1
    recipe.dataset.num_workers = 2
    recipe.train.eval_iters = 0

    recipe.mixed_precision = (
        bf16_mixed() if opts.disable_fp4 else nvfp4_mixed_precision()
    )

    pretrain(config=recipe, forward_step_func=forward_step)

    # Only rank 0 of a still-initialized process group reports memory usage.
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() != 0:
        return

    reserved_gb = torch.cuda.memory_reserved() / 1e9
    peak_reserved_gb = torch.cuda.max_memory_reserved() / 1e9
    peak_allocated_gb = torch.cuda.max_memory_allocated() / 1e9
    print(
        f"FINAL mem-reserved-gigabytes: {reserved_gb:.3f} | "
        f"mem-max-reserved-gigabytes: {peak_reserved_gb:.3f} | "
        f"mem-max-allocated-gigabytes: {peak_allocated_gb:.3f}",
        flush=True,
    )
|
|
|
|
|
|
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|