mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-06-18 20:42:20 +00:00
960 lines
35 KiB
Python
960 lines
35 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
#!/usr/bin/env python3
|
|
"""
|
|
FMHA Optimization Tutorial: From Naive to Optimized cuTile Implementation
|
|
|
|
This script demonstrates step-by-step optimization of Flash Multi-Head Attention
|
|
using NVIDIA cuTile, starting from a basic implementation and progressively
|
|
adding optimizations until reaching TileGym-level performance.
|
|
|
|
Target Platform: DGX Spark (sm121) with pre-determined optimal tile sizes.
|
|
Note: TileGym supports autotuning, but we use hardcoded values for this tutorial.
|
|
|
|
Configuration (matches TileGym bench_fused_attention.py):
|
|
- Batch: 4, Heads: 32, Head Dim: 128
|
|
- Sequence Lengths: 1024, 2048, 4096, 8192, 16384
|
|
- Benchmark: triton.testing.do_bench_cudagraph
|
|
|
|
Usage:
|
|
python fmha_optimization_tutorial.py [--iterations N] [--correctness-check]
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import math
|
|
import time
|
|
from dataclasses import dataclass, asdict
|
|
from typing import List, Optional
|
|
import sys
|
|
|
|
import torch
|
|
|
|
LOG_SEPARATOR = "=" * 80
|
|
LOG_SUBSEPARATOR = "-" * 60
|
|
|
|
@dataclass
class BenchmarkResult:
    """One benchmarked step, as collected by Logger and exported to JSON/Markdown."""

    # Step number shown in the reports (run_benchmark uses 0, 2-7, and 99 for TileGym).
    step: int
    # Short display name of the step.
    name: str
    # One-line description of what the step does.
    description: str
    # Measured latency of one call, in milliseconds.
    latency_ms: float
    # Throughput derived from the FLOP estimate and latency_ms.
    tflops: float
    # Baseline latency divided by this step's latency.
    speedup_vs_baseline: float
    # Previous step's latency divided by this step's latency.
    speedup_vs_previous: float
    # Outcome of the correctness check (run_benchmark stores True when the check is skipped).
    correct: bool
    # Bullet points describing what this step changed.
    key_changes: List[str]
|
|
|
|
class Logger:
    """Console logger that also retains messages and results for later export."""

    def __init__(self):
        self.results: List["BenchmarkResult"] = []
        self.logs: List[str] = []

    def log(self, msg: str):
        """Print ``msg`` and keep a copy for the exported log."""
        print(msg)
        self.logs.append(msg)

    def _banner(self, title: str, rule: str):
        """Log ``title`` framed by two ``rule`` lines (blank line first)."""
        for line in (f"\n{rule}", f" {title}", rule):
            self.log(line)

    def section(self, title: str):
        """Log a major heading."""
        self._banner(title, LOG_SEPARATOR)

    def subsection(self, title: str):
        """Log a minor heading."""
        self._banner(title, LOG_SUBSEPARATOR)

    def add_result(self, result: "BenchmarkResult"):
        """Record one benchmark result."""
        self.results.append(result)

    def export_json(self, filepath: str):
        """Write all results (as dicts) and all log lines to ``filepath`` as JSON."""
        payload = {
            "results": [asdict(entry) for entry in self.results],
            "logs": self.logs,
        }
        with open(filepath, 'w') as handle:
            json.dump(payload, handle, indent=2)

    def export_markdown(self, filepath: str):
        """Write a summary table and per-step details to ``filepath`` as Markdown."""
        with open(filepath, 'w') as handle:
            chunks = [
                "# FMHA Optimization Tutorial Results\n\n",
                "## Summary Table\n\n",
                "| Step | Name | Latency (ms) | TFLOPS | vs Baseline | vs Previous | Correct |\n",
                "|------|------|--------------|--------|-------------|-------------|--------|\n",
            ]
            chunks.extend(
                f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {r.speedup_vs_previous:.2f}x | {'Yes' if r.correct else 'No'} |\n"
                for r in self.results
            )
            chunks.append("\n## Detailed Steps\n\n")
            for r in self.results:
                chunks.append(f"### Step {r.step}: {r.name}\n\n")
                chunks.append(f"**Description**: {r.description}\n\n")
                chunks.append("**Key Changes**:\n")
                chunks.extend(f"- {change}\n" for change in r.key_changes)
                chunks.append(f"\n**Performance**: {r.latency_ms:.3f}ms, {r.tflops:.2f} TFLOPS, {r.speedup_vs_baseline:.2f}x vs baseline\n\n")
            handle.write("".join(chunks))
|
|
|
|
logger = Logger()
|
|
|
|
# Problem size used by run_benchmark when allocating Q, K and V.
BATCH = 4
N_HEADS = 32
HEAD_DIM = 128
# 1 / ln(2): exp(x) == exp2(x * INV_LOG_2), used by the exp2-based kernels.
INV_LOG_2 = 1.0 / math.log(2)

# Tile sizes passed to every cuTile kernel launch in this file.
TILE_M = 64
TILE_N = 64
# NOTE(review): OCCUPANCY and NUM_CTAS are not read anywhere in this file;
# the step 6/7 kernels hard-code occupancy=2 in their decorators.
OCCUPANCY = 2
NUM_CTAS = 1
|
|
|
|
def get_device():
    """Return the CUDA device, raising RuntimeError when CUDA is unavailable."""
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")
    return torch.device("cuda")
|
|
|
|
DEVICE = None
|
|
|
|
def compute_flops(batch, heads, seq_len, head_dim, causal=True):
    """Estimate the floating-point operations of one attention forward pass.

    Counts two matmuls (Q @ K^T and P @ V), each costing
    ``2 * batch * heads * seq_len * seq_len * head_dim`` FLOPs, and halves
    the total when ``causal`` is true.
    """
    per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    both_matmuls = 2 * per_matmul
    return both_matmuls * 0.5 if causal else both_matmuls
|
|
|
|
def benchmark_fn(fn, warmup=10, iterations=100):
    """Return the latency of ``fn`` in milliseconds.

    Prefers ``triton.testing.do_bench_cudagraph`` (matches TileGym). When
    Triton is not installed, or the CUDA-graph benchmark raises, falls back
    to a wall-clock loop bracketed by ``torch.cuda.synchronize()``.

    Args:
        fn: Zero-argument callable to time.
        warmup: Untimed calls made before timing (fallback path only; the
            Triton path is called with its own defaults).
        iterations: Timed calls averaged over (fallback path only).

    Raises:
        ValueError: If the fallback path is taken with ``iterations <= 0``.
    """
    try:
        import triton
    except ImportError:
        # Triton is optional; the manual timer below is the expected fallback.
        triton = None

    if triton is not None:
        try:
            return triton.testing.do_bench_cudagraph(fn)
        except Exception as exc:
            # Still best-effort, but no longer silent: the two timing methods
            # are not directly comparable, so say which one produced the number.
            logger.log(
                f"[WARN] do_bench_cudagraph failed ({type(exc).__name__}: {exc}); "
                "falling back to wall-clock timing"
            )

    if iterations <= 0:
        raise ValueError(f"iterations must be positive, got {iterations}")

    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    # Kernel launches are asynchronous; wait for them before stopping the clock.
    torch.cuda.synchronize()
    end = time.perf_counter()

    return (end - start) / iterations * 1000
|
|
|
|
def verify_correctness(output, reference, atol=1e-2, rtol=1e-2, max_abs_diff=0.1):
    """Check ``output`` against ``reference`` with a strict and a lenient test.

    The strict test is ``torch.testing.assert_close`` with ``atol``/``rtol``.
    If it fails, the maximum absolute difference is logged and the result is
    still accepted when that difference is below ``max_abs_diff``.

    Args:
        output: Tensor produced by the implementation under test.
        reference: Tensor holding the expected values.
        atol: Absolute tolerance for the strict test.
        rtol: Relative tolerance for the strict test.
        max_abs_diff: Lenient bound applied after the strict test fails
            (previously hard-coded to 0.1).

    Returns:
        True if either test passes, False otherwise (including on a shape
        mismatch or when the difference is NaN).
    """
    try:
        torch.testing.assert_close(output, reference, atol=atol, rtol=rtol)
        return True
    except AssertionError:
        pass

    # Without this guard, differently shaped tensors that happen to broadcast
    # would be compared elementwise and could be reported as correct.
    if output.shape != reference.shape:
        logger.log(f" [WARN] Shape mismatch: {tuple(output.shape)} vs {tuple(reference.shape)}")
        return False

    max_diff = (output - reference).abs().max().item()
    logger.log(f" [WARN] Max difference: {max_diff:.6f}")
    return max_diff < max_abs_diff
|
|
|
|
def reference_fmha(q, k, v, sm_scale, is_causal=True):
|
|
return torch.nn.functional.scaled_dot_product_attention(
|
|
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
|
|
)
|
|
|
|
def step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True):
    """Step 0: the PyTorch baseline, which is simply the reference implementation."""
    baseline_out = reference_fmha(q, k, v, sm_scale, is_causal)
    return baseline_out
|
|
|
|
try:
|
|
import cuda.tile as ct
|
|
from cuda.tile import RoundingMode as RMd
|
|
CUTILE_AVAILABLE = True
|
|
except ImportError:
|
|
CUTILE_AVAILABLE = False
|
|
logger.log("[WARN] cuTile not available. Only PyTorch baseline will run.")
|
|
|
|
if CUTILE_AVAILABLE:
|
|
ConstInt = ct.Constant[int]
|
|
ConstBool = ct.Constant[bool]
|
|
|
|
@ct.kernel()
def fmha_step2_mma(
    Q, K, V, Out,
    qk_scale: float,
    TILE_D: ConstInt,
    H: ConstInt,
    TILE_M: ConstInt,
    TILE_N: ConstInt,
    CAUSAL: ConstBool,
):
    """
    Step 2: Basic cuTile FMHA with MMA (Tensor Cores)
    - Uses ct.mma() for matrix multiply
    - Standard exp() for softmax
    - Online softmax algorithm

    Each block produces one TILE_M-row tile of Out for one (batch, head)
    pair. Q, K, V and Out are indexed as (batch, head, sequence, head_dim);
    run_step2 passes the head dimension as TILE_D and the head count as H.
    """
    # Grid axis 0 selects the query tile; axis 1 is a flattened (batch, head) id.
    bid_x = ct.bid(0)
    bid_y = ct.bid(1)
    batch_idx = bid_y // H
    head_idx = bid_y % H

    # Absolute query positions covered by this tile, as a column vector.
    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
    offs_m = offs_m[:, None]

    # Key positions within a single key tile, as a row vector.
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
    offs_n_tile = offs_n_tile[None, :]

    # Online-softmax state: running row max, running row sum, unnormalized output.
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
    q = q.reshape((TILE_M, TILE_D))

    k_seqlen = K.shape[2]
    if CAUSAL:
        # Key tiles past the last row of this query tile are never visited.
        m_end = (bid_x + 1) * TILE_M
        Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
        # Index of the key tile containing this tile's first query position;
        # earlier key tiles need no causal masking.
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        # NOTE(review): mask_start is only read under `CAUSAL` below, so this
        # value is unused in the non-causal case.
        mask_start = k_seqlen // TILE_N

    for j in range(0, Tc):
        # Load key tile j and transpose it explicitly for the Q @ K^T product.
        k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
        k_tile = k_tile.reshape((TILE_N, TILE_D))
        k_t = ct.permute(k_tile, (1, 0))

        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
        qk = ct.mma(q, k_t, qk)
        qk = qk * qk_scale

        if CAUSAL and j >= mask_start:
            # Keep only keys at or before each query position.
            offs_n = j * TILE_N + offs_n_tile
            mask = offs_m >= offs_n
            qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

        # New running max, then shift scores by it before exponentiating.
        m_ij = ct.max(qk, axis=-1, keepdims=True)
        m_ij = ct.maximum(m_i, m_ij)
        qk = qk - m_ij

        p = ct.exp(qk)
        l_ij = ct.sum(p, axis=-1, keepdims=True)
        # alpha rescales the previous sum and accumulator to the new max.
        alpha = ct.exp(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha

        v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
        v_tile = v_tile.reshape((TILE_N, TILE_D))
        # Probabilities are cast to the input dtype for the P @ V product.
        p_cast = p.astype(Q.dtype)
        acc = ct.mma(p_cast, v_tile, acc)

        m_i = m_ij

    # Final softmax normalization, then write the tile back in Out's dtype.
    acc = acc / l_i
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
|
|
|
def run_step2(q, k, v, sm_scale, is_causal=True):
    """Launch fmha_step2_mma on the current CUDA stream and return its output.

    The grid has one block per (query tile, batch * head) pair.
    """
    n_batch, n_heads, n_ctx, d_head = q.shape
    out = torch.empty_like(q)
    launch_grid = (math.ceil(n_ctx / TILE_M), n_batch * n_heads, 1)
    kernel_args = (q, k, v, out, sm_scale, d_head, n_heads, TILE_M, TILE_N, is_causal)
    ct.launch(torch.cuda.current_stream(), launch_grid, fmha_step2_mma, kernel_args)
    return out
|
|
|
|
@ct.kernel()
def fmha_step3_exp2(
    Q, K, V, Out,
    qk_scale: float,
    TILE_D: ConstInt,
    H: ConstInt,
    TILE_M: ConstInt,
    TILE_N: ConstInt,
    CAUSAL: ConstBool,
):
    """
    Step 3: Use exp2 with flush_to_zero for faster math
    - exp2(x) = 2^x is faster than exp(x) = e^x on GPU
    - Requires scaling adjustment: multiply by 1/log(2)
    - flush_to_zero handles denormals efficiently

    Each block produces one TILE_M-row tile of Out for one (batch, head)
    pair. Q, K, V and Out are indexed as (batch, head, sequence, head_dim).
    """
    # Grid axis 0 selects the query tile; axis 1 is a flattened (batch, head) id.
    bid_x = ct.bid(0)
    bid_y = ct.bid(1)
    batch_idx = bid_y // H
    head_idx = bid_y % H

    # Fold 1/ln(2) into the softmax scale so exp2 can stand in for exp.
    qk_scale_log2 = qk_scale * INV_LOG_2

    # Absolute query positions covered by this tile, as a column vector.
    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
    offs_m = offs_m[:, None]

    # Key positions within a single key tile, as a row vector.
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
    offs_n_tile = offs_n_tile[None, :]

    # Online-softmax state: running row max, running row sum, unnormalized output.
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
    q = q.reshape((TILE_M, TILE_D))

    k_seqlen = K.shape[2]
    if CAUSAL:
        # Key tiles past the last row of this query tile are never visited.
        m_end = (bid_x + 1) * TILE_M
        Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
        # Index of the key tile containing this tile's first query position.
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        # NOTE(review): only read under `CAUSAL` below; unused here.
        mask_start = k_seqlen // TILE_N

    for j in range(0, Tc):
        # Load key tile j and transpose it explicitly for the Q @ K^T product.
        k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
        k_tile = k_tile.reshape((TILE_N, TILE_D))
        k_t = ct.permute(k_tile, (1, 0))

        # Raw (unscaled) scores; the scale is applied together with the max below.
        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
        qk = ct.mma(q, k_t, qk)

        if CAUSAL and j >= mask_start:
            # Keep only keys at or before each query position.
            offs_n = j * TILE_N + offs_n_tile
            mask = offs_m >= offs_n
            qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

        # Running max is tracked in the scaled, base-2 domain.
        m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
        m_ij = ct.maximum(m_i, m_ij)
        qk = qk * qk_scale_log2 - m_ij

        p = ct.exp2(qk, flush_to_zero=True)
        l_ij = ct.sum(p, axis=-1, keepdims=True)
        # alpha rescales the previous sum and accumulator to the new max.
        alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha

        v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
        v_tile = v_tile.reshape((TILE_N, TILE_D))
        # Probabilities are cast to the input dtype for the P @ V product.
        p_cast = p.astype(Q.dtype)
        acc = ct.mma(p_cast, v_tile, acc)

        m_i = m_ij

    # Final softmax normalization, then write the tile back in Out's dtype.
    acc = acc / l_i
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
|
|
|
def run_step3(q, k, v, sm_scale, is_causal=True):
    """Launch fmha_step3_exp2 on the current CUDA stream and return its output.

    The grid has one block per (query tile, batch * head) pair.
    """
    n_batch, n_heads, n_ctx, d_head = q.shape
    out = torch.empty_like(q)
    launch_grid = (math.ceil(n_ctx / TILE_M), n_batch * n_heads, 1)
    kernel_args = (q, k, v, out, sm_scale, d_head, n_heads, TILE_M, TILE_N, is_causal)
    ct.launch(torch.cuda.current_stream(), launch_grid, fmha_step3_exp2, kernel_args)
    return out
|
|
|
|
@ct.kernel()
def fmha_step4_load_order(
    Q, K, V, Out,
    qk_scale: float,
    TILE_D: ConstInt,
    H: ConstInt,
    TILE_M: ConstInt,
    TILE_N: ConstInt,
    CAUSAL: ConstBool,
):
    """
    Step 4: Optimize K load with order parameter
    - Use order=(0,1,3,2) to load K already transposed
    - Avoids explicit ct.permute() operation
    - Reduces memory traffic

    Each block produces one TILE_M-row tile of Out for one (batch, head)
    pair. Q, K, V and Out are indexed as (batch, head, sequence, head_dim).
    """
    # Grid axis 0 selects the query tile; axis 1 is a flattened (batch, head) id.
    bid_x = ct.bid(0)
    bid_y = ct.bid(1)
    batch_idx = bid_y // H
    head_idx = bid_y % H

    # Fold 1/ln(2) into the softmax scale so exp2 can stand in for exp.
    qk_scale_log2 = qk_scale * INV_LOG_2

    # Absolute query positions covered by this tile, as a column vector.
    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
    offs_m = offs_m[:, None]

    # Key positions within a single key tile, as a row vector.
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
    offs_n_tile = offs_n_tile[None, :]

    # Online-softmax state: running row max, running row sum, unnormalized output.
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
    q = q.reshape((TILE_M, TILE_D))

    k_seqlen = K.shape[2]
    if CAUSAL:
        # Key tiles past the last row of this query tile are never visited.
        m_end = (bid_x + 1) * TILE_M
        Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
        # Index of the key tile containing this tile's first query position.
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        # NOTE(review): only read under `CAUSAL` below; unused here.
        mask_start = k_seqlen // TILE_N

    for j in range(0, Tc):
        # Key tile j arrives as (TILE_D, TILE_N): the last two axes are swapped
        # by `order`, so the tile index and shape are given in swapped order too.
        k_t = ct.load(
            K,
            index=(batch_idx, head_idx, 0, j),
            shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2)
        )
        k_t = k_t.reshape((TILE_D, TILE_N))

        # Raw (unscaled) scores; the scale is applied together with the max below.
        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
        qk = ct.mma(q, k_t, qk)

        if CAUSAL and j >= mask_start:
            # Keep only keys at or before each query position.
            offs_n = j * TILE_N + offs_n_tile
            mask = offs_m >= offs_n
            qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

        # Running max is tracked in the scaled, base-2 domain.
        m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
        m_ij = ct.maximum(m_i, m_ij)
        qk = qk * qk_scale_log2 - m_ij

        p = ct.exp2(qk, flush_to_zero=True)
        l_ij = ct.sum(p, axis=-1, keepdims=True)
        # alpha rescales the previous sum and accumulator to the new max.
        alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha

        v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
        v_tile = v_tile.reshape((TILE_N, TILE_D))
        # Probabilities are cast to the input dtype for the P @ V product.
        p_cast = p.astype(Q.dtype)
        acc = ct.mma(p_cast, v_tile, acc)

        m_i = m_ij

    # Final softmax normalization, then write the tile back in Out's dtype.
    acc = acc / l_i
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
|
|
|
def run_step4(q, k, v, sm_scale, is_causal=True):
    """Launch fmha_step4_load_order on the current CUDA stream and return its output.

    The grid has one block per (query tile, batch * head) pair.
    """
    n_batch, n_heads, n_ctx, d_head = q.shape
    out = torch.empty_like(q)
    launch_grid = (math.ceil(n_ctx / TILE_M), n_batch * n_heads, 1)
    kernel_args = (q, k, v, out, sm_scale, d_head, n_heads, TILE_M, TILE_N, is_causal)
    ct.launch(torch.cuda.current_stream(), launch_grid, fmha_step4_load_order, kernel_args)
    return out
|
|
|
|
@ct.kernel()
def fmha_step5_latency(
    Q, K, V, Out,
    qk_scale: float,
    TILE_D: ConstInt,
    H: ConstInt,
    TILE_M: ConstInt,
    TILE_N: ConstInt,
    CAUSAL: ConstBool,
):
    """
    Step 5: Add latency hints for better pipelining
    - latency=2 for K load (prefetch)
    - latency=4 for V load (more prefetch distance)
    - Helps overlap memory loads with computation

    Each block produces one TILE_M-row tile of Out for one (batch, head)
    pair. Q, K, V and Out are indexed as (batch, head, sequence, head_dim).
    """
    # Grid axis 0 selects the query tile; axis 1 is a flattened (batch, head) id.
    bid_x = ct.bid(0)
    bid_y = ct.bid(1)
    batch_idx = bid_y // H
    head_idx = bid_y % H

    # Fold 1/ln(2) into the softmax scale so exp2 can stand in for exp.
    qk_scale_log2 = qk_scale * INV_LOG_2

    # Absolute query positions covered by this tile, as a column vector.
    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
    offs_m = offs_m[:, None]

    # Key positions within a single key tile, as a row vector.
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
    offs_n_tile = offs_n_tile[None, :]

    # Online-softmax state: running row max, running row sum, unnormalized output.
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
    q = q.reshape((TILE_M, TILE_D))

    k_seqlen = K.shape[2]
    if CAUSAL:
        # Key tiles past the last row of this query tile are never visited.
        m_end = (bid_x + 1) * TILE_M
        Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
        # Index of the key tile containing this tile's first query position.
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        # NOTE(review): only read under `CAUSAL` below; unused here.
        mask_start = k_seqlen // TILE_N

    for j in range(0, Tc):
        # Key tile j, loaded already transposed via `order`, with a latency hint.
        k_t = ct.load(
            K,
            index=(batch_idx, head_idx, 0, j),
            shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2),
            latency=2
        )
        k_t = k_t.reshape((TILE_D, TILE_N))

        # Raw (unscaled) scores; the scale is applied together with the max below.
        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
        qk = ct.mma(q, k_t, qk)

        if CAUSAL and j >= mask_start:
            # Keep only keys at or before each query position.
            offs_n = j * TILE_N + offs_n_tile
            mask = offs_m >= offs_n
            qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

        # Running max is tracked in the scaled, base-2 domain.
        m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
        m_ij = ct.maximum(m_i, m_ij)
        qk = qk * qk_scale_log2 - m_ij

        p = ct.exp2(qk, flush_to_zero=True)
        l_ij = ct.sum(p, axis=-1, keepdims=True)
        # alpha rescales the previous sum and accumulator to the new max.
        alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha

        # Value tile j, with a larger latency hint than the K load.
        v_tile = ct.load(
            V,
            index=(batch_idx, head_idx, j, 0),
            shape=(1, 1, TILE_N, TILE_D),
            latency=4
        )
        v_tile = v_tile.reshape((TILE_N, TILE_D))
        # Probabilities are cast to the input dtype for the P @ V product.
        p_cast = p.astype(Q.dtype)
        acc = ct.mma(p_cast, v_tile, acc)

        m_i = m_ij

    # Final softmax normalization, then write the tile back in Out's dtype.
    acc = acc / l_i
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
|
|
|
def run_step5(q, k, v, sm_scale, is_causal=True):
    """Launch fmha_step5_latency on the current CUDA stream and return its output.

    The grid has one block per (query tile, batch * head) pair.
    """
    n_batch, n_heads, n_ctx, d_head = q.shape
    out = torch.empty_like(q)
    launch_grid = (math.ceil(n_ctx / TILE_M), n_batch * n_heads, 1)
    kernel_args = (q, k, v, out, sm_scale, d_head, n_heads, TILE_M, TILE_N, is_causal)
    ct.launch(torch.cuda.current_stream(), launch_grid, fmha_step5_latency, kernel_args)
    return out
|
|
|
|
@ct.kernel(occupancy=2)
def fmha_step6_occupancy(
    Q, K, V, Out,
    qk_scale: float,
    TILE_D: ConstInt,
    H: ConstInt,
    TILE_M: ConstInt,
    TILE_N: ConstInt,
    CAUSAL: ConstBool,
):
    """
    Step 6: Add occupancy hint
    - @ct.kernel(occupancy=2) improves SM utilization
    - Allows multiple thread blocks per SM
    - Better for hiding memory latency

    The body is the same as the step 5 kernel; only the decorator differs.
    Each block produces one TILE_M-row tile of Out for one (batch, head)
    pair. Q, K, V and Out are indexed as (batch, head, sequence, head_dim).
    """
    # Grid axis 0 selects the query tile; axis 1 is a flattened (batch, head) id.
    bid_x = ct.bid(0)
    bid_y = ct.bid(1)
    batch_idx = bid_y // H
    head_idx = bid_y % H

    # Fold 1/ln(2) into the softmax scale so exp2 can stand in for exp.
    qk_scale_log2 = qk_scale * INV_LOG_2

    # Absolute query positions covered by this tile, as a column vector.
    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
    offs_m = offs_m[:, None]

    # Key positions within a single key tile, as a row vector.
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
    offs_n_tile = offs_n_tile[None, :]

    # Online-softmax state: running row max, running row sum, unnormalized output.
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
    q = q.reshape((TILE_M, TILE_D))

    k_seqlen = K.shape[2]
    if CAUSAL:
        # Key tiles past the last row of this query tile are never visited.
        m_end = (bid_x + 1) * TILE_M
        Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
        # Index of the key tile containing this tile's first query position.
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        # NOTE(review): only read under `CAUSAL` below; unused here.
        mask_start = k_seqlen // TILE_N

    for j in range(0, Tc):
        # Key tile j, loaded already transposed via `order`, with a latency hint.
        k_t = ct.load(
            K,
            index=(batch_idx, head_idx, 0, j),
            shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2),
            latency=2
        )
        k_t = k_t.reshape((TILE_D, TILE_N))

        # Raw (unscaled) scores; the scale is applied together with the max below.
        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
        qk = ct.mma(q, k_t, qk)

        if CAUSAL and j >= mask_start:
            # Keep only keys at or before each query position.
            offs_n = j * TILE_N + offs_n_tile
            mask = offs_m >= offs_n
            qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

        # Running max is tracked in the scaled, base-2 domain.
        m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
        m_ij = ct.maximum(m_i, m_ij)
        qk = qk * qk_scale_log2 - m_ij

        p = ct.exp2(qk, flush_to_zero=True)
        l_ij = ct.sum(p, axis=-1, keepdims=True)
        # alpha rescales the previous sum and accumulator to the new max.
        alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha

        # Value tile j, with a larger latency hint than the K load.
        v_tile = ct.load(
            V,
            index=(batch_idx, head_idx, j, 0),
            shape=(1, 1, TILE_N, TILE_D),
            latency=4
        )
        v_tile = v_tile.reshape((TILE_N, TILE_D))
        # Probabilities are cast to the input dtype for the P @ V product.
        p_cast = p.astype(Q.dtype)
        acc = ct.mma(p_cast, v_tile, acc)

        m_i = m_ij

    # Final softmax normalization, then write the tile back in Out's dtype.
    acc = acc / l_i
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
|
|
|
def run_step6(q, k, v, sm_scale, is_causal=True):
    """Launch fmha_step6_occupancy on the current CUDA stream and return its output.

    The grid has one block per (query tile, batch * head) pair.
    """
    n_batch, n_heads, n_ctx, d_head = q.shape
    out = torch.empty_like(q)
    launch_grid = (math.ceil(n_ctx / TILE_M), n_batch * n_heads, 1)
    kernel_args = (q, k, v, out, sm_scale, d_head, n_heads, TILE_M, TILE_N, is_causal)
    ct.launch(torch.cuda.current_stream(), launch_grid, fmha_step6_occupancy, kernel_args)
    return out
|
|
|
|
@ct.kernel(occupancy=2)
def fmha_step7_approx_div(
    Q, K, V, Out,
    qk_scale: float,
    TILE_D: ConstInt,
    H: ConstInt,
    TILE_M: ConstInt,
    TILE_N: ConstInt,
    CAUSAL: ConstBool,
):
    """
    Step 7: Use approximate division for final normalization
    - ct.truediv with rounding_mode=APPROX is faster
    - Acceptable accuracy loss for inference
    - This matches TileGym's optimized implementation

    The loop is the same as the step 6 kernel; only the final division differs.
    Each block produces one TILE_M-row tile of Out for one (batch, head)
    pair. Q, K, V and Out are indexed as (batch, head, sequence, head_dim).
    """
    # Grid axis 0 selects the query tile; axis 1 is a flattened (batch, head) id.
    bid_x = ct.bid(0)
    bid_y = ct.bid(1)
    batch_idx = bid_y // H
    head_idx = bid_y % H

    # Fold 1/ln(2) into the softmax scale so exp2 can stand in for exp.
    qk_scale_log2 = qk_scale * INV_LOG_2

    # Absolute query positions covered by this tile, as a column vector.
    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
    offs_m = offs_m[:, None]

    # Key positions within a single key tile, as a row vector.
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
    offs_n_tile = offs_n_tile[None, :]

    # Online-softmax state: running row max, running row sum, unnormalized output.
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
    q = q.reshape((TILE_M, TILE_D))

    k_seqlen = K.shape[2]
    if CAUSAL:
        # Key tiles past the last row of this query tile are never visited.
        m_end = (bid_x + 1) * TILE_M
        Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
        # Index of the key tile containing this tile's first query position.
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        # NOTE(review): only read under `CAUSAL` below; unused here.
        mask_start = k_seqlen // TILE_N

    for j in range(0, Tc):
        # Key tile j, loaded already transposed via `order`, with a latency hint.
        k_t = ct.load(
            K,
            index=(batch_idx, head_idx, 0, j),
            shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2),
            latency=2
        )
        k_t = k_t.reshape((TILE_D, TILE_N))

        # Raw (unscaled) scores; the scale is applied together with the max below.
        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
        qk = ct.mma(q, k_t, qk)

        if CAUSAL and j >= mask_start:
            # Keep only keys at or before each query position.
            offs_n = j * TILE_N + offs_n_tile
            mask = offs_m >= offs_n
            qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

        # Running max is tracked in the scaled, base-2 domain.
        m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
        m_ij = ct.maximum(m_i, m_ij)
        qk = qk * qk_scale_log2 - m_ij

        p = ct.exp2(qk, flush_to_zero=True)
        l_ij = ct.sum(p, axis=-1, keepdims=True)
        # alpha rescales the previous sum and accumulator to the new max.
        alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha

        # Value tile j, with a larger latency hint than the K load.
        v_tile = ct.load(
            V,
            index=(batch_idx, head_idx, j, 0),
            shape=(1, 1, TILE_N, TILE_D),
            latency=4
        )
        v_tile = v_tile.reshape((TILE_N, TILE_D))
        # Probabilities are cast to the input dtype for the P @ V product.
        p_cast = p.astype(Q.dtype)
        acc = ct.mma(p_cast, v_tile, acc)

        m_i = m_ij

    # Final softmax normalization using the approximate division mode.
    acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
|
|
|
def run_step7(q, k, v, sm_scale, is_causal=True):
    """Launch fmha_step7_approx_div on the current CUDA stream and return its output.

    The grid has one block per (query tile, batch * head) pair.
    """
    n_batch, n_heads, n_ctx, d_head = q.shape
    out = torch.empty_like(q)
    launch_grid = (math.ceil(n_ctx / TILE_M), n_batch * n_heads, 1)
    kernel_args = (q, k, v, out, sm_scale, d_head, n_heads, TILE_M, TILE_N, is_causal)
    ct.launch(torch.cuda.current_stream(), launch_grid, fmha_step7_approx_div, kernel_args)
    return out
|
|
|
|
|
|
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
    """Run TileGym's FMHA (cuTile backend) for comparison.

    Returns the TileGym output, or None (after logging a warning) when an
    ImportError is raised.
    """
    tilegym_out = None
    try:
        import tilegym
        tilegym_out = tilegym.ops.fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
    except ImportError:
        logger.log("[WARN] TileGym not available for comparison")
    return tilegym_out
|
|
|
|
|
|
def run_benchmark(seq_len, iterations=100, check_correct=True):
    """Benchmark every available FMHA step at one sequence length.

    Allocates random float16 Q/K/V of shape (BATCH, N_HEADS, seq_len,
    HEAD_DIM), runs the PyTorch baseline and (when cuTile is importable)
    steps 2-7, then the TileGym reference if available. Each result is logged
    and appended to the module-level ``logger``.

    Args:
        seq_len: Sequence length of Q, K and V.
        iterations: Timed iterations forwarded to ``benchmark_fn``.
        check_correct: When true, each output (including TileGym's) is
            compared against the PyTorch reference; when false, results are
            recorded with ``correct=True`` without being checked.

    A failure in any single step, or in the TileGym comparison, is logged
    with its traceback and does not stop the remaining work.
    """
    global DEVICE
    import traceback

    DEVICE = get_device()

    logger.section(f"FMHA OPTIMIZATION TUTORIAL - SEQ_LEN={seq_len}")
    logger.log("Configuration:")
    logger.log(f" - Batch: {BATCH}")
    logger.log(f" - Heads: {N_HEADS}")
    logger.log(f" - Head Dim: {HEAD_DIM}")
    logger.log(f" - Sequence Length: {seq_len}")
    logger.log(f" - Tile M: {TILE_M}")
    logger.log(f" - Tile N: {TILE_N}")
    logger.log(" - Precision: float16")
    logger.log(" - Causal: True")
    logger.log(f" - Iterations: {iterations}")
    logger.log(f" - Device: {DEVICE}")

    q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
    k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
    v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
    sm_scale = 1.0 / math.sqrt(HEAD_DIM)

    flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)

    ref_output = reference_fmha(q, k, v, sm_scale, is_causal=True)

    # Each entry: (step number, name, description, callable, key changes).
    steps = [
        (0, "PyTorch Baseline", "torch.nn.functional.scaled_dot_product_attention",
         lambda: step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True),
         ["PyTorch SDPA with cuDNN backend", "Highly optimized baseline"]),
    ]

    if CUTILE_AVAILABLE:
        steps.extend([
            (2, "Basic cuTile + MMA", "Tiled FMHA with ct.mma() for Tensor Cores",
             lambda: run_step2(q, k, v, sm_scale, is_causal=True),
             ["@ct.kernel decorator", "ct.mma() for QK and PV products", "Online softmax with exp()"]),

            (3, "+ exp2 + flush_to_zero", "Faster exponential math",
             lambda: run_step3(q, k, v, sm_scale, is_causal=True),
             ["ct.exp2() instead of ct.exp()", "flush_to_zero=True for denormals", "qk_scale *= 1/log(2)"]),

            (4, "+ Load Order Transpose", "Avoid explicit transpose",
             lambda: run_step4(q, k, v, sm_scale, is_causal=True),
             ["order=(0,1,3,2) for K load", "K loaded already transposed", "Removes ct.permute() call"]),

            (5, "+ Latency Hints", "Better memory pipelining",
             lambda: run_step5(q, k, v, sm_scale, is_causal=True),
             ["latency=2 for K load", "latency=4 for V load", "Overlaps loads with compute"]),

            (6, "+ Occupancy=2", "Better SM utilization",
             lambda: run_step6(q, k, v, sm_scale, is_causal=True),
             ["@ct.kernel(occupancy=2)", "Multiple blocks per SM", "Hides memory latency"]),

            (7, "+ Approx Division (Final)", "Fast final normalization",
             lambda: run_step7(q, k, v, sm_scale, is_causal=True),
             ["ct.truediv with APPROX mode", "Matches TileGym implementation", "Full optimization achieved"]),
        ])

    # The first step that succeeds defines the baseline latency.
    baseline_latency = None
    prev_latency = None

    for step_idx, name, desc, fn, changes in steps:
        logger.subsection(f"Step {step_idx}: {name}")
        logger.log(f"Description: {desc}")
        logger.log("Key Changes:")
        for change in changes:
            logger.log(f" - {change}")

        try:
            output = fn()
            latency_ms = benchmark_fn(fn, warmup=10, iterations=iterations)
            tflops = flops * 1e-12 / (latency_ms * 1e-3)

            if baseline_latency is None:
                baseline_latency = latency_ms
                speedup_baseline = 1.0
            else:
                speedup_baseline = baseline_latency / latency_ms

            if prev_latency is None:
                speedup_prev = 1.0
            else:
                speedup_prev = prev_latency / latency_ms

            if check_correct and output is not None:
                correct = verify_correctness(output, ref_output)
            else:
                correct = True

            logger.log("\nResults:")
            logger.log(f" Latency: {latency_ms:.3f} ms")
            logger.log(f" TFLOPS: {tflops:.2f}")
            logger.log(f" vs Baseline: {speedup_baseline:.2f}x")
            logger.log(f" vs Previous: {speedup_prev:.2f}x")
            logger.log(f" Correct: {'Yes' if correct else 'No'}")

            result = BenchmarkResult(
                step=step_idx,
                name=name,
                description=desc,
                latency_ms=latency_ms,
                tflops=tflops,
                speedup_vs_baseline=speedup_baseline,
                speedup_vs_previous=speedup_prev,
                correct=correct,
                key_changes=changes
            )
            logger.add_result(result)

            prev_latency = latency_ms

        except Exception as e:
            logger.log(f"\n[ERROR] Step {step_idx} failed: {e}")
            logger.log(traceback.format_exc())

    # The TileGym comparison is optional; guard it like the steps above so a
    # failure here cannot discard the results already collected.
    try:
        tilegym_output = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
        if tilegym_output is not None:
            logger.subsection("TileGym Reference (for comparison)")
            tilegym_fn = lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
            tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
            tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
            tilegym_speedup = baseline_latency / tilegym_latency if baseline_latency else 1.0

            # Check the TileGym output like any other step instead of assuming it.
            if check_correct:
                tilegym_correct = verify_correctness(tilegym_output, ref_output)
            else:
                tilegym_correct = True

            logger.log("TileGym FMHA:")
            logger.log(f" Latency: {tilegym_latency:.3f} ms")
            logger.log(f" TFLOPS: {tilegym_tflops:.2f}")
            logger.log(f" vs Baseline: {tilegym_speedup:.2f}x")
            logger.log(f" Correct: {'Yes' if tilegym_correct else 'No'}")

            result = BenchmarkResult(
                step=99,
                name="TileGym Reference",
                description="TileGym's optimized FMHA implementation",
                latency_ms=tilegym_latency,
                tflops=tilegym_tflops,
                speedup_vs_baseline=tilegym_speedup,
                speedup_vs_previous=1.0,
                correct=tilegym_correct,
                key_changes=["Full TileGym implementation", "Pre-tuned for sm121", "Production ready"]
            )
            logger.add_result(result)
    except Exception as e:
        logger.log(f"\n[ERROR] TileGym reference failed: {e}")
        logger.log(traceback.format_exc())
|
|
|
|
|
|
def main():
    """Parse the command line, run the benchmark, and export the results.

    Writes ``fmha_tutorial_results.json``, ``fmha_tutorial_results.md`` and
    ``fmha_tutorial_log.txt`` into ``--output-dir``. The directory is created
    if it does not exist.
    """
    import os

    parser = argparse.ArgumentParser(description="FMHA Optimization Tutorial")
    parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
    parser.add_argument("--seq-len", type=int, default=2048, help="Sequence length (default matches TileGym)")
    parser.add_argument("--correctness-check", action="store_true", help="Enable correctness checking")
    parser.add_argument("--output-dir", type=str, default=".", help="Output directory for logs")
    args = parser.parse_args()

    # Create the output directory before benchmarking so an unusable path
    # fails immediately instead of after all the measurements are done.
    output_dir = args.output_dir or "."
    os.makedirs(output_dir, exist_ok=True)

    logger.section("FMHA OPTIMIZATION TUTORIAL")
    logger.log("From Naive to Optimized cuTile Implementation")
    logger.log("Target Platform: DGX Spark (sm121)")
    logger.log(f"Tile Sizes: TILE_M={TILE_M}, TILE_N={TILE_N} (hardcoded from TileGym)")
    logger.log("Note: TileGym supports autotuning, but we use pre-determined optimal values")

    run_benchmark(
        seq_len=args.seq_len,
        iterations=args.iterations,
        check_correct=args.correctness_check
    )

    logger.section("FINAL SUMMARY")
    logger.log("\n| Step | Name | Latency (ms) | TFLOPS | vs Baseline | Correct |")
    logger.log("|------|------|--------------|--------|-------------|---------|")
    for r in logger.results:
        logger.log(f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {'Yes' if r.correct else 'No'} |")

    json_path = os.path.join(output_dir, "fmha_tutorial_results.json")
    md_path = os.path.join(output_dir, "fmha_tutorial_results.md")
    log_path = os.path.join(output_dir, "fmha_tutorial_log.txt")

    logger.export_json(json_path)
    logger.export_markdown(md_path)

    # Logged tracebacks may contain non-ASCII text; write UTF-8 so the export
    # does not depend on the platform's default encoding.
    with open(log_path, 'w', encoding="utf-8") as f:
        f.write('\n'.join(logger.logs))

    logger.log("\nResults exported to:")
    logger.log(f" - {json_path}")
    logger.log(f" - {md_path}")
    logger.log(f" - {log_path}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|