mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-06-18 04:22:21 +00:00
892 lines
32 KiB
Python
892 lines
32 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||
# SPDX-License-Identifier: Apache-2.0
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
#!/usr/bin/env python3
|
||
"""
|
||
FMHA Scaling Analysis: How Optimizations Impact Performance at Different Sizes
|
||
|
||
This script demonstrates:
|
||
1. How FMHA performance scales with sequence length
|
||
2. Which optimizations provide the most benefit at larger sizes
|
||
3. Target-specific configurations for different GPU architectures
|
||
|
||
Target Platforms (from TileGym):
|
||
- DGX Spark (sm120/sm121): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
|
||
- Blackwell B300 (sm100): TILE_M=256, TILE_N=128 or 128x128, num_ctas=1, occupancy=1-2
|
||
|
||
Usage:
|
||
python fmha_scaling_analysis.py [--iterations N] [--output-dir DIR]
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import math
|
||
import time
|
||
from dataclasses import dataclass, asdict
|
||
from typing import List
|
||
from types import SimpleNamespace
|
||
|
||
import torch
|
||
|
||
# Horizontal rules used by Logger: a heavy one around section titles and a
# lighter one around subsection titles.
LOG_SEPARATOR = 80 * "="
LOG_SUBSEPARATOR = 60 * "-"
|
||
|
||
@dataclass
class StepResult:
    """Timing of one optimization step at a single sequence length."""

    step: int  # 0 = PyTorch baseline, 1-4 = cuTile optimization steps
    name: str  # label shown in the report tables
    latency_ms: float  # measured latency in milliseconds
    tflops: float  # achieved TFLOP/s derived from compute_flops() and latency
    speedup_vs_baseline: float  # baseline latency divided by this step's latency
|
||
|
||
@dataclass
class SeqLenResult:
    """All measurements collected for one sequence length."""

    seq_len: int  # sequence length used for Q, K and V
    steps: List[StepResult]  # one entry per step that ran successfully
    best_step: int  # step number with the highest speedup_vs_baseline
    best_speedup: float  # that step's speedup_vs_baseline
    tilegym_latency_ms: float  # TileGym reference latency; 0.0 when not measured
    tilegym_tflops: float  # TileGym reference TFLOP/s; 0.0 when not measured
    tilegym_speedup: float  # TileGym speedup over the baseline; 0.0 when not measured
|
||
|
||
class Logger:
    """Echo messages to stdout while keeping a transcript for later export."""

    def __init__(self):
        # Per-sequence-length results; assigned once the analysis run finishes.
        self.results: List[SeqLenResult] = []
        # Every message passed to log(), in emission order.
        self.logs: List[str] = []

    def log(self, msg: str):
        """Print ``msg`` and record it in the transcript."""
        print(msg)
        self.logs.append(msg)

    def _banner(self, title: str, rule: str):
        """Log ``title`` framed above and below by ``rule``."""
        for line in (f"\n{rule}", f" {title}", rule):
            self.log(line)

    def section(self, title: str):
        """Log a top-level heading framed by LOG_SEPARATOR."""
        self._banner(title, LOG_SEPARATOR)

    def subsection(self, title: str):
        """Log a second-level heading framed by LOG_SUBSEPARATOR."""
        self._banner(title, LOG_SUBSEPARATOR)
|
||
|
||
# Module-wide logger: everything logged is printed and also kept in logger.logs.
logger = Logger()
|
||
|
||
# Attention problem shape shared by every benchmark run.
BATCH = 4
N_HEADS = 32
HEAD_DIM = 128

# 1 / ln(2): converts a natural-exponent scale into the base-2 domain used by
# the exp2-based kernels (exp(x) == exp2(x / ln 2)).
INV_LOG_2 = 1.0 / math.log(2.0)

# Sequence lengths swept by run_scaling_analysis().
SEQ_LENS = [1024, 2048, 4096, 8192, 16384]
|
||
|
||
|
||
def get_fmha_config(capability=None):
    """Return the target-specific FMHA tile configurations (from TileGym attention.py).

    Mirrors TileGym's _fmha_autotune_configs():
      - sm120/sm121 (DGX Spark): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
      - anything else: the two sm100 (Blackwell B300) autotuning candidates.
        NOTE: non-Blackwell GPUs also land here and get the sm100 configs/labels.

    Args:
        capability: Optional ``(major, minor)`` compute capability (tuple or
            list). When None, the current CUDA device is queried, which
            requires a GPU.

    Returns:
        A non-empty list of SimpleNamespace objects with ``name``, ``TILE_M``,
        ``TILE_N``, ``num_ctas`` and ``occupancy``; callers use element 0 as
        the primary configuration.
    """
    if capability is None:
        capability = torch.cuda.get_device_capability()

    # tuple() so callers may pass a list such as [12, 1].
    if tuple(capability) in ((12, 0), (12, 1)):
        return [
            SimpleNamespace(
                name="DGX Spark (sm121)",
                TILE_M=64,
                TILE_N=64,
                num_ctas=1,
                occupancy=2,
            )
        ]

    return [
        SimpleNamespace(
            name="Blackwell B300 (sm100) - Config 1",
            TILE_M=256,
            TILE_N=128,
            num_ctas=1,
            occupancy=1,
        ),
        SimpleNamespace(
            name="Blackwell B300 (sm100) - Config 2",
            TILE_M=128,
            TILE_N=128,
            num_ctas=1,
            occupancy=2,
        ),
    ]
|
||
|
||
|
||
def get_device():
    """Return the default CUDA device.

    Raises:
        RuntimeError: if PyTorch reports that CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")
    return torch.device("cuda")
|
||
|
||
def compute_flops(batch, heads, seq_len, head_dim, causal=True):
    """Return the nominal FLOP count of one attention forward pass.

    Counts two matmuls (Q @ K^T and P @ V) of 2 * seq_len^2 * head_dim FLOPs
    each per (batch, head). With ``causal`` the total is halved, since only
    the lower triangle contributes.
    """
    per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total = per_matmul * 2
    return total * 0.5 if causal else total
|
||
|
||
def benchmark_fn(fn, warmup=10, iterations=100):
    """Return the latency of ``fn()`` in milliseconds.

    Prefers ``triton.testing.do_bench_cudagraph`` (the timer TileGym uses).
    If Triton is not installed, or CUDA-graph benchmarking fails (for example
    because ``fn`` cannot be captured into a graph), falls back to wall-clock
    timing of ``iterations`` calls after ``warmup`` calls, with a device
    synchronize on both sides of the timed loop.

    Args:
        fn: Zero-argument callable that launches the GPU work to time.
        warmup: Untimed calls before measuring (fallback path only).
        iterations: Timed calls to average over (fallback path only; values
            below 1 are treated as 1).
    """
    try:
        # Import the submodule explicitly rather than relying on `import triton`
        # to expose `triton.testing` as an attribute.
        import triton.testing
    except ImportError:
        bench_cudagraph = None
    else:
        bench_cudagraph = triton.testing.do_bench_cudagraph

    if bench_cudagraph is not None:
        try:
            return bench_cudagraph(fn)
        except Exception:
            # Deliberate best-effort: fall back to manual timing below.
            pass

    iterations = max(1, iterations)  # avoid ZeroDivisionError on iterations=0

    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    return elapsed / iterations * 1000
|
||
|
||
def reference_fmha(q, k, v, sm_scale, is_causal=True):
|
||
return torch.nn.functional.scaled_dot_product_attention(
|
||
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
|
||
)
|
||
|
||
# cuTile is optional. CUTILE_AVAILABLE gates the kernel definitions that follow
# and the cuTile benchmark steps.
try:
    import cuda.tile as ct
    from cuda.tile import RoundingMode as RMd
except ImportError:
    CUTILE_AVAILABLE = False
    logger.log("[WARN] cuTile not available.")
else:
    CUTILE_AVAILABLE = True
|
||
|
||
# The cuTile kernels are only defined when cuda.tile imported successfully.
# All five share one signature so run_kernel() can launch any of them; they
# differ only in the optimizations named in their docstrings.
if CUTILE_AVAILABLE:
    # Compile-time constant parameter types used by the kernel signatures.
    ConstInt = ct.Constant[int]
    ConstBool = ct.Constant[bool]

    @ct.kernel()
    def fmha_basic(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 1: Basic cuTile - no optimizations"""
        # Grid axis 0 selects a query tile; axis 1 is a fused (batch, head) index.
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H

        # Absolute query row indices (column vector) and per-tile key offsets
        # (row vector); both are only used to build the causal mask.
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]

        # Online-softmax state: running row max, running denominator and the
        # unnormalized output accumulator, all in fp32.
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

        # Query tile for this program, reused against every K/V tile. The index
        # appears to be in tile units (bid_x-th block of TILE_M rows) - consistent
        # with offs_m above; confirm against the ct.load docs.
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))

        k_seqlen = K.shape[2]
        if CAUSAL:
            # Only visit K tiles that start before this query tile ends; tiles
            # from mask_start onward may straddle the diagonal and get masked.
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            # Visit every K tile; mask_start is irrelevant because CAUSAL is False.
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N

        for j in range(0, Tc):
            # Load the j-th K tile and transpose it explicitly for Q @ K^T.
            k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            k_tile = k_tile.reshape((TILE_N, TILE_D))
            k_t = ct.permute(k_tile, (1, 0))

            # Scaled scores S = (Q @ K^T) * qk_scale, accumulated in fp32.
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            qk = qk * qk_scale

            # Keep only keys at or before the query position (query >= key).
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

            # Online softmax: new running max, probabilities relative to it, and
            # alpha to rescale the previously accumulated sum and output.
            m_ij = ct.max(qk, axis=-1, keepdims=True)
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk - m_ij

            p = ct.exp(qk)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp(m_i - m_ij)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha

            # Accumulate P @ V; P is cast to the input dtype for the MMA.
            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij

        # Normalize by the softmax denominator and store in the output dtype.
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    @ct.kernel()
    def fmha_math_opt(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 2: Math optimizations - exp2 + flush_to_zero"""
        # Same tiling and online softmax as fmha_basic; only the exponent math differs.
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H

        # Fold the softmax scale into base 2: exp(x * s) == exp2(x * s / ln 2).
        qk_scale_log2 = qk_scale * INV_LOG_2

        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]

        # Running max is kept in the scaled log2 domain from here on.
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))

        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N

        for j in range(0, Tc):
            k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            k_tile = k_tile.reshape((TILE_N, TILE_D))
            k_t = ct.permute(k_tile, (1, 0))

            # Raw (unscaled) scores; scaling is fused into the max/exp2 below.
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)

            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

            # Scaling after ct.max is only equivalent to max-of-scaled when
            # qk_scale_log2 > 0 (true for the 1/sqrt(head_dim) scale used by this
            # script). NOTE(review): confirm if other scales are ever passed.
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij

            # flush_to_zero=True lets exp2 flush denormal results to zero.
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha

            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij

        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    @ct.kernel()
    def fmha_memory_opt(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 3: Memory optimizations - load order + latency hints"""
        # Same math as fmha_math_opt; only the K/V load calls change.
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H

        qk_scale_log2 = qk_scale * INV_LOG_2

        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]

        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))

        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N

        for j in range(0, Tc):
            # Load K already transposed as a (TILE_D, TILE_N) tile by swapping the
            # last two axes with order=(0, 1, 3, 2), replacing the explicit
            # ct.permute. latency=2 is a scheduling hint for this load.
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2
            )
            k_t = k_t.reshape((TILE_D, TILE_N))

            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)

            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij

            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha

            # V gets a larger latency hint (4) than K (2).
            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij

        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    @ct.kernel(occupancy=2)
    def fmha_full_opt_occ2(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 4a: Full optimization with occupancy=2 (for sm120/sm121)"""
        # Identical to fmha_memory_opt except for occupancy=2 in the decorator and
        # the approximate final division.
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H

        qk_scale_log2 = qk_scale * INV_LOG_2

        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]

        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))

        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N

        for j in range(0, Tc):
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2
            )
            k_t = k_t.reshape((TILE_D, TILE_N))

            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)

            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij

            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha

            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij

        # Final normalization with approximate-rounding division and flush-to-zero.
        acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    @ct.kernel(occupancy=1)
    def fmha_full_opt_occ1(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 4b: Full optimization with occupancy=1 (for sm100 Blackwell)"""
        # Same body as fmha_full_opt_occ2; only the decorator's occupancy differs.
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H

        qk_scale_log2 = qk_scale * INV_LOG_2

        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]

        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))

        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N

        for j in range(0, Tc):
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2
            )
            k_t = k_t.reshape((TILE_D, TILE_N))

            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)

            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij

            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha

            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij

        # Final normalization with approximate-rounding division and flush-to-zero.
        acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||
|
||
def run_kernel(kernel_fn, q, k, v, sm_scale, tile_m, tile_n, is_causal=True):
    """Launch one of the cuTile FMHA kernels and return its output tensor.

    The grid has one program per (query tile, batch * head) pair. The kernel
    arguments follow the shared kernel signature
    (Q, K, V, Out, qk_scale, TILE_D, H, TILE_M, TILE_N, CAUSAL).
    q is unpacked as (batch, heads, seq_len, head_dim).
    """
    n_batch, n_heads, n_seq, d_head = q.shape
    out = torch.empty_like(q)

    query_tiles = math.ceil(n_seq / tile_m)
    launch_grid = (query_tiles, n_batch * n_heads, 1)
    kernel_args = (q, k, v, out, sm_scale, d_head, n_heads, tile_m, tile_n, is_causal)

    ct.launch(torch.cuda.current_stream(), launch_grid, kernel_fn, kernel_args)
    return out
|
||
|
||
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
    """Run TileGym's cuTile FMHA as a reference, or return None if TileGym is absent.

    Only the import is guarded: an ImportError raised while actually running
    ``tilegym.ops.fmha`` now propagates instead of being mistaken for
    "TileGym not installed".
    """
    try:
        import tilegym
    except ImportError:
        # TileGym is optional; callers treat None as "reference unavailable".
        return None
    return tilegym.ops.fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
|
||
|
||
|
||
def _log_run_overview(configs, primary_cfg, gpu_name, gpu_cap):
    """Log the GPU, tile configurations, optimization steps and platform notes."""
    tile_m = primary_cfg.TILE_M
    tile_n = primary_cfg.TILE_N

    logger.section("FMHA SCALING ANALYSIS (TileGym Benchmark Match)")
    logger.log("Matching TileGym bench_fused_attention.py configuration")
    logger.log(f"\nGPU: {gpu_name} (sm_{gpu_cap[0]}{gpu_cap[1]})")

    logger.subsection("TARGET-SPECIFIC CONFIGURATION (from TileGym)")
    for cfg in configs:
        logger.log(f"\n {cfg.name}:")
        logger.log(f" TILE_M={cfg.TILE_M}, TILE_N={cfg.TILE_N}")
        logger.log(f" num_ctas={cfg.num_ctas}, occupancy={cfg.occupancy}")

    logger.log(f"\nUsing primary config: TILE_M={tile_m}, TILE_N={tile_n}, occupancy={primary_cfg.occupancy}")
    logger.log("\nTest Configuration (matches TileGym bench_fused_attention.py):")
    logger.log(f" Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}")
    logger.log(" Causal: True, Precision: float16")
    logger.log(f" Sequence Lengths: {SEQ_LENS}")
    logger.log(" Benchmark: triton.testing.do_bench_cudagraph (same as TileGym)")

    logger.section("OPTIMIZATION STEPS")
    logger.log(f"""
Step 0: PyTorch Baseline
- torch.nn.functional.scaled_dot_product_attention
- Uses cuDNN Flash Attention backend
- Highly optimized reference

Step 1: Basic cuTile (TILE_M={tile_m}, TILE_N={tile_n})
- @ct.kernel with ct.mma() for Tensor Cores
- Standard exp() for softmax
- Explicit transpose with ct.permute()
- No memory/occupancy hints

Step 2: Math Optimizations
- ct.exp2() instead of ct.exp() (faster on GPU)
- flush_to_zero=True for denormals
- Scale adjustment: multiply by 1/log(2)

Step 3: Memory Optimizations
- Load order=(0,1,3,2) for implicit K transpose
- Latency hints: K=2, V=4 for prefetching
- Overlaps memory loads with computation

Step 4: Full Optimization (Target-Specific)
- @ct.kernel(occupancy={primary_cfg.occupancy}) for {'sm120/121' if primary_cfg.occupancy == 2 else 'sm100'}
- ct.truediv with APPROX rounding mode
- Matches TileGym production implementation
""")

    logger.section("PLATFORM DIFFERENCES: DGX Spark vs Blackwell B300")
    logger.log("""
| Parameter | DGX Spark (sm121) | Blackwell B300 (sm100) |
|--------------|-------------------|------------------------|
| TILE_M | 64 | 256 or 128 |
| TILE_N | 64 | 128 |
| num_ctas | 1 | 1 |
| occupancy | 2 | 1 or 2 |

Why the difference?
- B300 has more SMs and larger shared memory -> can use bigger tiles
- B300 benefits from larger tiles (256x128) with lower occupancy
- DGX Spark needs smaller tiles (64x64) with higher occupancy to hide latency
- B300's higher memory bandwidth makes larger tiles more efficient
""")


def _benchmark_seq_len(seq_len, device, iterations, tile_m, tile_n, kernels):
    """Benchmark the baseline, each cuTile step and TileGym at one sequence length.

    Args:
        kernels: (step, name, kernel) tuples to time; empty when cuTile is absent.

    Returns:
        A SeqLenResult. Steps that raise are logged and omitted. TileGym fields
        stay 0.0 when TileGym is unavailable or fails.
    """
    shape = (BATCH, N_HEADS, seq_len, HEAD_DIM)
    q = torch.randn(*shape, dtype=torch.float16, device=device)
    k = torch.randn(*shape, dtype=torch.float16, device=device)
    v = torch.randn(*shape, dtype=torch.float16, device=device)
    sm_scale = 1.0 / math.sqrt(HEAD_DIM)

    flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)

    def tflops_for(latency_ms):
        return flops * 1e-12 / (latency_ms * 1e-3)

    baseline_latency = benchmark_fn(
        lambda: reference_fmha(q, k, v, sm_scale, is_causal=True),
        warmup=10,
        iterations=iterations,
    )
    steps_results = [
        StepResult(0, "PyTorch Baseline", baseline_latency, tflops_for(baseline_latency), 1.0)
    ]

    for step, name, kernel in kernels:
        try:
            # Bind `kernel` as a default so the closure captures this iteration's kernel.
            fn = lambda kernel=kernel: run_kernel(kernel, q, k, v, sm_scale, tile_m, tile_n, is_causal=True)
            latency = benchmark_fn(fn, warmup=10, iterations=iterations)
            steps_results.append(
                StepResult(step, name, latency, tflops_for(latency), baseline_latency / latency)
            )
        except Exception as e:
            logger.log(f" [ERROR] Step {step} failed: {e}")

    # TileGym is an optional reference: a failure is logged like a failed step
    # instead of aborting the whole sweep.
    tilegym_latency = tilegym_tflops = tilegym_speedup = 0.0
    try:
        tilegym_out = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
        if tilegym_out is not None:
            tilegym_latency = benchmark_fn(
                lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True),
                warmup=10,
                iterations=iterations,
            )
            tilegym_tflops = tflops_for(tilegym_latency)
            tilegym_speedup = baseline_latency / tilegym_latency
    except Exception as e:
        logger.log(f" [ERROR] TileGym reference failed: {e}")
        tilegym_latency = tilegym_tflops = tilegym_speedup = 0.0

    best_step = max(steps_results, key=lambda x: x.speedup_vs_baseline)

    return SeqLenResult(
        seq_len=seq_len,
        steps=steps_results,
        best_step=best_step.step,
        best_speedup=best_step.speedup_vs_baseline,
        tilegym_latency_ms=tilegym_latency,
        tilegym_tflops=tilegym_tflops,
        tilegym_speedup=tilegym_speedup,
    )


def _log_seq_len_table(result):
    """Log the per-step table and best step for one SeqLenResult."""
    logger.log("\n | Step | Name | Latency (ms) | TFLOPS | Speedup |")
    logger.log(" |------|------|--------------|--------|---------|")

    for sr in result.steps:
        logger.log(f" | {sr.step} | {sr.name:<28} | {sr.latency_ms:>10.3f} | {sr.tflops:>6.2f} | {sr.speedup_vs_baseline:>6.2f}x |")
    if result.tilegym_latency_ms > 0:
        logger.log(f" | TG | TileGym Reference | {result.tilegym_latency_ms:>10.3f} | {result.tilegym_tflops:>6.2f} | {result.tilegym_speedup:>6.2f}x |")

    best_step = max(result.steps, key=lambda x: x.speedup_vs_baseline)
    logger.log(f"\n Best: Step {best_step.step} ({best_step.name}) with {best_step.speedup_vs_baseline:.2f}x speedup")


def run_scaling_analysis(iterations=100):
    """Benchmark every optimization step across SEQ_LENS and log the results.

    Args:
        iterations: Timed iterations for the fallback timer in benchmark_fn().

    Returns:
        One SeqLenResult per entry of SEQ_LENS. The list is also stored on
        ``logger.results``.

    Raises:
        RuntimeError: if CUDA is unavailable (from get_device()).

    When cuTile is not installed, only the PyTorch baseline and (if present)
    TileGym are benchmarked. Previously the step-4 kernel was looked up
    unconditionally, which raised NameError in that case.
    """
    device = get_device()
    gpu_cap = torch.cuda.get_device_capability()
    gpu_name = torch.cuda.get_device_name()

    configs = get_fmha_config()
    primary_cfg = configs[0]

    _log_run_overview(configs, primary_cfg, gpu_name, gpu_cap)

    # The kernels only exist when cuda.tile imported successfully.
    kernels = []
    if CUTILE_AVAILABLE:
        full_opt_kernel = fmha_full_opt_occ2 if primary_cfg.occupancy == 2 else fmha_full_opt_occ1
        kernels = [
            (1, "Basic cuTile", fmha_basic),
            (2, "Math Opt (exp2)", fmha_math_opt),
            (3, "Memory Opt (order+latency)", fmha_memory_opt),
            (4, f"Full Opt (occ={primary_cfg.occupancy})", full_opt_kernel),
        ]
    else:
        logger.log("\n[WARN] cuTile not available - only the PyTorch baseline and TileGym reference are benchmarked.")

    all_results = []
    for seq_len in SEQ_LENS:
        logger.subsection(f"Sequence Length: {seq_len}")
        result = _benchmark_seq_len(
            seq_len, device, iterations, primary_cfg.TILE_M, primary_cfg.TILE_N, kernels
        )
        all_results.append(result)
        _log_seq_len_table(result)

    logger.results = all_results
    return all_results
|
||
|
||
|
||
def print_summary(results: List[SeqLenResult]):
    """Log cross-sequence-length summary tables followed by explanatory notes.

    Args:
        results: Per-sequence-length results as returned by run_scaling_analysis().

    Columns with no measurement print "N/A" instead of misleading values:
    the Full Opt columns when step 4 did not run (e.g. cuTile unavailable),
    and the TileGym columns when tilegym_latency_ms <= 0.
    """
    primary_cfg = get_fmha_config()[0]

    logger.section("SCALING SUMMARY")

    logger.log(f"\nTarget Config: TILE_M={primary_cfg.TILE_M}, TILE_N={primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}")

    logger.log("\n## Performance vs Sequence Length\n")
    logger.log("| Seq Len | Baseline (ms) | Full Opt (ms) | Speedup | TileGym (ms) | TG Speedup |")
    logger.log("|---------|---------------|---------------|---------|--------------|------------|")
    for r in results:
        baseline = next((s for s in r.steps if s.step == 0), None)
        if baseline is None:
            continue
        full_opt = next((s for s in r.steps if s.step == 4), None)
        if full_opt is not None:
            full_cols = f"{full_opt.latency_ms:>13.3f} | {full_opt.speedup_vs_baseline:>6.2f}x"
        else:
            full_cols = f"{'N/A':>13} | {'N/A':>7}"
        if r.tilegym_latency_ms > 0:
            tg_cols = f"{r.tilegym_latency_ms:>12.3f} | {r.tilegym_speedup:>9.2f}x"
        else:
            tg_cols = f"{'N/A':>12} | {'N/A':>10}"
        logger.log(f"| {r.seq_len:>7} | {baseline.latency_ms:>13.3f} | {full_cols} | {tg_cols} |")

    logger.log("\n## Optimization Impact by Sequence Length\n")
    logger.log("| Seq Len | Basic | +Math | +Memory | +Full | Best |")
    logger.log("|---------|-------|-------|---------|-------|------|")
    for r in results:
        row = f"| {r.seq_len:>7} |"
        for step in [1, 2, 3, 4]:
            sr = next((s for s in r.steps if s.step == step), None)
            if sr:
                row += f" {sr.speedup_vs_baseline:>5.2f}x |"
            else:
                row += " N/A |"
        row += f" {r.best_speedup:>4.2f}x |"
        logger.log(row)

    # K-tile counts below follow the kernels' K loop, which steps by TILE_N
    # (64 on sm121, 128 on sm100).
    logger.section("KEY INSIGHTS")
    logger.log("""
## Why Larger Sequences Benefit More from Optimization

1. **Memory Bandwidth Dominance**
- Unfused attention materializes an O(N²) QK^T matrix
- At seq_len=8192: 8192² × 4 bytes = 256MB per head per batch
- FMHA never stores it, but each query tile still streams every K/V tile up to its causal limit
- Memory optimizations (order, latency hints) have larger impact

2. **More K-Loop Iterations**
- At seq_len=512: 8 K-tiles (512/64) for sm121, 4 K-tiles (512/128) for sm100
- At seq_len=8192: 128 K-tiles for sm121, 64 K-tiles for sm100
- Latency hiding through pipelining amortizes over more iterations

3. **Better Occupancy Utilization**
- More tiles = more parallelism opportunities
- sm121 uses occupancy=2 (smaller tiles, more blocks)
- sm100 uses occupancy=1 with larger tiles (256x128)

4. **Platform-Specific Tuning**
- DGX Spark (sm121): 64x64 tiles, occupancy=2 - optimized for bandwidth-limited workloads
- B300 (sm100): 256x128 tiles, occupancy=1 - optimized for compute-heavy workloads

## Optimization Priority by Problem Size

Small (seq_len <= 1024):
- Basic cuTile often sufficient
- Focus on correctness first

Medium (1024 < seq_len <= 4096):
- Math optimizations (exp2) provide ~5% gain
- Memory optimizations start to matter

Large (seq_len > 4096):
- Full optimization stack critical
- Platform-specific tuning essential
- Memory pipelining becomes essential
""")
|
||
|
||
|
||
def export_results(results: List[SeqLenResult], output_dir: str):
    """Write the benchmark results to ``output_dir`` as JSON, Markdown and a log.

    Creates ``output_dir`` (and any parents) if needed, then writes:
      - fmha_scaling_results.json: config plus per-seq-len step results
      - fmha_scaling_results.md: the same data as Markdown tables
      - fmha_scaling_log.txt: every message recorded by ``logger``

    All files are written as UTF-8. The logged text can contain non-ASCII
    characters (e.g. "²" and "×" in the summary notes), which would raise
    UnicodeEncodeError under a non-UTF-8 default locale encoding.
    """
    from pathlib import Path

    primary_cfg = get_fmha_config()[0]

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    data = {
        "config": {
            "batch": BATCH,
            "n_heads": N_HEADS,
            "head_dim": HEAD_DIM,
            "tile_m": primary_cfg.TILE_M,
            "tile_n": primary_cfg.TILE_N,
            "occupancy": primary_cfg.occupancy,
            "num_ctas": primary_cfg.num_ctas,
            "platform": primary_cfg.name,
        },
        "results": [
            {
                "seq_len": r.seq_len,
                "steps": [asdict(s) for s in r.steps],
                "best_step": r.best_step,
                "best_speedup": r.best_speedup,
                "tilegym_latency_ms": r.tilegym_latency_ms,
                "tilegym_tflops": r.tilegym_tflops,
                "tilegym_speedup": r.tilegym_speedup,
            }
            for r in results
        ],
    }

    json_path = out_dir / "fmha_scaling_results.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    md_lines = [
        "# FMHA Scaling Analysis Results\n\n",
        "## Configuration\n",
        f"- Platform: {primary_cfg.name}\n",
        f"- Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}\n",
        f"- Tile: {primary_cfg.TILE_M}x{primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}\n\n",
        "## Target-Specific Configs (from TileGym)\n\n",
        "| Platform | TILE_M | TILE_N | num_ctas | occupancy |\n",
        "|----------|--------|--------|----------|----------|\n",
        "| DGX Spark (sm121) | 64 | 64 | 1 | 2 |\n",
        "| B300 (sm100) Config 1 | 256 | 128 | 1 | 1 |\n",
        "| B300 (sm100) Config 2 | 128 | 128 | 1 | 2 |\n\n",
        "## Results by Sequence Length\n\n",
    ]
    for r in results:
        md_lines.append(f"### Seq Len = {r.seq_len}\n\n")
        md_lines.append("| Step | Name | Latency (ms) | TFLOPS | Speedup |\n")
        md_lines.append("|------|------|--------------|--------|--------|\n")
        for s in r.steps:
            md_lines.append(f"| {s.step} | {s.name} | {s.latency_ms:.3f} | {s.tflops:.2f} | {s.speedup_vs_baseline:.2f}x |\n")
        # tilegym_latency_ms stays 0.0 when the TileGym reference was not measured.
        if r.tilegym_latency_ms > 0:
            md_lines.append(f"| TG | TileGym Reference | {r.tilegym_latency_ms:.3f} | {r.tilegym_tflops:.2f} | {r.tilegym_speedup:.2f}x |\n")
        md_lines.append("\n")

    md_path = out_dir / "fmha_scaling_results.md"
    with open(md_path, "w", encoding="utf-8") as f:
        f.writelines(md_lines)

    log_path = out_dir / "fmha_scaling_log.txt"
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("\n".join(logger.logs))

    logger.log("\nResults exported to:")
    logger.log(f" - {json_path}")
    logger.log(f" - {md_path}")
    logger.log(f" - {log_path}")
|
||
|
||
|
||
def main():
    """CLI entry point: run the sweep, log a summary, then export the reports."""
    cli = argparse.ArgumentParser(description="FMHA Scaling Analysis")
    cli.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
    cli.add_argument("--output-dir", type=str, default=".", help="Output directory")
    opts = cli.parse_args()

    sweep = run_scaling_analysis(iterations=opts.iterations)
    print_summary(sweep)
    export_results(sweep, opts.output_dir)


if __name__ == "__main__":
    main()
|