# Source: dgx-spark-playbooks/nvidia/cutile-kernels/assets/fmha_optimization_tutorial.py
# (page metadata: 960 lines, 35 KiB, Python; snapshot dated 2026-06-03 15:15:33 +00:00)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python3
"""
FMHA Optimization Tutorial: From Naive to Optimized cuTile Implementation

This script demonstrates step-by-step optimization of Fused Multi-Head Attention
(FlashAttention-style online softmax) using NVIDIA cuTile, starting from a basic
implementation and progressively adding optimizations until reaching
TileGym-level performance.

Target Platform: DGX Spark (sm121) with pre-determined optimal tile sizes.
Note: TileGym supports autotuning, but this tutorial uses hardcoded tile sizes.

Configuration (matches TileGym bench_fused_attention.py):
- Batch: 4, Heads: 32, Head Dim: 128
- Sequence Lengths: 1024, 2048, 4096, 8192, 16384 (one length per run, chosen
  with --seq-len; default 2048)
- Benchmark: triton.testing.do_bench_cudagraph (simple wall-clock timing is
  used as a fallback if that is unavailable)

Usage:
    python fmha_optimization_tutorial.py [--seq-len N] [--iterations N]
        [--correctness-check] [--output-dir DIR]
"""
import argparse
import json
import math
import time
from dataclasses import dataclass, asdict
from typing import List, Optional
import sys
import torch
# Horizontal rules drawn around section / subsection titles by the Logger.
LOG_SEPARATOR, LOG_SUBSEPARATOR = "=" * 80, "-" * 60
@dataclass
class BenchmarkResult:
    """One row of the tutorial results: an optimization step or a reference run.

    Instances are serialized with ``dataclasses.asdict`` (field names become the
    JSON keys) and rendered into the markdown / console summary tables.
    """

    step: int  # step number shown in the tables (the TileGym reference uses 99)
    name: str  # short display name of the step
    description: str  # one-line description of the step
    latency_ms: float  # measured latency in milliseconds
    tflops: float  # achieved throughput: compute_flops(...) / latency
    speedup_vs_baseline: float  # baseline latency / this step's latency
    speedup_vs_previous: float  # previous step's latency / this step's latency
    correct: bool  # outcome of the correctness check (True when the check is skipped)
    key_changes: List[str]  # bullet points describing what this step changed
class Logger:
def __init__(self):
self.results: List[BenchmarkResult] = []
self.logs: List[str] = []
def log(self, msg: str):
print(msg)
self.logs.append(msg)
def section(self, title: str):
self.log(f"\n{LOG_SEPARATOR}")
self.log(f" {title}")
self.log(LOG_SEPARATOR)
def subsection(self, title: str):
self.log(f"\n{LOG_SUBSEPARATOR}")
self.log(f" {title}")
self.log(LOG_SUBSEPARATOR)
def add_result(self, result: BenchmarkResult):
self.results.append(result)
def export_json(self, filepath: str):
data = {
"results": [asdict(r) for r in self.results],
"logs": self.logs
}
with open(filepath, 'w') as f:
json.dump(data, f, indent=2)
def export_markdown(self, filepath: str):
with open(filepath, 'w') as f:
f.write("# FMHA Optimization Tutorial Results\n\n")
f.write("## Summary Table\n\n")
f.write("| Step | Name | Latency (ms) | TFLOPS | vs Baseline | vs Previous | Correct |\n")
f.write("|------|------|--------------|--------|-------------|-------------|--------|\n")
for r in self.results:
f.write(f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {r.speedup_vs_previous:.2f}x | {'Yes' if r.correct else 'No'} |\n")
f.write("\n## Detailed Steps\n\n")
for r in self.results:
f.write(f"### Step {r.step}: {r.name}\n\n")
f.write(f"**Description**: {r.description}\n\n")
f.write("**Key Changes**:\n")
for change in r.key_changes:
f.write(f"- {change}\n")
f.write(f"\n**Performance**: {r.latency_ms:.3f}ms, {r.tflops:.2f} TFLOPS, {r.speedup_vs_baseline:.2f}x vs baseline\n\n")
# Module-wide logger: every message is printed and also kept for export.
logger = Logger()
# Problem size benchmarked by the tutorial (see the module docstring).
BATCH = 4
N_HEADS = 32
HEAD_DIM = 128
# 1/ln(2): lets the kernels use exp2, since exp(x) == exp2(x * INV_LOG_2).
INV_LOG_2 = 1.0 / math.log(2)
# Hardcoded query (M) / key (N) tile sizes; TileGym would autotune these.
TILE_M = 64
TILE_N = 64
# Launch hints. OCCUPANCY matches the occupancy used by the occupancy-tuned
# steps (6 and 7). NUM_CTAS is recorded for reference only.
# TODO(review): confirm whether any launch is meant to consume NUM_CTAS.
OCCUPANCY = 2
NUM_CTAS = 1
def get_device():
    """Return the CUDA device; the tutorial cannot run without one.

    Raises:
        RuntimeError: if PyTorch reports that CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")
    return torch.device("cuda")
# CUDA device for the benchmark tensors; stays None until run_benchmark() sets it via get_device().
DEVICE: Optional[torch.device] = None
def compute_flops(batch, heads, seq_len, head_dim, causal=True):
    """Return the nominal FLOP count of one attention forward pass.

    Attention performs two GEMMs of equal cost (Q @ K^T and P @ V), each
    ``2 * seq_len * seq_len * head_dim`` FLOPs per (batch, head). With causal
    masking only about half of the score matrix is needed, so the total is halved.
    """
    gemm_flops = 2.0 * batch * heads * seq_len * seq_len * head_dim
    both_gemms = 2 * gemm_flops
    return both_gemms * 0.5 if causal else both_gemms
def benchmark_fn(fn, warmup=10, iterations=100):
    """Return the mean latency of ``fn()`` in milliseconds.

    Prefers ``triton.testing.do_bench_cudagraph``, the timing method TileGym's
    benchmarks use. If Triton cannot be imported, plain wall-clock timing is
    used instead. If CUDA-graph benchmarking raises, wall-clock timing is also
    used, and a ``RuntimeWarning`` reports the switch so it is not silent.

    Args:
        fn: Zero-argument callable to time.
        warmup: Untimed calls before measuring (wall-clock fallback only).
        iterations: Timed calls to average (wall-clock fallback only; values
            below 1 are treated as 1 to avoid dividing by zero).

    Returns:
        Latency in milliseconds.
    """
    import warnings

    try:
        import triton.testing

        bench_cudagraph = getattr(triton.testing, "do_bench_cudagraph", None)
    except Exception:  # ImportError, or a broken optional install: fall back
        bench_cudagraph = None

    if bench_cudagraph is not None:
        try:
            return bench_cudagraph(fn)
        except Exception as exc:
            warnings.warn(
                f"do_bench_cudagraph failed ({exc!r}); falling back to wall-clock timing",
                RuntimeWarning,
                stacklevel=2,
            )

    iterations = max(1, int(iterations))
    # Only synchronize when CUDA exists so CPU-only callables can be timed too.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(warmup):
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    sync()
    end = time.perf_counter()
    return (end - start) / iterations * 1000
def verify_correctness(output, reference, atol=1e-2, rtol=1e-2, max_abs_diff=0.1):
    """Return True if ``output`` matches ``reference`` closely enough.

    First checks ``torch.testing.assert_close`` with ``atol`` / ``rtol``. If that
    fails, the result is still accepted when the largest absolute elementwise
    difference is below ``max_abs_diff``, a deliberately looser fallback bound.
    The maximum difference is logged in that case.

    The difference is computed in float32 so half-precision inputs cannot
    overflow. NaNs in the difference make the check fail. A shape mismatch is
    logged and reported as False instead of raising.

    Args:
        output: Tensor produced by the implementation under test.
        reference: Ground-truth tensor.
        atol, rtol: Strict tolerances passed to ``assert_close``.
        max_abs_diff: Looser absolute bound used when the strict check fails.
    """
    try:
        torch.testing.assert_close(output, reference, atol=atol, rtol=rtol)
        return True
    except AssertionError:
        pass
    if output.shape != reference.shape:
        logger.log(f" [WARN] Shape mismatch: {tuple(output.shape)} vs {tuple(reference.shape)}")
        return False
    max_diff = (output.float() - reference.float()).abs().max().item()
    logger.log(f" [WARN] Max difference: {max_diff:.6f}")
    return max_diff < max_abs_diff
def reference_fmha(q, k, v, sm_scale, is_causal=True):
return torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
)
def step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True):
    """Step 0: the PyTorch SDPA baseline; identical to the reference computation."""
    return reference_fmha(q, k, v, sm_scale, is_causal=is_causal)
# cuTile is optional: without it, only the PyTorch baseline is benchmarked.
try:
    import cuda.tile as ct
    from cuda.tile import RoundingMode as RMd
except ImportError:
    CUTILE_AVAILABLE = False
    logger.log("[WARN] cuTile not available. Only PyTorch baseline will run.")
else:
    CUTILE_AVAILABLE = True
if CUTILE_AVAILABLE:
    # Parameter annotations marking kernel arguments as compile-time constants.
    ConstInt = ct.Constant[int]
    ConstBool = ct.Constant[bool]

    def _launch_fmha(kernel, q, k, v, sm_scale, is_causal):
        """Launch one of the tutorial FMHA kernels and return its output tensor.

        All tutorial kernels share the same signature and launch geometry: grid
        axis 0 walks query tiles of TILE_M rows, grid axis 1 walks the
        batch * heads pairs. ``q.shape`` is unpacked as
        (batch, heads, seq_len, head_dim).
        """
        batch_size, num_heads, seq_len, head_dim = q.shape
        o = torch.empty_like(q)
        # Integer ceil-division so every query row is covered by some block.
        grid = ((seq_len + TILE_M - 1) // TILE_M, batch_size * num_heads, 1)
        ct.launch(
            torch.cuda.current_stream(),
            grid,
            kernel,
            (q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal),
        )
        return o

    @ct.kernel()
    def fmha_step2_mma(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """
        Step 2: Basic cuTile FMHA with MMA (Tensor Cores)
        - Uses ct.mma() for both matrix multiplies (Q @ K^T and P @ V)
        - Standard exp() for softmax
        - Online softmax algorithm (running row max and row sum)
        """
        # Block (bid_x, bid_y) owns query rows [bid_x*TILE_M, (bid_x+1)*TILE_M)
        # of one (batch, head) pair; bid_y enumerates batch * H.
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        # Absolute query rows as a column and in-tile key offsets as a row; they
        # broadcast to a (TILE_M, TILE_N) causal mask below.
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        # Online-softmax state: running row max, running row sum, unnormalized output.
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            # Keys past this tile's last query row are fully masked: stop there.
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            # Key tiles before mask_start lie strictly left of the diagonal, so
            # every entry is visible and no mask is needed.
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            # NOTE(review): no key-padding mask is applied on this path, so it
            # assumes k_seqlen is a multiple of TILE_N — confirm.
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            k_tile = k_tile.reshape((TILE_N, TILE_D))
            k_t = ct.permute(k_tile, (1, 0))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            qk = qk * qk_scale
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True)
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk - m_ij
            p = ct.exp(qk)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            # Rescale the previous sum/accumulator to the new running max.
            alpha = ct.exp(m_i - m_ij)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            # The P @ V MMA takes P in the input dtype.
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        # Final softmax normalization.
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    def run_step2(q, k, v, sm_scale, is_causal=True):
        """Run the step-2 kernel (basic MMA + exp) and return the attention output."""
        return _launch_fmha(fmha_step2_mma, q, k, v, sm_scale, is_causal)

    @ct.kernel()
    def fmha_step3_exp2(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """
        Step 3: Use exp2 with flush_to_zero for faster math
        - exp2(x) = 2^x is faster than exp(x) = e^x on GPU
        - Requires scaling adjustment: multiply by 1/log(2)
        - flush_to_zero handles denormals efficiently
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        # Fold 1/ln(2) into the softmax scale: exp(s * x) == exp2(s * INV_LOG_2 * x).
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            # NOTE(review): assumes k_seqlen is a multiple of TILE_N (no key-padding mask).
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            k_tile = k_tile.reshape((TILE_N, TILE_D))
            k_t = ct.permute(k_tile, (1, 0))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            # Raw scores are masked here; scaling is fused into the max subtraction.
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            # max(s * x) == s * max(x) only for s > 0 (assumes qk_scale > 0).
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    def run_step3(q, k, v, sm_scale, is_causal=True):
        """Run the step-3 kernel (exp2 + flush_to_zero) and return the attention output."""
        return _launch_fmha(fmha_step3_exp2, q, k, v, sm_scale, is_causal)

    @ct.kernel()
    def fmha_step4_load_order(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """
        Step 4: Optimize K load with order parameter
        - Use order=(0,1,3,2) to load K already transposed
        - Avoids explicit ct.permute() operation on the loaded tile
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            # NOTE(review): assumes k_seqlen is a multiple of TILE_N (no key-padding mask).
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            # Swapping the last two axes in the load yields K^T directly as a
            # (TILE_D, TILE_N) tile, so no ct.permute() is needed.
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
            )
            k_t = k_t.reshape((TILE_D, TILE_N))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    def run_step4(q, k, v, sm_scale, is_causal=True):
        """Run the step-4 kernel (transposed K load) and return the attention output."""
        return _launch_fmha(fmha_step4_load_order, q, k, v, sm_scale, is_causal)

    @ct.kernel()
    def fmha_step5_latency(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """
        Step 5: Add latency hints for better pipelining
        - latency=2 for K load (prefetch)
        - latency=4 for V load (more prefetch distance)
        - Helps overlap memory loads with computation
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            # NOTE(review): assumes k_seqlen is a multiple of TILE_N (no key-padding mask).
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2,
            )
            k_t = k_t.reshape((TILE_D, TILE_N))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            # V is consumed later in the iteration than K, hence the larger hint.
            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4,
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    def run_step5(q, k, v, sm_scale, is_causal=True):
        """Run the step-5 kernel (load latency hints) and return the attention output."""
        return _launch_fmha(fmha_step5_latency, q, k, v, sm_scale, is_causal)

    @ct.kernel(occupancy=OCCUPANCY)
    def fmha_step6_occupancy(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """
        Step 6: Add occupancy hint
        - @ct.kernel(occupancy=OCCUPANCY), i.e. 2, improves SM utilization
        - Allows multiple thread blocks per SM
        - Better for hiding memory latency
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            # NOTE(review): assumes k_seqlen is a multiple of TILE_N (no key-padding mask).
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2,
            )
            k_t = k_t.reshape((TILE_D, TILE_N))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4,
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    def run_step6(q, k, v, sm_scale, is_causal=True):
        """Run the step-6 kernel (occupancy hint) and return the attention output."""
        return _launch_fmha(fmha_step6_occupancy, q, k, v, sm_scale, is_causal)

    @ct.kernel(occupancy=OCCUPANCY)
    def fmha_step7_approx_div(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """
        Step 7: Use approximate division for final normalization
        - ct.truediv with rounding_mode=APPROX is faster
        - Acceptable accuracy loss for inference
        - This matches TileGym's optimized implementation
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            # NOTE(review): assumes k_seqlen is a multiple of TILE_N (no key-padding mask).
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2,
            )
            k_t = k_t.reshape((TILE_D, TILE_N))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4,
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        # Approximate (faster) division for the final normalization.
        acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    def run_step7(q, k, v, sm_scale, is_causal=True):
        """Run the step-7 kernel (approximate division) and return the attention output."""
        return _launch_fmha(fmha_step7_approx_div, q, k, v, sm_scale, is_causal)
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
    """Run TileGym's cuTile FMHA on the same inputs, for comparison.

    Returns the attention output. If importing or calling TileGym raises
    ImportError, a warning is logged and None is returned.
    """
    try:
        import tilegym

        fmha_kwargs = {"scaling": sm_scale, "is_causal": is_causal, "backend": "cutile"}
        return tilegym.ops.fmha(q, k, v, **fmha_kwargs)
    except ImportError:
        logger.log("[WARN] TileGym not available for comparison")
    return None
def _log_run_config(seq_len, iterations, device):
    """Log the problem configuration for one run_benchmark() call."""
    logger.section(f"FMHA OPTIMIZATION TUTORIAL - SEQ_LEN={seq_len}")
    logger.log("Configuration:")
    logger.log(f" - Batch: {BATCH}")
    logger.log(f" - Heads: {N_HEADS}")
    logger.log(f" - Head Dim: {HEAD_DIM}")
    logger.log(f" - Sequence Length: {seq_len}")
    logger.log(f" - Tile M: {TILE_M}")
    logger.log(f" - Tile N: {TILE_N}")
    logger.log(" - Precision: float16")
    logger.log(" - Causal: True")
    logger.log(f" - Iterations: {iterations}")
    logger.log(f" - Device: {device}")


def _build_steps(q, k, v, sm_scale):
    """Return the (step, name, description, fn, key_changes) tuples to benchmark.

    Step 0 (PyTorch) always runs. The cuTile steps are added only when cuTile
    imported successfully. The tutorial has no step 1.
    """
    steps = [
        (0, "PyTorch Baseline", "torch.nn.functional.scaled_dot_product_attention",
         lambda: step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True),
         ["PyTorch SDPA with cuDNN backend", "Highly optimized baseline"]),
    ]
    if CUTILE_AVAILABLE:
        steps.extend([
            (2, "Basic cuTile + MMA", "Tiled FMHA with ct.mma() for Tensor Cores",
             lambda: run_step2(q, k, v, sm_scale, is_causal=True),
             ["@ct.kernel decorator", "ct.mma() for QK and PV products", "Online softmax with exp()"]),
            (3, "+ exp2 + flush_to_zero", "Faster exponential math",
             lambda: run_step3(q, k, v, sm_scale, is_causal=True),
             ["ct.exp2() instead of ct.exp()", "flush_to_zero=True for denormals", "qk_scale *= 1/log(2)"]),
            (4, "+ Load Order Transpose", "Avoid explicit transpose",
             lambda: run_step4(q, k, v, sm_scale, is_causal=True),
             ["order=(0,1,3,2) for K load", "K loaded already transposed", "Removes ct.permute() call"]),
            (5, "+ Latency Hints", "Better memory pipelining",
             lambda: run_step5(q, k, v, sm_scale, is_causal=True),
             ["latency=2 for K load", "latency=4 for V load", "Overlaps loads with compute"]),
            (6, "+ Occupancy=2", "Better SM utilization",
             lambda: run_step6(q, k, v, sm_scale, is_causal=True),
             ["@ct.kernel(occupancy=2)", "Multiple blocks per SM", "Hides memory latency"]),
            (7, "+ Approx Division (Final)", "Fast final normalization",
             lambda: run_step7(q, k, v, sm_scale, is_causal=True),
             ["ct.truediv with APPROX mode", "Matches TileGym implementation", "Full optimization achieved"]),
        ])
    return steps


def _run_tilegym_comparison(q, k, v, sm_scale, *, flops, iterations, baseline_latency,
                            ref_output, check_correct):
    """Benchmark TileGym's FMHA (if installed) and record it as step 99.

    Any failure is logged with its traceback instead of propagating, so the
    caller can still export the results of the tutorial steps.
    """
    try:
        tilegym_output = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
        if tilegym_output is None:
            return

        def tilegym_fn():
            return run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)

        logger.subsection("TileGym Reference (for comparison)")
        tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
        tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
        tilegym_speedup = baseline_latency / tilegym_latency if baseline_latency else 1.0
        # Verify TileGym like every other step rather than assuming it is correct.
        correct = verify_correctness(tilegym_output, ref_output) if check_correct else True
        logger.log("TileGym FMHA:")
        logger.log(f" Latency: {tilegym_latency:.3f} ms")
        logger.log(f" TFLOPS: {tilegym_tflops:.2f}")
        logger.log(f" vs Baseline: {tilegym_speedup:.2f}x")
        if check_correct:
            logger.log(f" Correct: {'Yes' if correct else 'No'}")
        logger.add_result(BenchmarkResult(
            step=99,
            name="TileGym Reference",
            description="TileGym's optimized FMHA implementation",
            latency_ms=tilegym_latency,
            tflops=tilegym_tflops,
            speedup_vs_baseline=tilegym_speedup,
            speedup_vs_previous=1.0,
            correct=correct,
            key_changes=["Full TileGym implementation", "Pre-tuned for sm121", "Production ready"],
        ))
    except Exception as e:
        import traceback
        logger.log(f"\n[ERROR] TileGym comparison failed: {e}")
        logger.log(traceback.format_exc())


def run_benchmark(seq_len, iterations=100, check_correct=True):
    """Benchmark every tutorial step at one sequence length.

    Creates random float16 Q/K/V of shape (BATCH, N_HEADS, seq_len, HEAD_DIM) on
    the CUDA device and times each step with ``benchmark_fn``. Results are added
    to the module-level ``logger``. A failing step is logged with its traceback
    and skipped. The first step that succeeds becomes the speedup baseline.

    Args:
        seq_len: Query/key sequence length.
        iterations: Passed to ``benchmark_fn`` as its ``iterations``.
        check_correct: Compare each output against PyTorch SDPA when True.
    """
    global DEVICE
    DEVICE = get_device()
    _log_run_config(seq_len, iterations, DEVICE)
    q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
    k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
    v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
    sm_scale = 1.0 / math.sqrt(HEAD_DIM)
    flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)
    ref_output = reference_fmha(q, k, v, sm_scale, is_causal=True)
    baseline_latency = None
    prev_latency = None
    for step_idx, name, desc, fn, changes in _build_steps(q, k, v, sm_scale):
        logger.subsection(f"Step {step_idx}: {name}")
        logger.log(f"Description: {desc}")
        logger.log("Key Changes:")
        for change in changes:
            logger.log(f" - {change}")
        try:
            # The first call also absorbs one-time costs (e.g. kernel compilation).
            output = fn()
            latency_ms = benchmark_fn(fn, warmup=10, iterations=iterations)
            tflops = flops * 1e-12 / (latency_ms * 1e-3)
            if baseline_latency is None:
                baseline_latency = latency_ms
                speedup_baseline = 1.0
            else:
                speedup_baseline = baseline_latency / latency_ms
            speedup_prev = 1.0 if prev_latency is None else prev_latency / latency_ms
            if check_correct and output is not None:
                correct = verify_correctness(output, ref_output)
            else:
                correct = True
            logger.log("\nResults:")
            logger.log(f" Latency: {latency_ms:.3f} ms")
            logger.log(f" TFLOPS: {tflops:.2f}")
            logger.log(f" vs Baseline: {speedup_baseline:.2f}x")
            logger.log(f" vs Previous: {speedup_prev:.2f}x")
            logger.log(f" Correct: {'Yes' if correct else 'No'}")
            logger.add_result(BenchmarkResult(
                step=step_idx,
                name=name,
                description=desc,
                latency_ms=latency_ms,
                tflops=tflops,
                speedup_vs_baseline=speedup_baseline,
                speedup_vs_previous=speedup_prev,
                correct=correct,
                key_changes=changes,
            ))
            prev_latency = latency_ms
        except Exception as e:
            logger.log(f"\n[ERROR] Step {step_idx} failed: {e}")
            import traceback
            logger.log(traceback.format_exc())
    _run_tilegym_comparison(
        q, k, v, sm_scale,
        flops=flops,
        iterations=iterations,
        baseline_latency=baseline_latency,
        ref_output=ref_output,
        check_correct=check_correct,
    )
def main():
    """Parse CLI arguments, run the tutorial benchmark, and export the results.

    Writes fmha_tutorial_results.json, fmha_tutorial_results.md and
    fmha_tutorial_log.txt into --output-dir, creating that directory up front.
    """
    import os

    parser = argparse.ArgumentParser(description="FMHA Optimization Tutorial")
    parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
    parser.add_argument("--seq-len", type=int, default=2048, help="Sequence length (default matches TileGym)")
    parser.add_argument("--correctness-check", action="store_true", help="Enable correctness checking")
    parser.add_argument("--output-dir", type=str, default=".", help="Output directory for logs")
    args = parser.parse_args()
    if args.iterations < 1:
        parser.error("--iterations must be a positive integer")
    if args.seq_len < 1:
        parser.error("--seq-len must be a positive integer")
    # Fail fast on an unusable output directory instead of after the whole benchmark.
    os.makedirs(args.output_dir, exist_ok=True)

    logger.section("FMHA OPTIMIZATION TUTORIAL")
    logger.log("From Naive to Optimized cuTile Implementation")
    logger.log("Target Platform: DGX Spark (sm121)")
    logger.log(f"Tile Sizes: TILE_M={TILE_M}, TILE_N={TILE_N} (hardcoded from TileGym)")
    logger.log("Note: TileGym supports autotuning, but we use pre-determined optimal values")
    run_benchmark(
        seq_len=args.seq_len,
        iterations=args.iterations,
        check_correct=args.correctness_check,
    )

    logger.section("FINAL SUMMARY")
    logger.log("\n| Step | Name | Latency (ms) | TFLOPS | vs Baseline | Correct |")
    logger.log("|------|------|--------------|--------|-------------|---------|")
    for r in logger.results:
        logger.log(f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {'Yes' if r.correct else 'No'} |")

    json_path = os.path.join(args.output_dir, "fmha_tutorial_results.json")
    md_path = os.path.join(args.output_dir, "fmha_tutorial_results.md")
    log_path = os.path.join(args.output_dir, "fmha_tutorial_log.txt")
    logger.export_json(json_path)
    logger.export_markdown(md_path)
    logger.log("\nResults exported to:")
    logger.log(f" - {json_path}")
    logger.log(f" - {md_path}")
    logger.log(f" - {log_path}")
    # Written last so the saved transcript includes the export summary above.
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("\n".join(logger.logs))
if __name__ == "__main__":
    # Run the tutorial only when executed as a script, not when imported.
    main()