# dgx-spark-playbooks/nvidia/station-kernel-dev-ft/assets/benchmark_kernels.py
# Listing metadata: 193 lines, 6.6 KiB, Python; timestamp 2026-05-26 18:25:53 +00:00
"""
Benchmark script for custom Triton kernels.

Measures latency, effective bandwidth (GB/s), and speedup for the fused
RMSNorm kernel, and latency, speedup, and peak memory for the fused
cross-entropy kernel, comparing each against a PyTorch reference
implementation.

Usage:
    python benchmark_kernels.py --kernel rmsnorm
    python benchmark_kernels.py --kernel cross_entropy
    python benchmark_kernels.py --kernel all
"""
import argparse
import torch
from tabulate import tabulate
def benchmark_fn(fn, warmup=10, iters=100):
    """Time a GPU function using CUDA events for accurate measurement.

    Args:
        fn: Zero-argument callable that launches the work to be timed.
        warmup: Number of untimed calls made before measurement starts.
        iters: Number of timed calls; must be at least 1.

    Returns:
        Mean elapsed time per call, in milliseconds.

    Raises:
        ValueError: If ``iters`` is less than 1.
    """
    # Validate before any work is launched: iters == 0 would divide by zero
    # after the warmup has already run, and a negative count would silently
    # report a negative "latency" (empty timed loop divided by iters).
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    for _ in range(warmup):
        fn()
    # Wait for the warmup work to finish so it is not counted in the timing.
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # Both events must have completed before elapsed_time() can be queried.
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per iteration
def benchmark_rmsnorm():
    """Compare the Triton RMSNorm module with a plain PyTorch version.

    For several token counts, times one forward + backward pass of each
    implementation on BF16 CUDA tensors and prints a table of per-call
    latency (us), estimated bandwidth (GB/s), and speedup.
    """
    from rmsnorm_kernel import TritonRMSNorm

    rule = "=" * 70
    print(rule)
    print(" RMSNorm Benchmark — Custom Triton vs. PyTorch Reference")
    print(rule)
    print(f" GPU: {torch.cuda.get_device_name(0)}\n")

    hidden_size = 4096  # LLaMA 3.1 8B
    eps = 1e-6

    class ReferenceRMSNorm(torch.nn.Module):
        """PyTorch baseline (written to match HuggingFace LlamaRMSNorm)."""

        def __init__(self, width, epsilon):
            super().__init__()
            self.weight = torch.nn.Parameter(
                torch.ones(width, device="cuda", dtype=torch.bfloat16)
            )
            self.eps = epsilon

        def forward(self, x):
            # The mean of squares is taken in float32; the normalized value
            # is cast back to the input dtype before the weight is applied.
            upcast = x.to(torch.float32)
            mean_sq = upcast.pow(2).mean(-1, keepdim=True)
            normed = upcast * torch.rsqrt(mean_sq + self.eps)
            return self.weight * normed.to(x.dtype)

    ref_norm = ReferenceRMSNorm(hidden_size, eps)
    custom_norm = TritonRMSNorm(hidden_size, eps).to(device="cuda", dtype=torch.bfloat16)
    # Give both modules identical weights.
    custom_norm.weight.data.copy_(ref_norm.weight.data)

    table = []
    for num_tokens in (256, 1024, 4096, 16384):
        x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        grad = torch.randn_like(x)

        def fwd_bwd(module):
            """Return a closure running one forward + backward pass of `module`."""

            def step():
                # A fresh leaf per call keeps input gradients from accumulating.
                leaf = x.detach().requires_grad_(True)
                module(leaf).backward(grad)

            return step

        # Forward and backward are timed together (training workload).
        ref_ms = benchmark_fn(fwd_bwd(ref_norm))
        custom_ms = benchmark_fn(fwd_bwd(custom_norm))

        # Approximate memory traffic for forward + backward: about six
        # (tokens x hidden) BF16 tensors at 2 bytes per element. The smaller
        # weight and per-row terms are left out of the estimate.
        traffic_bytes = 6 * num_tokens * hidden_size * 2
        ref_gbps = traffic_bytes / (ref_ms * 1e-3) / 1e9
        custom_gbps = traffic_bytes / (custom_ms * 1e-3) / 1e9
        speedup = ref_ms / custom_ms

        table.append(
            [
                f"{num_tokens:,}",
                f"{custom_ms * 1000:.1f}",
                f"{ref_ms * 1000:.1f}",
                f"{custom_gbps:.0f}",
                f"{ref_gbps:.0f}",
                f"{speedup:.2f}x",
            ]
        )

    headers = ["Tokens", "Custom (us)", "PyTorch (us)", "Custom (GB/s)", "PyTorch (GB/s)", "Speedup"]
    print(tabulate(table, headers=headers, tablefmt="simple"))
    print()
def benchmark_cross_entropy():
    """Compare the Triton cross-entropy loss with torch.nn.CrossEntropyLoss.

    For several token counts, times one forward + backward pass of each loss,
    then separately records how far peak CUDA memory rises above the level
    measured right after the input logits are allocated. Prints a table with
    latency (us), speedup, and memory (MB).
    """
    from cross_entropy_kernel import TritonCrossEntropyLoss

    rule = "=" * 70
    print(rule)
    print(" Cross-Entropy Benchmark — Custom Triton (online softmax) vs. PyTorch")
    print(rule)
    print(f" GPU: {torch.cuda.get_device_name(0)}")
    print(f" Vocabulary size: 128,256 (LLaMA 3.1)\n")

    vocab_size = 128256
    ref_ce = torch.nn.CrossEntropyLoss()
    custom_ce = TritonCrossEntropyLoss()

    table = []
    for num_tokens in (128, 256, 512, 1024):
        logits = torch.randn(num_tokens, vocab_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        targets = torch.randint(0, vocab_size, (num_tokens,), device="cuda")
        grad = torch.tensor(1.0, device="cuda")

        def reference_loss(t):
            # The PyTorch baseline is fed float32 logits.
            return ref_ce(t.float(), targets)

        def fused_loss(t):
            return custom_ce(t, targets)

        def fwd_bwd(loss_fn):
            """Return a closure running one forward + backward pass of `loss_fn`."""

            def step():
                leaf = logits.detach().requires_grad_(True)
                loss_fn(leaf).backward(grad)

            return step

        def peak_extra_bytes(loss_fn):
            """Peak allocated bytes above the level right after logits exist."""
            fresh = torch.randn(num_tokens, vocab_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
            baseline = torch.cuda.memory_allocated()
            loss = loss_fn(fresh)
            loss.backward()
            extra = torch.cuda.max_memory_allocated() - baseline
            # Drop the tensors so the caller's empty_cache() can release them.
            del fresh, loss
            return extra

        # --- Latency comparison ---
        ref_ms = benchmark_fn(fwd_bwd(reference_loss), warmup=5, iters=20)
        custom_ms = benchmark_fn(fwd_bwd(fused_loss), warmup=5, iters=20)

        # --- Memory comparison ---
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        std_peak = peak_extra_bytes(reference_loss)
        torch.cuda.empty_cache()

        torch.cuda.reset_peak_memory_stats()
        fused_peak = peak_extra_bytes(fused_loss)
        torch.cuda.empty_cache()

        speedup = ref_ms / custom_ms
        if fused_peak > 0:
            mem_reduction = std_peak / fused_peak
        else:
            mem_reduction = float("inf")

        mib = 1024
        table.append(
            [
                f"{num_tokens:,}",
                f"{custom_ms * 1000:.0f}",
                f"{ref_ms * 1000:.0f}",
                f"{speedup:.2f}x",
                f"{fused_peak / mib / mib:.0f}",
                f"{std_peak / mib / mib:.0f}",
                f"{mem_reduction:.1f}x",
            ]
        )

    headers = ["Tokens", "Custom (us)", "PyTorch (us)", "Speedup", "Custom Mem (MB)", "PyTorch Mem (MB)", "Mem Reduction"]
    print(tabulate(table, headers=headers, tablefmt="simple"))
    print()
if __name__ == "__main__":
    # Command-line entry point: choose which benchmark(s) to run.
    parser = argparse.ArgumentParser(description="Benchmark custom Triton kernels")
    parser.add_argument("--kernel", choices=["rmsnorm", "cross_entropy", "all"], default="all")
    args = parser.parse_args()

    # Fixed run order: RMSNorm first, then cross-entropy; "all" selects both.
    for kernel_name, run_benchmark in (
        ("rmsnorm", benchmark_rmsnorm),
        ("cross_entropy", benchmark_cross_entropy),
    ):
        if args.kernel in (kernel_name, "all"):
            run_benchmark()