mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-06-21 21:59:30 +00:00
892 lines
32 KiB
Python
892 lines
32 KiB
Python
|
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|||
|
|
# SPDX-License-Identifier: Apache-2.0
|
|||
|
|
#
|
|||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|||
|
|
# you may not use this file except in compliance with the License.
|
|||
|
|
# You may obtain a copy of the License at
|
|||
|
|
#
|
|||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|||
|
|
#
|
|||
|
|
# Unless required by applicable law or agreed to in writing, software
|
|||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
|
|
# See the License for the specific language governing permissions and
|
|||
|
|
# limitations under the License.
|
|||
|
|
|
|||
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
FMHA Scaling Analysis: How Optimizations Impact Performance at Different Sizes
|
|||
|
|
|
|||
|
|
This script demonstrates:
|
|||
|
|
1. How FMHA performance scales with sequence length
|
|||
|
|
2. Which optimizations provide the most benefit at larger sizes
|
|||
|
|
3. Target-specific configurations for different GPU architectures
|
|||
|
|
|
|||
|
|
Target Platforms (from TileGym):
|
|||
|
|
- DGX Spark (sm120/sm121): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
|
|||
|
|
- Blackwell B300 (sm100): TILE_M=256, TILE_N=128 or 128x128, num_ctas=1, occupancy=1-2
|
|||
|
|
|
|||
|
|
Usage:
|
|||
|
|
python fmha_scaling_analysis.py [--iterations N]
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import argparse
|
|||
|
|
import json
|
|||
|
|
import math
|
|||
|
|
import time
|
|||
|
|
from dataclasses import dataclass, asdict
|
|||
|
|
from typing import List
|
|||
|
|
from types import SimpleNamespace
|
|||
|
|
|
|||
|
|
import torch
|
|||
|
|
|
|||
|
|
# Widths of the horizontal rules printed around section / subsection titles.
_SECTION_RULE_WIDTH = 80
_SUBSECTION_RULE_WIDTH = 60

LOG_SEPARATOR = "=" * _SECTION_RULE_WIDTH
LOG_SUBSEPARATOR = "-" * _SUBSECTION_RULE_WIDTH
|
|||
|
|
|
|||
|
|
@dataclass
class StepResult:
    """Measurement of one optimization step at a single sequence length."""

    # Step number shown in the result tables (0 is the PyTorch baseline).
    step: int
    # Human-readable label of the step.
    name: str
    # Measured latency in milliseconds.
    latency_ms: float
    # Throughput derived from the FLOP estimate and the latency.
    tflops: float
    # Baseline latency divided by this step's latency (1.0 for the baseline).
    speedup_vs_baseline: float
|
|||
|
|
|
|||
|
|
@dataclass
class SeqLenResult:
    """Everything measured for one sequence length."""

    # Sequence length the measurements were taken at.
    seq_len: int
    # Per-step measurements, baseline first.
    steps: List[StepResult]
    # Step number of the entry with the highest speedup.
    best_step: int
    # Speedup of that best entry relative to the baseline.
    best_speedup: float
    # TileGym reference numbers; all three stay 0.0 when TileGym was not run.
    tilegym_latency_ms: float
    tilegym_tflops: float
    tilegym_speedup: float
|
|||
|
|
|
|||
|
|
class Logger:
    """Print messages to stdout while keeping a transcript for later export."""

    def __init__(self):
        # Filled in by run_scaling_analysis once all sequence lengths are done.
        self.results: List[SeqLenResult] = []
        # Every message passed to log(), in order.
        self.logs: List[str] = []

    def log(self, msg: str):
        """Print ``msg`` and append it to the transcript."""
        print(msg)
        self.logs.append(msg)

    def _banner(self, rule: str, title: str):
        """Emit ``title`` framed by ``rule`` as three separate log entries."""
        for line in (f"\n{rule}", f" {title}", rule):
            self.log(line)

    def section(self, title: str):
        """Log a top-level heading framed by the heavy separator."""
        self._banner(LOG_SEPARATOR, title)

    def subsection(self, title: str):
        """Log a second-level heading framed by the light separator."""
        self._banner(LOG_SUBSEPARATOR, title)


# Module-wide logger shared by every function in this script.
logger = Logger()
|
|||
|
|
|
|||
|
|
# Fixed problem shape used for every benchmark run.
BATCH = 4
N_HEADS = 32
HEAD_DIM = 128

# 1/ln(2): folded into the softmax scale so kernels can use exp2 instead of exp.
INV_LOG_2 = 1.0 / math.log(2)

# Sequence lengths to sweep: powers of two from 1024 up to 16384.
SEQ_LENS = [2 ** exponent for exponent in range(10, 15)]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_fmha_config(gpu_capability=None):
    """Return the target-specific FMHA tile configurations (from TileGym attention.py).

    Mirrors TileGym's ``_fmha_autotune_configs()``:

    - sm120/sm121 (DGX Spark): one config, TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
    - anything else (treated as sm100 / Blackwell B300): two candidate configs

    Args:
        gpu_capability: Optional ``(major, minor)`` compute capability. When
            omitted, the current CUDA device is queried, which is the original
            behaviour. Passing it explicitly lets callers inspect the
            configuration for a platform other than the one they run on.

    Returns:
        A list of ``SimpleNamespace`` objects with the attributes ``name``,
        ``TILE_M``, ``TILE_N``, ``num_ctas`` and ``occupancy``. The first
        entry is the one the rest of this script treats as primary.
    """
    if gpu_capability is None:
        gpu_capability = torch.cuda.get_device_capability()

    # Normalise so a list such as [12, 1] matches the tuples below.
    capability = tuple(gpu_capability)

    if capability in ((12, 0), (12, 1)):
        return [
            SimpleNamespace(
                name="DGX Spark (sm121)",
                TILE_M=64,
                TILE_N=64,
                num_ctas=1,
                occupancy=2,
            ),
        ]

    return [
        SimpleNamespace(
            name="Blackwell B300 (sm100) - Config 1",
            TILE_M=256,
            TILE_N=128,
            num_ctas=1,
            occupancy=1,
        ),
        SimpleNamespace(
            name="Blackwell B300 (sm100) - Config 2",
            TILE_M=128,
            TILE_N=128,
            num_ctas=1,
            occupancy=2,
        ),
    ]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_device():
    """Return the CUDA device, raising ``RuntimeError`` when CUDA is unavailable."""
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")
    return torch.device("cuda")
|
|||
|
|
|
|||
|
|
def compute_flops(batch, heads, seq_len, head_dim, causal=True):
    """Estimate the floating-point operations of one attention forward pass.

    Counts two matmuls (Q·K^T and P·V), each costing
    ``2 * batch * heads * seq_len^2 * head_dim`` FLOPs. With ``causal=True``
    the total is halved.
    """
    per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    both_matmuls = 2 * per_matmul
    return both_matmuls * 0.5 if causal else both_matmuls
|
|||
|
|
|
|||
|
|
def benchmark_fn(fn, warmup=10, iterations=100):
    """Return the latency of ``fn`` in milliseconds.

    Prefers ``triton.testing.do_bench_cudagraph`` (the harness TileGym uses).
    If Triton is missing or the CUDA-graph benchmark fails, falls back to
    wall-clock timing around ``iterations`` calls and emits a
    ``RuntimeWarning`` so the change of method is visible.

    Args:
        fn: Zero-argument callable that enqueues the GPU work to time.
        warmup: Untimed calls made before measuring (fallback path only).
        iterations: Timed calls to average over (fallback path only).

    Raises:
        ValueError: If the fallback path is taken with ``iterations < 1``.
    """
    try:
        import triton

        # Use triton's cudagraph benchmark - same as TileGym
        return triton.testing.do_bench_cudagraph(fn)
    except Exception as exc:
        # Deliberate best-effort: any Triton problem falls back to manual
        # timing, but say so instead of silently switching methods.
        import warnings

        warnings.warn(
            f"triton cudagraph benchmark unavailable ({exc!r}); using manual timing",
            RuntimeWarning,
            stacklevel=2,
        )

    if iterations < 1:
        raise ValueError("iterations must be >= 1")

    # Fallback to manual timing
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    # Wait for all queued GPU work before stopping the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    return elapsed / iterations * 1000
|
|||
|
|
|
|||
|
|
def reference_fmha(q, k, v, sm_scale, is_causal=True):
|
|||
|
|
return torch.nn.functional.scaled_dot_product_attention(
|
|||
|
|
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# cuTile is optional: without it only the PyTorch baseline can be measured.
try:
    import cuda.tile as ct
    from cuda.tile import RoundingMode as RMd
except ImportError:
    CUTILE_AVAILABLE = False
    logger.log("[WARN] cuTile not available.")
else:
    CUTILE_AVAILABLE = True

if CUTILE_AVAILABLE:
    # Aliases for kernel parameters declared as ct.Constant.
    ConstInt = ct.Constant[int]
    ConstBool = ct.Constant[bool]
|
|||
|
|
|
|||
|
|
# Guarded so the module still imports when cuTile is not installed.
if CUTILE_AVAILABLE:

    @ct.kernel()
    def fmha_basic(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 1: plain cuTile attention kernel without tuning hints.

        Block index 0 selects the TILE_M-row query tile; block index 1 is
        split into (batch, head) using H. K/V tiles are streamed while a
        running row maximum (m_i), running row sum (l_i) and output
        accumulator (acc) are kept up to date.
        """
        q_blk = ct.bid(0)
        bh = ct.bid(1)
        batch_idx = bh // H
        head_idx = bh % H

        # Absolute query rows (column vector) and tile-local key columns
        # (row vector), used to build the causal mask.
        row_ids = q_blk * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        row_ids = row_ids[:, None]
        col_ids_local = ct.arange(TILE_N, dtype=ct.int32)
        col_ids_local = col_ids_local[None, :]

        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

        q = ct.load(Q, index=(batch_idx, head_idx, q_blk, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))

        kv_len = K.shape[2]
        if CAUSAL:
            # Stop at the last K tile this query tile can see; masking is
            # only needed from the first tile that reaches the diagonal.
            row_end = (q_blk + 1) * TILE_M
            n_kv_tiles = ct.cdiv(min(row_end, kv_len), TILE_N)
            first_masked_tile = (q_blk * TILE_M) // TILE_N
        else:
            n_kv_tiles = ct.cdiv(kv_len, TILE_N)
            first_masked_tile = kv_len // TILE_N

        for j in range(0, n_kv_tiles):
            k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            k_tile = k_tile.reshape((TILE_N, TILE_D))
            k_t = ct.permute(k_tile, (1, 0))

            scores = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            scores = ct.mma(q, k_t, scores)
            scores = scores * qk_scale

            if CAUSAL and j >= first_masked_tile:
                col_ids = j * TILE_N + col_ids_local
                visible = row_ids >= col_ids
                scores = ct.where(visible, scores, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

            m_new = ct.max(scores, axis=-1, keepdims=True)
            m_new = ct.maximum(m_i, m_new)
            scores = scores - m_new

            probs = ct.exp(scores)
            row_sum = ct.sum(probs, axis=-1, keepdims=True)
            # Rescale earlier contributions to the new running maximum.
            rescale = ct.exp(m_i - m_new)
            l_i = l_i * rescale + row_sum
            acc = acc * rescale

            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            probs_in = probs.astype(Q.dtype)
            acc = ct.mma(probs_in, v_tile, acc)
            m_i = m_new

        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, q_blk, 0), tile=acc)
|
|||
|
|
|
|||
|
|
# Guarded so the module still imports when cuTile is not installed.
if CUTILE_AVAILABLE:

    @ct.kernel()
    def fmha_math_opt(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 2: same algorithm as step 1, with exp2 and flush_to_zero.

        The softmax scale is multiplied by 1/ln(2) and applied together with
        the running-maximum subtraction, so ``ct.exp2`` replaces ``ct.exp``.
        """
        q_blk = ct.bid(0)
        bh = ct.bid(1)
        batch_idx = bh // H
        head_idx = bh % H

        # Scale expressed in base 2 so exp2 can be used below.
        scale_log2 = qk_scale * INV_LOG_2

        row_ids = q_blk * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        row_ids = row_ids[:, None]
        col_ids_local = ct.arange(TILE_N, dtype=ct.int32)
        col_ids_local = col_ids_local[None, :]

        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

        q = ct.load(Q, index=(batch_idx, head_idx, q_blk, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))

        kv_len = K.shape[2]
        if CAUSAL:
            row_end = (q_blk + 1) * TILE_M
            n_kv_tiles = ct.cdiv(min(row_end, kv_len), TILE_N)
            first_masked_tile = (q_blk * TILE_M) // TILE_N
        else:
            n_kv_tiles = ct.cdiv(kv_len, TILE_N)
            first_masked_tile = kv_len // TILE_N

        for j in range(0, n_kv_tiles):
            k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            k_tile = k_tile.reshape((TILE_N, TILE_D))
            k_t = ct.permute(k_tile, (1, 0))

            scores = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            scores = ct.mma(q, k_t, scores)

            if CAUSAL and j >= first_masked_tile:
                col_ids = j * TILE_N + col_ids_local
                visible = row_ids >= col_ids
                scores = ct.where(visible, scores, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

            # The scale is applied to the row maximum and to the scores
            # separately; the raw scores are never scaled on their own.
            m_new = ct.max(scores, axis=-1, keepdims=True) * scale_log2
            m_new = ct.maximum(m_i, m_new)
            scores = scores * scale_log2 - m_new

            probs = ct.exp2(scores, flush_to_zero=True)
            row_sum = ct.sum(probs, axis=-1, keepdims=True)
            rescale = ct.exp2(m_i - m_new, flush_to_zero=True)
            l_i = l_i * rescale + row_sum
            acc = acc * rescale

            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            probs_in = probs.astype(Q.dtype)
            acc = ct.mma(probs_in, v_tile, acc)
            m_i = m_new

        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, q_blk, 0), tile=acc)
|
|||
|
|
|
|||
|
|
# Guarded so the module still imports when cuTile is not installed.
if CUTILE_AVAILABLE:

    @ct.kernel()
    def fmha_memory_opt(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 3: step 2 plus load-order and latency hints on the K/V loads.

        K is loaded as a (TILE_D, TILE_N) tile with ``order=(0, 1, 3, 2)``,
        which replaces the explicit ``ct.permute`` of the earlier steps. The
        K and V loads carry ``latency=2`` and ``latency=4``.
        """
        q_blk = ct.bid(0)
        bh = ct.bid(1)
        batch_idx = bh // H
        head_idx = bh % H

        scale_log2 = qk_scale * INV_LOG_2

        row_ids = q_blk * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        row_ids = row_ids[:, None]
        col_ids_local = ct.arange(TILE_N, dtype=ct.int32)
        col_ids_local = col_ids_local[None, :]

        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

        q = ct.load(Q, index=(batch_idx, head_idx, q_blk, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))

        kv_len = K.shape[2]
        if CAUSAL:
            row_end = (q_blk + 1) * TILE_M
            n_kv_tiles = ct.cdiv(min(row_end, kv_len), TILE_N)
            first_masked_tile = (q_blk * TILE_M) // TILE_N
        else:
            n_kv_tiles = ct.cdiv(kv_len, TILE_N)
            first_masked_tile = kv_len // TILE_N

        for j in range(0, n_kv_tiles):
            # K arrives already transposed, so no permute is needed.
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2,
            )
            k_t = k_t.reshape((TILE_D, TILE_N))

            scores = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            scores = ct.mma(q, k_t, scores)

            if CAUSAL and j >= first_masked_tile:
                col_ids = j * TILE_N + col_ids_local
                visible = row_ids >= col_ids
                scores = ct.where(visible, scores, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

            m_new = ct.max(scores, axis=-1, keepdims=True) * scale_log2
            m_new = ct.maximum(m_i, m_new)
            scores = scores * scale_log2 - m_new

            probs = ct.exp2(scores, flush_to_zero=True)
            row_sum = ct.sum(probs, axis=-1, keepdims=True)
            rescale = ct.exp2(m_i - m_new, flush_to_zero=True)
            l_i = l_i * rescale + row_sum
            acc = acc * rescale

            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4,
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            probs_in = probs.astype(Q.dtype)
            acc = ct.mma(probs_in, v_tile, acc)
            m_i = m_new

        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, q_blk, 0), tile=acc)
|
|||
|
|
|
|||
|
|
# Guarded so the module still imports when cuTile is not installed.
if CUTILE_AVAILABLE:

    @ct.kernel(occupancy=2)
    def fmha_full_opt_occ2(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 4a: fully optimized kernel compiled with occupancy=2 (sm120/sm121).

        Same body as step 3, except that the final normalisation uses
        ``ct.truediv`` with flush_to_zero and the APPROX rounding mode.
        """
        q_blk = ct.bid(0)
        bh = ct.bid(1)
        batch_idx = bh // H
        head_idx = bh % H

        scale_log2 = qk_scale * INV_LOG_2

        row_ids = q_blk * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        row_ids = row_ids[:, None]
        col_ids_local = ct.arange(TILE_N, dtype=ct.int32)
        col_ids_local = col_ids_local[None, :]

        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

        q = ct.load(Q, index=(batch_idx, head_idx, q_blk, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))

        kv_len = K.shape[2]
        if CAUSAL:
            row_end = (q_blk + 1) * TILE_M
            n_kv_tiles = ct.cdiv(min(row_end, kv_len), TILE_N)
            first_masked_tile = (q_blk * TILE_M) // TILE_N
        else:
            n_kv_tiles = ct.cdiv(kv_len, TILE_N)
            first_masked_tile = kv_len // TILE_N

        for j in range(0, n_kv_tiles):
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2,
            )
            k_t = k_t.reshape((TILE_D, TILE_N))

            scores = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            scores = ct.mma(q, k_t, scores)

            if CAUSAL and j >= first_masked_tile:
                col_ids = j * TILE_N + col_ids_local
                visible = row_ids >= col_ids
                scores = ct.where(visible, scores, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

            m_new = ct.max(scores, axis=-1, keepdims=True) * scale_log2
            m_new = ct.maximum(m_i, m_new)
            scores = scores * scale_log2 - m_new

            probs = ct.exp2(scores, flush_to_zero=True)
            row_sum = ct.sum(probs, axis=-1, keepdims=True)
            rescale = ct.exp2(m_i - m_new, flush_to_zero=True)
            l_i = l_i * rescale + row_sum
            acc = acc * rescale

            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4,
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            probs_in = probs.astype(Q.dtype)
            acc = ct.mma(probs_in, v_tile, acc)
            m_i = m_new

        # Approximate division for the final softmax normalisation.
        acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, q_blk, 0), tile=acc)
|
|||
|
|
|
|||
|
|
# Guarded so the module still imports when cuTile is not installed.
if CUTILE_AVAILABLE:

    @ct.kernel(occupancy=1)
    def fmha_full_opt_occ1(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 4b: fully optimized kernel compiled with occupancy=1 (sm100 Blackwell).

        The body matches the occupancy=2 variant; only the decorator argument
        differs.
        """
        q_blk = ct.bid(0)
        bh = ct.bid(1)
        batch_idx = bh // H
        head_idx = bh % H

        scale_log2 = qk_scale * INV_LOG_2

        row_ids = q_blk * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        row_ids = row_ids[:, None]
        col_ids_local = ct.arange(TILE_N, dtype=ct.int32)
        col_ids_local = col_ids_local[None, :]

        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

        q = ct.load(Q, index=(batch_idx, head_idx, q_blk, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))

        kv_len = K.shape[2]
        if CAUSAL:
            row_end = (q_blk + 1) * TILE_M
            n_kv_tiles = ct.cdiv(min(row_end, kv_len), TILE_N)
            first_masked_tile = (q_blk * TILE_M) // TILE_N
        else:
            n_kv_tiles = ct.cdiv(kv_len, TILE_N)
            first_masked_tile = kv_len // TILE_N

        for j in range(0, n_kv_tiles):
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2,
            )
            k_t = k_t.reshape((TILE_D, TILE_N))

            scores = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            scores = ct.mma(q, k_t, scores)

            if CAUSAL and j >= first_masked_tile:
                col_ids = j * TILE_N + col_ids_local
                visible = row_ids >= col_ids
                scores = ct.where(visible, scores, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

            m_new = ct.max(scores, axis=-1, keepdims=True) * scale_log2
            m_new = ct.maximum(m_i, m_new)
            scores = scores * scale_log2 - m_new

            probs = ct.exp2(scores, flush_to_zero=True)
            row_sum = ct.sum(probs, axis=-1, keepdims=True)
            rescale = ct.exp2(m_i - m_new, flush_to_zero=True)
            l_i = l_i * rescale + row_sum
            acc = acc * rescale

            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4,
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            probs_in = probs.astype(Q.dtype)
            acc = ct.mma(probs_in, v_tile, acc)
            m_i = m_new

        # Approximate division for the final softmax normalisation.
        acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, q_blk, 0), tile=acc)
|
|||
|
|
|
|||
|
|
def run_kernel(kernel_fn, q, k, v, sm_scale, tile_m, tile_n, is_causal=True):
    """Launch ``kernel_fn`` on the current CUDA stream and return its output.

    The grid has one entry per ``tile_m``-row query tile along the first axis
    and one entry per (batch, head) pair along the second. The output tensor
    is allocated with the same shape and dtype as ``q``.
    """
    batch_size, num_heads, seq_len, head_dim = q.shape
    out = torch.empty_like(q)

    launch_grid = (math.ceil(seq_len / tile_m), batch_size * num_heads, 1)
    # Positional order must match the kernel signature:
    # Q, K, V, Out, qk_scale, TILE_D, H, TILE_M, TILE_N, CAUSAL.
    kernel_args = (q, k, v, out, sm_scale, head_dim, num_heads, tile_m, tile_n, is_causal)

    ct.launch(torch.cuda.current_stream(), launch_grid, kernel_fn, kernel_args)
    return out
|
|||
|
|
|
|||
|
|
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
    """Run TileGym's cuTile FMHA, or return ``None`` if an ImportError occurs.

    The ImportError may come from importing ``tilegym`` itself or from inside
    the call; both cases yield ``None``.
    """
    output = None
    try:
        import tilegym

        output = tilegym.ops.fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
    except ImportError:
        # TileGym is optional; callers treat None as "reference unavailable".
        pass
    return output
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_scaling_analysis(iterations=100):
    """Benchmark every optimization step across ``SEQ_LENS`` and log the results.

    For each sequence length this times the PyTorch baseline, the cuTile
    kernels for steps 1-4 (only when cuTile is importable) and, when
    available, the TileGym reference. A cuTile step that raises is logged
    and skipped, so the remaining steps still run.

    Args:
        iterations: Timed iterations handed to ``benchmark_fn``.

    Returns:
        A list with one ``SeqLenResult`` per sequence length; the same list
        is also stored on ``logger.results``.
    """
    device = get_device()
    gpu_cap = torch.cuda.get_device_capability()
    gpu_name = torch.cuda.get_device_name()

    configs = get_fmha_config()
    primary_cfg = configs[0]
    TILE_M = primary_cfg.TILE_M
    TILE_N = primary_cfg.TILE_N

    logger.section("FMHA SCALING ANALYSIS (TileGym Benchmark Match)")
    logger.log("Matching TileGym bench_fused_attention.py configuration")
    logger.log(f"\nGPU: {gpu_name} (sm_{gpu_cap[0]}{gpu_cap[1]})")

    logger.subsection("TARGET-SPECIFIC CONFIGURATION (from TileGym)")
    for cfg in configs:
        logger.log(f"\n {cfg.name}:")
        logger.log(f" TILE_M={cfg.TILE_M}, TILE_N={cfg.TILE_N}")
        logger.log(f" num_ctas={cfg.num_ctas}, occupancy={cfg.occupancy}")

    logger.log(f"\nUsing primary config: TILE_M={TILE_M}, TILE_N={TILE_N}, occupancy={primary_cfg.occupancy}")
    logger.log("\nTest Configuration (matches TileGym bench_fused_attention.py):")
    logger.log(f" Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}")
    logger.log(" Causal: True, Precision: float16")
    logger.log(f" Sequence Lengths: {SEQ_LENS}")
    logger.log(" Benchmark: triton.testing.do_bench_cudagraph (same as TileGym)")

    logger.section("OPTIMIZATION STEPS")
    logger.log(f"""
Step 0: PyTorch Baseline
- torch.nn.functional.scaled_dot_product_attention
- Uses cuDNN Flash Attention backend
- Highly optimized reference

Step 1: Basic cuTile (TILE_M={TILE_M}, TILE_N={TILE_N})
- @ct.kernel with ct.mma() for Tensor Cores
- Standard exp() for softmax
- Explicit transpose with ct.permute()
- No memory/occupancy hints

Step 2: Math Optimizations
- ct.exp2() instead of ct.exp() (faster on GPU)
- flush_to_zero=True for denormals
- Scale adjustment: multiply by 1/log(2)

Step 3: Memory Optimizations
- Load order=(0,1,3,2) for implicit K transpose
- Latency hints: K=2, V=4 for prefetching
- Overlaps memory loads with computation

Step 4: Full Optimization (Target-Specific)
- @ct.kernel(occupancy={primary_cfg.occupancy}) for {'sm120/121' if primary_cfg.occupancy == 2 else 'sm100'}
- ct.truediv with APPROX rounding mode
- Matches TileGym production implementation
""")

    logger.section("PLATFORM DIFFERENCES: DGX Spark vs Blackwell B300")
    logger.log("""
| Parameter | DGX Spark (sm121) | Blackwell B300 (sm100) |
|--------------|-------------------|------------------------|
| TILE_M | 64 | 256 or 128 |
| TILE_N | 64 | 128 |
| num_ctas | 1 | 1 |
| occupancy | 2 | 1 or 2 |

Why the difference?
- B300 has more SMs and larger shared memory -> can use bigger tiles
- B300 benefits from larger tiles (256x128) with lower occupancy
- DGX Spark needs smaller tiles (64x64) with higher occupancy to hide latency
- B300's higher memory bandwidth makes larger tiles more efficient
""")

    # The kernels only exist when cuTile imported; naming them otherwise
    # would raise NameError, so build the step list under the guard. The
    # list does not depend on the sequence length, so build it once.
    kernel_steps = []
    if CUTILE_AVAILABLE:
        select_kernel = fmha_full_opt_occ2 if primary_cfg.occupancy == 2 else fmha_full_opt_occ1
        kernel_steps = [
            (1, "Basic cuTile", fmha_basic),
            (2, "Math Opt (exp2)", fmha_math_opt),
            (3, "Memory Opt (order+latency)", fmha_memory_opt),
            (4, f"Full Opt (occ={primary_cfg.occupancy})", select_kernel),
        ]

    all_results = []

    for seq_len in SEQ_LENS:
        logger.subsection(f"Sequence Length: {seq_len}")

        q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
        k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
        v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
        sm_scale = 1.0 / math.sqrt(HEAD_DIM)

        flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)

        steps_results = []

        # Step 0: PyTorch baseline; every speedup is relative to this latency.
        baseline_fn = lambda: reference_fmha(q, k, v, sm_scale, is_causal=True)
        baseline_latency = benchmark_fn(baseline_fn, warmup=10, iterations=iterations)
        baseline_tflops = flops * 1e-12 / (baseline_latency * 1e-3)
        steps_results.append(StepResult(0, "PyTorch Baseline", baseline_latency, baseline_tflops, 1.0))

        for step, name, kernel in kernel_steps:
            try:
                # Bind the kernel as a default so the lambda keeps this iteration's value.
                fn = lambda kernel=kernel: run_kernel(kernel, q, k, v, sm_scale, TILE_M, TILE_N, is_causal=True)
                latency = benchmark_fn(fn, warmup=10, iterations=iterations)
                tflops = flops * 1e-12 / (latency * 1e-3)
                speedup = baseline_latency / latency
                steps_results.append(StepResult(step, name, latency, tflops, speedup))
            except Exception as e:
                logger.log(f" [ERROR] Step {step} failed: {e}")

        # TileGym reference; the zeros mean "not measured".
        tilegym_latency = 0.0
        tilegym_tflops = 0.0
        tilegym_speedup = 0.0
        tilegym_out = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
        if tilegym_out is not None:
            tilegym_fn = lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
            tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
            tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
            tilegym_speedup = baseline_latency / tilegym_latency

        best_step = max(steps_results, key=lambda x: x.speedup_vs_baseline)

        result = SeqLenResult(
            seq_len=seq_len,
            steps=steps_results,
            best_step=best_step.step,
            best_speedup=best_step.speedup_vs_baseline,
            tilegym_latency_ms=tilegym_latency,
            tilegym_tflops=tilegym_tflops,
            tilegym_speedup=tilegym_speedup,
        )
        all_results.append(result)

        logger.log("\n | Step | Name | Latency (ms) | TFLOPS | Speedup |")
        logger.log(" |------|------|--------------|--------|---------|")

        for sr in steps_results:
            logger.log(f" | {sr.step} | {sr.name:<28} | {sr.latency_ms:>10.3f} | {sr.tflops:>6.2f} | {sr.speedup_vs_baseline:>6.2f}x |")
        if tilegym_latency > 0:
            logger.log(f" | TG | TileGym Reference | {tilegym_latency:>10.3f} | {tilegym_tflops:>6.2f} | {tilegym_speedup:>6.2f}x |")

        logger.log(f"\n Best: Step {best_step.step} ({best_step.name}) with {best_step.speedup_vs_baseline:.2f}x speedup")

    logger.results = all_results
    return all_results
|
|||
|
|
|
|||
|
|
|
|||
|
|
def print_summary(results: List[SeqLenResult]):
    """Log the cross-sequence-length summary tables and the key insights.

    Two tables are logged: baseline vs. step 4 (with the TileGym reference),
    and the speedup of each cuTile step. TileGym columns show ``N/A`` when no
    TileGym measurement exists, matching the per-length tables, which omit
    the TileGym row in that case. Rows lacking a baseline or step-4 entry are
    skipped in the first table.

    Args:
        results: Per-sequence-length results from ``run_scaling_analysis``.
    """
    configs = get_fmha_config()
    primary_cfg = configs[0]

    logger.section("SCALING SUMMARY")

    logger.log(f"\nTarget Config: TILE_M={primary_cfg.TILE_M}, TILE_N={primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}")

    logger.log("\n## Performance vs Sequence Length\n")
    logger.log("| Seq Len | Baseline (ms) | Full Opt (ms) | Speedup | TileGym (ms) | TG Speedup |")
    logger.log("|---------|---------------|---------------|---------|--------------|------------|")
    for r in results:
        baseline = next((s for s in r.steps if s.step == 0), None)
        full_opt = next((s for s in r.steps if s.step == 4), None)
        if baseline is None or full_opt is None:
            continue
        if r.tilegym_latency_ms > 0:
            tg_latency = f"{r.tilegym_latency_ms:>12.3f}"
            tg_speedup = f"{r.tilegym_speedup:>9.2f}x"
        else:
            # A latency of 0.0 means TileGym was never benchmarked.
            tg_latency = f"{'N/A':>12}"
            tg_speedup = f"{'N/A':>10}"
        logger.log(f"| {r.seq_len:>7} | {baseline.latency_ms:>13.3f} | {full_opt.latency_ms:>13.3f} | {full_opt.speedup_vs_baseline:>6.2f}x | {tg_latency} | {tg_speedup} |")

    logger.log("\n## Optimization Impact by Sequence Length\n")
    logger.log("| Seq Len | Basic | +Math | +Memory | +Full | Best |")
    logger.log("|---------|-------|-------|---------|-------|------|")
    for r in results:
        row = f"| {r.seq_len:>7} |"
        for step in [1, 2, 3, 4]:
            sr = next((s for s in r.steps if s.step == step), None)
            if sr:
                row += f" {sr.speedup_vs_baseline:>5.2f}x |"
            else:
                row += " N/A |"
        row += f" {r.best_speedup:>4.2f}x |"
        logger.log(row)

    logger.section("KEY INSIGHTS")
    # K-tile counts follow TILE_N, the step of the kernels' K loop:
    # 64 on sm121 and 128 on sm100.
    logger.log("""
## Why Larger Sequences Benefit More from Optimization

1. **Memory Bandwidth Dominance**
- Attention has O(N²) memory complexity for the QK^T matrix
- At seq_len=8192: 8192² × 4 bytes = 256MB per head per batch
- Memory optimizations (order, latency hints) have larger impact

2. **More K-Loop Iterations**
- At seq_len=512: 8 K-tiles (512/64) for sm121, 4 K-tiles (512/128) for sm100
- At seq_len=8192: 128 K-tiles for sm121, 64 K-tiles for sm100
- Latency hiding through pipelining amortizes over more iterations

3. **Better Occupancy Utilization**
- More tiles = more parallelism opportunities
- sm121 uses occupancy=2 (smaller tiles, more blocks)
- sm100 uses occupancy=1 with larger tiles (256x128)

4. **Platform-Specific Tuning**
- DGX Spark (sm121): 64x64 tiles, occupancy=2 - optimized for bandwidth-limited workloads
- B300 (sm100): 256x128 tiles, occupancy=1 - optimized for compute-heavy workloads

## Optimization Priority by Problem Size

Small (seq_len <= 1024):
- Basic cuTile often sufficient
- Focus on correctness first

Medium (1024 < seq_len <= 4096):
- Math optimizations (exp2) provide ~5% gain
- Memory optimizations start to matter

Large (seq_len > 4096):
- Full optimization stack critical
- Platform-specific tuning essential
- Memory pipelining becomes essential
""")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def export_results(results: List[SeqLenResult], output_dir: str):
    """Write the results as JSON and Markdown, plus the full log transcript.

    Creates ``fmha_scaling_results.json``, ``fmha_scaling_results.md`` and
    ``fmha_scaling_log.txt`` inside ``output_dir``. The directory is created
    if it does not exist. All files are written as UTF-8 because the log
    contains non-ASCII characters.

    Args:
        results: Per-sequence-length results from ``run_scaling_analysis``.
        output_dir: Destination directory; an empty string means the
            current working directory.
    """
    import os

    configs = get_fmha_config()
    primary_cfg = configs[0]

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    data = {
        "config": {
            "batch": BATCH,
            "n_heads": N_HEADS,
            "head_dim": HEAD_DIM,
            "tile_m": primary_cfg.TILE_M,
            "tile_n": primary_cfg.TILE_N,
            "occupancy": primary_cfg.occupancy,
            "num_ctas": primary_cfg.num_ctas,
            "platform": primary_cfg.name,
        },
        "results": [
            {
                "seq_len": r.seq_len,
                "steps": [asdict(s) for s in r.steps],
                "best_step": r.best_step,
                "best_speedup": r.best_speedup,
                "tilegym_latency_ms": r.tilegym_latency_ms,
                "tilegym_tflops": r.tilegym_tflops,
                "tilegym_speedup": r.tilegym_speedup,
            }
            for r in results
        ],
    }

    json_path = os.path.join(output_dir, "fmha_scaling_results.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    # Assemble the Markdown report in memory and write it in one call.
    md_parts = [
        "# FMHA Scaling Analysis Results\n\n",
        "## Configuration\n",
        f"- Platform: {primary_cfg.name}\n",
        f"- Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}\n",
        f"- Tile: {primary_cfg.TILE_M}x{primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}\n\n",
        "## Target-Specific Configs (from TileGym)\n\n",
        "| Platform | TILE_M | TILE_N | num_ctas | occupancy |\n",
        "|----------|--------|--------|----------|----------|\n",
        "| DGX Spark (sm121) | 64 | 64 | 1 | 2 |\n",
        "| B300 (sm100) Config 1 | 256 | 128 | 1 | 1 |\n",
        "| B300 (sm100) Config 2 | 128 | 128 | 1 | 2 |\n\n",
        "## Results by Sequence Length\n\n",
    ]
    for r in results:
        md_parts.append(f"### Seq Len = {r.seq_len}\n\n")
        md_parts.append("| Step | Name | Latency (ms) | TFLOPS | Speedup |\n")
        md_parts.append("|------|------|--------------|--------|--------|\n")
        for s in r.steps:
            md_parts.append(f"| {s.step} | {s.name} | {s.latency_ms:.3f} | {s.tflops:.2f} | {s.speedup_vs_baseline:.2f}x |\n")
        # A latency of 0.0 means TileGym was not benchmarked for this length.
        if r.tilegym_latency_ms > 0:
            md_parts.append(f"| TG | TileGym Reference | {r.tilegym_latency_ms:.3f} | {r.tilegym_tflops:.2f} | {r.tilegym_speedup:.2f}x |\n")
        md_parts.append("\n")

    md_path = os.path.join(output_dir, "fmha_scaling_results.md")
    with open(md_path, "w", encoding="utf-8") as f:
        f.write("".join(md_parts))

    log_path = os.path.join(output_dir, "fmha_scaling_log.txt")
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("\n".join(logger.logs))

    logger.log("\nResults exported to:")
    logger.log(f" - {json_path}")
    logger.log(f" - {md_path}")
    logger.log(f" - {log_path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Parse the command line, run the analysis, then summarise and export it."""
    cli = argparse.ArgumentParser(description="FMHA Scaling Analysis")
    cli.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
    cli.add_argument("--output-dir", type=str, default=".", help="Output directory")
    opts = cli.parse_args()

    measured = run_scaling_analysis(iterations=opts.iterations)
    print_summary(measured)
    export_results(measured, opts.output_dir)


if __name__ == "__main__":
    main()
|