mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-06-18 04:22:21 +00:00
263 lines
11 KiB
Python
263 lines
11 KiB
Python
|
|
"""
|
||
|
|
Fused RMSNorm Triton Kernel — Forward and Backward
|
||
|
|
====================================================
|
||
|
|
|
||
|
|
This module implements a fused RMSNorm (Root Mean Square Layer Normalization)
|
||
|
|
as a Triton kernel. RMSNorm is used in every transformer layer of LLaMA 3.1 8B
|
||
|
|
(and most modern LLMs). The formula is:
|
||
|
|
|
||
|
|
RMSNorm(x) = (x / sqrt(mean(x^2) + eps)) * weight
|
||
|
|
|
||
|
|
PyTorch's default implementation breaks this into multiple separate GPU operations
|
||
|
|
(square, mean, rsqrt, multiply, multiply-by-weight), each of which reads from and
|
||
|
|
writes to GPU memory. Our fused kernel does everything in a single pass: read x
|
||
|
|
once, compute the result, write once. This eliminates redundant memory traffic and
|
||
|
|
improves bandwidth utilization from ~11% to ~80-90% of the GPU's peak.
|
||
|
|
|
||
|
|
Key Triton concepts introduced:
|
||
|
|
- @triton.jit: JIT-compiles a Python function into GPU machine code
|
||
|
|
- tl.program_id: Each "program" handles one row of the input (like a CUDA block)
|
||
|
|
- tl.load / tl.store: Read from / write to GPU memory
|
||
|
|
- tl.sum: Parallel reduction across elements within a program
|
||
|
|
- tl.constexpr: Compile-time constants (like BLOCK_SIZE) that Triton optimizes for
|
||
|
|
"""
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import triton
|
||
|
|
import triton.language as tl
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Forward kernel
|
||
|
|
# =============================================================================
|
||
|
|
# Each program processes one row of the input tensor. For LLaMA 3.1 8B with
|
||
|
|
# hidden_size=4096, each row is 4096 elements. The kernel reads the row once,
|
||
|
|
# computes the RMS normalization, multiplies by the learned weight, and writes
|
||
|
|
# the result. It also saves the inverse RMS value (rnorm) for the backward pass.
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
@triton.jit
|
||
|
|
def _rmsnorm_fwd_kernel(
|
||
|
|
# Pointers to tensors in GPU memory
|
||
|
|
X_ptr, # Input tensor: shape [num_rows, hidden_size]
|
||
|
|
W_ptr, # Weight tensor: shape [hidden_size]
|
||
|
|
Y_ptr, # Output tensor: shape [num_rows, hidden_size]
|
||
|
|
Rnorm_ptr, # Saved inverse RMS: shape [num_rows] (for backward)
|
||
|
|
# Dimensions
|
||
|
|
stride_x, # Stride between rows of X (number of elements to skip)
|
||
|
|
hidden_size, # Number of elements per row (e.g., 4096 for LLaMA 8B)
|
||
|
|
eps, # Small constant for numerical stability (typically 1e-6)
|
||
|
|
# Compile-time constant: how many elements each program processes at once.
|
||
|
|
# Triton will generate specialized GPU code for this specific block size.
|
||
|
|
BLOCK_SIZE: tl.constexpr,
|
||
|
|
):
|
||
|
|
# Which row this program is responsible for.
|
||
|
|
# tl.program_id(0) returns a unique index for each program, similar to
|
||
|
|
# blockIdx.x in CUDA but at a higher abstraction level.
|
||
|
|
row_idx = tl.program_id(0)
|
||
|
|
|
||
|
|
# Compute the memory offset for the start of this row.
|
||
|
|
row_start = row_idx * stride_x
|
||
|
|
|
||
|
|
# Create a vector of offsets [0, 1, 2, ..., BLOCK_SIZE-1] for loading
|
||
|
|
# elements within the row. If BLOCK_SIZE > hidden_size, some offsets
|
||
|
|
# will be out of bounds — the mask below handles that.
|
||
|
|
offsets = tl.arange(0, BLOCK_SIZE)
|
||
|
|
mask = offsets < hidden_size
|
||
|
|
|
||
|
|
# Load the entire row from GPU memory into registers.
|
||
|
|
# Elements beyond hidden_size are set to 0.0 (won't affect the sum).
|
||
|
|
x = tl.load(X_ptr + row_start + offsets, mask=mask, other=0.0)
|
||
|
|
|
||
|
|
# Load the weight vector (same for every row).
|
||
|
|
w = tl.load(W_ptr + offsets, mask=mask, other=0.0)
|
||
|
|
|
||
|
|
# Compute the RMS normalization in one pass:
|
||
|
|
# 1. Square each element and sum across the row
|
||
|
|
# 2. Divide by hidden_size to get the mean of squares
|
||
|
|
# 3. Add epsilon for numerical stability
|
||
|
|
# 4. Take the inverse square root
|
||
|
|
x_fp32 = x.to(tl.float32)
|
||
|
|
variance = tl.sum(x_fp32 * x_fp32, axis=0) / hidden_size
|
||
|
|
rnorm = 1.0 / tl.sqrt(variance + eps)
|
||
|
|
|
||
|
|
# Apply normalization and multiply by the learned weight.
|
||
|
|
y = (x_fp32 * rnorm).to(x.dtype) * w
|
||
|
|
|
||
|
|
# Write the normalized output.
|
||
|
|
tl.store(Y_ptr + row_start + offsets, y, mask=mask)
|
||
|
|
|
||
|
|
# Save the inverse RMS value for the backward pass. Each row produces
|
||
|
|
# one scalar that the backward kernel needs to compute gradients.
|
||
|
|
tl.store(Rnorm_ptr + row_idx, rnorm)
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Backward kernel
|
||
|
|
# =============================================================================
|
||
|
|
# The backward pass computes two gradients:
|
||
|
|
# 1. grad_x: gradient of the loss w.r.t. the input x
|
||
|
|
# 2. grad_w: gradient of the loss w.r.t. the weight (accumulated across rows)
|
||
|
|
#
|
||
|
|
# For grad_x, the derivation is:
|
||
|
|
# y = x * rnorm * w
|
||
|
|
# dy/dx = rnorm * w - x * rnorm^3 * (1/N) * sum(x * dy * w)
|
||
|
|
# = rnorm * (dy * w - x * (1/N) * sum(x * dy * w) * rnorm^2)
|
||
|
|
#
|
||
|
|
# For grad_w, it is simply: dL/dw = sum_over_rows(dy * x * rnorm)
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
@triton.jit
|
||
|
|
def _rmsnorm_bwd_kernel(
|
||
|
|
# Pointers to tensors
|
||
|
|
DY_ptr, # Upstream gradient: shape [num_rows, hidden_size]
|
||
|
|
X_ptr, # Original input (saved from forward): shape [num_rows, hidden_size]
|
||
|
|
W_ptr, # Weight: shape [hidden_size]
|
||
|
|
Rnorm_ptr, # Saved inverse RMS from forward: shape [num_rows]
|
||
|
|
DX_ptr, # Output: gradient w.r.t. input: shape [num_rows, hidden_size]
|
||
|
|
DW_ptr, # Output: partial gradient w.r.t. weight: shape [num_rows, hidden_size]
|
||
|
|
# Dimensions
|
||
|
|
stride_x,
|
||
|
|
hidden_size,
|
||
|
|
BLOCK_SIZE: tl.constexpr,
|
||
|
|
DTYPE_IS_FP32: tl.constexpr, # True if input dtype is float32
|
||
|
|
DTYPE_IS_BF16: tl.constexpr, # True if input dtype is bfloat16 (else float16)
|
||
|
|
):
|
||
|
|
row_idx = tl.program_id(0)
|
||
|
|
row_start = row_idx * stride_x
|
||
|
|
offsets = tl.arange(0, BLOCK_SIZE)
|
||
|
|
mask = offsets < hidden_size
|
||
|
|
|
||
|
|
# Load everything we need for this row.
|
||
|
|
dy = tl.load(DY_ptr + row_start + offsets, mask=mask, other=0.0).to(tl.float32)
|
||
|
|
x = tl.load(X_ptr + row_start + offsets, mask=mask, other=0.0).to(tl.float32)
|
||
|
|
w = tl.load(W_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
|
||
|
|
rnorm = tl.load(Rnorm_ptr + row_idx)
|
||
|
|
|
||
|
|
# Compute the normalized input (same as forward, reconstructed from saved rnorm).
|
||
|
|
x_hat = x * rnorm
|
||
|
|
|
||
|
|
# Gradient w.r.t. weight for this row: dy * x_hat_quantized
|
||
|
|
# Important: we must cast x_hat to the input dtype before computing dw,
|
||
|
|
# matching what the forward pass does (cast normalized output to BF16
|
||
|
|
# before multiplying by weight). This ensures our gradient matches PyTorch's.
|
||
|
|
if DTYPE_IS_FP32:
|
||
|
|
x_hat_q = x_hat
|
||
|
|
elif DTYPE_IS_BF16:
|
||
|
|
x_hat_q = x_hat.to(tl.bfloat16).to(tl.float32)
|
||
|
|
else:
|
||
|
|
x_hat_q = x_hat.to(tl.float16).to(tl.float32)
|
||
|
|
dw = dy * x_hat_q
|
||
|
|
|
||
|
|
# Gradient w.r.t. input:
|
||
|
|
# dx = rnorm * (dy * w - x_hat * mean(dy * w * x_hat))
|
||
|
|
dy_w = dy * w
|
||
|
|
dot = tl.sum(dy_w * x_hat, axis=0) / hidden_size
|
||
|
|
dx = rnorm * (dy_w - x_hat * dot)
|
||
|
|
|
||
|
|
# Cast dx back to the input dtype, but keep dw in float32 for accurate
|
||
|
|
# accumulation when summing across rows.
|
||
|
|
if DTYPE_IS_FP32:
|
||
|
|
out_dtype = tl.float32
|
||
|
|
elif DTYPE_IS_BF16:
|
||
|
|
out_dtype = tl.bfloat16
|
||
|
|
else:
|
||
|
|
out_dtype = tl.float16
|
||
|
|
tl.store(DX_ptr + row_start + offsets, dx.to(out_dtype), mask=mask)
|
||
|
|
tl.store(DW_ptr + row_start + offsets, dw, mask=mask) # dw already float32
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Autograd wrapper
|
||
|
|
# =============================================================================
|
||
|
|
# torch.autograd.Function connects the Triton kernels to PyTorch's automatic
|
||
|
|
# differentiation system. The forward() method runs the forward kernel and
|
||
|
|
# saves tensors needed for backward. The backward() method runs the backward
|
||
|
|
# kernel using those saved tensors.
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class TritonRMSNormFunction(torch.autograd.Function):
|
||
|
|
@staticmethod
|
||
|
|
def forward(ctx, x, weight, eps):
|
||
|
|
# x shape: [*, hidden_size] — flatten all leading dims into num_rows
|
||
|
|
orig_shape = x.shape
|
||
|
|
x_2d = x.view(-1, orig_shape[-1])
|
||
|
|
num_rows, hidden_size = x_2d.shape
|
||
|
|
|
||
|
|
# Allocate output and saved tensors
|
||
|
|
y = torch.empty_like(x_2d)
|
||
|
|
rnorm = torch.empty(num_rows, dtype=torch.float32, device=x.device)
|
||
|
|
|
||
|
|
# Choose BLOCK_SIZE: must be a power of 2 >= hidden_size.
|
||
|
|
# Triton generates specialized code for each BLOCK_SIZE value.
|
||
|
|
BLOCK_SIZE = triton.next_power_of_2(hidden_size)
|
||
|
|
|
||
|
|
# Launch the kernel: one program per row.
|
||
|
|
_rmsnorm_fwd_kernel[(num_rows,)](
|
||
|
|
x_2d, weight, y, rnorm,
|
||
|
|
x_2d.stride(0),
|
||
|
|
hidden_size,
|
||
|
|
eps,
|
||
|
|
BLOCK_SIZE=BLOCK_SIZE,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Save tensors needed by the backward pass.
|
||
|
|
ctx.save_for_backward(x_2d, weight, rnorm)
|
||
|
|
ctx.hidden_size = hidden_size
|
||
|
|
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||
|
|
ctx.orig_shape = orig_shape
|
||
|
|
|
||
|
|
return y.view(orig_shape)
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def backward(ctx, dy):
|
||
|
|
x_2d, weight, rnorm = ctx.saved_tensors
|
||
|
|
dy_2d = dy.view(-1, ctx.hidden_size)
|
||
|
|
num_rows = x_2d.shape[0]
|
||
|
|
|
||
|
|
dx = torch.empty_like(x_2d)
|
||
|
|
# Per-row weight gradients stored in FP32 for accurate accumulation.
|
||
|
|
# Summing thousands of BF16 values would lose significant precision.
|
||
|
|
dw_rows = torch.empty(x_2d.shape, dtype=torch.float32, device=x_2d.device)
|
||
|
|
|
||
|
|
_rmsnorm_bwd_kernel[(num_rows,)](
|
||
|
|
dy_2d, x_2d, weight, rnorm, dx, dw_rows,
|
||
|
|
x_2d.stride(0),
|
||
|
|
ctx.hidden_size,
|
||
|
|
BLOCK_SIZE=ctx.BLOCK_SIZE,
|
||
|
|
DTYPE_IS_FP32=(x_2d.dtype == torch.float32),
|
||
|
|
DTYPE_IS_BF16=(x_2d.dtype == torch.bfloat16),
|
||
|
|
)
|
||
|
|
|
||
|
|
# Sum per-row weight gradients across all rows (already in FP32).
|
||
|
|
dw = dw_rows.sum(dim=0).to(x_2d.dtype)
|
||
|
|
|
||
|
|
return dx.view(ctx.orig_shape), dw, None # None for eps (not differentiable)
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Drop-in nn.Module replacement
|
||
|
|
# =============================================================================
|
||
|
|
# This module has the same interface as transformers.models.llama.LlamaRMSNorm,
|
||
|
|
# so it can be swapped in without changing any other code.
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class TritonRMSNorm(torch.nn.Module):
|
||
|
|
def __init__(self, hidden_size, eps=1e-6):
|
||
|
|
super().__init__()
|
||
|
|
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||
|
|
self.eps = eps
|
||
|
|
|
||
|
|
def forward(self, x):
|
||
|
|
return TritonRMSNormFunction.apply(x, self.weight, self.eps)
|
||
|
|
|
||
|
|
@classmethod
|
||
|
|
def from_llama_rmsnorm(cls, llama_norm):
|
||
|
|
"""Create a TritonRMSNorm from an existing LlamaRMSNorm, copying its weights."""
|
||
|
|
norm = cls(
|
||
|
|
hidden_size=llama_norm.weight.shape[0],
|
||
|
|
eps=llama_norm.variance_epsilon,
|
||
|
|
)
|
||
|
|
norm.weight = llama_norm.weight
|
||
|
|
return norm
|