dgx-spark-playbooks/nvidia/station-kernel-dev-ft/assets/cross_entropy_test.py
2026-05-26 18:25:53 +00:00

129 lines
4.6 KiB
Python

"""
Correctness and memory tests for the fused cross-entropy Triton kernel.
Compares loss values and gradients against torch.nn.CrossEntropyLoss, and
reports peak GPU memory usage for both approaches to demonstrate the memory
savings from online softmax.
"""
import torch
from cross_entropy_kernel import TritonCrossEntropyLoss
def _dtype_label(dtype):
    """Return a short, human-readable label for a torch dtype (e.g. "FP32", "BF16")."""
    known = {torch.float32: "FP32", torch.bfloat16: "BF16", torch.float16: "FP16"}
    # Fall back to the dtype's own name so unexpected dtypes are never mislabeled.
    return known.get(dtype, str(dtype).replace("torch.", "").upper())


def test_cross_entropy_correctness(
    dtype=torch.float32,
    loss_atol=1e-4,
    loss_rtol=1e-4,
    grad_atol=1e-4,
    grad_rtol=1e-4,
    num_tokens=512,
    vocab_size=128256,
    seed=0,
    device="cuda",
):
    """Test forward loss and backward gradients for a given dtype.

    Runs ``torch.nn.CrossEntropyLoss`` (default mean reduction) and
    ``TritonCrossEntropyLoss`` on identical copies of random logits and
    asserts that both the scalar loss and the logit gradients agree within
    the given tolerances.

    BF16 logits use a looser loss tolerance (online softmax vs. PyTorch's path).
    Gradients are compared in float32 to avoid BF16 rounding false failures on
    128K-wide reductions.

    Every parameter has a default (the FP32 configuration) so that pytest,
    which collects ``*_test.py`` files and ``test_*`` functions, can call this
    function directly instead of failing with "fixture 'dtype' not found".

    Args:
        dtype: dtype of the logits fed to both implementations.
        loss_atol, loss_rtol: ``torch.allclose`` tolerances for the loss.
        grad_atol, grad_rtol: ``torch.allclose`` tolerances for the logit
            gradients. These are absolute values on the gradient of a
            mean-reduced loss, whose entries are bounded by 1/num_tokens in
            magnitude, so pick atol relative to that scale.
        num_tokens: number of rows (batch_size * seq_len).
        vocab_size: number of classes; the default is the LLaMA 3.1 vocabulary size.
        seed: RNG seed for reproducible inputs; ``None`` leaves the RNG untouched.
        device: device on which all tensors are created.

    Raises:
        AssertionError: if the loss or gradients disagree beyond tolerance,
            or if the fused loss produced no gradient for its input.
    """
    dtype_name = _dtype_label(dtype)
    if seed is not None:
        # Fixed seed so any tolerance failure can be reproduced exactly.
        torch.manual_seed(seed)

    # Source tensors. ``logits`` itself needs no autograd: each path below
    # works on its own detached clone so the two gradients stay independent.
    logits = torch.randn(num_tokens, vocab_size, device=device, dtype=dtype)
    targets = torch.randint(0, vocab_size, (num_tokens,), device=device)

    # --- Reference: PyTorch CrossEntropyLoss ---
    logits_ref = logits.detach().clone().requires_grad_(True)
    ref_loss = torch.nn.CrossEntropyLoss()(logits_ref, targets)
    ref_loss.backward()

    # --- Custom: Triton fused cross-entropy ---
    logits_custom = logits.detach().clone().requires_grad_(True)
    custom_ce = TritonCrossEntropyLoss()
    custom_loss = custom_ce(logits_custom, targets)
    custom_loss.backward()

    # Compare forward loss values (read each scalar once; .item() syncs the device).
    ref_val = ref_loss.item()
    custom_val = custom_loss.item()
    loss_diff = abs(ref_val - custom_val)
    loss_ok = torch.allclose(
        ref_loss.detach().float(),
        custom_loss.detach().float(),
        atol=loss_atol,
        rtol=loss_rtol,
    )
    print(f" {dtype_name} Loss — ref: {ref_val:.6f} custom: {custom_val:.6f} diff: {loss_diff:.2e} {'PASSED' if loss_ok else 'FAILED'}")
    assert loss_ok, f"Loss correctness check FAILED for {dtype_name}"

    # Compare backward gradients in FP32 (BF16 allclose on raw grads is too strict for V-wide CE).
    assert logits_custom.grad is not None, f"Fused loss produced no gradient for the logits ({dtype_name})"
    g_ref = logits_ref.grad.float()
    g_custom = logits_custom.grad.float()
    grad_diff = (g_ref - g_custom).abs().max().item()
    # Report the reference gradient scale so the absolute tolerance can be judged against it.
    grad_scale = g_ref.abs().max().item()
    grad_ok = torch.allclose(g_ref, g_custom, atol=grad_atol, rtol=grad_rtol)
    print(f" {dtype_name} Gradient — max diff (vs ref, fp32 compare): {grad_diff:.2e} (max |ref grad|: {grad_scale:.2e}) {'PASSED' if grad_ok else 'FAILED'}")
    assert grad_ok, f"Gradient correctness check FAILED for {dtype_name}"
def test_memory_savings(num_tokens=512, vocab_size=128256):
    """Compare peak GPU memory between standard and fused cross-entropy.

    For each implementation, the BF16 logits are allocated first. The CUDA
    peak-memory counter is then reset and the allocated bytes are recorded
    as a baseline. Finally, loss forward + backward runs and the new peak is
    read.

    Two numbers are reported per path:
      * the total peak, which includes the input logits and their gradient
        and is therefore common to both paths;
      * the working set above the baseline, which isolates what the loss
        computation itself allocates. This is where a fused kernel's savings
        show up undiluted.

    Args:
        num_tokens: number of rows (batch_size * seq_len).
        vocab_size: number of classes; the default is the LLaMA 3.1 vocabulary size.
    """
    mib = 1024 * 1024
    print("\nMemory Comparison")
    print("-" * 60)

    # Shared targets, allocated before either measurement so both baselines include them.
    targets = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

    # --- Standard PyTorch cross-entropy ---
    torch.cuda.empty_cache()
    logits_std = torch.randn(num_tokens, vocab_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    torch.cuda.reset_peak_memory_stats()
    std_before = torch.cuda.memory_allocated()
    loss_std = torch.nn.CrossEntropyLoss()(logits_std, targets)
    loss_std.backward()
    std_peak = torch.cuda.max_memory_allocated()
    # Free the standard path's tensors (and the logits' .grad) before measuring the fused path.
    del logits_std, loss_std
    torch.cuda.empty_cache()

    # --- Custom fused cross-entropy ---
    logits_fused = torch.randn(num_tokens, vocab_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    torch.cuda.reset_peak_memory_stats()
    fused_before = torch.cuda.memory_allocated()
    custom_ce = TritonCrossEntropyLoss()
    loss_fused = custom_ce(logits_fused, targets)
    loss_fused.backward()
    fused_peak = torch.cuda.max_memory_allocated()
    del logits_fused, loss_fused

    std_extra = std_peak - std_before
    fused_extra = fused_peak - fused_before
    total_reduction = std_peak / fused_peak if fused_peak > 0 else float("inf")
    extra_reduction = std_extra / fused_extra if fused_extra > 0 else float("inf")

    print(f" Standard PyTorch CE — peak memory: {std_peak / mib:.1f} MB (+{std_extra / mib:.1f} MB above baseline)")
    print(f" Fused Triton CE — peak memory: {fused_peak / mib:.1f} MB (+{fused_extra / mib:.1f} MB above baseline)")
    print(f" Memory reduction (total peak): {total_reduction:.1f}x")
    print(f" Memory reduction (loss working set): {extra_reduction:.1f}x")
if __name__ == "__main__":
    # Fail fast with a clear message instead of an opaque CUDA error raised
    # by the first tensor allocation on device="cuda".
    if not torch.cuda.is_available():
        raise SystemExit("Cross-entropy tests require a CUDA GPU (torch.cuda.is_available() is False).")

    print("Cross-Entropy Correctness Tests")
    print("=" * 60)

    print("\nTest 1: Float32")
    # log-sum-exp over 128K elements accumulates small drift vs PyTorch's CE path.
    test_cross_entropy_correctness(
        dtype=torch.float32,
        loss_atol=1e-4,
        loss_rtol=1e-4,
        grad_atol=1e-4,
        grad_rtol=1e-4,
    )

    print("\nTest 2: BFloat16 (relaxed tolerance)")
    test_cross_entropy_correctness(
        dtype=torch.bfloat16,
        loss_atol=5e-2,
        loss_rtol=5e-2,
        grad_atol=2e-1,
        grad_rtol=2e-1,
    )

    # Memory comparison is informational only; it reports numbers and does not assert.
    test_memory_savings()

    print("\n" + "=" * 60)
    print("All cross-entropy tests PASSED")