mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-06-18 12:32:23 +00:00
259 lines
10 KiB
Python
259 lines
10 KiB
Python
|
|
"""
|
||
|
|
Fused Cross-Entropy Triton Kernel — Online Softmax
|
||
|
|
====================================================
|
||
|
|
|
||
|
|
This module implements a fused cross-entropy loss that avoids materializing the
|
||
|
|
full softmax (probability) tensor. Standard PyTorch cross-entropy computes:
|
||
|
|
|
||
|
|
loss = -log(softmax(logits)[target])
|
||
|
|
|
||
|
|
This requires computing softmax over the entire vocabulary for every token
|
||
|
|
position. For LLaMA 3.1 8B with vocabulary size 128,256, a batch of 512 tokens
|
||
|
|
produces a logit tensor of shape [512, 128256] — about 250 MB in float32.
|
||
|
|
During training, PyTorch also stores this for the backward pass, roughly
|
||
|
|
doubling the memory cost.
|
||
|
|
|
||
|
|
Our fused kernel uses the **online softmax** algorithm (Milakov & Gimelshein, 2018):
|
||
|
|
instead of computing softmax over all vocabulary entries at once, it processes
|
||
|
|
the vocabulary in chunks. For each chunk, it maintains a running maximum and a
|
||
|
|
running sum-of-exponentials. After processing all chunks, it has enough
|
||
|
|
information to compute the loss — without ever allocating the full [B*T, V]
|
||
|
|
tensor.
|
||
|
|
|
||
|
|
Memory savings: For V=128256, the standard approach allocates O(B*T*V) memory.
|
||
|
|
The online approach allocates O(B*T) — avoiding the full vocabulary-sized
|
||
|
|
intermediate. In practice, the input logits are still retained for the backward
|
||
|
|
pass, so the measured end-to-end memory reduction is ~6x at realistic batch
|
||
|
|
sizes — still a significant saving.
|
||
|
|
|
||
|
|
Key Triton concepts introduced beyond rmsnorm_kernel.py:
|
||
|
|
- Loops inside kernels: Processing vocabulary in chunks with a for loop
|
||
|
|
- tl.where: Conditional element selection (for masking the last chunk)
|
||
|
|
- Streaming reductions: Maintaining running state (max, sum-of-exp) across loop iterations
|
||
|
|
"""
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import triton
|
||
|
|
import triton.language as tl
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Forward kernel
|
||
|
|
# =============================================================================
|
||
|
|
# Each program handles one row (one token position). It iterates over the
|
||
|
|
# vocabulary in chunks of BLOCK_SIZE, maintaining:
|
||
|
|
# m: running maximum logit (for numerical stability)
|
||
|
|
# d: running sum of exp(logit - m) (the softmax denominator)
|
||
|
|
# After all chunks, the loss for this row is: log(d) + m - logit[target]
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
@triton.jit
def _cross_entropy_fwd_kernel(
    # Pointers
    Logits_ptr,     # Input logits: shape [num_rows, vocab_size]
    Targets_ptr,    # Target class indices: shape [num_rows]
    Losses_ptr,     # Output per-row loss: shape [num_rows]
    Max_ptr,        # Saved running max for backward: shape [num_rows]
    Denom_ptr,      # Saved running denominator for backward: shape [num_rows]
    # Dimensions
    vocab_size,
    stride_logits,  # Stride between rows of logits
    # Compile-time constant
    BLOCK_SIZE: tl.constexpr,
):
    """Compute one row's cross-entropy loss with an online softmax.

    Each program instance handles the row selected by ``tl.program_id(0)``.
    It walks the vocabulary in chunks of ``BLOCK_SIZE`` while keeping a
    running maximum ``m`` and a running sum ``d`` of ``exp(logit - m)``.
    The row loss is ``log(d) + m - logit[target]``. ``m`` and ``d`` are also
    written out so the backward kernel can rebuild the softmax.

    Elements within a row are addressed as ``row_start + offset``, i.e. the
    vocabulary axis is assumed to have unit stride. The kernel has no
    special handling for "ignored" targets: a loss is stored for every row.
    """
    # Widen the row index to 64 bits *before* multiplying by the row stride.
    # program_id is 32-bit, so for large inputs (num_rows * vocab_size above
    # 2**31 elements) a 32-bit product would wrap and address the wrong row.
    row_idx = tl.program_id(0).to(tl.int64)
    logit_row_start = row_idx * stride_logits
    target = tl.load(Targets_ptr + row_idx)

    # Initialize the running max and running sum-of-exp.
    # We start with -inf for the max so any real logit will be larger.
    m = float("-inf")  # Running maximum logit
    d = 0.0            # Running sum of exp(logit_i - m)

    # Also track the logit value at the target index (needed for the loss).
    target_logit = 0.0

    # --- Online softmax: iterate over vocabulary in chunks ---
    # Instead of loading a whole row at once, process BLOCK_SIZE logits at a
    # time and update the running statistics.
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < vocab_size

        # Load a chunk of logits and do all arithmetic in float32.
        logits_chunk = tl.load(
            Logits_ptr + logit_row_start + offsets,
            mask=mask,
            other=float("-inf"),  # Out-of-bounds values won't affect max/sum
        ).to(tl.float32)

        # Update running max.
        chunk_max = tl.max(logits_chunk, axis=0)
        new_m = tl.maximum(m, chunk_max)

        # Update running sum-of-exp using the new max.
        # Identity: sum(exp(x_i - m_new))
        #           = exp(m_old - m_new) * d_old + sum(exp(chunk - m_new))
        # This rescales the previous sum to account for a possibly larger max.
        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(logits_chunk - new_m), axis=0)
        m = new_m

        # Pick up the target logit if the target index falls in this chunk.
        # At most one lane matches, so the masked sum selects that value.
        target_mask = offsets == target
        target_logit = target_logit + tl.sum(tl.where(target_mask, logits_chunk, 0.0), axis=0)

    # Compute the cross-entropy loss for this row:
    #   loss = log(sum(exp(logit_i - m))) + m - logit[target]
    #        = log(d) + m - target_logit
    loss = tl.log(d) + m - target_logit

    # Store loss and save m, d for the backward pass.
    tl.store(Losses_ptr + row_idx, loss)
    tl.store(Max_ptr + row_idx, m)
    tl.store(Denom_ptr + row_idx, d)
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Backward kernel
|
||
|
|
# =============================================================================
|
||
|
|
# The gradient of cross-entropy loss w.r.t. logits is:
|
||
|
|
# grad_logit[i] = softmax(logits)[i] - (1 if i == target else 0)
|
||
|
|
# scaled by the upstream gradient (grad_output).
|
||
|
|
#
|
||
|
|
# We compute softmax using the saved m and d from the forward pass:
|
||
|
|
# softmax(logits)[i] = exp(logits[i] - m) / d
|
||
|
|
#
|
||
|
|
# Like the forward kernel, we process the vocabulary in chunks to avoid
|
||
|
|
# materializing the full softmax vector.
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
@triton.jit
def _cross_entropy_bwd_kernel(
    # Pointers
    Logits_ptr,      # Original logits (re-read from memory): shape [num_rows, vocab_size]
    Targets_ptr,     # Target class indices: shape [num_rows]
    GradOutput_ptr,  # Upstream gradient (scalar per row): shape [num_rows]
    Max_ptr,         # Saved max from forward: shape [num_rows]
    Denom_ptr,       # Saved denominator from forward: shape [num_rows]
    GradLogits_ptr,  # Output gradient w.r.t. logits: shape [num_rows, vocab_size]
    # Dimensions
    vocab_size,
    stride_logits,
    BLOCK_SIZE: tl.constexpr,
    DTYPE_IS_FP32: tl.constexpr,  # True if input dtype is float32
    DTYPE_IS_BF16: tl.constexpr,  # True if input dtype is bfloat16 (else float16)
):
    """Compute one row of the cross-entropy gradient w.r.t. the logits.

    Each program instance handles the row selected by ``tl.program_id(0)``.
    Using the ``m`` and ``d`` saved by the forward kernel, it rebuilds the
    softmax chunk by chunk as ``exp(logit - m) / d``, subtracts 1 at the
    target index, scales by that row's upstream gradient and stores the
    result.

    The same ``stride_logits`` row stride is used to address both
    ``Logits_ptr`` and ``GradLogits_ptr``, and elements within a row are
    addressed with unit stride, so both buffers must share that layout.
    """
    # Widen the row index to 64 bits *before* multiplying by the row stride.
    # program_id is 32-bit, so for large inputs (num_rows * vocab_size above
    # 2**31 elements) a 32-bit product would wrap, and this kernel would then
    # *store* gradients to the wrong addresses.
    row_idx = tl.program_id(0).to(tl.int64)
    logit_row_start = row_idx * stride_logits

    # Load saved values and upstream gradient.
    target = tl.load(Targets_ptr + row_idx)
    grad_output = tl.load(GradOutput_ptr + row_idx)
    m = tl.load(Max_ptr + row_idx)
    d = tl.load(Denom_ptr + row_idx)

    # Process vocabulary in chunks, computing gradient for each chunk.
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < vocab_size

        logits_chunk = tl.load(
            Logits_ptr + logit_row_start + offsets,
            mask=mask,
            other=float("-inf"),  # exp(-inf) == 0; these lanes are not stored anyway
        ).to(tl.float32)

        # Compute softmax probabilities for this chunk using saved m and d.
        probs = tl.exp(logits_chunk - m) / d

        # Subtract 1 at the target position.
        is_target = offsets == target
        grad = (probs - tl.where(is_target, 1.0, 0.0)) * grad_output

        # Cast gradient back to the input dtype. The flags are constexpr, so
        # the branch is resolved at compile time.
        if DTYPE_IS_FP32:
            out_dtype = tl.float32
        elif DTYPE_IS_BF16:
            out_dtype = tl.bfloat16
        else:
            out_dtype = tl.float16
        tl.store(
            GradLogits_ptr + logit_row_start + offsets,
            grad.to(out_dtype),
            mask=mask,
        )
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Autograd wrapper
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class FusedCrossEntropyFunction(torch.autograd.Function):
    """Autograd wrapper around the fused cross-entropy Triton kernels.

    ``forward`` returns the mean of the per-row losses as a float32 scalar;
    ``backward`` returns the gradient w.r.t. ``logits`` in the logits' dtype
    and ``None`` for ``targets``.
    """

    @staticmethod
    def forward(ctx, logits, targets):
        """Launch the forward kernel and return the mean loss.

        Args:
            logits: 2D tensor ``[num_rows, vocab_size]``.
            targets: tensor holding ``num_rows`` class indices.

        Returns:
            0-dim float32 tensor: mean of the per-row losses.

        Raises:
            ValueError: if ``logits`` is not 2D or ``targets`` does not hold
                exactly one entry per row.
        """
        if logits.dim() != 2:
            raise ValueError(
                f"logits must be 2D [num_rows, vocab_size], got shape {tuple(logits.shape)}"
            )
        num_rows, vocab_size = logits.shape
        if targets.numel() != num_rows:
            raise ValueError(
                f"targets must hold {num_rows} entries (one per logits row), "
                f"got shape {tuple(targets.shape)}"
            )

        # The kernels address row elements with unit stride, and backward
        # reuses logits.stride(0) for the gradient buffer. A non-contiguous
        # view (e.g. a sliced or transposed tensor) would therefore be read,
        # and its gradient written, with the wrong layout. .contiguous() is a
        # no-op for inputs that are already contiguous.
        logits = logits.contiguous()
        targets = targets.contiguous()

        losses = torch.empty(num_rows, dtype=torch.float32, device=logits.device)
        saved_max = torch.empty(num_rows, dtype=torch.float32, device=logits.device)
        saved_denom = torch.empty(num_rows, dtype=torch.float32, device=logits.device)

        # Chunk width for the in-kernel vocabulary loop, capped at 4096.
        BLOCK_SIZE = min(triton.next_power_of_2(vocab_size), 4096)

        # An empty grid cannot be launched; with no rows there is nothing to do.
        if num_rows > 0:
            _cross_entropy_fwd_kernel[(num_rows,)](
                logits, targets, losses, saved_max, saved_denom,
                vocab_size,
                logits.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,
            )

        ctx.save_for_backward(logits, targets, saved_max, saved_denom)
        ctx.vocab_size = vocab_size
        ctx.BLOCK_SIZE = BLOCK_SIZE

        return losses.mean()

    @staticmethod
    def backward(ctx, grad_output):
        """Launch the backward kernel.

        Args:
            grad_output: upstream gradient of the scalar mean loss.

        Returns:
            ``(grad_logits, None)``; ``targets`` receives no gradient.
        """
        logits, targets, saved_max, saved_denom = ctx.saved_tensors
        num_rows = logits.shape[0]

        # logits was saved contiguous, so this buffer has the same row stride.
        grad_logits = torch.empty_like(logits)
        if num_rows == 0:
            return grad_logits, None

        # Mean reduction: each row receives grad_output / num_rows. The
        # per-row tensor is built on-device (no .item() host sync) and made
        # contiguous because the kernel indexes it by raw pointer arithmetic.
        grad_per_row = (
            (grad_output.to(torch.float32) / num_rows)
            .reshape(1)
            .expand(num_rows)
            .contiguous()
        )

        _cross_entropy_bwd_kernel[(num_rows,)](
            logits, targets, grad_per_row, saved_max, saved_denom, grad_logits,
            ctx.vocab_size,
            logits.stride(0),
            BLOCK_SIZE=ctx.BLOCK_SIZE,
            DTYPE_IS_FP32=(logits.dtype == torch.float32),
            DTYPE_IS_BF16=(logits.dtype == torch.bfloat16),
        )

        return grad_logits, None  # None for targets
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Drop-in nn.Module replacement
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class TritonCrossEntropyLoss(torch.nn.Module):
    """Fused cross-entropy loss using online softmax.

    Intended as a replacement for ``torch.nn.CrossEntropyLoss`` with mean
    reduction. The module takes only ``logits`` and ``targets`` and has no
    configuration options of its own.
    """

    def forward(self, logits, targets):
        """Flatten the inputs and return the mean cross-entropy loss.

        Args:
            logits: tensor whose last dimension is the vocabulary, e.g.
                ``[batch, seq_len, vocab]`` or ``[num_rows, vocab]``.
            targets: class indices, e.g. ``[batch, seq_len]`` or ``[num_rows]``.

        Returns:
            The result of ``FusedCrossEntropyFunction.apply`` on the
            flattened inputs.
        """
        # Flatten logits to 2D if needed
        # (e.g., [batch, seq_len, vocab] -> [batch*seq_len, vocab]).
        # reshape (not view) so non-contiguous inputs, such as the shifted
        # slice logits[:, :-1, :], are accepted instead of raising.
        if logits.dim() > 2:
            logits = logits.reshape(-1, logits.size(-1))
        if targets.dim() > 1:
            targets = targets.reshape(-1)
        return FusedCrossEntropyFunction.apply(logits, targets)
|