""" Fused RMSNorm Triton Kernel — Forward and Backward ==================================================== This module implements a fused RMSNorm (Root Mean Square Layer Normalization) as a Triton kernel. RMSNorm is used in every transformer layer of LLaMA 3.1 8B (and most modern LLMs). The formula is: RMSNorm(x) = (x / sqrt(mean(x^2) + eps)) * weight PyTorch's default implementation breaks this into multiple separate GPU operations (square, mean, rsqrt, multiply, multiply-by-weight), each of which reads from and writes to GPU memory. Our fused kernel does everything in a single pass: read x once, compute the result, write once. This eliminates redundant memory traffic and improves bandwidth utilization from ~11% to ~80-90% of the GPU's peak. Key Triton concepts introduced: - @triton.jit: JIT-compiles a Python function into GPU machine code - tl.program_id: Each "program" handles one row of the input (like a CUDA block) - tl.load / tl.store: Read from / write to GPU memory - tl.sum: Parallel reduction across elements within a program - tl.constexpr: Compile-time constants (like BLOCK_SIZE) that Triton optimizes for """ import torch import triton import triton.language as tl # ============================================================================= # Forward kernel # ============================================================================= # Each program processes one row of the input tensor. For LLaMA 3.1 8B with # hidden_size=4096, each row is 4096 elements. The kernel reads the row once, # computes the RMS normalization, multiplies by the learned weight, and writes # the result. It also saves the inverse RMS value (rnorm) for the backward pass. # ============================================================================= @triton.jit def _rmsnorm_fwd_kernel( # Pointers to tensors in GPU memory X_ptr, # Input tensor: shape [num_rows, hidden_size] W_ptr, # Weight tensor: shape [hidden_size] Y_ptr, # Output tensor: shape [num_rows, hidden_size] Rnorm_ptr, # Saved inverse RMS: shape [num_rows] (for backward) # Dimensions stride_x, # Stride between rows of X (number of elements to skip) hidden_size, # Number of elements per row (e.g., 4096 for LLaMA 8B) eps, # Small constant for numerical stability (typically 1e-6) # Compile-time constant: how many elements each program processes at once. # Triton will generate specialized GPU code for this specific block size. BLOCK_SIZE: tl.constexpr, ): # Which row this program is responsible for. # tl.program_id(0) returns a unique index for each program, similar to # blockIdx.x in CUDA but at a higher abstraction level. row_idx = tl.program_id(0) # Compute the memory offset for the start of this row. row_start = row_idx * stride_x # Create a vector of offsets [0, 1, 2, ..., BLOCK_SIZE-1] for loading # elements within the row. If BLOCK_SIZE > hidden_size, some offsets # will be out of bounds — the mask below handles that. offsets = tl.arange(0, BLOCK_SIZE) mask = offsets < hidden_size # Load the entire row from GPU memory into registers. # Elements beyond hidden_size are set to 0.0 (won't affect the sum). x = tl.load(X_ptr + row_start + offsets, mask=mask, other=0.0) # Load the weight vector (same for every row). w = tl.load(W_ptr + offsets, mask=mask, other=0.0) # Compute the RMS normalization in one pass: # 1. Square each element and sum across the row # 2. Divide by hidden_size to get the mean of squares # 3. Add epsilon for numerical stability # 4. Take the inverse square root x_fp32 = x.to(tl.float32) variance = tl.sum(x_fp32 * x_fp32, axis=0) / hidden_size rnorm = 1.0 / tl.sqrt(variance + eps) # Apply normalization and multiply by the learned weight. y = (x_fp32 * rnorm).to(x.dtype) * w # Write the normalized output. tl.store(Y_ptr + row_start + offsets, y, mask=mask) # Save the inverse RMS value for the backward pass. Each row produces # one scalar that the backward kernel needs to compute gradients. tl.store(Rnorm_ptr + row_idx, rnorm) # ============================================================================= # Backward kernel # ============================================================================= # The backward pass computes two gradients: # 1. grad_x: gradient of the loss w.r.t. the input x # 2. grad_w: gradient of the loss w.r.t. the weight (accumulated across rows) # # For grad_x, the derivation is: # y = x * rnorm * w # dy/dx = rnorm * w - x * rnorm^3 * (1/N) * sum(x * dy * w) # = rnorm * (dy * w - x * (1/N) * sum(x * dy * w) * rnorm^2) # # For grad_w, it is simply: dL/dw = sum_over_rows(dy * x * rnorm) # ============================================================================= @triton.jit def _rmsnorm_bwd_kernel( # Pointers to tensors DY_ptr, # Upstream gradient: shape [num_rows, hidden_size] X_ptr, # Original input (saved from forward): shape [num_rows, hidden_size] W_ptr, # Weight: shape [hidden_size] Rnorm_ptr, # Saved inverse RMS from forward: shape [num_rows] DX_ptr, # Output: gradient w.r.t. input: shape [num_rows, hidden_size] DW_ptr, # Output: partial gradient w.r.t. weight: shape [num_rows, hidden_size] # Dimensions stride_x, hidden_size, BLOCK_SIZE: tl.constexpr, DTYPE_IS_FP32: tl.constexpr, # True if input dtype is float32 DTYPE_IS_BF16: tl.constexpr, # True if input dtype is bfloat16 (else float16) ): row_idx = tl.program_id(0) row_start = row_idx * stride_x offsets = tl.arange(0, BLOCK_SIZE) mask = offsets < hidden_size # Load everything we need for this row. dy = tl.load(DY_ptr + row_start + offsets, mask=mask, other=0.0).to(tl.float32) x = tl.load(X_ptr + row_start + offsets, mask=mask, other=0.0).to(tl.float32) w = tl.load(W_ptr + offsets, mask=mask, other=0.0).to(tl.float32) rnorm = tl.load(Rnorm_ptr + row_idx) # Compute the normalized input (same as forward, reconstructed from saved rnorm). x_hat = x * rnorm # Gradient w.r.t. weight for this row: dy * x_hat_quantized # Important: we must cast x_hat to the input dtype before computing dw, # matching what the forward pass does (cast normalized output to BF16 # before multiplying by weight). This ensures our gradient matches PyTorch's. if DTYPE_IS_FP32: x_hat_q = x_hat elif DTYPE_IS_BF16: x_hat_q = x_hat.to(tl.bfloat16).to(tl.float32) else: x_hat_q = x_hat.to(tl.float16).to(tl.float32) dw = dy * x_hat_q # Gradient w.r.t. input: # dx = rnorm * (dy * w - x_hat * mean(dy * w * x_hat)) dy_w = dy * w dot = tl.sum(dy_w * x_hat, axis=0) / hidden_size dx = rnorm * (dy_w - x_hat * dot) # Cast dx back to the input dtype, but keep dw in float32 for accurate # accumulation when summing across rows. if DTYPE_IS_FP32: out_dtype = tl.float32 elif DTYPE_IS_BF16: out_dtype = tl.bfloat16 else: out_dtype = tl.float16 tl.store(DX_ptr + row_start + offsets, dx.to(out_dtype), mask=mask) tl.store(DW_ptr + row_start + offsets, dw, mask=mask) # dw already float32 # ============================================================================= # Autograd wrapper # ============================================================================= # torch.autograd.Function connects the Triton kernels to PyTorch's automatic # differentiation system. The forward() method runs the forward kernel and # saves tensors needed for backward. The backward() method runs the backward # kernel using those saved tensors. # ============================================================================= class TritonRMSNormFunction(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, eps): # x shape: [*, hidden_size] — flatten all leading dims into num_rows orig_shape = x.shape x_2d = x.view(-1, orig_shape[-1]) num_rows, hidden_size = x_2d.shape # Allocate output and saved tensors y = torch.empty_like(x_2d) rnorm = torch.empty(num_rows, dtype=torch.float32, device=x.device) # Choose BLOCK_SIZE: must be a power of 2 >= hidden_size. # Triton generates specialized code for each BLOCK_SIZE value. BLOCK_SIZE = triton.next_power_of_2(hidden_size) # Launch the kernel: one program per row. _rmsnorm_fwd_kernel[(num_rows,)]( x_2d, weight, y, rnorm, x_2d.stride(0), hidden_size, eps, BLOCK_SIZE=BLOCK_SIZE, ) # Save tensors needed by the backward pass. ctx.save_for_backward(x_2d, weight, rnorm) ctx.hidden_size = hidden_size ctx.BLOCK_SIZE = BLOCK_SIZE ctx.orig_shape = orig_shape return y.view(orig_shape) @staticmethod def backward(ctx, dy): x_2d, weight, rnorm = ctx.saved_tensors dy_2d = dy.view(-1, ctx.hidden_size) num_rows = x_2d.shape[0] dx = torch.empty_like(x_2d) # Per-row weight gradients stored in FP32 for accurate accumulation. # Summing thousands of BF16 values would lose significant precision. dw_rows = torch.empty(x_2d.shape, dtype=torch.float32, device=x_2d.device) _rmsnorm_bwd_kernel[(num_rows,)]( dy_2d, x_2d, weight, rnorm, dx, dw_rows, x_2d.stride(0), ctx.hidden_size, BLOCK_SIZE=ctx.BLOCK_SIZE, DTYPE_IS_FP32=(x_2d.dtype == torch.float32), DTYPE_IS_BF16=(x_2d.dtype == torch.bfloat16), ) # Sum per-row weight gradients across all rows (already in FP32). dw = dw_rows.sum(dim=0).to(x_2d.dtype) return dx.view(ctx.orig_shape), dw, None # None for eps (not differentiable) # ============================================================================= # Drop-in nn.Module replacement # ============================================================================= # This module has the same interface as transformers.models.llama.LlamaRMSNorm, # so it can be swapped in without changing any other code. # ============================================================================= class TritonRMSNorm(torch.nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() self.weight = torch.nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, x): return TritonRMSNormFunction.apply(x, self.weight, self.eps) @classmethod def from_llama_rmsnorm(cls, llama_norm): """Create a TritonRMSNorm from an existing LlamaRMSNorm, copying its weights.""" norm = cls( hidden_size=llama_norm.weight.shape[0], eps=llama_norm.variance_epsilon, ) norm.weight = llama_norm.weight return norm