mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-06-18 12:32:23 +00:00
129 lines
4.6 KiB
Python
129 lines
4.6 KiB
Python
"""
|
|
Correctness and memory tests for the fused cross-entropy Triton kernel.
|
|
|
|
Compares loss values and gradients against torch.nn.CrossEntropyLoss, and
|
|
reports peak GPU memory usage for both approaches to demonstrate the memory
|
|
savings from online softmax.
|
|
"""
|
|
|
|
import torch
|
|
from cross_entropy_kernel import TritonCrossEntropyLoss
|
|
|
|
|
|
def test_cross_entropy_correctness(
    dtype=torch.float32,
    loss_atol=1e-4,
    loss_rtol=1e-4,
    grad_atol=1e-4,
    grad_rtol=1e-4,
):
    """Check the fused kernel's loss and logits gradient against PyTorch for one dtype.

    Both implementations run on identical copies of random logits of shape
    ``(num_tokens, vocab_size)`` with random integer targets, and each result is
    back-propagated once.

    Args:
        dtype: Logits dtype. Defaults to ``torch.float32`` so pytest can collect
            this function directly (arguments with defaults are not fixtures).
        loss_atol: ``torch.allclose`` absolute tolerance for the scalar loss.
        loss_rtol: ``torch.allclose`` relative tolerance for the scalar loss.
            BF16 logits need looser loss tolerances (online softmax vs.
            PyTorch's path).
        grad_atol: ``torch.allclose`` absolute tolerance for the gradients, in
            per-token units (see below).
        grad_rtol: ``torch.allclose`` relative tolerance for the gradients.

    Raises:
        AssertionError: if the loss or gradient comparison fails, or if the
            fused path produced no gradient. Raised explicitly so the check is
            not stripped under ``python -O``.

    Gradients are upcast to float32 to avoid BF16 rounding false failures on
    128K-wide reductions. They are also multiplied by ``num_tokens``. With the
    default mean reduction of ``torch.nn.CrossEntropyLoss``, every reference
    gradient entry has magnitude at most ``1 / num_tokens``. Absolute tolerances
    on the raw gradient (e.g. 2e-1) would therefore accept any result, even all
    zeros. In per-token units the target-class entries are of order 1 again.
    """
    dtype_name = {torch.float32: "FP32", torch.bfloat16: "BF16"}.get(dtype, str(dtype))
    vocab_size = 128256  # LLaMA 3.1 vocabulary size
    num_tokens = 512  # batch_size * seq_len

    # Shared random inputs; each implementation gets its own leaf copy below,
    # so the base tensor itself does not need requires_grad.
    logits = torch.randn(num_tokens, vocab_size, device="cuda", dtype=dtype)
    targets = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

    # --- Reference: PyTorch CrossEntropyLoss ---
    logits_ref = logits.detach().clone().requires_grad_(True)
    ref_loss = torch.nn.CrossEntropyLoss()(logits_ref, targets)
    ref_loss.backward()

    # --- Custom: Triton fused cross-entropy ---
    logits_custom = logits.detach().clone().requires_grad_(True)
    custom_ce = TritonCrossEntropyLoss()
    custom_loss = custom_ce(logits_custom, targets)
    custom_loss.backward()

    # Compare forward loss values.
    ref_value = ref_loss.item()
    custom_value = custom_loss.item()
    loss_diff = abs(ref_value - custom_value)
    loss_ok = torch.allclose(
        ref_loss.detach().float(),
        custom_loss.detach().float(),
        atol=loss_atol,
        rtol=loss_rtol,
    )
    print(
        f" {dtype_name} Loss — ref: {ref_value:.6f} custom: {custom_value:.6f} "
        f"diff: {loss_diff:.2e} {'PASSED' if loss_ok else 'FAILED'}"
    )
    if not loss_ok:
        raise AssertionError(f"Loss correctness check FAILED for {dtype_name}")

    if logits_custom.grad is None:
        raise AssertionError(f"Fused cross-entropy produced no gradient for {dtype_name} logits")

    # Compare backward gradients in FP32 (BF16 allclose on raw grads is too strict
    # for V-wide CE), rescaled to per-token units so the tolerances are meaningful.
    g_ref = logits_ref.grad.float() * num_tokens
    g_custom = logits_custom.grad.float() * num_tokens
    grad_diff = (g_ref - g_custom).abs().max().item()
    grad_ok = torch.allclose(g_ref, g_custom, atol=grad_atol, rtol=grad_rtol)
    print(
        f" {dtype_name} Gradient — max diff (vs ref, per-token fp32 compare): "
        f"{grad_diff:.2e} {'PASSED' if grad_ok else 'FAILED'}"
    )
    if not grad_ok:
        raise AssertionError(f"Gradient correctness check FAILED for {dtype_name}")
|
|
|
|
|
|
def _measure_ce_memory(loss_fn, num_tokens, vocab_size, targets):
    """Run one BF16 forward/backward of ``loss_fn``; return ``(peak_bytes, overhead_bytes)``.

    The peak-memory counter is reset before the logits are allocated.
    ``peak_bytes`` is ``torch.cuda.max_memory_allocated()`` after the backward
    pass, so it covers the logits, their gradient and anything already alive
    (e.g. ``targets``). ``overhead_bytes`` is ``peak_bytes`` minus the
    allocation right after the logits exist. It isolates what the loss
    forward/backward adds.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    logits = torch.randn(num_tokens, vocab_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    before = torch.cuda.memory_allocated()

    loss = loss_fn(logits, targets)
    loss.backward()
    peak = torch.cuda.max_memory_allocated()
    # logits, its .grad and loss are released when this frame returns.
    return peak, peak - before


def test_memory_savings():
    """Compare peak GPU memory between standard and fused cross-entropy.

    For each implementation, prints two figures:
    - the total peak allocation during one forward/backward;
    - the part of that peak added on top of the input logits.

    Also prints the standard/fused ratio of both figures. This test only
    reports; nothing is asserted.
    """
    vocab_size = 128256
    num_tokens = 512

    print("\nMemory Comparison")
    print("-" * 60)

    targets = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

    # --- Standard PyTorch cross-entropy ---
    std_peak, std_overhead = _measure_ce_memory(
        torch.nn.CrossEntropyLoss(), num_tokens, vocab_size, targets
    )
    # --- Custom fused cross-entropy ---
    fused_peak, fused_overhead = _measure_ce_memory(
        TritonCrossEntropyLoss(), num_tokens, vocab_size, targets
    )

    mib = 1024 * 1024
    reduction = std_peak / fused_peak if fused_peak > 0 else float("inf")
    overhead_reduction = std_overhead / fused_overhead if fused_overhead > 0 else float("inf")

    print(f" Standard PyTorch CE — peak memory: {std_peak / mib:.1f} MB")
    print(f" Fused Triton CE — peak memory: {fused_peak / mib:.1f} MB")
    print(f" Memory reduction: {reduction:.1f}x")
    # Overhead excludes the input logits, which both implementations must hold.
    print(f" Standard PyTorch CE — loss overhead: {std_overhead / mib:.1f} MB")
    print(f" Fused Triton CE — loss overhead: {fused_overhead / mib:.1f} MB")
    print(f" Overhead reduction: {overhead_reduction:.1f}x")
|
|
|
|
|
|
if __name__ == "__main__":
    separator = "=" * 60
    print("Cross-Entropy Correctness Tests")
    print(separator)

    # Each entry: (header, dtype, loss_atol, loss_rtol, grad_atol, grad_rtol).
    correctness_cases = [
        # log-sum-exp over 128K elements accumulates small drift vs PyTorch's CE path.
        ("\nTest 1: Float32", torch.float32, 1e-4, 1e-4, 1e-4, 1e-4),
        ("\nTest 2: BFloat16 (relaxed tolerance)", torch.bfloat16, 5e-2, 5e-2, 2e-1, 2e-1),
    ]
    for header, case_dtype, l_atol, l_rtol, g_atol, g_rtol in correctness_cases:
        print(header)
        test_cross_entropy_correctness(
            dtype=case_dtype,
            loss_atol=l_atol,
            loss_rtol=l_rtol,
            grad_atol=g_atol,
            grad_rtol=g_rtol,
        )

    test_memory_savings()

    # Only reached if none of the calls above raised.
    print("\n" + separator)
    print("All cross-entropy tests PASSED")
|