# dgx-spark-playbooks/nvidia/station-kernel-dev-ft/assets/rmsnorm_kernel.py
# 2026-05-26 18:25:53 +00:00

# 263 lines
# 11 KiB
# Python

"""
Fused RMSNorm Triton Kernel — Forward and Backward
====================================================
This module implements a fused RMSNorm (Root Mean Square Layer Normalization)
as a Triton kernel. RMSNorm is used in every transformer layer of LLaMA 3.1 8B
(and most modern LLMs). The formula is:
RMSNorm(x) = (x / sqrt(mean(x^2) + eps)) * weight
PyTorch's default implementation breaks this into multiple separate GPU operations
(square, mean, rsqrt, multiply, multiply-by-weight), each of which reads from and
writes to GPU memory. Our fused kernel does everything in a single pass: read x
once, compute the result, write once. This eliminates redundant memory traffic and
improves bandwidth utilization from ~11% to ~80-90% of the GPU's peak.
Key Triton concepts introduced:
- @triton.jit: JIT-compiles a Python function into GPU machine code
- tl.program_id: Each "program" handles one row of the input (like a CUDA block)
- tl.load / tl.store: Read from / write to GPU memory
- tl.sum: Parallel reduction across elements within a program
- tl.constexpr: Compile-time constants (like BLOCK_SIZE) that Triton optimizes for
"""
import torch
import triton
import triton.language as tl
# =============================================================================
# Forward kernel
# =============================================================================
# Each program processes one row of the input tensor. For LLaMA 3.1 8B with
# hidden_size=4096, each row is 4096 elements. The kernel reads the row once,
# computes the RMS normalization, multiplies by the learned weight, and writes
# the result. It also saves the inverse RMS value (rnorm) for the backward pass.
# =============================================================================
@triton.jit
def _rmsnorm_fwd_kernel(
    # Pointers to tensors in GPU memory
    X_ptr,      # Input tensor: shape [num_rows, hidden_size]
    W_ptr,      # Weight tensor: shape [hidden_size]
    Y_ptr,      # Output tensor: shape [num_rows, hidden_size]
    Rnorm_ptr,  # Saved inverse RMS: shape [num_rows] (for backward)
    # Dimensions
    stride_x,     # Stride between rows of X (number of elements to skip)
    hidden_size,  # Number of elements per row (e.g., 4096 for LLaMA 8B)
    eps,          # Small constant for numerical stability (typically 1e-6)
    # Compile-time constant: how many elements each program processes at once.
    # Triton will generate specialized GPU code for this specific block size.
    BLOCK_SIZE: tl.constexpr,
):
    """Fused RMSNorm forward: one program normalizes one row.

    Computes ``y = (x * rnorm).to(x.dtype) * w`` with
    ``rnorm = 1 / sqrt(mean(x^2) + eps)`` and stores ``rnorm`` per row so the
    backward kernel does not have to recompute it.

    Layout requirements (the caller must guarantee these):
      * elements within a row are adjacent (unit stride in the last dim);
      * Y uses the same row stride as X, because both are indexed with
        ``row_idx * stride_x``;
      * ``BLOCK_SIZE >= hidden_size``, since a single block covers the row
        (there is no loop over column chunks).
    """
    # Which row this program is responsible for.
    # tl.program_id(0) is a 32-bit index; promote it to int64 before
    # multiplying by the row stride so that the element offset cannot wrap
    # around once num_rows * stride_x reaches 2**31 elements.
    row_idx = tl.program_id(0).to(tl.int64)

    # Element offset of the first entry of this row.
    row_start = row_idx * stride_x

    # Column offsets [0, 1, ..., BLOCK_SIZE-1]. If BLOCK_SIZE > hidden_size,
    # the trailing offsets are out of bounds; the mask disables them.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < hidden_size

    # Load the row. Masked-off lanes read as 0.0 so they do not contribute
    # to the sum of squares below.
    x = tl.load(X_ptr + row_start + offsets, mask=mask, other=0.0)
    # Load the weight vector (same for every row).
    w = tl.load(W_ptr + offsets, mask=mask, other=0.0)

    # Compute the inverse RMS in float32:
    #   1. Square each element and sum across the row
    #   2. Divide by hidden_size to get the mean of squares
    #   3. Add epsilon for numerical stability
    #   4. Take the inverse square root
    x_fp32 = x.to(tl.float32)
    variance = tl.sum(x_fp32 * x_fp32, axis=0) / hidden_size
    rnorm = 1.0 / tl.sqrt(variance + eps)

    # Normalize, round back to the input dtype, then scale by the weight.
    # The backward kernel mirrors this rounding when it forms grad_w.
    y = (x_fp32 * rnorm).to(x.dtype) * w

    # Write the normalized output.
    tl.store(Y_ptr + row_start + offsets, y, mask=mask)

    # Save the inverse RMS for the backward pass: one scalar per row.
    tl.store(Rnorm_ptr + row_idx, rnorm)
# =============================================================================
# Backward kernel
# =============================================================================
# The backward pass computes two gradients:
# 1. grad_x: gradient of the loss w.r.t. the input x
# 2. grad_w: gradient of the loss w.r.t. the weight (accumulated across rows)
#
# For grad_x, writing dy for the upstream gradient dL/dy and N for hidden_size:
# y = x * rnorm * w, with rnorm = 1 / sqrt(mean(x^2) + eps)
# dL/dx = rnorm * dy * w - x * rnorm^3 * (1/N) * sum(x * dy * w)
# = rnorm * (dy * w - x * rnorm^2 * (1/N) * sum(x * dy * w))
#
# For grad_w, it is simply: dL/dw = sum_over_rows(dy * x * rnorm)
# =============================================================================
@triton.jit
def _rmsnorm_bwd_kernel(
    # Pointers to tensors
    DY_ptr,     # Upstream gradient: shape [num_rows, hidden_size]
    X_ptr,      # Original input (saved from forward): shape [num_rows, hidden_size]
    W_ptr,      # Weight: shape [hidden_size]
    Rnorm_ptr,  # Saved inverse RMS from forward: shape [num_rows]
    DX_ptr,     # Output: gradient w.r.t. input: shape [num_rows, hidden_size]
    DW_ptr,     # Output: per-row partial gradient w.r.t. weight (float32):
                # shape [num_rows, hidden_size]
    # Dimensions
    stride_x,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
    DTYPE_IS_FP32: tl.constexpr,  # True if input dtype is float32
    DTYPE_IS_BF16: tl.constexpr,  # True if input dtype is bfloat16 (else float16)
):
    """Fused RMSNorm backward: one program handles one row.

    Writes ``dx`` for the row and the row's contribution to ``dw``. The
    per-row ``dw`` values are not reduced here; the caller sums them over
    the row dimension.

    Layout requirements (the caller must guarantee these):
      * DY, X, DX and DW all share the row stride ``stride_x`` and have unit
        stride in the last dimension, because all four are indexed with
        ``row_idx * stride_x + offsets``;
      * ``BLOCK_SIZE >= hidden_size`` (a single block covers the row).
    """
    # tl.program_id(0) is a 32-bit index; promote it to int64 before
    # multiplying by the row stride so that the element offset cannot wrap
    # around once num_rows * stride_x reaches 2**31 elements.
    row_idx = tl.program_id(0).to(tl.int64)
    row_start = row_idx * stride_x

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < hidden_size

    # Load everything we need for this row, promoting to float32 for the
    # arithmetic. Masked-off lanes read as 0.0 and drop out of the reduction.
    dy = tl.load(DY_ptr + row_start + offsets, mask=mask, other=0.0).to(tl.float32)
    x = tl.load(X_ptr + row_start + offsets, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    rnorm = tl.load(Rnorm_ptr + row_idx)

    # Reconstruct the normalized input from the saved inverse RMS.
    x_hat = x * rnorm

    # Gradient w.r.t. weight for this row: dy * x_hat_quantized.
    # The forward kernel rounds the normalized value to the input dtype
    # before multiplying by the weight, so the same rounding is applied here
    # (round-trip through the input dtype) to keep dw consistent with what
    # the forward pass actually computed.
    if DTYPE_IS_FP32:
        x_hat_q = x_hat
    elif DTYPE_IS_BF16:
        x_hat_q = x_hat.to(tl.bfloat16).to(tl.float32)
    else:
        x_hat_q = x_hat.to(tl.float16).to(tl.float32)
    dw = dy * x_hat_q

    # Gradient w.r.t. input:
    #   dx = rnorm * (dy * w - x_hat * mean(dy * w * x_hat))
    dy_w = dy * w
    dot = tl.sum(dy_w * x_hat, axis=0) / hidden_size
    dx = rnorm * (dy_w - x_hat * dot)

    # Cast dx back to the input dtype, but keep dw in float32 for accurate
    # accumulation when the caller sums across rows.
    if DTYPE_IS_FP32:
        out_dtype = tl.float32
    elif DTYPE_IS_BF16:
        out_dtype = tl.bfloat16
    else:
        out_dtype = tl.float16
    tl.store(DX_ptr + row_start + offsets, dx.to(out_dtype), mask=mask)
    tl.store(DW_ptr + row_start + offsets, dw, mask=mask)  # dw already float32
# =============================================================================
# Autograd wrapper
# =============================================================================
# torch.autograd.Function connects the Triton kernels to PyTorch's automatic
# differentiation system. The forward() method runs the forward kernel and
# saves tensors needed for backward. The backward() method runs the backward
# kernel using those saved tensors.
# =============================================================================
class TritonRMSNormFunction(torch.autograd.Function):
    """Autograd bridge between PyTorch and the fused RMSNorm Triton kernels.

    ``forward`` launches ``_rmsnorm_fwd_kernel`` and stashes what the backward
    pass needs; ``backward`` launches ``_rmsnorm_bwd_kernel`` and reduces the
    per-row weight gradients.
    """

    @staticmethod
    def forward(ctx, x, weight, eps):
        """Apply RMSNorm over the last dimension of ``x``.

        Args:
            ctx: Autograd context used to pass state to ``backward``.
            x: Input of shape ``[*, hidden_size]``.
            weight: Scale vector of shape ``[hidden_size]``.
            eps: Stability constant added to the mean of squares.

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: If ``weight`` is not a vector of length ``hidden_size``
                (the kernel would otherwise read past the end of ``weight``).
        """
        # x shape: [*, hidden_size] — flatten all leading dims into num_rows.
        orig_shape = x.shape
        hidden_size = orig_shape[-1]
        if weight.shape != (hidden_size,):
            raise ValueError(
                f"weight must have shape ({hidden_size},), got {tuple(weight.shape)}"
            )

        # The kernel indexes X and Y with one shared row stride and assumes
        # unit stride inside a row. reshape() accepts inputs that view()
        # rejects, and contiguous() guarantees the compact layout, so strided
        # or transposed inputs no longer raise or index past the output.
        # Both calls are no-ops for an already-contiguous input.
        x_2d = x.reshape(-1, hidden_size).contiguous()
        weight = weight.contiguous()
        num_rows = x_2d.shape[0]

        # Allocate output and saved tensors.
        y = torch.empty_like(x_2d)
        rnorm = torch.empty(num_rows, dtype=torch.float32, device=x.device)

        # Choose BLOCK_SIZE: must be a power of 2 >= hidden_size.
        # Triton generates specialized code for each BLOCK_SIZE value.
        BLOCK_SIZE = triton.next_power_of_2(hidden_size)

        # Launch the kernel: one program per row.
        _rmsnorm_fwd_kernel[(num_rows,)](
            x_2d, weight, y, rnorm,
            x_2d.stride(0),
            hidden_size,
            eps,
            BLOCK_SIZE=BLOCK_SIZE,
        )

        # Save tensors and launch parameters needed by the backward pass.
        ctx.save_for_backward(x_2d, weight, rnorm)
        ctx.hidden_size = hidden_size
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.orig_shape = orig_shape
        return y.view(orig_shape)

    @staticmethod
    def backward(ctx, dy):
        """Compute gradients for ``x`` and ``weight``.

        Args:
            ctx: Autograd context populated by ``forward``.
            dy: Upstream gradient with the same shape as the forward output.

        Returns:
            ``(grad_x, grad_weight, None)``; the trailing ``None`` is for
            ``eps``, which is not differentiable.
        """
        x_2d, weight, rnorm = ctx.saved_tensors

        # The kernel indexes dy with x_2d's row stride, so dy must share its
        # compact layout. Incoming gradients are often non-contiguous, which
        # view() would reject.
        dy_2d = dy.reshape(-1, ctx.hidden_size).contiguous()
        num_rows = x_2d.shape[0]

        dx = torch.empty_like(x_2d)
        # Per-row weight gradients stored in FP32 for accurate accumulation.
        # Summing thousands of BF16 values would lose significant precision.
        dw_rows = torch.empty(x_2d.shape, dtype=torch.float32, device=x_2d.device)

        _rmsnorm_bwd_kernel[(num_rows,)](
            dy_2d, x_2d, weight, rnorm, dx, dw_rows,
            x_2d.stride(0),
            ctx.hidden_size,
            BLOCK_SIZE=ctx.BLOCK_SIZE,
            DTYPE_IS_FP32=(x_2d.dtype == torch.float32),
            DTYPE_IS_BF16=(x_2d.dtype == torch.bfloat16),
        )

        # Sum per-row weight gradients across all rows (already in FP32) and
        # cast to the weight's own dtype. Casting to x's dtype would round an
        # FP32 weight's gradient through BF16/FP16 when activations are
        # lower precision than the parameter.
        dw = dw_rows.sum(dim=0).to(weight.dtype)
        return dx.view(ctx.orig_shape), dw, None  # None for eps (not differentiable)
# =============================================================================
# Drop-in nn.Module replacement
# =============================================================================
# This module has the same interface as transformers.models.llama.LlamaRMSNorm,
# so it can be swapped in without changing any other code.
# =============================================================================
class TritonRMSNorm(torch.nn.Module):
    """RMSNorm layer whose forward pass runs through TritonRMSNormFunction.

    Holds a learnable ``weight`` vector (initialised to ones) and the
    stability constant ``eps``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """Create the layer with a ones-initialised weight of length ``hidden_size``."""
        super().__init__()
        initial_scale = torch.ones(hidden_size)
        self.weight = torch.nn.Parameter(initial_scale)
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` with the layer's weight and eps."""
        return TritonRMSNormFunction.apply(x, self.weight, self.eps)

    @classmethod
    def from_llama_rmsnorm(cls, llama_norm):
        """Build a TritonRMSNorm from an existing LlamaRMSNorm-style module.

        The new module is sized from ``llama_norm.weight`` and takes its eps
        from ``llama_norm.variance_epsilon``. The weight is not cloned: the
        returned module is assigned the very same ``weight`` object, so both
        modules share it.
        """
        replacement = cls(
            hidden_size=llama_norm.weight.shape[0],
            eps=llama_norm.variance_epsilon,
        )
        # Rebind rather than copy, so the parameter object is shared.
        replacement.weight = llama_norm.weight
        return replacement