dgx-spark-playbooks/nvidia/station-kernel-dev-ft/assets/profile_baseline.py
2026-05-26 18:25:53 +00:00

173 lines
6.4 KiB
Python

"""
Profile a LLaMA 3.1 8B fine-tuning step to identify GPU bottlenecks.

After one un-profiled warm-up step, runs a single forward + backward + optimizer
step under torch.profiler, then exports a TensorBoard trace and a Chrome trace
and prints a summary table of the most time-consuming GPU operations. Optional
flags enable custom Triton kernels so the same step can be re-profiled after
optimization.

Usage:
    python profile_baseline.py                                        # Baseline profile
    python profile_baseline.py --use-custom-rmsnorm                   # With custom RMSNorm
    python profile_baseline.py --use-custom-rmsnorm --use-custom-ce   # With both custom kernels
    python profile_baseline.py --batch-size 2 --seq-len 256           # Custom dimensions
"""
import argparse
import os
import shutil
import torch
from torch.profiler import profile, ProfilerActivity, schedule
def _filter_profiler_table(table: str) -> str:
    """Return *table* with noisy driver / submission rows removed.

    Rows containing "Command Buffer Full" (matched case-insensitively) are
    dropped because they tend to confuse beginners reading the summary.
    The remaining rows are re-joined with single newlines, so a trailing
    newline in the input is not preserved.
    """
    kept_rows = [
        row
        for row in table.splitlines()
        if "command buffer full" not in row.lower()
    ]
    return "\n".join(kept_rows)
def replace_rmsnorm(model):
    """Swap every LlamaRMSNorm submodule of *model* for a TritonRMSNorm.

    Submodules are matched by class name (``type(m).__name__``), and each
    replacement is built with ``TritonRMSNorm.from_llama_rmsnorm``.

    Returns:
        The number of modules that were replaced.
    """
    from rmsnorm_kernel import TritonRMSNorm

    # Two passes: gather every swap first, then apply. Mutating the module
    # tree while named_modules() is still iterating can skip modules.
    pending = []
    for qualified_name, submodule in model.named_modules():
        if type(submodule).__name__ != "LlamaRMSNorm":
            continue
        # "a.b.norm" -> owner path "a.b", attribute "norm"; a direct child of
        # the root has an empty owner path and is owned by the model itself.
        owner_path, _, leaf = qualified_name.rpartition(".")
        owner = model.get_submodule(owner_path) if owner_path else model
        pending.append((owner, leaf, submodule))

    for owner, leaf, stale in pending:
        setattr(owner, leaf, TritonRMSNorm.from_llama_rmsnorm(stale))

    return len(pending)
def _run_training_step(model, optimizer, input_ids, labels, custom_ce=None):
    """Run one forward + backward + optimizer step and return the loss.

    Args:
        model: Causal LM called as ``model(input_ids=..., labels=...)``.
        optimizer: Optimizer stepped and zeroed after the backward pass.
        input_ids: Token ids fed to the model.
        labels: Target token ids.
        custom_ce: Optional loss callable. When given, the model is run
            without labels and the loss is computed here instead.
    """
    if custom_ce is not None:
        outputs = model(input_ids=input_ids)
        # Next-token objective: logits at position t are scored against the
        # label at position t + 1, so drop the last logit and the first label.
        loss = custom_ce(
            outputs.logits[:, :-1, :].contiguous(),
            labels[:, 1:].contiguous(),
        )
    else:
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss


def main():
    """Profile one LLaMA 3.1 8B training step and report the hottest GPU ops.

    Parses the command-line flags, loads the model in bfloat16 on CUDA,
    optionally swaps in the custom Triton kernels, runs one un-profiled
    warm-up step, then profiles a single training step. Writes a TensorBoard
    trace directory and a Chrome trace JSON under ``traces/`` and prints the
    top operations by self CUDA time plus the peak allocated GPU memory.

    Exits with an error message when the arguments are invalid or no CUDA
    device is available.
    """
    parser = argparse.ArgumentParser(description="Profile LLaMA 3.1 8B fine-tuning step")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size (default: 1)")
    parser.add_argument("--seq-len", type=int, default=512, help="Sequence length (default: 512)")
    parser.add_argument("--use-custom-rmsnorm", action="store_true", help="Use custom Triton RMSNorm")
    parser.add_argument("--use-custom-ce", action="store_true", help="Use custom Triton cross-entropy")
    args = parser.parse_args()

    # Fail fast with a clear message instead of an opaque tensor-shape or
    # CUDA initialization error later on.
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    if args.seq_len < 2:
        # The next-token loss needs at least one (input, shifted target) pair.
        parser.error("--seq-len must be at least 2")
    if not torch.cuda.is_available():
        raise SystemExit("error: no CUDA device available; this profiler requires a CUDA GPU")

    print("=" * 70)
    print(" LLaMA 3.1 8B Fine-Tuning Profiler")
    print("=" * 70)
    print(f" GPU: {torch.cuda.get_device_name(0)}")
    print(f" Batch size: {args.batch_size}, Sequence length: {args.seq_len}")
    print(f" Custom RMSNorm: {'ON' if args.use_custom_rmsnorm else 'OFF'}")
    print(f" Custom Cross-Entropy: {'ON' if args.use_custom_ce else 'OFF'}")
    print()

    # Load model
    print("Loading meta-llama/Llama-3.1-8B...")
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )
    model.train()
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Apply custom kernels if requested
    if args.use_custom_rmsnorm:
        n = replace_rmsnorm(model)
        print(f" Replaced {n} RMSNorm module(s) with custom Triton kernel")
    if args.use_custom_ce:
        from cross_entropy_kernel import TritonCrossEntropyLoss

        custom_ce = TritonCrossEntropyLoss()
        print(" Using custom Triton cross-entropy loss")
    else:
        custom_ce = None

    # Create synthetic training data; labels mirror the inputs.
    input_ids = torch.randint(
        0, model.config.vocab_size, (args.batch_size, args.seq_len), device="cuda"
    )
    labels = input_ids.clone()

    # Set up optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    # Warm-up step (not profiled) — this triggers Triton JIT compilation
    # and CUDA lazy initialization so they don't appear in the profile.
    print("\nRunning warm-up step...")
    _run_training_step(model, optimizer, input_ids, labels, custom_ce)
    torch.cuda.synchronize()

    # Profiled step
    print("Running profiled training step...")
    os.makedirs("traces", exist_ok=True)
    trace_name = "trace"
    if args.use_custom_rmsnorm:
        trace_name += "_custom_rmsnorm"
    if args.use_custom_ce:
        trace_name += "_custom_ce"
    trace_dir = os.path.join("traces", trace_name)
    chrome_trace_path = os.path.join("traces", f"{trace_name}_chrome.json")

    # tensorboard_trace_handler and Chrome export fail on repeat runs if paths already exist.
    if os.path.isdir(trace_dir):
        shutil.rmtree(trace_dir)
    elif os.path.isfile(trace_dir):
        os.remove(trace_dir)
    if os.path.isfile(chrome_trace_path):
        os.remove(chrome_trace_path)

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        with_stack=True,
        schedule=schedule(wait=0, warmup=0, active=1, repeat=1),
        # Write to the same directory that was just cleaned above.
        on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
    ) as prof:
        _run_training_step(model, optimizer, input_ids, labels, custom_ce)
        # Advance the schedule so the single active step is finalized and
        # handed to on_trace_ready.
        prof.step()

    # Also export a Chrome trace JSON for manual inspection.
    prof.export_chrome_trace(chrome_trace_path)

    # Print summary table
    print(f"\nChrome trace saved to: {chrome_trace_path}")
    print(f"TensorBoard trace saved to: traces/{trace_name}/")
    print("\n" + "=" * 70)
    print(" Top 20 CUDA Operations by Total GPU Time")
    print("=" * 70)
    raw_table = prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20)
    print(_filter_profiler_table(raw_table))

    # Print peak memory (bytes -> GiB-scaled value, labelled "GB" in the output)
    peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
    print(f"\nPeak GPU memory allocated: {peak_mem:.1f} GB")
# Run the profiler only when this file is executed as a script, so that
# importing the module does not start a profiling run.
if __name__ == "__main__":
    main()