dgx-spark-playbooks/nvidia/station-kernel-dev-ft/assets/rmsnorm_test.py
2026-05-26 18:25:53 +00:00

113 lines
4.6 KiB
Python

"""
Correctness tests for the fused RMSNorm Triton kernel.
Compares forward and backward outputs against PyTorch's reference LlamaRMSNorm
at shapes matching LLaMA 3.1 8B (hidden_size=4096). Tests both FP32 and BF16
to verify numerical correctness under realistic training precision.
"""
import torch
from rmsnorm_kernel import TritonRMSNorm
class ReferenceLlamaRMSNorm(torch.nn.Module):
    """Plain-PyTorch RMSNorm used as ground truth (mirrors HuggingFace's LlamaRMSNorm).

    The input is normalized over its last dimension in float32, cast back to
    the input dtype, and then scaled by a learnable per-channel weight.

    Args:
        hidden_size: Length of the last dimension and of ``weight``.
        eps: Constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Scale starts at ones, so a freshly built module is a pure normalization.
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Return ``weight * rmsnorm(hidden_states)``; the norm is computed in float32."""
        orig_dtype = hidden_states.dtype
        # Upcast so the mean square is computed in float32 even for BF16 input.
        x32 = hidden_states.to(torch.float32)
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.variance_epsilon)
        normed = (x32 * inv_rms).to(orig_dtype)
        # Cast back before scaling, as the HF implementation does; the product's
        # dtype then follows PyTorch type promotion between weight and input.
        return self.weight * normed
def test_rmsnorm(
    dtype=torch.float32,
    atol=1e-4,
    rtol=1e-4,
    dw_atol=None,
    dw_rtol=None,
    *,
    hidden_size=4096,
    batch_size=4,
    seq_len=512,
    seed=0,
    device="cuda",
):
    """Compare TritonRMSNorm against ReferenceLlamaRMSNorm, forward and backward.

    Both modules see the same random input and the same upstream gradient.
    The output, the input gradient and the weight gradient are then checked
    with ``torch.allclose``.

    Every argument has a default. pytest collects ``test_*`` functions from
    ``*_test.py`` files, and it would otherwise try to resolve ``dtype``,
    ``atol`` and ``rtol`` as fixtures. The defaults reproduce the FP32 case.

    Args:
        dtype: dtype of the input, of both modules' weights, and of the outputs.
        atol, rtol: Tolerances for the forward output and the input gradient.
        dw_atol, dw_rtol: Separate tolerances for the weight gradient. It
            accumulates across all rows and is more sensitive to FP ordering
            differences. They default to ``atol`` / ``rtol``.
        hidden_size, batch_size, seq_len: The input shape is
            ``(batch_size, seq_len, hidden_size)``. Defaults match LLaMA 3.1 8B
            (hidden_size=4096).
        seed: Passed to ``torch.manual_seed`` so a failure can be reproduced.
            ``None`` leaves the global RNG state untouched.
        device: Device for the input tensor and both modules.

    Raises:
        AssertionError: On an output shape/dtype mismatch, a missing custom
            gradient, or any comparison outside tolerance. It is raised
            explicitly rather than via ``assert``, so the checks still run
            under ``python -O``.
    """
    dtype_name = {
        torch.float32: "FP32",
        torch.bfloat16: "BF16",
        torch.float16: "FP16",
    }.get(dtype, str(dtype))
    if dw_atol is None:
        dw_atol = atol
    if dw_rtol is None:
        dw_rtol = rtol
    if seed is not None:
        torch.manual_seed(seed)

    def compare(expected, actual, tol_abs, tol_rel):
        """Return ``(max |expected - actual|, allclose verdict)``."""
        # Measure the diff in float32 so a low-precision subtraction doesn't
        # round the reported value.
        diff = (expected.detach().float() - actual.detach().float()).abs().max().item()
        return diff, torch.allclose(expected, actual, atol=tol_abs, rtol=tol_rel)

    # Shared input; each implementation gets its own leaf copy below.
    x = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)

    # Reference implementation
    ref_norm = ReferenceLlamaRMSNorm(hidden_size).to(device=device, dtype=dtype)
    # Custom Triton implementation (copy weights from reference)
    custom_norm = TritonRMSNorm(hidden_size).to(device=device, dtype=dtype)
    with torch.no_grad():
        custom_norm.weight.copy_(ref_norm.weight)
    # NOTE(review): the reference weight is initialised to ones, so the forward
    # check cannot catch a kernel that ignores or mis-indexes the weight.
    # Consider randomising ref_norm.weight before the copy.

    # --- Forward pass ---
    # Independent leaves so each implementation accumulates its own x.grad.
    x_ref = x.clone().requires_grad_(True)
    x_custom = x.clone().requires_grad_(True)
    y_ref = ref_norm(x_ref)
    y_custom = custom_norm(x_custom)
    if y_custom.shape != y_ref.shape or y_custom.dtype != y_ref.dtype:
        raise AssertionError(
            f"{dtype_name} output mismatch: expected {tuple(y_ref.shape)} {y_ref.dtype}, "
            f"got {tuple(y_custom.shape)} {y_custom.dtype}"
        )
    fwd_max_diff, fwd_pass = compare(y_ref, y_custom, atol, rtol)
    print(f" {dtype_name} Forward — max diff: {fwd_max_diff:.2e} {'PASSED' if fwd_pass else 'FAILED'}")
    if not fwd_pass:
        raise AssertionError(f"Forward correctness check FAILED for {dtype_name}")

    # --- Backward pass ---
    # Use the same upstream gradient for both.
    grad_output = torch.randn_like(y_ref)
    y_ref.backward(grad_output)
    y_custom.backward(grad_output)
    # Without this check, a missing gradient would surface as an obscure
    # TypeError inside the subtraction.
    if x_custom.grad is None or custom_norm.weight.grad is None:
        raise AssertionError(f"Custom backward produced no gradient for {dtype_name}")

    grad_x_diff, grad_x_pass = compare(x_ref.grad, x_custom.grad, atol, rtol)
    # The weight gradient uses a separate (potentially relaxed) tolerance
    # because it sums per-row contributions across batch*seq_len rows. The
    # summation order differs between our FP32-accumulated kernel and
    # PyTorch's autograd, which produces larger absolute differences. This is
    # worst in BF16, where the final cast has a 0.25-0.5 step size for typical
    # gradient magnitudes.
    grad_w_diff, grad_w_pass = compare(
        ref_norm.weight.grad, custom_norm.weight.grad, dw_atol, dw_rtol
    )
    print(f" {dtype_name} Backward (dx) — max diff: {grad_x_diff:.2e} {'PASSED' if grad_x_pass else 'FAILED'}")
    print(f" {dtype_name} Backward (dw) — max diff: {grad_w_diff:.2e} {'PASSED' if grad_w_pass else 'FAILED'}")
    if not grad_x_pass:
        raise AssertionError(f"Backward grad_x check FAILED for {dtype_name}")
    if not grad_w_pass:
        raise AssertionError(f"Backward grad_w check FAILED for {dtype_name}")
if __name__ == "__main__":
    # The Triton kernel and the test tensors live on the GPU; fail fast with a
    # clear message instead of an error deep inside tensor allocation.
    if not torch.cuda.is_available():
        raise SystemExit("RMSNorm correctness tests require a CUDA device; none was found.")

    print("RMSNorm Correctness Tests")
    print("=" * 60)

    print("\nTest 1: Float32")
    # In FP32 the two implementations differ only in summation order. That
    # covers the per-row reductions and the weight gradient's sum over
    # batch_size * seq_len = 2048 rows. A single 1e-4 tolerance (well above
    # FP32 epsilon) therefore covers every check, including dw.
    test_rmsnorm(dtype=torch.float32, atol=1e-4, rtol=1e-4)

    print("\nTest 2: BFloat16 (relaxed tolerance)")
    # BF16 has only 7 explicit mantissa bits, so the output and dx use 1e-2
    # tolerances. dw sums 2048 rows and is then rounded to BF16, whose step is
    # 0.25-0.5 at typical dw magnitudes. torch.allclose passes when
    # |ref - custom| <= dw_atol + dw_rtol * |custom|. With these values the
    # 2.0 absolute term absorbs a few BF16 steps; the 2% relative term only
    # dominates for elements with |custom| > 100.
    test_rmsnorm(dtype=torch.bfloat16, atol=1e-2, rtol=1e-2,
                 dw_atol=2.0, dw_rtol=2e-2)

    print("\n" + "=" * 60)
    print("All RMSNorm correctness tests PASSED")