dgx-spark-playbooks/nvidia/cutile-kernels/assets/fmha_scaling_analysis.py
2026-06-03 15:15:33 +00:00

892 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python3
"""
FMHA Scaling Analysis: How Optimizations Impact Performance at Different Sizes
This script demonstrates:
1. How FMHA performance scales with sequence length
2. Which optimizations provide the most benefit at larger sizes
3. Target-specific configurations for different GPU architectures
Target Platforms (from TileGym):
- DGX Spark (sm120/sm121): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
- Blackwell B300 (sm100): TILE_M=256, TILE_N=128 (occupancy=1) or TILE_M=128, TILE_N=128 (occupancy=2), num_ctas=1
Usage:
python fmha_scaling_analysis.py [--iterations N] [--output-dir DIR]
"""
import argparse
import json
import math
import time
from dataclasses import dataclass, asdict
from typing import List
from types import SimpleNamespace
import torch
# Horizontal rules printed around Logger.section() / Logger.subsection() titles.
LOG_SEPARATOR = "=" * 80  # 80-character rule for top-level sections
LOG_SUBSEPARATOR = "-" * 60  # 60-character rule for subsections
@dataclass
class StepResult:
    """Benchmark outcome of one optimization step at one sequence length."""

    # Numeric step id; the script uses 0 for the PyTorch baseline and 1-4 for the cuTile kernels.
    step: int
    # Human-readable label shown in the result tables.
    name: str
    # Measured latency in milliseconds.
    latency_ms: float
    # Achieved throughput derived from the FLOP count and the latency.
    tflops: float
    # Baseline latency divided by this step's latency (1.0 for the baseline itself).
    speedup_vs_baseline: float
@dataclass
class SeqLenResult:
    """All measurements collected for a single sequence length."""

    # Sequence length these measurements belong to.
    seq_len: int
    # One entry per step that ran successfully (baseline first).
    steps: List[StepResult]
    # Step id of the entry with the highest speedup.
    best_step: int
    # Speedup of that best entry relative to the baseline.
    best_speedup: float
    # TileGym reference numbers; the script leaves them at 0.0 when TileGym was not measured.
    tilegym_latency_ms: float
    tilegym_tflops: float
    tilegym_speedup: float
class Logger:
    """Print messages to stdout while keeping a copy of everything in memory.

    ``logs`` accumulates every message in emission order; ``results`` is a
    slot where the analysis stores its final list of results.
    """

    def __init__(self):
        self.logs: List[str] = []
        self.results: List[SeqLenResult] = []

    def log(self, msg: str):
        """Print ``msg`` and remember it in ``self.logs``."""
        print(msg)
        self.logs.append(msg)

    def _banner(self, rule: str, title: str):
        """Emit ``title`` framed by ``rule`` as three separate log entries."""
        for line in (f"\n{rule}", f" {title}", rule):
            self.log(line)

    def section(self, title: str):
        """Log a top-level heading framed by LOG_SEPARATOR."""
        self._banner(LOG_SEPARATOR, title)

    def subsection(self, title: str):
        """Log a second-level heading framed by LOG_SUBSEPARATOR."""
        self._banner(LOG_SUBSEPARATOR, title)
# Module-wide logger shared by every function in this script.
logger = Logger()
# Benchmark problem size; input tensors are created as (BATCH, N_HEADS, seq_len, HEAD_DIM).
BATCH = 4
N_HEADS = 32
HEAD_DIM = 128
# 1 / ln(2): multiplying the softmax scale by this lets the kernels use exp2 instead of exp.
INV_LOG_2 = 1.0 / math.log(2)
# Sequence lengths swept by run_scaling_analysis().
SEQ_LENS = [1024, 2048, 4096, 8192, 16384]
def get_fmha_config():
    """
    Return the tile configurations to use on the current CUDA device.

    Mirrors TileGym's _fmha_autotune_configs() (attention.py):
    - compute capability 12.0 / 12.1 (DGX Spark): a single 64x64 config with occupancy=2
    - any other device: the two sm100 (Blackwell B300) candidates, 256x128
      with occupancy=1 first, then 128x128 with occupancy=2

    Returns:
        A list of SimpleNamespace objects with the attributes ``name``,
        ``TILE_M``, ``TILE_N``, ``num_ctas`` and ``occupancy``. Callers treat
        the first element as the primary configuration.
    """
    spark_capabilities = ((12, 0), (12, 1))
    if torch.cuda.get_device_capability() in spark_capabilities:
        spark_cfg = SimpleNamespace(
            name="DGX Spark (sm121)",
            TILE_M=64,
            TILE_N=64,
            num_ctas=1,
            occupancy=2,
        )
        return [spark_cfg]

    # Everything that is not a DGX Spark falls through to the sm100 candidates.
    large_tile_cfg = SimpleNamespace(
        name="Blackwell B300 (sm100) - Config 1",
        TILE_M=256,
        TILE_N=128,
        num_ctas=1,
        occupancy=1,
    )
    square_tile_cfg = SimpleNamespace(
        name="Blackwell B300 (sm100) - Config 2",
        TILE_M=128,
        TILE_N=128,
        num_ctas=1,
        occupancy=2,
    )
    return [large_tile_cfg, square_tile_cfg]
def get_device():
    """Return the CUDA torch.device, or raise RuntimeError when CUDA is unavailable."""
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")
    return torch.device("cuda")
def compute_flops(batch, heads, seq_len, head_dim, causal=True):
    """Return the floating-point operation count of one attention forward pass.

    Two matmuls (Q @ K^T and P @ V) are counted, each costing
    2 * batch * heads * seq_len^2 * head_dim operations. With ``causal=True``
    the total is halved.
    """
    per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    causal_factor = 0.5 if causal else 1.0
    return 2 * per_matmul * causal_factor
def benchmark_fn(fn, warmup=10, iterations=100):
    """Return the average latency of ``fn`` in milliseconds.

    ``triton.testing.do_bench_cudagraph`` is preferred (the same method TileGym
    uses). When Triton is not installed, or its benchmark raises, the function
    falls back to wall-clock timing around ``iterations`` calls of ``fn``.

    Args:
        fn: Zero-argument callable to time.
        warmup: Untimed calls made before measuring (fallback path only).
        iterations: Timed calls to average over (fallback path only).

    Returns:
        Latency per call in milliseconds.

    Raises:
        ValueError: If the fallback path is taken and ``iterations`` is not
            positive (the average would be undefined).
    """
    import warnings  # local import: only needed to report a Triton failure

    try:
        import triton
        # Use triton's cudagraph benchmark - same as TileGym
        return triton.testing.do_bench_cudagraph(fn)
    except ImportError:
        # Triton is simply not installed: manual timing is the expected path.
        pass
    except Exception as exc:
        # Keep the best-effort fallback, but make the switch of timing method
        # visible instead of silently changing how numbers were produced.
        warnings.warn(
            f"triton.testing.do_bench_cudagraph failed ({exc!r}); "
            "falling back to manual wall-clock timing",
            RuntimeWarning,
            stacklevel=2,
        )

    if iterations <= 0:
        raise ValueError("iterations must be a positive integer")

    # Synchronize only when a CUDA device exists so the fallback does not
    # crash on hosts without one.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    for _ in range(warmup):
        fn()
    sync()  # make sure warmup work has finished before the clock starts
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    sync()  # wait for all timed work before stopping the clock
    elapsed = time.perf_counter() - start
    return elapsed / iterations * 1000
def reference_fmha(q, k, v, sm_scale, is_causal=True):
return torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
)
# cuTile is optional: without it only the PyTorch baseline (and TileGym, if
# installed) can be benchmarked.
try:
    import cuda.tile as ct
    from cuda.tile import RoundingMode as RMd
except ImportError:
    CUTILE_AVAILABLE = False
    logger.log("[WARN] cuTile not available.")
else:
    CUTILE_AVAILABLE = True
# The kernels reference `ct`, so they are only defined when cuTile imported.
if CUTILE_AVAILABLE:
    # Aliases for cuTile's Constant[...] annotations, used for the tile-size
    # and flag parameters of the kernels below.
    ConstInt = ct.Constant[int]
    ConstBool = ct.Constant[bool]

    @ct.kernel()
    def fmha_basic(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 1: Basic cuTile - no optimizations

        Each block handles one TILE_M-row query tile of one (batch, head) pair
        and loops over K/V tiles, keeping a running row maximum (m_i), running
        softmax denominator (l_i) and un-normalized output accumulator (acc).

        Args:
            Q, K, V: inputs; run_kernel() passes tensors unpacked as
                (batch, heads, seq_len, head_dim).
            Out: output, indexed the same way as Q.
            qk_scale: softmax scale applied to the Q @ K^T scores.
            TILE_D: tile size along the last dimension (run_kernel passes head_dim).
            H: number of heads, used to decode the flattened batch*head block index.
            TILE_M, TILE_N: query-tile and key-tile sizes.
            CAUSAL: apply the causal mask (key index <= query index) when True.
        """
        # Grid axis 0 indexes the query tile; axis 1 is the flattened
        # batch * heads index (see the grid built in run_kernel()).
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        # Query row indices of this tile as a column, key offsets within a
        # tile as a row, so that `offs_m >= offs_n` broadcasts to (TILE_M, TILE_N).
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        # Running softmax state: row maximum, row sum, output accumulator.
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            # Only key tiles up to the last query row of this tile are visited.
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            # Key tiles before this index lie entirely below the diagonal and
            # need no element-wise masking.
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            # Not consulted in the non-causal case (the mask branch requires CAUSAL).
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            k_tile = k_tile.reshape((TILE_N, TILE_D))
            # Explicit transpose to (TILE_D, TILE_N) for the Q @ K^T product.
            k_t = ct.permute(k_tile, (1, 0))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            qk = qk * qk_scale
            if CAUSAL and j >= mask_start:
                # Replace scores of keys located after the query with -inf.
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            # New running maximum = max(previous maximum, maximum of this tile).
            m_ij = ct.max(qk, axis=-1, keepdims=True)
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk - m_ij
            p = ct.exp(qk)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            # alpha rescales the previous sum/accumulator to the new maximum.
            alpha = ct.exp(m_i - m_ij)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            # Cast the probabilities to the input dtype before the P @ V product.
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        # Final normalization by the softmax denominator.
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    @ct.kernel()
    def fmha_math_opt(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 2: Math optimizations - exp2 + flush_to_zero

        Same algorithm and parameters as fmha_basic, except that the scores
        are scaled by qk_scale / ln(2) so the exponentials can be computed
        with ct.exp2(..., flush_to_zero=True) instead of ct.exp.
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        # exp(x * s) == exp2(x * s / ln 2): fold 1/ln(2) into the scale once.
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        # Running softmax state (maximum is tracked in the scaled, base-2 domain).
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            # Visit key tiles only up to the last query row of this tile.
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            # First key tile that may contain masked elements.
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            k_tile = k_tile.reshape((TILE_N, TILE_D))
            k_t = ct.permute(k_tile, (1, 0))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            # The scale is applied after the row maximum is taken (on a
            # (TILE_M, 1) tile); this assumes qk_scale is positive - TODO confirm.
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    @ct.kernel()
    def fmha_memory_opt(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 3: Memory optimizations - load order + latency hints

        Same math as fmha_math_opt. K is loaded directly as a
        (TILE_D, TILE_N) tile via order=(0, 1, 3, 2) instead of being
        transposed with ct.permute, and the K/V loads carry latency hints
        (2 and 4 respectively).
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            # Load K with its last two axes swapped, so the tile arrives as
            # (TILE_D, TILE_N) and the tile index j moves to the last position.
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2
            )
            k_t = k_t.reshape((TILE_D, TILE_N))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            # V keeps its natural (TILE_N, TILE_D) layout; only a latency hint is added.
            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    @ct.kernel(occupancy=2)
    def fmha_full_opt_occ2(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 4a: Full optimization with occupancy=2 (for sm120/sm121)

        Same loop as fmha_memory_opt; in addition the kernel is declared with
        occupancy=2 and the final normalization uses ct.truediv with
        flush_to_zero=True and APPROX rounding.
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            # K arrives already transposed as (TILE_D, TILE_N).
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2
            )
            k_t = k_t.reshape((TILE_D, TILE_N))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        # Approximate division for the final softmax normalization.
        acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    @ct.kernel(occupancy=1)
    def fmha_full_opt_occ1(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 4b: Full optimization with occupancy=1 (for sm100 Blackwell)

        The body is the same as fmha_full_opt_occ2; only the occupancy value
        passed to the @ct.kernel decorator differs.
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            # K arrives already transposed as (TILE_D, TILE_N).
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2
            )
            k_t = k_t.reshape((TILE_D, TILE_N))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        # Approximate division for the final softmax normalization.
        acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
def run_kernel(kernel_fn, q, k, v, sm_scale, tile_m, tile_n, is_causal=True):
    """Launch ``kernel_fn`` on the current CUDA stream and return its output.

    ``q`` is unpacked as (batch, heads, seq_len, head_dim). The launch grid
    has one block per ``tile_m`` query rows along axis 0 and one per
    (batch, head) pair along axis 1. The output tensor is allocated with the
    same shape and dtype as ``q``.
    """
    n_batch, n_heads, n_tokens, d_head = q.shape
    out = torch.empty_like(q)
    launch_grid = (math.ceil(n_tokens / tile_m), n_batch * n_heads, 1)
    # Positional order must match the kernel signature:
    # Q, K, V, Out, qk_scale, TILE_D, H, TILE_M, TILE_N, CAUSAL.
    kernel_args = (q, k, v, out, sm_scale, d_head, n_heads, tile_m, tile_n, is_causal)
    ct.launch(torch.cuda.current_stream(), launch_grid, kernel_fn, kernel_args)
    return out
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
    """Run TileGym's cuTile FMHA as a reference implementation.

    Returns the attention output, or ``None`` when an ImportError occurs
    (typically because the ``tilegym`` package is not installed).
    """
    try:
        import tilegym

        fmha = tilegym.ops.fmha
        return fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
    except ImportError:
        return None
def run_scaling_analysis(iterations=100):
    """Benchmark every optimization step across SEQ_LENS and log the results.

    For each sequence length the PyTorch baseline is timed first, then (when
    cuTile is available) the four cuTile kernels, then the TileGym reference
    when it can be run. A failing step is logged and skipped so the remaining
    measurements still complete.

    Args:
        iterations: Timed iterations used by benchmark_fn's manual fallback.

    Returns:
        A list with one SeqLenResult per entry of SEQ_LENS. The same list is
        also stored on ``logger.results``.

    Raises:
        RuntimeError: If CUDA is not available (raised by get_device()).
    """
    device = get_device()
    gpu_cap = torch.cuda.get_device_capability()
    gpu_name = torch.cuda.get_device_name()
    configs = get_fmha_config()
    primary_cfg = configs[0]
    TILE_M = primary_cfg.TILE_M
    TILE_N = primary_cfg.TILE_N

    logger.section("FMHA SCALING ANALYSIS (TileGym Benchmark Match)")
    logger.log("Matching TileGym bench_fused_attention.py configuration")
    logger.log(f"\nGPU: {gpu_name} (sm_{gpu_cap[0]}{gpu_cap[1]})")
    logger.subsection("TARGET-SPECIFIC CONFIGURATION (from TileGym)")
    for cfg in configs:
        logger.log(f"\n {cfg.name}:")
        logger.log(f" TILE_M={cfg.TILE_M}, TILE_N={cfg.TILE_N}")
        logger.log(f" num_ctas={cfg.num_ctas}, occupancy={cfg.occupancy}")
    logger.log(f"\nUsing primary config: TILE_M={TILE_M}, TILE_N={TILE_N}, occupancy={primary_cfg.occupancy}")
    logger.log("\nTest Configuration (matches TileGym bench_fused_attention.py):")
    logger.log(f" Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}")
    logger.log(" Causal: True, Precision: float16")
    logger.log(f" Sequence Lengths: {SEQ_LENS}")
    logger.log(" Benchmark: triton.testing.do_bench_cudagraph (same as TileGym)")
    logger.section("OPTIMIZATION STEPS")
    logger.log(f"""
Step 0: PyTorch Baseline
- torch.nn.functional.scaled_dot_product_attention
- Uses cuDNN Flash Attention backend
- Highly optimized reference
Step 1: Basic cuTile (TILE_M={TILE_M}, TILE_N={TILE_N})
- @ct.kernel with ct.mma() for Tensor Cores
- Standard exp() for softmax
- Explicit transpose with ct.permute()
- No memory/occupancy hints
Step 2: Math Optimizations
- ct.exp2() instead of ct.exp() (faster on GPU)
- flush_to_zero=True for denormals
- Scale adjustment: multiply by 1/log(2)
Step 3: Memory Optimizations
- Load order=(0,1,3,2) for implicit K transpose
- Latency hints: K=2, V=4 for prefetching
- Overlaps memory loads with computation
Step 4: Full Optimization (Target-Specific)
- @ct.kernel(occupancy={primary_cfg.occupancy}) for {'sm120/121' if primary_cfg.occupancy == 2 else 'sm100'}
- ct.truediv with APPROX rounding mode
- Matches TileGym production implementation
""")
    logger.section("PLATFORM DIFFERENCES: DGX Spark vs Blackwell B300")
    logger.log("""
| Parameter | DGX Spark (sm121) | Blackwell B300 (sm100) |
|--------------|-------------------|------------------------|
| TILE_M | 64 | 256 or 128 |
| TILE_N | 64 | 128 |
| num_ctas | 1 | 1 |
| occupancy | 2 | 1 or 2 |
Why the difference?
- B300 has more SMs and larger shared memory -> can use bigger tiles
- B300 benefits from larger tiles (256x128) with lower occupancy
- DGX Spark needs smaller tiles (64x64) with higher occupancy to hide latency
- B300's higher memory bandwidth makes larger tiles more efficient
""")

    # The kernel functions only exist when cuTile imported successfully, so
    # they must not be referenced otherwise (baseline-only runs stay possible).
    if CUTILE_AVAILABLE:
        select_kernel = fmha_full_opt_occ2 if primary_cfg.occupancy == 2 else fmha_full_opt_occ1
        kernels = [
            (1, "Basic cuTile", fmha_basic),
            (2, "Math Opt (exp2)", fmha_math_opt),
            (3, "Memory Opt (order+latency)", fmha_memory_opt),
            (4, f"Full Opt (occ={primary_cfg.occupancy})", select_kernel),
        ]
    else:
        kernels = []

    all_results = []
    for seq_len in SEQ_LENS:
        logger.subsection(f"Sequence Length: {seq_len}")
        q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
        k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
        v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
        sm_scale = 1.0 / math.sqrt(HEAD_DIM)
        flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)
        steps_results = []

        # Step 0: PyTorch baseline; every speedup is relative to this latency.
        baseline_fn = lambda: reference_fmha(q, k, v, sm_scale, is_causal=True)
        baseline_latency = benchmark_fn(baseline_fn, warmup=10, iterations=iterations)
        baseline_tflops = flops * 1e-12 / (baseline_latency * 1e-3)
        steps_results.append(StepResult(0, "PyTorch Baseline", baseline_latency, baseline_tflops, 1.0))

        for step, name, kernel in kernels:
            try:
                # Bind `kernel` as a default so the closure keeps this step's kernel.
                fn = lambda kernel=kernel: run_kernel(kernel, q, k, v, sm_scale, TILE_M, TILE_N, is_causal=True)
                latency = benchmark_fn(fn, warmup=10, iterations=iterations)
                tflops = flops * 1e-12 / (latency * 1e-3)
                speedup = baseline_latency / latency
                steps_results.append(StepResult(step, name, latency, tflops, speedup))
            except Exception as e:
                logger.log(f" [ERROR] Step {step} failed: {e}")

        # TileGym reference; stays at 0.0 when TileGym is missing or fails.
        tilegym_latency = 0.0
        tilegym_tflops = 0.0
        tilegym_speedup = 0.0
        try:
            tilegym_out = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
            if tilegym_out is not None:
                tilegym_fn = lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
                tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
                tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
                tilegym_speedup = baseline_latency / tilegym_latency
        except Exception as e:
            # Same policy as the kernel steps: report and keep going.
            logger.log(f" [ERROR] TileGym reference failed: {e}")
            tilegym_latency = 0.0
            tilegym_tflops = 0.0
            tilegym_speedup = 0.0

        best_step = max(steps_results, key=lambda x: x.speedup_vs_baseline)
        result = SeqLenResult(
            seq_len=seq_len,
            steps=steps_results,
            best_step=best_step.step,
            best_speedup=best_step.speedup_vs_baseline,
            tilegym_latency_ms=tilegym_latency,
            tilegym_tflops=tilegym_tflops,
            tilegym_speedup=tilegym_speedup,
        )
        all_results.append(result)

        logger.log("\n | Step | Name | Latency (ms) | TFLOPS | Speedup |")
        logger.log(" |------|------|--------------|--------|---------|")
        for sr in steps_results:
            logger.log(f" | {sr.step} | {sr.name:<28} | {sr.latency_ms:>10.3f} | {sr.tflops:>6.2f} | {sr.speedup_vs_baseline:>6.2f}x |")
        if tilegym_latency > 0:
            logger.log(f" | TG | TileGym Reference | {tilegym_latency:>10.3f} | {tilegym_tflops:>6.2f} | {tilegym_speedup:>6.2f}x |")
        logger.log(f"\n Best: Step {best_step.step} ({best_step.name}) with {best_step.speedup_vs_baseline:.2f}x speedup")

    logger.results = all_results
    return all_results
def print_summary(results: List[SeqLenResult]):
    """Log the cross-sequence-length summary tables and the static key insights.

    Two tables are produced: baseline vs. full optimization (step 4) vs.
    TileGym, and the speedup of each cuTile step (1-4) per sequence length.
    Rows of the first table are skipped when step 0 or step 4 is missing.
    """
    primary_cfg = get_fmha_config()[0]

    def find_step(record, wanted):
        """Return the first step result whose id equals ``wanted``, else None."""
        for candidate in record.steps:
            if candidate.step == wanted:
                return candidate
        return None

    logger.section("SCALING SUMMARY")
    logger.log(f"\nTarget Config: TILE_M={primary_cfg.TILE_M}, TILE_N={primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}")

    # Table 1: baseline vs. fully optimized kernel vs. TileGym.
    logger.log("\n## Performance vs Sequence Length\n")
    logger.log("| Seq Len | Baseline (ms) | Full Opt (ms) | Speedup | TileGym (ms) | TG Speedup |")
    logger.log("|---------|---------------|---------------|---------|--------------|------------|")
    for r in results:
        baseline = find_step(r, 0)
        full_opt = find_step(r, 4)
        if not (baseline and full_opt):
            continue
        logger.log(f"| {r.seq_len:>7} | {baseline.latency_ms:>13.3f} | {full_opt.latency_ms:>13.3f} | {full_opt.speedup_vs_baseline:>6.2f}x | {r.tilegym_latency_ms:>12.3f} | {r.tilegym_speedup:>9.2f}x |")

    # Table 2: per-step speedups, with N/A for steps that did not run.
    logger.log("\n## Optimization Impact by Sequence Length\n")
    logger.log("| Seq Len | Basic | +Math | +Memory | +Full | Best |")
    logger.log("|---------|-------|-------|---------|-------|------|")
    for r in results:
        cells = [f"| {r.seq_len:>7} |"]
        for step in (1, 2, 3, 4):
            sr = find_step(r, step)
            cells.append(f" {sr.speedup_vs_baseline:>5.2f}x |" if sr else " N/A |")
        cells.append(f" {r.best_speedup:>4.2f}x |")
        logger.log("".join(cells))

    logger.section("KEY INSIGHTS")
    logger.log("""
## Why Larger Sequences Benefit More from Optimization
1. **Memory Bandwidth Dominance**
- Attention has O(N²) memory complexity for the QK^T matrix
- At seq_len=8192: 8192² × 4 bytes = 256MB per head per batch
- Memory optimizations (order, latency hints) have larger impact
2. **More K-Loop Iterations**
- At seq_len=512: 8 K-tiles (512/64) for sm121, 2 K-tiles (512/256) for sm100
- At seq_len=8192: 128 K-tiles for sm121, 32 K-tiles for sm100
- Latency hiding through pipelining amortizes over more iterations
3. **Better Occupancy Utilization**
- More tiles = more parallelism opportunities
- sm121 uses occupancy=2 (smaller tiles, more blocks)
- sm100 uses occupancy=1 with larger tiles (256x128)
4. **Platform-Specific Tuning**
- DGX Spark (sm121): 64x64 tiles, occupancy=2 - optimized for bandwidth-limited workloads
- B300 (sm100): 256x128 tiles, occupancy=1 - optimized for compute-heavy workloads
## Optimization Priority by Problem Size
Small (seq_len <= 1024):
- Basic cuTile often sufficient
- Focus on correctness first
Medium (1024 < seq_len <= 4096):
- Math optimizations (exp2) provide ~5% gain
- Memory optimizations start to matter
Large (seq_len > 4096):
- Full optimization stack critical
- Platform-specific tuning essential
- Memory pipelining becomes essential
""")
def export_results(results: List[SeqLenResult], output_dir: str):
    """Write the results as JSON and Markdown, plus the full log as text.

    Three files are created inside ``output_dir``:
    ``fmha_scaling_results.json``, ``fmha_scaling_results.md`` and
    ``fmha_scaling_log.txt``. The directory is created when it does not
    exist, and all files are written as UTF-8 because the log contains
    non-ASCII characters.

    Args:
        results: One SeqLenResult per benchmarked sequence length.
        output_dir: Destination directory.
    """
    import os  # local import: only needed to create the output directory

    configs = get_fmha_config()
    primary_cfg = configs[0]

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    data = {
        "config": {
            "batch": BATCH,
            "n_heads": N_HEADS,
            "head_dim": HEAD_DIM,
            "tile_m": primary_cfg.TILE_M,
            "tile_n": primary_cfg.TILE_N,
            "occupancy": primary_cfg.occupancy,
            "num_ctas": primary_cfg.num_ctas,
            "platform": primary_cfg.name,
        },
        "results": [
            {
                "seq_len": r.seq_len,
                "steps": [asdict(s) for s in r.steps],
                "best_step": r.best_step,
                "best_speedup": r.best_speedup,
                "tilegym_latency_ms": r.tilegym_latency_ms,
                "tilegym_tflops": r.tilegym_tflops,
                "tilegym_speedup": r.tilegym_speedup,
            }
            for r in results
        ],
    }

    json_path = f"{output_dir}/fmha_scaling_results.json"
    with open(json_path, 'w', encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    md_path = f"{output_dir}/fmha_scaling_results.md"
    with open(md_path, 'w', encoding="utf-8") as f:
        f.write("# FMHA Scaling Analysis Results\n\n")
        f.write("## Configuration\n")
        f.write(f"- Platform: {primary_cfg.name}\n")
        f.write(f"- Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}\n")
        f.write(f"- Tile: {primary_cfg.TILE_M}x{primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}\n\n")
        f.write("## Target-Specific Configs (from TileGym)\n\n")
        f.write("| Platform | TILE_M | TILE_N | num_ctas | occupancy |\n")
        f.write("|----------|--------|--------|----------|----------|\n")
        f.write("| DGX Spark (sm121) | 64 | 64 | 1 | 2 |\n")
        f.write("| B300 (sm100) Config 1 | 256 | 128 | 1 | 1 |\n")
        f.write("| B300 (sm100) Config 2 | 128 | 128 | 1 | 2 |\n\n")
        f.write("## Results by Sequence Length\n\n")
        for r in results:
            f.write(f"### Seq Len = {r.seq_len}\n\n")
            f.write("| Step | Name | Latency (ms) | TFLOPS | Speedup |\n")
            f.write("|------|------|--------------|--------|--------|\n")
            for s in r.steps:
                f.write(f"| {s.step} | {s.name} | {s.latency_ms:.3f} | {s.tflops:.2f} | {s.speedup_vs_baseline:.2f}x |\n")
            # The TileGym row is only written when a measurement exists.
            if r.tilegym_latency_ms > 0:
                f.write(f"| TG | TileGym Reference | {r.tilegym_latency_ms:.3f} | {r.tilegym_tflops:.2f} | {r.tilegym_speedup:.2f}x |\n")
            f.write("\n")

    # The log snapshot is taken before the "exported to" lines below are logged.
    log_path = f"{output_dir}/fmha_scaling_log.txt"
    with open(log_path, 'w', encoding="utf-8") as f:
        f.write('\n'.join(logger.logs))

    logger.log("\nResults exported to:")
    logger.log(f" - {json_path}")
    logger.log(f" - {md_path}")
    logger.log(f" - {log_path}")
def main():
    """Command-line entry point: run the analysis, print the summary, export files."""
    cli = argparse.ArgumentParser(description="FMHA Scaling Analysis")
    cli.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
    cli.add_argument("--output-dir", type=str, default=".", help="Output directory")
    opts = cli.parse_args()

    measurements = run_scaling_analysis(iterations=opts.iterations)
    print_summary(measurements)
    export_results(measurements, opts.output_dir)
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()