dgx-spark-playbooks/nvidia/station-kernel-dev-ft/assets/cross_entropy_kernel.py

259 lines
10 KiB
Python
Raw Permalink Normal View History

2026-05-26 18:25:53 +00:00
"""
Fused Cross-Entropy Triton Kernel — Online Softmax
==================================================
This module implements a fused cross-entropy loss that avoids materializing the
full logit tensor. Standard PyTorch cross-entropy computes:
loss = -log(softmax(logits)[target])
This requires computing softmax over the entire vocabulary for every token
position. For LLaMA 3.1 8B with vocabulary size 128,256, a batch of 512 tokens
produces a logit tensor of shape [512, 128256] — about 250 MB in float32.
During training, PyTorch also stores this for the backward pass, roughly
doubling the memory cost.
Our fused kernel uses the **online softmax** algorithm (Milakov & Gimelshein, 2018):
instead of computing softmax over all vocabulary entries at once, it processes
the vocabulary in chunks. For each chunk, it maintains a running maximum and a
running sum-of-exponentials. After processing all chunks, it has enough
information to compute the loss without ever allocating the full [B*T, V]
tensor.
Memory savings: For V=128256, the standard approach allocates O(B*T*V) memory
for its softmax intermediates. The online approach needs only O(B*T) extra
state (one running max and one running sum per row), avoiding the full
vocabulary-sized intermediate. In practice, the input logits are still retained
for the backward pass, so the measured end-to-end memory reduction is ~6x at
realistic batch sizes — still a significant saving.
Key Triton concepts introduced beyond rmsnorm_kernel.py:
- Loops inside kernels: Processing vocabulary in chunks with a for loop
- tl.where: Conditional element selection (for masking the last chunk)
- Multi-pass algorithms: Maintaining running state across iterations
"""
import torch
import triton
import triton.language as tl
# =============================================================================
# Forward kernel
# =============================================================================
# Each program handles one row (one token position). It iterates over the
# vocabulary in chunks of BLOCK_SIZE, maintaining:
# m: running maximum logit (for numerical stability)
# d: running sum of exp(logit - m) (the softmax denominator)
# After all chunks, the loss for this row is: log(d) + m - logit[target]
# =============================================================================
@triton.jit
def _cross_entropy_fwd_kernel(
    # Pointers
    Logits_ptr,  # Input logits: shape [num_rows, vocab_size]
    Targets_ptr,  # Target class indices: shape [num_rows]
    Losses_ptr,  # Output per-row loss: shape [num_rows]
    Max_ptr,  # Saved running max for backward: shape [num_rows]
    Denom_ptr,  # Saved running denominator for backward: shape [num_rows]
    # Dimensions
    vocab_size,
    stride_logits,  # Stride between rows of logits (in elements)
    # Compile-time constant
    BLOCK_SIZE: tl.constexpr,
):
    """Compute one row's cross-entropy loss with a chunked online softmax.

    Each program handles the row selected by ``tl.program_id(0)``. It walks
    the vocabulary in chunks of ``BLOCK_SIZE`` and writes three per-row
    scalars: the loss, the running max ``m`` and the running denominator
    ``d`` (the latter two are reused by the backward kernel).

    Elements inside a row are addressed as ``row_start + column``, so the
    logits must have unit stride along the vocabulary dimension.

    If the target index matches no column (negative or >= vocab_size), the
    target logit stays 0 and the stored loss is just ``log(d) + m``.
    """
    # Widen the row index to 64 bits before multiplying by the row stride.
    # num_rows * vocab_size can exceed 2**31 for large batches (for example
    # 32K rows x 128K vocabulary); a 32-bit product would wrap around and
    # address the wrong memory.
    row_idx = tl.program_id(0).to(tl.int64)
    logit_row_start = row_idx * stride_logits
    target = tl.load(Targets_ptr + row_idx)

    # Running statistics. Starting the max at -inf means any real logit
    # replaces it on the first chunk.
    m = float("-inf")  # Running maximum logit
    d = 0.0  # Running sum of exp(logit_i - m)
    # Logit at the target index, accumulated from whichever chunk holds it.
    target_logit = 0.0

    # --- Online softmax: iterate over the vocabulary in chunks ---
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < vocab_size

        # Out-of-bounds lanes read -inf so they contribute exp(-inf) = 0 to
        # the sum and never win the max.
        logits_chunk = tl.load(
            Logits_ptr + logit_row_start + offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)

        # Update the running max.
        chunk_max = tl.max(logits_chunk, axis=0)
        new_m = tl.maximum(m, chunk_max)

        # Rescale the previous sum to the new max, then add this chunk:
        #   sum(exp(x_i - m_new)) = exp(m_old - m_new) * d_old
        #                           + sum(exp(chunk - m_new))
        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(logits_chunk - new_m), axis=0)
        m = new_m

        # Pick up the target logit if it falls inside this chunk. tl.where
        # (rather than a multiply) keeps the -inf padding out of the sum.
        target_mask = offsets == target
        target_logit = target_logit + tl.sum(
            tl.where(target_mask, logits_chunk, 0.0), axis=0
        )

    # loss = log(sum(exp(logit_i - m))) + m - logit[target]
    #      = log(d) + m - target_logit
    loss = tl.log(d) + m - target_logit

    # Store the loss and save m, d for the backward pass.
    tl.store(Losses_ptr + row_idx, loss)
    tl.store(Max_ptr + row_idx, m)
    tl.store(Denom_ptr + row_idx, d)
# =============================================================================
# Backward kernel
# =============================================================================
# The gradient of cross-entropy loss w.r.t. logits is:
# grad_logit[i] = softmax(logits)[i] - (1 if i == target else 0)
# scaled by the upstream gradient (grad_output).
#
# We compute softmax using the saved m and d from the forward pass:
# softmax(logits)[i] = exp(logits[i] - m) / d
#
# Like the forward kernel, we process the vocabulary in chunks to avoid
# materializing the full softmax vector.
# =============================================================================
@triton.jit
def _cross_entropy_bwd_kernel(
    # Pointers
    Logits_ptr,  # Original logits (re-read from memory): shape [num_rows, vocab_size]
    Targets_ptr,  # Target class indices: shape [num_rows]
    GradOutput_ptr,  # Upstream gradient (scalar per row): shape [num_rows]
    Max_ptr,  # Saved max from forward: shape [num_rows]
    Denom_ptr,  # Saved denominator from forward: shape [num_rows]
    GradLogits_ptr,  # Output gradient w.r.t. logits: shape [num_rows, vocab_size]
    # Dimensions
    vocab_size,
    stride_logits,  # Row stride (in elements), used for BOTH logits and grad
    BLOCK_SIZE: tl.constexpr,
    DTYPE_IS_FP32: tl.constexpr,  # True if input dtype is float32
    DTYPE_IS_BF16: tl.constexpr,  # True if input dtype is bfloat16 (else float16)
):
    """Write one row of d(loss)/d(logits) using the saved softmax statistics.

    For the row selected by ``tl.program_id(0)`` this computes, chunk by
    chunk::

        grad[i] = (exp(logits[i] - m) / d - [i == target]) * grad_output

    ``stride_logits`` addresses both ``Logits_ptr`` and ``GradLogits_ptr``,
    so the two buffers must share the same row stride, with unit stride
    inside a row.

    A target that matches no column (e.g. a negative index) simply has no
    one-hot term subtracted.
    """
    # Widen the row index to 64 bits before multiplying by the row stride so
    # the element offset cannot wrap when num_rows * vocab_size >= 2**31.
    row_idx = tl.program_id(0).to(tl.int64)
    logit_row_start = row_idx * stride_logits

    # Load saved values and the upstream gradient for this row.
    target = tl.load(Targets_ptr + row_idx)
    grad_output = tl.load(GradOutput_ptr + row_idx)
    m = tl.load(Max_ptr + row_idx)
    d = tl.load(Denom_ptr + row_idx)

    # The output dtype depends only on compile-time flags, so resolve it once
    # instead of on every chunk.
    if DTYPE_IS_FP32:
        out_dtype = tl.float32
    elif DTYPE_IS_BF16:
        out_dtype = tl.bfloat16
    else:
        out_dtype = tl.float16

    # Process the vocabulary in chunks, computing the gradient for each chunk.
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < vocab_size

        logits_chunk = tl.load(
            Logits_ptr + logit_row_start + offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)

        # Softmax probabilities for this chunk from the saved m and d.
        probs = tl.exp(logits_chunk - m) / d

        # Subtract 1 at the target position, then apply the upstream gradient.
        is_target = offsets == target
        grad = (probs - tl.where(is_target, 1.0, 0.0)) * grad_output

        # Cast back to the input dtype; masked lanes are not written.
        tl.store(
            GradLogits_ptr + logit_row_start + offsets,
            grad.to(out_dtype),
            mask=mask,
        )
# =============================================================================
# Autograd wrapper
# =============================================================================
class FusedCrossEntropyFunction(torch.autograd.Function):
    """Autograd wrapper around the fused cross-entropy Triton kernels.

    Returns the mean cross-entropy loss over all rows whose target is not
    ``ignore_index``.
    """

    @staticmethod
    def forward(ctx, logits, targets, ignore_index=-100):
        """Launch the forward kernel and reduce the per-row losses.

        Args:
            logits: 2-D tensor ``[num_rows, vocab_size]``.
            targets: Class indices with ``num_rows`` elements.
            ignore_index: Rows whose target equals this value contribute
                neither to the loss nor to the gradient. Defaults to -100.

        Returns:
            A 0-dim float32 tensor: the sum of the non-ignored per-row losses
            divided by the number of non-ignored rows (NaN when every row is
            ignored, since that is 0 / 0).

        Raises:
            ValueError: If ``logits`` is not 2-D or ``targets`` does not hold
                exactly one index per row.
        """
        # logits: [num_rows, vocab_size], targets: [num_rows]
        num_rows, vocab_size = logits.shape
        targets = targets.reshape(-1).contiguous()
        if targets.numel() != num_rows:
            raise ValueError(
                f"expected {num_rows} targets (one per logits row), "
                f"got {targets.numel()}"
            )

        # The kernels address a row as `row * stride + column` (unit inner
        # stride) and the backward kernel reuses this row stride for the
        # gradient buffer. Forcing a contiguous layout guarantees both; it is
        # a no-op when the input is already contiguous.
        logits = logits.contiguous()

        losses = torch.empty(num_rows, dtype=torch.float32, device=logits.device)
        saved_max = torch.empty(num_rows, dtype=torch.float32, device=logits.device)
        saved_denom = torch.empty(num_rows, dtype=torch.float32, device=logits.device)

        BLOCK_SIZE = min(triton.next_power_of_2(vocab_size), 4096)
        _cross_entropy_fwd_kernel[(num_rows,)](
            logits, targets, losses, saved_max, saved_denom,
            vocab_size,
            logits.stride(0),
            BLOCK_SIZE=BLOCK_SIZE,
        )

        # Rows to keep, and the per-row weight d(mean)/d(loss_row). The clamp
        # only affects the all-ignored case, where it makes the backward pass
        # produce zeros instead of NaNs.
        valid = targets != ignore_index
        num_valid = valid.sum()
        row_scale = valid.to(torch.float32) / num_valid.clamp(min=1)

        ctx.save_for_backward(logits, targets, saved_max, saved_denom, row_scale)
        ctx.vocab_size = vocab_size
        ctx.BLOCK_SIZE = BLOCK_SIZE

        # Ignored rows still got a (meaningless) kernel loss; zero them out
        # before reducing.
        return losses.masked_fill(~valid, 0.0).sum() / num_valid

    @staticmethod
    def backward(ctx, grad_output):
        """Launch the backward kernel and return the gradient w.r.t. logits.

        Args:
            grad_output: Upstream gradient of the scalar loss.

        Returns:
            ``(grad_logits, None, None)``; ``targets`` and ``ignore_index``
            receive no gradient.
        """
        logits, targets, saved_max, saved_denom, row_scale = ctx.saved_tensors
        num_rows = logits.shape[0]
        grad_logits = torch.empty_like(logits)

        # Per-row upstream gradient, built on the device: no .item() host
        # sync, no division by a zero row count, and ignored rows get exactly
        # 0 so the kernel writes a zero gradient for them.
        grad_per_row = (row_scale * grad_output.to(torch.float32)).contiguous()

        _cross_entropy_bwd_kernel[(num_rows,)](
            logits, targets, grad_per_row, saved_max, saved_denom, grad_logits,
            ctx.vocab_size,
            logits.stride(0),
            BLOCK_SIZE=ctx.BLOCK_SIZE,
            DTYPE_IS_FP32=(logits.dtype == torch.float32),
            DTYPE_IS_BF16=(logits.dtype == torch.bfloat16),
        )
        # Trailing Nones: targets and ignore_index are not differentiable.
        return grad_logits, None, None
# =============================================================================
# Drop-in nn.Module replacement
# =============================================================================
class TritonCrossEntropyLoss(torch.nn.Module):
    """Fused cross-entropy loss using online softmax.

    Intended as a replacement for ``torch.nn.CrossEntropyLoss`` with mean
    reduction and class-index targets. This module takes no constructor
    options: ``weight``, ``label_smoothing`` and other ``reduction`` modes
    are not exposed.
    """

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Flatten the inputs to rows and return the scalar loss.

        Args:
            logits: ``[..., vocab_size]``; leading dimensions are merged.
            targets: Class indices; flattened to 1-D.

        Returns:
            Whatever ``FusedCrossEntropyFunction.apply`` returns for the
            flattened inputs.
        """
        # Flatten logits to 2D if needed
        # (e.g., [batch, seq_len, vocab] -> [batch*seq_len, vocab]).
        # reshape (not view) so non-contiguous inputs such as the shifted
        # slice logits[:, :-1, :] are accepted instead of raising.
        if logits.dim() > 2:
            logits = logits.reshape(-1, logits.size(-1))
        if targets.dim() > 1:
            targets = targets.reshape(-1)
        return FusedCrossEntropyFunction.apply(logits, targets)