mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-06-17 20:12:20 +00:00
chore: Regenerate all playbooks
parent
3eff7461e1
commit
231a45230d
@ -21,11 +21,13 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
|
||||
|
||||
### NVIDIA
|
||||
|
||||
- [CLI Coding Agent](nvidia/cli-coding-agent/)
|
||||
- [Comfy UI](nvidia/comfy-ui/)
|
||||
- [Connect Three DGX Spark in a Ring Topology](nvidia/connect-three-sparks/)
|
||||
- [Set Up Local Network Access](nvidia/connect-to-your-spark/)
|
||||
- [Connect Two Sparks](nvidia/connect-two-sparks/)
|
||||
- [CUDA-X Data Science](nvidia/cuda-x-data-science/)
|
||||
- [cuTile Kernels](nvidia/cutile-kernels/)
|
||||
- [DGX Dashboard](nvidia/dgx-dashboard/)
|
||||
- [FLUX.1 Dreambooth LoRA Fine-tuning](nvidia/flux-finetuning/)
|
||||
- [Run Hermes Agent with Local Models](nvidia/hermes-agent/)
|
||||
@ -41,6 +43,7 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
|
||||
- [NCCL for Two Sparks](nvidia/nccl/)
|
||||
- [Fine-tune with NeMo](nvidia/nemo-fine-tune/)
|
||||
- [Run NemoClaw with a Local LLM](nvidia/nemoclaw/)
|
||||
- [🦞 Set Up Example NemoClaw Agents 🦞](nvidia/nemoclaw-applications/)
|
||||
- [Nemotron-3-Nano with llama.cpp](nvidia/nemotron/)
|
||||
- [NIM on Spark](nvidia/nim-llm/)
|
||||
- [NVFP4 Quantization](nvidia/nvfp4-quantization/)
|
||||
@ -52,6 +55,7 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
|
||||
- [Fine-tune with Pytorch](nvidia/pytorch-fine-tune/)
|
||||
- [RAG Application in AI Workbench](nvidia/rag-ai-workbench/)
|
||||
- [Spark & Reachy Photo Booth](nvidia/reachy-photo-booth/)
|
||||
- [Register DGX Spark to Brev](nvidia/register-to-brev/)
|
||||
- [SGLang for Inference](nvidia/sglang/)
|
||||
- [Single-cell RNA Sequencing](nvidia/single-cell/)
|
||||
- [Speculative Decoding](nvidia/speculative-decoding/)
|
||||
|
||||
474
nvidia/cli-coding-agent/README.md
Normal file
@ -0,0 +1,474 @@
|
||||
# CLI Coding Agent
|
||||
|
||||
> Build local CLI coding agents with Ollama
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Claude Code](#claude-code)
|
||||
- [OpenCode](#opencode)
|
||||
- [Codex CLI](#codex-cli)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
## Basic idea
|
||||
|
||||
Use [Ollama](https://ollama.com) on [DGX Spark](https://www.nvidia.com/en-us/products/workstations/dgx-spark/) to run a local coding model and connect a CLI coding agent. This
playbook supports three options: **[Claude Code](https://docs.claude.com/en/docs/claude-code)**, **[OpenCode](https://opencode.ai)**, and **[Codex CLI](https://github.com/openai/codex)**. Each
agent is wired up with Ollama's built-in [launch method](https://ollama.com/blog/launch) (`ollama launch <agent>`), so you
can work without environment variables, provider config files, or external cloud APIs.
|
||||
|
||||
## Choose your CLI agent
|
||||
|
||||
Pick the tab that matches the CLI agent you want to use:
|
||||
|
||||
- **Claude Code**: Fastest path to a working CLI agent with a local Ollama model.
|
||||
- **OpenCode**: Open-source CLI launched directly from Ollama.
|
||||
- **Codex CLI**: OpenAI Codex CLI launched directly from Ollama against the local model.
|
||||
|
||||
## What you'll accomplish
|
||||
|
||||
You will run a local coding model ([Qwen3.6](https://ollama.com/library/qwen3.6)) on your DGX Spark with Ollama, launch your
|
||||
chosen CLI agent against it with a single command, and complete a small coding task end-to-end.
|
||||
|
||||
## What to know before starting
|
||||
|
||||
- Comfort with Linux command line basics
|
||||
- Experience running terminal-based tools and editors
|
||||
- Familiarity with Python for the short coding task
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- DGX Spark access with NVIDIA DGX OS 7.3.1 (Ubuntu 24.04.3 LTS base)
|
||||
- Internet access to download model weights
|
||||
- [Ollama](https://ollama.com/download) v0.15 or newer (required for [`ollama launch`](https://ollama.com/blog/launch))
|
||||
- GPU memory depends on the Qwen3.6 variant you choose:
  - `qwen3.6:latest` (35B-a3b, MoE): ~24GB, 256K context
  - `qwen3.6:35b-a3b-nvfp4`: ~22GB, NVIDIA FP4 build tuned for Blackwell (DGX Spark)
  - `qwen3.6:35b-a3b-q8_0`: ~39GB, higher-quality quant
  - `qwen3.6:35b-a3b-bf16`: ~71GB, full precision (fits Spark's unified memory)
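
As a rough aid for choosing a tag, the sketch below picks the largest listed variant that fits a memory budget. The sizes are the approximate figures from the list above; the function name `pick_variant` and the 8 GB headroom default are illustrative assumptions, not part of Ollama.

```python
# Approximate weight sizes in GB, copied from the prerequisites list.
VARIANT_SIZES_GB = {
    "qwen3.6:35b-a3b-nvfp4": 22,
    "qwen3.6:latest": 24,
    "qwen3.6:35b-a3b-q8_0": 39,
    "qwen3.6:35b-a3b-bf16": 71,
}


def pick_variant(available_gb, headroom_gb=8.0):
    """Return the largest variant whose weights fit, or None if none fit.

    headroom_gb is a hypothetical allowance for the KV cache and other
    processes sharing unified memory; tune it for your workload.
    """
    budget = available_gb - headroom_gb
    fitting = [(size, tag) for tag, size in VARIANT_SIZES_GB.items() if size <= budget]
    if not fitting:
        return None
    # max() compares sizes first, so the largest fitting variant wins.
    return max(fitting)[1]


if __name__ == "__main__":
    for free_gb in (24, 48, 100):
        print(free_gb, "GB free ->", pick_variant(free_gb))
```

Long contexts need noticeably more memory than the weights alone, so treat the result as a starting point rather than a guarantee.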
|
||||
|
||||
## Time & risk
|
||||
|
||||
* **Duration**: ~15-25 minutes (mostly model download time)
* **Risk level**: Low
  * Large model downloads can fail if network connectivity is unstable
  * Ollama versions older than 0.15 do not support `ollama launch`
* **Rollback**: Remove the downloaded model with `ollama rm qwen3.6`, then stop the Ollama service
* **Last Updated:** 04/16/2026
  * Switched to `ollama launch` method and upgraded the default model to Qwen3.6
|
||||
|
||||
## Claude Code
|
||||
|
||||
## Step 1. Confirm your environment
|
||||
|
||||
**Description**: Verify the OS version and GPU are visible before installing anything.
|
||||
|
||||
```bash
|
||||
cat /etc/os-release | head -n 2
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
|
||||
|
||||
## Step 2. Install or update Ollama
|
||||
|
||||
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
|
||||
|
||||
```bash
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
ollama --version
|
||||
```
|
||||
|
||||
If Ollama is already installed, just verify the version:
|
||||
|
||||
```bash
|
||||
ollama --version
|
||||
```
|
||||
|
||||
Expected output should show Ollama v0.15 or newer.
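
If you want to script this check, the sketch below compares a reported version against the v0.15 minimum. It assumes the version text contains a dotted number such as `0.15.2`; the helper names are illustrative. Comparing integer tuples avoids the string-comparison trap where `"0.9"` sorts after `"0.15"`.

```python
import re


def parse_version(text):
    """Extract (major, minor, patch) from text such as 'ollama version is 0.15.2'."""
    match = re.search(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"no version number found in {text!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def supports_launch(version_text, minimum=(0, 15, 0)):
    """True if the reported version meets the minimum stated in this playbook."""
    return parse_version(version_text) >= minimum


if __name__ == "__main__":
    # Sample strings; replace with the text printed by `ollama --version`.
    for sample in ("ollama version is 0.14.9", "ollama version is 0.15.0"):
        print(sample, "->", supports_launch(sample))
```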
|
||||
|
||||
## Step 3. Pull Qwen3.6
|
||||
|
||||
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6
|
||||
```
|
||||
|
||||
Optional variants if you want different memory footprints or precision:
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
|
||||
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
|
||||
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
|
||||
```
|
||||
|
||||
Expected output should show `qwen3.6` (and any optional variants) in `ollama list`.
|
||||
|
||||
## Step 4. Test local inference (optional)
|
||||
|
||||
**Description**: Run a quick prompt to confirm the model loads.
|
||||
|
||||
```bash
|
||||
ollama run qwen3.6
|
||||
```
|
||||
|
||||
Try a prompt like:
|
||||
|
||||
```text
|
||||
Write a short README checklist for a Python project.
|
||||
```
|
||||
|
||||
Expected output should show the model responding in the terminal. When you are done, type `/bye` or press `Ctrl+D` to exit the interactive session before continuing.
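
The same smoke test can be run non-interactively through Ollama's local REST API. The sketch below assumes the server is listening on the default `localhost:11434` and that the `qwen3.6` tag has been pulled; `build_request` and `generate` are illustrative helper names.

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434/api/generate"  # default local endpoint


def build_request(prompt, model="qwen3.6"):
    """Build a non-streaming generate request for the local Ollama server."""
    payload = {"model": model, "prompt": prompt, "stream": False}
    return urllib.request.Request(
        OLLAMA_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def generate(prompt, model="qwen3.6", timeout=300):
    """Send the prompt and return the model's full text response."""
    with urllib.request.urlopen(build_request(prompt, model), timeout=timeout) as resp:
        return json.loads(resp.read())["response"]
```

For example, `print(generate("Write a short README checklist for a Python project."))` should print a response once the model has loaded; the first call can take a while because the weights are loaded into memory.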
|
||||
|
||||
## Step 5. Launch Claude Code with Ollama
|
||||
|
||||
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [Claude Code](https://docs.claude.com/en/docs/claude-code) against your local model. No environment variables or config files are required.
|
||||
|
||||
```bash
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
Expected output should show Claude Code starting and using the local Qwen3.6 model. Qwen3.6 ships with a 256K context window by default; adjust context length through Ollama's settings if you need to tune it further.
|
||||
|
||||
## Step 6. Complete a small coding task
|
||||
|
||||
**Description**: Create a tiny repo and let Claude Code implement a function and tests.
|
||||
|
||||
```bash
mkdir -p ~/cli-agent-demo
cd ~/cli-agent-demo

printf 'def add(a, b):\n    """Return the sum of a and b."""\n    pass\n' > math_utils.py
printf 'import math_utils\n\n\ndef test_add():\n    assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
```
|
||||
|
||||
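If pytest is not already installed (if pip reports `externally-managed-environment` on Ubuntu 24.04, run this inside a Python virtual environment instead):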
|
||||
|
||||
```bash
|
||||
python3 -m pip install -U pytest
|
||||
```
|
||||
|
||||
In Claude Code:
|
||||
|
||||
```text
|
||||
Please implement add() in math_utils.py and make sure the test passes.
|
||||
```
|
||||
|
||||
Run the test:
|
||||
|
||||
```bash
|
||||
python3 -m pytest -q
|
||||
```
|
||||
|
||||
Expected output should show the test passing.
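
For reference, a passing solution can be as small as the sketch below. This is one plausible result, not the agent's guaranteed output; the test is inlined here so the block is self-contained, whereas the playbook keeps it in `test_math_utils.py`.

```python
def add(a, b):
    """Return the sum of a and b."""
    return a + b


def test_add():
    # Same assertion as the generated test file, without the module prefix.
    assert add(1, 2) == 3
```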
|
||||
|
||||
## Step 7. Cleanup and rollback
|
||||
|
||||
**Description**: Remove the model and stop services if you no longer need them.

> [!WARNING]
> This will delete the downloaded model files.

Remove the model first, while the Ollama service is still running, because `ollama rm` sends its request to the local server:

```bash
ollama rm qwen3.6
```

Then stop the service:

```bash
sudo systemctl stop ollama
```
|
||||
|
||||
## Step 8. Next steps
|
||||
|
||||
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
|
||||
- Use Claude Code on multi-file refactors or test-generation tasks
|
||||
- Explore the full 256K context window on larger codebases
|
||||
|
||||
## OpenCode
|
||||
|
||||
## Step 1. Confirm your environment
|
||||
|
||||
**Description**: Verify the OS version and GPU are visible before installing anything.
|
||||
|
||||
```bash
|
||||
cat /etc/os-release | head -n 2
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
|
||||
|
||||
## Step 2. Install or update Ollama
|
||||
|
||||
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
|
||||
|
||||
```bash
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
ollama --version
|
||||
```
|
||||
|
||||
If Ollama is already installed, just verify the version:
|
||||
|
||||
```bash
|
||||
ollama --version
|
||||
```
|
||||
|
||||
Expected output should show Ollama v0.15 or newer.
|
||||
|
||||
## Step 3. Pull Qwen3.6
|
||||
|
||||
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6
|
||||
```
|
||||
|
||||
Optional variants if you want different memory footprints or precision:
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
|
||||
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
|
||||
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
|
||||
```
|
||||
|
||||
Expected output should show `qwen3.6` in `ollama list`.
|
||||
|
||||
## Step 4. Test local inference (optional)
|
||||
|
||||
**Description**: Run a quick prompt to confirm the model loads.
|
||||
|
||||
```bash
|
||||
ollama run qwen3.6
|
||||
```
|
||||
|
||||
Try a prompt like:
|
||||
|
||||
```text
|
||||
Write a short README checklist for a Python project.
|
||||
```
|
||||
|
||||
Expected output should show the model responding. When you are done, type `/bye` or press `Ctrl+D` to exit before continuing.
|
||||
|
||||
## Step 5. Launch OpenCode with Ollama
|
||||
|
||||
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [OpenCode](https://opencode.ai) against your local model. No [`opencode.json`](https://opencode.ai/docs/config/) provider configuration is required.
|
||||
|
||||
```bash
|
||||
ollama launch opencode
|
||||
```
|
||||
|
||||
If you want to pre-configure OpenCode without launching immediately:
|
||||
|
||||
```bash
|
||||
ollama launch opencode --config
|
||||
```
|
||||
|
||||
Expected output should show OpenCode starting with Ollama preselected as the provider and Qwen3.6 as the model. Qwen3.6 ships with a 256K context window by default.
|
||||
|
||||
## Step 6. Complete a small coding task
|
||||
|
||||
**Description**: Create a tiny repo and let OpenCode implement a function and tests.
|
||||
|
||||
```bash
mkdir -p ~/cli-agent-demo
cd ~/cli-agent-demo

printf 'def add(a, b):\n    """Return the sum of a and b."""\n    pass\n' > math_utils.py
printf 'import math_utils\n\n\ndef test_add():\n    assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
```
|
||||
|
||||
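If pytest is not already installed (if pip reports `externally-managed-environment` on Ubuntu 24.04, run this inside a Python virtual environment instead):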
|
||||
|
||||
```bash
|
||||
python3 -m pip install -U pytest
|
||||
```
|
||||
|
||||
In OpenCode:
|
||||
|
||||
```text
|
||||
Please implement add() in math_utils.py and make sure the test passes.
|
||||
```
|
||||
|
||||
Run the test:
|
||||
|
||||
```bash
|
||||
python3 -m pytest -q
|
||||
```
|
||||
|
||||
Expected output should show the test passing.
|
||||
|
||||
## Step 7. Cleanup and rollback
|
||||
|
||||
**Description**: Remove the model and stop services if you no longer need them.

> [!WARNING]
> This will delete the downloaded model files.

Remove the model first, while the Ollama service is still running, because `ollama rm` sends its request to the local server:

```bash
ollama rm qwen3.6
```

Then stop the service:

```bash
sudo systemctl stop ollama
```
|
||||
|
||||
## Step 8. Next steps
|
||||
|
||||
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
|
||||
- Use OpenCode on multi-file changes or test-generation tasks
|
||||
- Explore the full 256K context window on larger codebases
|
||||
|
||||
## Codex CLI
|
||||
|
||||
## Step 1. Confirm your environment
|
||||
|
||||
**Description**: Verify the OS version and GPU are visible before installing anything.
|
||||
|
||||
```bash
|
||||
cat /etc/os-release | head -n 2
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
|
||||
|
||||
## Step 2. Install or update Ollama
|
||||
|
||||
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
|
||||
|
||||
```bash
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
ollama --version
|
||||
```
|
||||
|
||||
If Ollama is already installed, just verify the version:
|
||||
|
||||
```bash
|
||||
ollama --version
|
||||
```
|
||||
|
||||
Expected output should show Ollama v0.15 or newer.
|
||||
|
||||
## Step 3. Pull Qwen3.6
|
||||
|
||||
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6
|
||||
```
|
||||
|
||||
Optional variants if you want different memory footprints or precision:
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
|
||||
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
|
||||
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
|
||||
```
|
||||
|
||||
Expected output should show `qwen3.6` in `ollama list`.
|
||||
|
||||
## Step 4. Test local inference (optional)
|
||||
|
||||
**Description**: Run a quick prompt to confirm the model loads.
|
||||
|
||||
```bash
|
||||
ollama run qwen3.6
|
||||
```
|
||||
|
||||
Try a prompt like:
|
||||
|
||||
```text
|
||||
Write a short README checklist for a Python project.
|
||||
```
|
||||
|
||||
Expected output should show the model responding. When you are done, type `/bye` or press `Ctrl+D` to exit before continuing.
|
||||
|
||||
## Step 5. Launch Codex CLI with Ollama
|
||||
|
||||
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [Codex CLI](https://github.com/openai/codex) against your local model. No `~/.codex/config.toml` and no manual `npm install -g @openai/codex` are required — Ollama handles the Codex integration.
|
||||
|
||||
```bash
|
||||
ollama launch codex
|
||||
```
|
||||
|
||||
Expected output should show Codex CLI starting with Ollama as the provider and Qwen3.6 as the model. Qwen3.6 ships with a 256K context window by default, which is well suited to Codex's agentic workflows.
|
||||
|
||||
## Step 6. Complete a small coding task
|
||||
|
||||
**Description**: Create a tiny repo and let Codex implement a function and tests.
|
||||
|
||||
```bash
mkdir -p ~/cli-agent-demo
cd ~/cli-agent-demo

printf 'def add(a, b):\n    """Return the sum of a and b."""\n    pass\n' > math_utils.py
printf 'import math_utils\n\n\ndef test_add():\n    assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
```
|
||||
|
||||
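If pytest is not already installed (if pip reports `externally-managed-environment` on Ubuntu 24.04, run this inside a Python virtual environment instead):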
|
||||
|
||||
```bash
|
||||
python3 -m pip install -U pytest
|
||||
```
|
||||
|
||||
In Codex:
|
||||
|
||||
```text
|
||||
Please implement add() in math_utils.py and make sure the test passes.
|
||||
```
|
||||
|
||||
Run the test:
|
||||
|
||||
```bash
|
||||
python3 -m pytest -q
|
||||
```
|
||||
|
||||
Expected output should show the test passing.
|
||||
|
||||
## Step 7. Cleanup and rollback
|
||||
|
||||
**Description**: Remove the model and stop services if you no longer need them.

> [!WARNING]
> This will delete the downloaded model files.

Remove the model first, while the Ollama service is still running, because `ollama rm` sends its request to the local server:

```bash
ollama rm qwen3.6
```

Then stop the service:

```bash
sudo systemctl stop ollama
```
|
||||
|
||||
## Step 8. Next steps
|
||||
|
||||
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
|
||||
- Use Codex CLI on multi-file changes or test-generation tasks
|
||||
- Explore the full 256K context window on larger codebases
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Symptom | Cause | Fix |
|---------|-------|-----|
| `ollama: command not found` | Ollama not installed or PATH not updated | Rerun `curl -fsSL https://ollama.com/install.sh \| sh` and open a new shell |
| `ollama launch` reports unknown command | Ollama is older than v0.15 | Update Ollama: `curl -fsSL https://ollama.com/install.sh \| sh` |
| Model load fails with version error or HTTP 412 | Ollama version is too old for the model | Update Ollama: `curl -fsSL https://ollama.com/install.sh \| sh` |
| `model not found` when launching an agent | Model was not pulled | Run `ollama pull qwen3.6` and retry |
| `connection refused` to localhost:11434 | Ollama service not running | Start with `ollama serve` or `sudo systemctl start ollama` |
| `ollama launch <agent>` exits immediately | Agent integration failed to initialize | Re-run `ollama launch <agent>`; if it persists, check `journalctl -u ollama` |
| Slow responses or OOM errors | Model variant too large for GPU memory | Switch to `qwen3.6:35b-a3b-nvfp4` or close other GPU workloads |
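
To tell a stopped service apart from an agent-side problem, a quick TCP probe of the port from the table is often enough. This is a minimal sketch; `ollama_reachable` is an illustrative name, and a successful connection only shows that something is listening on the port, not that it is a healthy Ollama server.

```python
import socket


def ollama_reachable(host="localhost", port=11434, timeout=2.0):
    """Return True if a TCP connection to host:port succeeds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers connection refused, timeouts, and name-resolution failures.
        return False


if __name__ == "__main__":
    print("Ollama reachable:", ollama_reachable())
```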
|
||||
|
||||
> [!NOTE]
> DGX Spark uses a Unified Memory Architecture (UMA), which enables dynamic memory sharing
> between the GPU and CPU. If you see memory pressure, flush the buffer cache with:
> ```bash
> sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
> ```
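
To see whether flushing helped, you can compare `MemAvailable` before and after. The sketch below parses `/proc/meminfo`-style text; the helper names are illustrative, and it assumes the standard Linux format where values are reported in kB.

```python
def parse_meminfo(text):
    """Parse /proc/meminfo-style text into a dict of integer values."""
    values = {}
    for line in text.splitlines():
        key, sep, rest = line.partition(":")
        fields = rest.split()
        if sep and fields and fields[0].isdigit():
            values[key.strip()] = int(fields[0])
    return values


def available_gib(text):
    """Return MemAvailable converted from kB to GiB."""
    return parse_meminfo(text)["MemAvailable"] / (1024 ** 2)


if __name__ == "__main__":
    try:
        with open("/proc/meminfo") as fh:
            print(f"MemAvailable: {available_gib(fh.read()):.1f} GiB")
    except FileNotFoundError:
        print("/proc/meminfo not found (not a Linux host?)")
```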
|
||||
859
nvidia/cutile-kernels/README.md
Normal file
@ -0,0 +1,859 @@
|
||||
# cuTile Kernels
|
||||
|
||||
> Run cuTile kernel benchmarks, FMHA implementation, and LLM inference on DGX Spark and B300
|
||||
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Kernel Benchmarks](#kernel-benchmarks)
|
||||
- [End-to-End Inference](#end-to-end-inference)
|
||||
- [FMHA Implementation](#fmha-implementation)
|
||||
- [Attention Basics](#attention-basics)
|
||||
- [Flash Attention Algorithm](#flash-attention-algorithm)
|
||||
- [cuTile Pseudocode → Actual Mapping](#cutile-pseudocode-actual-mapping)
|
||||
- [Kernel Pseudocode](#kernel-pseudocode)
|
||||
- [cuTile Implementation](#cutile-implementation)
|
||||
- [Launching the Kernel](#launching-the-kernel)
|
||||
- [Optimizations](#optimizations)
|
||||
- [Platform Configuration](#platform-configuration)
|
||||
- [Performance Results](#performance-results)
|
||||
- [Common Issues](#common-issues)
|
||||
- [Companion Scripts](#companion-scripts)
|
||||
- [References](#references)
|
||||
- [Platform Comparison](#platform-comparison)
|
||||
- [End-to-End Throughput](#end-to-end-throughput)
|
||||
- [CUDA Kernel Time](#cuda-kernel-time)
|
||||
- [cuTile Kernel Breakdown](#cutile-kernel-breakdown)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
## Basic idea
|
||||
|
||||
[TileGym](https://github.com/NVIDIA/TileGym) is NVIDIA's benchmark suite and integration framework for cuTile kernels: high-performance GPU kernels written in the cuTile Python DSL. cuTile compiles to Tile IR, so developers can write efficient kernels without low-level CUDA programming.
|
||||
|
||||
This playbook covers three workflows:
1. **[Kernel Benchmarks](#kernel-benchmarks)** - Run standalone cuTile kernel benchmarks (FMHA, MatMul, RMSNorm, etc.)
2. **[End-to-End Inference](#end-to-end-inference)** - Run LLM inference with cuTile-optimized kernels via monkey-patching
3. **[FMHA Implementation](#fmha-implementation)** - Step-by-step tutorial that builds a Flash Multi-Head Attention kernel from pseudocode to optimized cuTile, with companion scripts to run and benchmark it
|
||||
|
||||
The same cuTile code runs on both DGX Spark (sm_121) and B300 (sm_103): cuTile JIT-compiles it for the appropriate GPU architecture automatically.
|
||||
|
||||
## What you'll accomplish
|
||||
|
||||
- Run the TileGym benchmark suite on DGX Spark
|
||||
- Run Qwen2-7B or DeepSeek-V2-Lite inference with cuTile-optimized kernels
|
||||
- Observe performance scaling between DGX Spark and B300
|
||||
- Build an FMHA kernel step-by-step from pseudocode to optimized cuTile implementation
|
||||
|
||||
## What to know before starting
|
||||
|
||||
- Basic familiarity with Docker and command-line tools
|
||||
- Understanding of GPU compute concepts (TFLOPS, memory bandwidth)
|
||||
- No CUDA programming experience required
|
||||
- HuggingFace account with access token (for LLM inference)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
**Hardware Requirements:**
|
||||
- DGX Spark with Ubuntu 24.04 or B300 cloud instance
|
||||
- Minimum 16GB GPU memory for LLM inference
|
||||
- At least 50GB available storage space for model downloads
|
||||
|
||||
**Software Requirements:**
|
||||
- Docker installed and configured: `docker ps`
|
||||
- CUDA Toolkit 13.x with Tile IR support
|
||||
- HuggingFace token for model access (LLM inference only)
|
||||
- Network access for pulling containers and downloading models
|
||||
|
||||
Verify Docker is available:
|
||||
```bash
|
||||
docker ps
|
||||
```
|
||||
|
||||
If you get a permission error:
|
||||
```bash
|
||||
sudo usermod -aG docker $USER
|
||||
newgrp docker
|
||||
```
|
||||
|
||||
## Kernel support matrix
|
||||
|
||||
| Kernel | Category | Data Types | Description |
|--------|----------|------------|-------------|
| **FMHA** | Attention | float16, float8 | Flash Multi-Head Attention |
| **MLA** | Attention | bfloat16, float8 | Multi-head Latent Attention |
| **MLA Decoding** | Attention | float16, float8 | MLA for decode phase |
| **MatMul** | Matrix Ops | float16, float8 | Matrix multiplication |
| **BMM** | Matrix Ops | float16 | Batched matrix multiplication |
| **Group GEMM** | Matrix Ops | float16, float8 | Grouped GEMM for MoE |
| **RMSNorm** | Normalization | float16, bfloat16 | Root mean square normalization |
| **RoPE** | Positional | float16 | Rotary position embedding |
| **SiLU** | Activation | float16, float32 | SiLU activation with multiply |
| **SwiGLU** | Activation | float16, float32 | SwiGLU fused operation |
| **Softmax** | Activation | float16 | Softmax normalization |
| **Dropout** | Regularization | float16, float32 | Dropout forward |
|
||||
|
||||
## Model support for LLM inference
|
||||
|
||||
| Model | Supported Kernels | Batch Size | Output Tokens | Notes |
|-------|-------------------|------------|---------------|-------|
| **Qwen2-7B** | RoPE, RMSNorm, SwiGLU, FMHA | 16 | 50 | Standard transformer |
| **DeepSeek-V2-Lite** | RoPE, RMSNorm, SiLU, MLA, MoE | 1 | 100 | MLA attention, MoE layers |
|
||||
|
||||
## Ancillary files
|
||||
|
||||
All required assets can be found in the [TileGym repository](https://github.com/NVIDIA/TileGym).
|
||||
|
||||
- `tests/benchmark/run_all.sh` - Run all kernel benchmarks
|
||||
- `modeling/transformers/bench_qwen.sh` - Qwen2-7B benchmark script
|
||||
- `modeling/transformers/bench_deepseek.sh` - DeepSeek-V2-Lite benchmark script
|
||||
- `modeling/transformers/infer.py` - Main inference script with TileGym integration
|
||||
- [`assets/fmha_optimization_tutorial.py`](assets/fmha_optimization_tutorial.py) - FMHA step-by-step optimization tutorial
|
||||
- [`assets/fmha_scaling_analysis.py`](assets/fmha_scaling_analysis.py) - FMHA scaling analysis across sequence lengths
|
||||
|
||||
## Time & risk
|
||||
|
||||
* **Estimated time:** 30-45 minutes (including model download for LLM inference)
* **Risk level:** Low
  * Large downloads may fail due to network issues
  * First run includes JIT compilation overhead
* **Rollback:** Remove Docker container to undo all changes
* **Last Updated:** February 2026
  * First Publication
|
||||
|
||||
## Kernel Benchmarks
|
||||
|
||||
## Step 1. Pull CUDA NGC container with CTK 13.x
|
||||
|
||||
```bash
|
||||
docker pull nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04
|
||||
```
|
||||
|
||||
Launch an interactive session with GPU access:
|
||||
|
||||
```bash
docker run --gpus all -it --rm \
  -v ~/TileGym:/workspace/TileGym \
  nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 \
  /bin/bash
```
|
||||
|
||||
> [!NOTE]
|
||||
> The `-v` flag mounts a local directory to persist the TileGym repository. The `--rm` flag automatically removes the container when you exit; omit it if you want to keep the container for later use.
|
||||
|
||||
Or if running outside a container, install Tile IR directly:
|
||||
|
||||
```bash
# Requires root privileges - run with sudo or as root
sudo apt-get install cuda-tile-ir-13-1 cuda-compiler-13-1
```
|
||||
|
||||
## Step 2. Clone TileGym repository
|
||||
|
||||
```bash
|
||||
git clone https://github.com/NVIDIA/TileGym
|
||||
cd TileGym
|
||||
pip install .
|
||||
```
|
||||
|
||||
## Step 3. Run benchmark suite
|
||||
|
||||
```bash
|
||||
cd tests/benchmark/
|
||||
bash run_all.sh
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The benchmark runs sequentially to ensure accurate timing results. This may take 10-15 minutes to complete all kernels.
|
||||
|
||||
## Step 4. View results
|
||||
|
||||
Results show cuTile performance for each kernel and sequence length.
|
||||
|
||||
Expected output should look like:
|
||||
|
||||
```text
==========================================
Running bench_fused_attention.py...
==========================================
fused-attention-batch4-head32-d128-fwd-causal=True-float16-TFLOPS:
     N_CTX     CuTile
0   1024.0  58.188262
1   2048.0  80.906892
2   4096.0  86.189532
3   8192.0  88.891086
4  16384.0  89.491869
✓ PASSED: bench_fused_attention.py
```
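
To relate the TFLOPS column to wall-clock time, the sketch below uses a common FLOP-counting convention for forward attention: two matmuls of 2·B·H·N²·D FLOPs each, halved for causal masking. That the TileGym benchmark counts FLOPs this way is an assumption; check `bench_fused_attention.py` before relying on the numbers. The shape (batch 4, 32 heads, head dimension 128) is read from the benchmark name in the sample output.

```python
def attention_fwd_flops(batch, heads, n_ctx, head_dim, causal=True):
    """FLOPs for one forward attention pass under the assumed convention."""
    flops_per_matmul = 2 * batch * heads * n_ctx * n_ctx * head_dim
    total = 2 * flops_per_matmul  # Q @ K^T and P @ V
    return total // 2 if causal else total


def implied_time_ms(tflops, flops):
    """Time in milliseconds implied by a reported TFLOPS figure."""
    return flops / (tflops * 1e12) * 1e3


if __name__ == "__main__":
    # TFLOPS values taken from the sample output above.
    for n_ctx, tflops in [(1024, 58.188262), (16384, 89.491869)]:
        flops = attention_fwd_flops(4, 32, n_ctx, 128)
        print(f"N_CTX={n_ctx}: ~{implied_time_ms(tflops, flops):.3f} ms per forward pass")
```

Because the work grows with N², the longer sequences keep the GPU busier, which is consistent with TFLOPS rising and then flattening as `N_CTX` increases.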
|
||||
|
||||
## Step 5. Run individual benchmarks
|
||||
|
||||
To run specific kernel benchmarks:
|
||||
|
||||
```bash
# Flash Multi-Head Attention
python bench_fused_attention.py

# Matrix Multiplication
python bench_matrix_multiplication.py

# RMSNorm
python bench_rmsnorm.py

# RoPE
python bench_rope.py

# SwiGLU
python bench_swiglu.py
```
|
||||
|
||||
## Step 6. Clean up
|
||||
|
||||
Exit the container:
|
||||
|
||||
```bash
|
||||
exit
|
||||
```
|
||||
|
||||
Remove this workflow's containers (if you ran without `--rm`):
|
||||
|
||||
```bash
# Preferred: remove only containers from this workflow's image
docker rm $(docker ps -a --filter ancestor=nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 --format '{{.ID}}')

# Alternative: prune all stopped containers (will prompt for confirmation)
# docker container prune
```
|
||||
|
||||
Remove the image (optional):
|
||||
|
||||
```bash
|
||||
docker rmi nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04
|
||||
```
|
||||
|
||||
## Step 7. Repeat on B300
|
||||
|
||||
Repeat Steps 1-6 on B300 hardware to observe scaling. See the **Platform Comparison** tab for expected scaling results.
|
||||
|
||||
## End-to-End Inference
|
||||
|
||||
## Step 1. Set up environment
|
||||
|
||||
If you haven't already, pull the CUDA container and clone TileGym (see **Kernel Benchmarks** tab for details).
|
||||
|
||||
First, clone TileGym on the host:
|
||||
|
||||
```bash
|
||||
mkdir -p ~/TileGym
|
||||
git clone https://github.com/NVIDIA/TileGym ~/TileGym
|
||||
```
|
||||
|
||||
Then launch the container with the repository mounted:
|
||||
|
||||
```bash
docker run --gpus all -it --rm \
  -v ~/TileGym:/workspace/TileGym \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 \
  /bin/bash
```
|
||||
|
||||
> [!NOTE]
|
||||
> The `-v ~/.cache/huggingface:/root/.cache/huggingface` mounts your HuggingFace cache to avoid re-downloading models.
|
||||
|
||||
Install TileGym inside the container:
|
||||
|
||||
```bash
|
||||
cd /workspace/TileGym
|
||||
pip install .
|
||||
```
|
||||
|
||||
Set your HuggingFace token for accessing gated models:
|
||||
|
||||
```bash
|
||||
export HF_TOKEN=<your_huggingface_token>
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> You need a HuggingFace account and access token. Get one at https://huggingface.co/settings/tokens
|
||||
|
||||
## Step 2. Run inference benchmark
|
||||
|
||||
Navigate to the transformers benchmark directory:
|
||||
|
||||
```bash
|
||||
cd modeling/transformers
|
||||
```
|
||||
|
||||
**Option A: Run Qwen2-7B benchmark**
|
||||
|
||||
```bash
|
||||
./bench_qwen.sh
|
||||
```
|
||||
|
||||
Configuration: Model `Qwen/Qwen2-7B`, Batch size 16, Output length 50 tokens.
|
||||
|
||||
**Option B: Run DeepSeek-V2-Lite benchmark**
|
||||
|
||||
```bash
|
||||
./bench_deepseek.sh
|
||||
```
|
||||
|
||||
Configuration: Model `deepseek-ai/DeepSeek-V2-Lite-Chat`, Batch size 1, Output length 100 tokens.
|
||||
|
||||
Both scripts run two configurations:
|
||||
1. **PyTorch baseline** - Standard HuggingFace inference
|
||||
2. **TileGym cuTile** - With cuTile kernel replacements
|
||||
|
||||
## Step 3. View results
|
||||
|
||||
**Sample DGX Spark (GB10) Results for Qwen2-7B:**
|
||||
|
||||
```text
========================================
Benchmark Results
========================================
Qwen2-7B_naive_bfloat16 | 15.66 tokens/s | 51.10s | 51151.0ms CUDA
Qwen2-7B_cutile_attn | 18.52 tokens/s | 43.20s | 43079.7ms CUDA
========================================
```
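
The sample figures can be cross-checked against the Qwen2-7B configuration (batch size 16, 50 output tokens). The sketch below assumes throughput is total generated tokens divided by wall time, which reproduces the tokens/s values shown; the function names are illustrative.

```python
BATCH_SIZE = 16      # from the Qwen2-7B benchmark configuration
OUTPUT_TOKENS = 50   # output length per sequence


def tokens_per_second(wall_seconds, batch_size=BATCH_SIZE, output_tokens=OUTPUT_TOKENS):
    """Total generated tokens divided by wall-clock time."""
    return batch_size * output_tokens / wall_seconds


def speedup(baseline_seconds, optimized_seconds):
    """How many times faster the optimized run finished."""
    return baseline_seconds / optimized_seconds


if __name__ == "__main__":
    print(f"baseline: {tokens_per_second(51.10):.2f} tokens/s")
    print(f"cuTile:   {tokens_per_second(43.20):.2f} tokens/s")
    print(f"speedup:  {speedup(51.10, 43.20):.2f}x")
```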
|
||||
|
||||
**cuTile Kernel Breakdown (DGX Spark - Qwen2):**
|
||||
|
||||
| Kernel | CUDA Time (ms) | Calls |
|--------|----------------|-------|
| `fmha_kernel` | 4185.9 | 28 |
| `swiglu_forward_kernel` | 2459.8 | 1400 |
| `attention_decode_kernel_grouped` | 2271.8 | 1372 |
| `rms_norm_kernel_static_persistent` | 634.7 | 57 |
| `rope_kernel` | 355.6 | 1400 |
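
Dividing each kernel's CUDA time by its call count separates a few expensive prefill calls from many cheap decode calls. The sketch below does that arithmetic on the table values; the shares are relative to the listed cuTile kernels only, not to the full CUDA time of the run.

```python
# (CUDA time in ms, number of calls), copied from the table above.
KERNELS = {
    "fmha_kernel": (4185.9, 28),
    "swiglu_forward_kernel": (2459.8, 1400),
    "attention_decode_kernel_grouped": (2271.8, 1372),
    "rms_norm_kernel_static_persistent": (634.7, 57),
    "rope_kernel": (355.6, 1400),
}


def summarize(kernels):
    """Return per-call time and share of the listed total for each kernel."""
    total_ms = sum(ms for ms, _ in kernels.values())
    return {
        name: {"ms_per_call": ms / calls, "share": ms / total_ms}
        for name, (ms, calls) in kernels.items()
    }


if __name__ == "__main__":
    for name, stats in summarize(KERNELS).items():
        print(f"{name:36s} {stats['ms_per_call']:9.3f} ms/call {stats['share']:6.1%}")
```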
|
||||
|
||||
## Step 4. How TileGym monkey-patching works

TileGym replaces PyTorch model operations with cuTile kernels. The snippet below is taken from TileGym's [`src/tilegym/transformers/monkey_patch.py`](https://github.com/NVIDIA/TileGym/blob/main/src/tilegym/transformers/monkey_patch.py) and invoked from [`modeling/transformers/infer.py`](https://github.com/NVIDIA/TileGym/blob/main/modeling/transformers/infer.py):

```python
from tilegym.transformers import apply_tilegym_kernel_to_qwen2

apply_tilegym_kernel_to_qwen2(
    rope=True,        # Replace RoPE with cuTile kernel
    rms_norm=True,    # Replace RMSNorm with cuTile kernel
    swiglu=True,      # Replace SwiGLU with cuTile kernel
    attn=True,        # Replace attention with cuTile FMHA
    use_cutile=True   # Use cuTile backend (vs Triton)
)
```

**Patched Kernels for Qwen2:**

| Kernel | PyTorch Operation | cuTile Replacement |
|--------|-------------------|-------------------|
| `rms_norm_kernel_static_persistent` | `nn.RMSNorm` | Persistent RMSNorm |
| `rope_kernel` | Rotary position embedding | Fused RoPE |
| `fmha_kernel` | `F.scaled_dot_product_attention` | Flash Attention |
| `swiglu_forward_kernel` | SiLU + Mul | Fused SwiGLU |
| `attention_decode_kernel_grouped` | Decode attention | Grouped decode |

**Patched Kernels for DeepSeek-V2:** (see [`src/tilegym/transformers/monkey_patch.py`](https://github.com/NVIDIA/TileGym/blob/main/src/tilegym/transformers/monkey_patch.py))

```python
from tilegym.transformers import apply_tilegym_kernel_to_deepseek_v2

apply_tilegym_kernel_to_deepseek_v2(
    rope=True,        # Replace RoPE with cuTile kernel
    rms_norm=True,    # Replace RMSNorm with cuTile kernel
    swiglu=True,      # Replace SiLU+Mul with cuTile kernel
    attn=True,        # Replace MLA attention with cuTile
    moe=True,         # Replace MoE routing with cuTile
    use_cutile=True
)
```

| Kernel | PyTorch Operation | cuTile Replacement |
|--------|-------------------|-------------------|
| `prefill_mla` | MLA prefill attention | Multi-head Latent Attention |
| `_mla_decoding_split_kv` | MLA decode attention | Split-KV decoding |
| `fused_moe_kernel` | MoE expert routing | Fused MoE |
| `group_gemm_kernel` | Expert FFN | Grouped GEMM |

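The general monkey-patching technique used in this step can be sketched in plain Python. This is a toy illustration, not TileGym's code: `ToyRMSNorm`, `fast_forward`, and `apply_toy_patch` are hypothetical names, and the "fast" path is ordinary Python where a real patch would launch a cuTile kernel.

```python
class ToyRMSNorm:
    """Stand-in for a model layer (hypothetical, not a PyTorch or TileGym class)."""

    def __init__(self, eps=1e-6):
        self.eps = eps

    def forward(self, xs):
        mean_sq = sum(x * x for x in xs) / len(xs)
        return [x / (mean_sq + self.eps) ** 0.5 for x in xs]


calls = []  # Records which implementation ran.


def fast_forward(self, xs):
    """Replacement implementation; a real patch would launch a GPU kernel here."""
    calls.append("fast")
    inv_rms = (sum(x * x for x in xs) / len(xs) + self.eps) ** -0.5
    return [x * inv_rms for x in xs]


def apply_toy_patch(rms_norm=True):
    """Swap the method on the class so existing and future instances use it."""
    if rms_norm:
        ToyRMSNorm.forward = fast_forward


layer = ToyRMSNorm()
before = layer.forward([3.0, 4.0])   # original implementation
apply_toy_patch(rms_norm=True)
after = layer.forward([3.0, 4.0])    # patched implementation, same numerics
```

Because the method is replaced on the class rather than on one instance, a model that was already constructed picks up the new implementation without any change to its own code.
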
## Step 5. Platform-specific tuning (Advanced)

cuTile exposes two complementary performance-tuning mechanisms:

- **[`ct.ByTarget`](https://docs.nvidia.com/cuda/cutile-python/performance.html)** - Select different kernel launch parameters per GPU architecture (`sm_<major><minor>`). The compiler picks the value matching the current target at JIT time; if no entry matches, the `default` value is used. See the [Performance Tuning](https://docs.nvidia.com/cuda/cutile-python/performance.html) and [Execution Model](https://docs.nvidia.com/cuda/cutile-python/execution.html) pages.
  - **`num_ctas`** - Number of Cooperative Thread Arrays (thread blocks) launched per kernel invocation. Tune to the number of SMs on the target GPU.
  - **`occupancy`** - Hint for the number of concurrent CTAs the compiler should target per SM. Higher occupancy hides memory latency but increases register/shared-memory pressure. See the [Execution Model](https://docs.nvidia.com/cuda/cutile-python/execution.html) documentation.
- **[`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html)** - Search a list of candidate values at runtime and pick the fastest configuration. Results are reported via [`cuda.tile.tune.TuningResult`](https://docs.nvidia.com/cuda/cutile-python/generated/cuda.tile.tune.TuningResult.html) / [`Measurement`](https://docs.nvidia.com/cuda/cutile-python/generated/cuda.tile.tune.Measurement.html).

```python
import cuda.tile as ct

@ct.kernel(
    # num_ctas: how many thread blocks to launch.
    # Use ByTarget to pick an arch-specific value at JIT time.
    num_ctas=ct.ByTarget({
        "sm_103": 8,    # B300 - more SMs, launch more CTAs
        "sm_121": 4,    # DGX Spark - fewer SMs (48), use fewer CTAs
        "default": 1,   # Fallback for any other GPU architecture
    }),
    # occupancy: hint for concurrent CTAs per SM (latency hiding vs. register pressure).
    occupancy=ct.ByTarget({
        "sm_103": 16,   # B300 - high occupancy, plenty of registers/SMEM
        "sm_121": 12,   # DGX Spark - moderate occupancy
        "default": 8,   # Conservative fallback
    }),
    opt_level=3  # Maximum compiler optimization level
)
def optimized_kernel(A, B, C):
    # Same kernel code works on all platforms;
    # ByTarget swaps in the arch-specific launch params automatically.
    ...
```

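The selection rule described in this step (use the entry for the current architecture, otherwise fall back to `default`) behaves like a dictionary lookup. The sketch below is plain Python that mirrors the semantics as this playbook describes them; it is not cuTile's implementation, and `resolve_by_target` and `target_name` are hypothetical helpers.

```python
def target_name(major, minor):
    """Build an architecture key such as 'sm_121' from a compute capability."""
    return f"sm_{major}{minor}"


def resolve_by_target(values, target):
    """Return the entry for `target`, falling back to the 'default' entry."""
    if target in values:
        return values[target]
    return values["default"]


# Same per-architecture values as the num_ctas example above.
num_ctas = {"sm_103": 8, "sm_121": 4, "default": 1}

on_spark = resolve_by_target(num_ctas, target_name(12, 1))  # DGX Spark
on_b300 = resolve_by_target(num_ctas, target_name(10, 3))   # B300
on_other = resolve_by_target(num_ctas, target_name(9, 0))   # no entry, uses default
```

Here `on_spark` is 4, `on_b300` is 8, and an architecture without its own entry gets the `default` value of 1.
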
For automatic tuning, use [`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to search over candidate values and pick the fastest configuration at runtime:

```python
@ct.kernel(
    # autotune: benchmark each value and pick the fastest.
    num_ctas=ct.autotune([1, 2, 4, 8, 16]),
    occupancy=ct.autotune([8, 12, 16, 24]),
    opt_level=3
)
def autotuned_kernel(A, B, C):
    ...
```

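Conceptually, autotuning times each candidate configuration and keeps the fastest one. The sketch below shows that idea in plain Python under two loudly stated assumptions: it searches the full Cartesian product of candidates (cuTile's actual search strategy may differ), and `fake_latency_ms` is a made-up cost model standing in for a real kernel measurement. `pick_fastest` is a hypothetical helper, not a cuTile API.

```python
import itertools


def pick_fastest(candidates, measure):
    """Return (best_config, timings) for the candidate with the lowest time."""
    timings = {cfg: measure(*cfg) for cfg in candidates}
    best = min(timings, key=timings.get)
    return best, timings


def fake_latency_ms(num_ctas, occupancy):
    """Made-up cost model; a real tuner would launch and time the kernel."""
    return 1.0 + 0.10 * abs(num_ctas - 4) + 0.01 * abs(occupancy - 12)


# Same candidate lists as the ct.autotune example above.
search_space = list(itertools.product([1, 2, 4, 8, 16], [8, 12, 16, 24]))
best, timings = pick_fastest(search_space, fake_latency_ms)
print(f"best (num_ctas, occupancy) = {best}")
```

With this synthetic cost model the search evaluates 20 configurations and selects `(4, 12)`; on real hardware the winner depends on the kernel and the GPU.
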
## Step 6. Repeat on B300

Repeat Steps 1-3 on B300 hardware. The **same code runs without modification** - cuTile JIT compiles for sm_103 automatically.

See the **Platform Comparison** tab for detailed scaling results.

## FMHA Implementation

## FMHA Implementation Guide

> [!NOTE]
> This is a guide to understanding FMHA implementation in cuTile, not a complete reference. For comprehensive documentation, see the [cuTile Python Documentation](https://docs.nvidia.com/cuda/cutile-python/).

### Attention Basics

Attention allows a neural network to focus on relevant parts of the input. In transformers (GPT, LLaMA, Qwen), each position computes how much to attend to every other position using three vectors:

- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What do I contain?"
- **Value (V)**: "Here is my content"

```text
Attention(Q, K, V) = softmax(Q × K^T / √d) × V

Shapes:
  Q, K, V = [batch, heads, seq_len, head_dim]
  Q × K^T = [batch, heads, seq_len, seq_len]   # Attention scores
  Output  = [batch, heads, seq_len, head_dim]
```

For autoregressive models, **causal masking** ensures each token only attends to previous tokens by setting future scores to -infinity before softmax.

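The formula and the causal mask can be checked on a tiny example. The sketch below is a deliberately naive, single-head reference in pure Python (no batching, no GPU); the function names and the toy matrices are illustrative only.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V, causal=True):
    """Naive single-head attention on lists of row vectors."""
    d = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        if causal:
            # Position i may only attend to positions j <= i.
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
        weights = softmax(scores)
        out.append([sum(w * v[c] for w, v in zip(weights, V))
                    for c in range(len(V[0]))])
    return out


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

causal_out = attention(Q, K, V, causal=True)
full_out = attention(Q, K, V, causal=False)
```

With the causal mask, the first position can only see itself, so `causal_out[0]` equals `V[0]`; without the mask it becomes a blend of all three value rows.
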
### Flash Attention Algorithm

Standard attention materializes a [seq_len × seq_len] score matrix for every head (e.g., 2 GiB per head in 16-bit precision for seq_len=32768). Flash Attention avoids this by processing K and V in tiles with **online softmax**, keeping only a running maximum, a running sum, and a running output accumulator:

```text
m = -infinity    # Running maximum
l = 0            # Running sum of exp(x - m)
acc = 0          # Running weighted sum of values

FOR each K,V tile:
    scores = Q_tile @ K_tile.T * scale
    m_new = max(m, max(scores))
    correction = exp(m - m_new)
    l = l * correction + sum(exp(scores - m_new))
    acc = acc * correction + exp(scores - m_new) @ V_tile
    m = m_new

output = acc / l
```

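The online-softmax recurrence above can be verified against the direct formula. The following sketch, for a single query row in pure Python, is an illustration rather than TileGym code; `attend_full`, `attend_tiled`, and the toy data are made up for this example.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attend_full(q, K, V, scale):
    """Direct softmax(q K^T * scale) V for one query row."""
    scores = [scale * dot(q, k) for k in K]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    l = sum(p)
    return [sum(pi * v[c] for pi, v in zip(p, V)) / l for c in range(len(V[0]))]


def attend_tiled(q, K, V, scale, tile_n):
    """Same result, visiting K/V in tiles and never storing all scores."""
    m, l = -math.inf, 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), tile_n):
        k_tile, v_tile = K[start:start + tile_n], V[start:start + tile_n]
        scores = [scale * dot(q, k) for k in k_tile]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)       # rescales earlier partial results
        p = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pi * v[c] for pi, v in zip(p, v_tile))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]


q = [0.5, -1.0]
K = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0], [0.25, 0.75]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
scale = 1.0 / math.sqrt(len(q))

full = attend_full(q, K, V, scale)
tiled = attend_tiled(q, K, V, scale, tile_n=2)   # last tile is partial
```

The two results agree to floating-point precision, which is why the tiled kernel can replace the materialized score matrix.
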
### cuTile Pseudocode → Actual Mapping

| Concept | Pseudocode | cuTile |
|---|---|---|
| Define kernel | `KERNEL fmha(...)` | `@ct.kernel()` |
| Get block ID | `block_x = BLOCK_ID_X` | `bid_x = ct.bid(0)` |
| Create indices | `range(0, N)` | `ct.arange(N, dtype=ct.int32)` |
| Create constant tile | `tile = zeros(M, N)` | `ct.full((M, N), 0.0, dtype)` |
| Load from memory | `tile = LOAD(ptr, shape)` | `ct.load(tensor, index, shape)` |
| Store to memory | `STORE(ptr, tile)` | `ct.store(tensor, index, tile)` |
| Matrix multiply | `C = A @ B + C` | `ct.mma(A, B, C)` |
| Reduction | `max_val = MAX(tile, axis)` | `ct.max(tile, axis, keepdims)` |

### Kernel Pseudocode

```text
KERNEL fmha(Q, K, V, Out, scale, TILE_M, TILE_N):
    tile_row = BLOCK_ID_X
    batch_head = BLOCK_ID_Y
    batch = batch_head // num_heads
    head = batch_head % num_heads

    m_i = full(TILE_M, -infinity)
    l_i = full(TILE_M, 0)
    acc = zeros(TILE_M, head_dim)

    q = LOAD(Q[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :])

    FOR j = 0 to num_k_tiles:
        k = LOAD(K[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        v = LOAD(V[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        scores = MMA(q, transpose(k)) * scale
        IF causal AND in_mask_region:
            scores = WHERE(valid_mask, scores, -infinity)
        m_new = max(m_i, row_max(scores))
        correction = exp(m_i - m_new)
        p = exp(scores - m_new)
        l_i = l_i * correction + row_sum(p)
        acc = acc * correction + MMA(p, v)
        m_i = m_new

    out = acc / l_i
    STORE(Out[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :], out)
```

### cuTile Implementation

```python
import cuda.tile as ct
import math
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]

@ct.kernel()
def fmha_kernel(Q, K, V, Out, qk_scale: float, TILE_D: ConstInt, H: ConstInt,
                TILE_M: ConstInt, TILE_N: ConstInt, CAUSAL: ConstBool):
    bid_x, bid_y = ct.bid(0), ct.bid(1)
    batch_idx, head_idx = bid_y // H, bid_y % H

    offs_m = (bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32))[:, None]
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)[None, :]

    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0),
                shape=(1, 1, TILE_M, TILE_D)).reshape((TILE_M, TILE_D))

    k_seqlen = K.shape[2]
    if CAUSAL:
        Tc = ct.cdiv(min((bid_x + 1) * TILE_M, k_seqlen), TILE_N)
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        mask_start = k_seqlen // TILE_N

    for j in range(0, Tc):
        k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0),
                         shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        k_t = ct.permute(k_tile, (1, 0))

        qk = ct.mma(q, k_t, ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32))
        qk = qk * qk_scale

        if CAUSAL and j >= mask_start:
            offs_n = j * TILE_N + offs_n_tile
            qk = ct.where(offs_m >= offs_n, qk,
                          ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

        m_ij = ct.maximum(m_i, ct.max(qk, axis=-1, keepdims=True))
        qk = qk - m_ij
        p = ct.exp(qk)
        alpha = ct.exp(m_i - m_ij)
        l_i = l_i * alpha + ct.sum(p, axis=-1, keepdims=True)
        acc = acc * alpha

        v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0),
                         shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        acc = ct.mma(p.astype(Q.dtype), v_tile, acc)
        m_i = m_ij

    acc = (acc / l_i).reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
```

### Launching the Kernel

```python
def run_fmha(q, k, v, sm_scale, is_causal=True):
    import torch
    TILE_M, TILE_N = 64, 64  # Platform-specific (see below)
    batch, num_heads, seq_len, head_dim = q.shape
    out = torch.empty_like(q)
    grid = (math.ceil(seq_len / TILE_M), batch * num_heads, 1)
    ct.launch(
        torch.cuda.current_stream(), grid, fmha_kernel,
        (q, k, v, out, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
    )
    return out
```

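To make the launch geometry concrete, the sketch below reproduces the grid computation and the per-block K/V tile count (`Tc` in the kernel) in plain Python. The helper names are hypothetical; the shapes use the tutorial configuration (batch 4, 32 heads) with a sequence length of 2048 and 64×64 tiles.

```python
import math


def launch_grid(batch, num_heads, seq_len, tile_m):
    """Grid shape used by run_fmha: one CTA per (query tile, batch*head)."""
    return (math.ceil(seq_len / tile_m), batch * num_heads, 1)


def kv_tiles_for_block(bid_x, seq_len, tile_m, tile_n, causal):
    """Number of K/V tiles a CTA loops over, mirroring Tc in the kernel."""
    if causal:
        return math.ceil(min((bid_x + 1) * tile_m, seq_len) / tile_n)
    return math.ceil(seq_len / tile_n)


grid = launch_grid(batch=4, num_heads=32, seq_len=2048, tile_m=64)
total_ctas = grid[0] * grid[1] * grid[2]

first_block = kv_tiles_for_block(0, 2048, 64, 64, causal=True)
last_block = kv_tiles_for_block(grid[0] - 1, 2048, 64, 64, causal=True)
non_causal = kv_tiles_for_block(0, 2048, 64, 64, causal=False)
```

The grid is `(32, 128, 1)`, or 4096 CTAs. With causal masking the first query tile visits a single K/V tile while the last visits all 32, which is why causal attention performs roughly half the work of the non-causal case.
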
### Optimizations

#### exp2 + flush_to_zero

`exp2(x) = 2^x` is typically faster than `exp(x)` on GPU. Because `exp(x) = exp2(x / ln 2)`, the softmax scale is multiplied by `1/ln(2)` once, outside the loop.

```python
# Convert natural-exp scale to base-2 so we can use the faster ct.exp2 intrinsic.
# exp(x) == exp2(x / log(2)) == exp2(x * INV_LOG_2).
INV_LOG_2 = 1.0 / math.log(2)          # ≈ 1.4427
qk_scale_log2 = qk_scale * INV_LOG_2   # Pre-multiply the softmax scale once

# ... in loop (qk is the unscaled Q @ K^T tile; the separate `qk = qk * qk_scale` step is dropped):
# Fuse the running-max update with the scale multiplication.
m_ij = ct.maximum(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2)
# Subtract the running max for numerical stability (online softmax).
qk = qk * qk_scale_log2 - m_ij
# flush_to_zero=True: flush denormals to 0 -> avoids slow denormal handling on GPU.
p = ct.exp2(qk, flush_to_zero=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)  # Correction factor for previous acc/l_i
```

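The identity behind this optimization is easy to confirm numerically. The check below uses only the Python standard library (`2.0 ** y` stands in for a base-2 exponential intrinsic); the sample scores are arbitrary.

```python
import math

INV_LOG_2 = 1.0 / math.log(2)   # same constant as in the kernel snippet

scores = [-3.0, -0.5, 0.0, 1.25]
via_exp = [math.exp(x) for x in scores]
via_exp2 = [2.0 ** (x * INV_LOG_2) for x in scores]   # exp2(x / ln 2)

# Largest relative difference between the two formulations.
max_rel_err = max(abs(a - b) / a for a, b in zip(via_exp, via_exp2))
print(f"max relative error: {max_rel_err:.2e}")
```

The two formulations agree to within floating-point rounding, so folding `INV_LOG_2` into the scale changes only which hardware instruction is used, not the softmax result.
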
#### Load Order Transpose

Load K already transposed by using the `order` parameter, which avoids an explicit permute.

```python
# order=(0,1,3,2) swaps the last two axes during the load,
# producing K^T directly in registers -- no extra ct.permute() needed.
# shape is expressed in the transposed layout: (1, 1, TILE_D, TILE_N).
k_t = ct.load(K, index=(..., 0, j), shape=(1,1,TILE_D,TILE_N),
              order=(0,1,3,2)).reshape((TILE_D, TILE_N))
```

#### Latency Hints

Prefetch data to overlap memory loads with computation. See the [Performance Tuning docs](https://docs.nvidia.com/cuda/cutile-python/performance.html) for the full list of load/store hints (e.g. `allow_tma`, `latency`).

```python
# latency=N tells the compiler to issue this load N loop iterations in
# advance of its use, so the memory transfer overlaps with the MMA work
# from earlier iterations. Larger latency = deeper software pipeline but
# more register pressure.
k_t = ct.load(K, ..., latency=2)     # Prefetch K 2 iterations ahead
v_tile = ct.load(V, ..., latency=4)  # Prefetch V 4 iterations ahead (used later in the loop)
```

#### Occupancy

Allow multiple thread blocks per SM to hide memory latency. See the [Execution Model docs](https://docs.nvidia.com/cuda/cutile-python/execution.html) for details on how `occupancy` interacts with registers and shared memory.

```python
# occupancy=N is a hint to the compiler to target N concurrent CTAs per SM.
# Higher occupancy -> more warps available to hide memory latency,
# but constrains the per-CTA register/SMEM budget.
@ct.kernel(occupancy=2)  # 2 thread blocks (CTAs) co-resident per SM
def fmha_optimized(...):
    ...  # Kernel body unchanged
```

#### Approximate Division

Use fast approximate division for the final normalization.

```python
from cuda.tile import RoundingMode as RMd
# RMd.APPROX -> hardware approximate reciprocal/divide (MUFU), much faster
# than IEEE-compliant division. Safe here because it's the final softmax
# normalization step where a small ULP error is acceptable.
# flush_to_zero=True flushes denormals to 0 to avoid the slow path.
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
```

### Platform Configuration

The same kernel code works on all platforms; only configuration parameters change. Use [`ct.ByTarget`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to select values per architecture, or [`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to search candidate values automatically.

| Platform | TILE_M | TILE_N | Occupancy | Rationale |
|---|---|---|---|---|
| DGX Spark (sm_121) | 64 | 64 | 2 | Smaller tiles, higher occupancy for 48 SMs |
| B300 (sm_103) | 256 | 128 | 1 | Large tiles maximize HBM3e throughput |
| B300 alternate | 128 | 128 | 2 | Higher occupancy, balanced parallelism |

```python
import cuda.tile as ct

@ct.kernel(
    # TILE_M / TILE_N: rows/cols of the Q and K/V tiles processed per CTA.
    # Larger tiles -> more arithmetic intensity; smaller tiles -> higher occupancy.
    # occupancy: target concurrent CTAs per SM (latency hiding vs. register pressure).
    occupancy=ct.ByTarget({
        "sm_121": 2,    # DGX Spark (48 SMs): 2 CTAs/SM for latency hiding
        "sm_103": 1,    # B300: larger tiles already saturate the SM
        "default": 1,   # Conservative fallback for other architectures
    }),
    opt_level=3  # Maximum compiler optimization level
)
def fmha_kernel(...):
    ...
```

### Performance Results

> **Note:** PyTorch SDPA is used for correctness verification only, not performance comparison.

#### DGX Spark (sm_121) - Seq 2048

| Step | Optimization | Latency (ms) | TFLOPS |
|---|---|---|---|
| 1 | Basic cuTile | 2.19 | 62.8 |
| 2 | + exp2 | 2.07 | 66.5 |
| 3 | + Load Order | 2.07 | 66.3 |
| 4 | + Latency Hints | 2.07 | 66.5 |
| 5 | + Occupancy=2 | 1.73 | 79.5 |
| 6 | + Approx Div (Final) | 1.69 | 81.1 |

#### B300 (sm_103) - Various Seq Lengths

| Seq Len | Latency (ms) | TFLOPS | vs Spark |
|---|---|---|---|
| 1024 | 0.074 | 465 | 5.7x |
| 2048 | 0.178 | 770 | 9.5x |
| 4096 | 0.550 | 999 | 15.1x |
| 8192 | 1.897 | 1159 | 14.6x |
| 16384 | 7.014 | 1254 | 14.2x |

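The TFLOPS columns follow from the latency and the attention FLOP count. The sketch below mirrors the `compute_flops` helper in the companion tutorial script and assumes the tables report causal attention with the tutorial configuration (batch 4, 32 heads, head dim 128); under that assumption it reproduces the DGX Spark figures to within rounding of the listed latencies.

```python
def attention_tflops(batch, heads, seq_len, head_dim, latency_ms, causal=True):
    """Achieved TFLOPS for one attention forward pass."""
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total_flops = 2 * flops_per_matmul       # Q @ K^T and P @ V
    if causal:
        total_flops *= 0.5                   # only the lower triangle is computed
    return total_flops / (latency_ms * 1e-3) / 1e12


# DGX Spark, seq_len 2048: first and last rows of the optimization table.
basic = attention_tflops(4, 32, 2048, 128, latency_ms=2.19)
final = attention_tflops(4, 32, 2048, 128, latency_ms=1.69)
print(f"basic: {basic:.1f} TFLOPS, final: {final:.1f} TFLOPS")
```

The computed values land close to the table's 62.8 and 81.1 TFLOPS; the small gap for the final step comes from the latency being rounded to two decimals.
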
### Common Issues

| Issue | Solution |
|---|---|
| Shape mismatch in ct.mma | Ensure A is (M,K), B is (K,N), C is (M,N) |
| dtype errors | Use `.astype()` before mma; accumulator should be float32 |
| Incorrect results with causal | Check mask_start calculation and `offs_m >= offs_n` logic |
| Low performance | Try different TILE_M/N, check occupancy, verify latency hints |

### Companion Scripts

The following scripts are included in this playbook and can be run on DGX Spark or B300:

- **[`assets/fmha_optimization_tutorial.py`](assets/fmha_optimization_tutorial.py)** - Step-by-step optimization tutorial. Builds the FMHA kernel from basic to fully optimized, matching the progression in this guide.
- **[`assets/fmha_scaling_analysis.py`](assets/fmha_scaling_analysis.py)** - Scaling analysis across sequence lengths. Benchmarks each optimization level and generates performance data.

```bash
# Run the optimization tutorial (DGX Spark)
python assets/fmha_optimization_tutorial.py --correctness-check

# Run the scaling analysis
python assets/fmha_scaling_analysis.py --iterations 100
```

### References

- [cuTile Python Documentation](https://docs.nvidia.com/cuda/cutile-python/)
- [Tile IR Specification](https://docs.nvidia.com/cuda/tile-ir/)
- [TileGym (pre-optimized kernels)](https://github.com/NVIDIA/TileGym)
- [NVIDIA Blog: Tuning Flash Attention for Peak Performance in CUDA Tile](https://developer.nvidia.com/blog/tuning-flash-attention-for-peak-performance-in-nvidia-cuda-tile/)
- [Flash Attention Paper](https://arxiv.org/abs/2205.14135)

## Platform Comparison

## DGX Spark vs B300 Performance Comparison

This page summarizes performance scaling between DGX Spark (GB10) and B300 for both kernel benchmarks and end-to-end LLM inference.

## Kernel Benchmark Scaling

Use the ratios below as a reference for how kernel performance scales from DGX Spark (GB10) to B300.

| Kernel | Metric | B300 / GB10 |
|--------|--------|-------------|
| FMHA (causal, 8192) | TFLOPS | 13.7x |
| FMHA (non-causal, 8192) | TFLOPS | 15.1x |
| MatMul (8192) | TFLOPS | 18.9x |
| BMM (batch8, 4096) | TFLOPS | 19.4x |
| Group GEMM (4096) | TFLOPS | 23.9x |
| RMSNorm (4096) | GB/s | 33.1x |
| RoPE (16384) | GB/s | 22.8x |

**Key Observations:**
- Compute-heavy kernels typically scale 14-24x from GB10 to B300
- Memory-bound kernels can scale 20-33x due to HBM bandwidth advantage

## Qwen2-7B Performance

### End-to-End Throughput

| Configuration | DGX Spark | B300 | Platform Speedup |
|---------------|-----------|------|------------------|
| **cuTile** | 18.52 tok/s | 257.33 tok/s | **13.9x** |

### CUDA Kernel Time

| Configuration | DGX Spark | B300 | Platform Speedup |
|---------------|-----------|------|------------------|
| **cuTile** | 43,080 ms | 2,954 ms | **14.6x** |

### cuTile Kernel Breakdown

**DGX Spark (GB10):**

| Kernel | CUDA Time (ms) | Calls |
|--------|----------------|-------|
| `fmha_kernel` | 4,185.9 | 28 |
| `swiglu_forward_kernel` | 2,459.8 | 1,400 |
| `attention_decode_kernel_grouped` | 2,271.8 | 1,372 |
| `rms_norm_kernel_static_persistent` | 634.7 | 57 |
| `rope_kernel` | 355.6 | 1,400 |

**B300:**

| Kernel | CUDA Time (ms) | Speedup vs Spark |
|--------|----------------|------------------|
| `fmha_kernel` | 337.9 | 12.4x |
| `swiglu_forward_kernel` | 226.3 | 10.9x |
| `attention_decode_kernel_grouped` | 111.0 | 20.5x |
| `rms_norm_kernel_static_persistent` | 29.7 | 21.4x |
| `rope_kernel` | 16.7 | 21.3x |

**Same code, different architectures** - cuTile JIT compiles for sm_121 (Spark) and sm_103 (B300)

## Platform Specifications

| Specification | DGX Spark (GB10) | B300 |
|---------------|------------------|------|
| Compute Capability | sm_121 (12.1) | sm_103 (10.3) |
| SMs | 48 | 132 |
| Memory | 128 GB LPDDR5x | 192 GB HBM3e |
| Memory Bandwidth | 273 GB/s | 8 TB/s |

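The platform ratios can be recomputed directly from the tables on this page. The short calculation below uses only those figures; it assumes decimal units (8 TB/s = 8000 GB/s), and the dictionary keys are illustrative.

```python
# Figures copied from the Qwen2-7B and Platform Specifications tables.
spark = {"tok_s": 18.52, "cuda_ms": 43080.0, "bandwidth_gb_s": 273.0}
b300 = {"tok_s": 257.33, "cuda_ms": 2954.0, "bandwidth_gb_s": 8000.0}

throughput_speedup = b300["tok_s"] / spark["tok_s"]
cuda_time_speedup = spark["cuda_ms"] / b300["cuda_ms"]
bandwidth_ratio = b300["bandwidth_gb_s"] / spark["bandwidth_gb_s"]

print(f"throughput: {throughput_speedup:.1f}x")
print(f"CUDA time:  {cuda_time_speedup:.1f}x")
print(f"bandwidth:  {bandwidth_ratio:.1f}x")
```

The throughput and CUDA-time ratios match the 13.9x and 14.6x in the tables. The memory-bandwidth ratio of about 29x falls inside the 20-33x range observed for the memory-bound kernels, which is consistent with bandwidth being their limiting factor.
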
## Troubleshooting

| Symptom | Cause | Fix |
|---------|-------|-----|
| `docker: permission denied` | User not in docker group | `sudo usermod -aG docker $USER && newgrp docker` |
| `401 Client Error: Unauthorized` | Missing HuggingFace token | `export HF_TOKEN=<your_token>` |
| `ModuleNotFoundError: tilegym` | TileGym not installed | `cd TileGym && pip install .` |
| `RuntimeError: CUDA out of memory` | Model too large | Reduce batch size or use smaller model |
| `Killed` during model load | Out of system memory | Clear cache: `sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'` |
| Slow first run | JIT compilation | Normal - cuTile compiles kernels on first run |
| `FileNotFoundError: input_prompt_small.txt` | Missing input file | Run from `modeling/transformers` directory |
| `torch.cuda.OutOfMemoryError` | Insufficient GPU memory | Reduce `--batch_size` parameter |
| `ImportError: cuda.tile` | Missing Tile IR | Install: `apt-get install cuda-tile-ir-13-1` |
| Benchmark hangs | GPU busy or locked | Check `nvidia-smi` for other processes |

> [!NOTE]
> DGX Spark uses a Unified Memory Architecture (UMA), which enables dynamic memory sharing between the GPU and CPU.
> With many applications still updating to take advantage of UMA, you may encounter memory issues even when within
> the memory capacity of DGX Spark. If that happens, manually flush the buffer cache with:

```bash
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
```

> [!TIP]
> First run of cuTile kernels includes JIT compilation overhead. Subsequent runs will be faster as compiled kernels are cached.

For the latest known issues, please review the [DGX Spark User Guide](https://docs.nvidia.com/dgx/dgx-spark/known-issues.html).

959
nvidia/cutile-kernels/assets/fmha_optimization_tutorial.py
Normal file
@ -0,0 +1,959 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/env python3
"""
FMHA Optimization Tutorial: From Naive to Optimized cuTile Implementation

This script demonstrates step-by-step optimization of Flash Multi-Head Attention
using NVIDIA cuTile, starting from a basic implementation and progressively
adding optimizations until reaching TileGym-level performance.

Target Platform: DGX Spark (sm121) with pre-determined optimal tile sizes.
Note: TileGym supports autotuning, but we use hardcoded values for this tutorial.

Configuration (matches TileGym bench_fused_attention.py):
- Batch: 4, Heads: 32, Head Dim: 128
- Sequence Lengths: 1024, 2048, 4096, 8192, 16384
- Benchmark: triton.testing.do_bench_cudagraph

Usage:
    python fmha_optimization_tutorial.py [--iterations N] [--correctness-check]
"""

import argparse
import json
import math
import time
from dataclasses import dataclass, asdict
from typing import List, Optional
import sys

import torch

LOG_SEPARATOR = "=" * 80
LOG_SUBSEPARATOR = "-" * 60

@dataclass
class BenchmarkResult:
    step: int
    name: str
    description: str
    latency_ms: float
    tflops: float
    speedup_vs_baseline: float
    speedup_vs_previous: float
    correct: bool
    key_changes: List[str]

class Logger:
    def __init__(self):
        self.results: List[BenchmarkResult] = []
        self.logs: List[str] = []

    def log(self, msg: str):
        print(msg)
        self.logs.append(msg)

    def section(self, title: str):
        self.log(f"\n{LOG_SEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SEPARATOR)

    def subsection(self, title: str):
        self.log(f"\n{LOG_SUBSEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SUBSEPARATOR)

    def add_result(self, result: BenchmarkResult):
        self.results.append(result)

    def export_json(self, filepath: str):
        data = {
            "results": [asdict(r) for r in self.results],
            "logs": self.logs
        }
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

    def export_markdown(self, filepath: str):
        with open(filepath, 'w') as f:
            f.write("# FMHA Optimization Tutorial Results\n\n")
            f.write("## Summary Table\n\n")
            f.write("| Step | Name | Latency (ms) | TFLOPS | vs Baseline | vs Previous | Correct |\n")
            f.write("|------|------|--------------|--------|-------------|-------------|--------|\n")
            for r in self.results:
                f.write(f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {r.speedup_vs_previous:.2f}x | {'Yes' if r.correct else 'No'} |\n")
            f.write("\n## Detailed Steps\n\n")
            for r in self.results:
                f.write(f"### Step {r.step}: {r.name}\n\n")
                f.write(f"**Description**: {r.description}\n\n")
                f.write("**Key Changes**:\n")
                for change in r.key_changes:
                    f.write(f"- {change}\n")
                f.write(f"\n**Performance**: {r.latency_ms:.3f}ms, {r.tflops:.2f} TFLOPS, {r.speedup_vs_baseline:.2f}x vs baseline\n\n")

logger = Logger()

BATCH = 4
N_HEADS = 32
HEAD_DIM = 128
INV_LOG_2 = 1.0 / math.log(2)

TILE_M = 64
TILE_N = 64
OCCUPANCY = 2
NUM_CTAS = 1

def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    raise RuntimeError("CUDA not available")

DEVICE = None

def compute_flops(batch, heads, seq_len, head_dim, causal=True):
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    return total_flops

def benchmark_fn(fn, warmup=10, iterations=100):
    """Benchmark using triton's do_bench_cudagraph for accurate timing (matches TileGym)"""
    try:
        import triton
        ms = triton.testing.do_bench_cudagraph(fn)
        return ms
    except (ImportError, Exception):
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()

        start = time.perf_counter()
        for _ in range(iterations):
            fn()
        torch.cuda.synchronize()
        end = time.perf_counter()

        return (end - start) / iterations * 1000

def verify_correctness(output, reference, atol=1e-2, rtol=1e-2):
    try:
        torch.testing.assert_close(output, reference, atol=atol, rtol=rtol)
        return True
    except AssertionError:
        max_diff = (output - reference).abs().max().item()
        logger.log(f" [WARN] Max difference: {max_diff:.6f}")
        return max_diff < 0.1

def reference_fmha(q, k, v, sm_scale, is_causal=True):
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
    )

def step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True):
    return reference_fmha(q, k, v, sm_scale, is_causal)

try:
    import cuda.tile as ct
    from cuda.tile import RoundingMode as RMd
    CUTILE_AVAILABLE = True
except ImportError:
    CUTILE_AVAILABLE = False
    logger.log("[WARN] cuTile not available. Only PyTorch baseline will run.")

if CUTILE_AVAILABLE:
|
||||
ConstInt = ct.Constant[int]
|
||||
ConstBool = ct.Constant[bool]
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_step2_mma(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 2: Basic cuTile FMHA with MMA (Tensor Cores)
|
||||
- Uses ct.mma() for matrix multiply
|
||||
- Standard exp() for softmax
|
||||
- Online softmax algorithm
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
k_tile = k_tile.reshape((TILE_N, TILE_D))
|
||||
k_t = ct.permute(k_tile, (1, 0))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
qk = qk * qk_scale
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True)
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk - m_ij
|
||||
|
||||
p = ct.exp(qk)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp(m_i - m_ij)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step2(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step2_mma,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
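The online softmax that `fmha_step2_mma` relies on can be illustrated without a GPU. The sketch below is plain Python, not cuTile: `softmax_weighted_sum` and `online_softmax_weighted_sum` are hypothetical helper names, and a single query row with scalar values stands in for a full tile. It keeps the same running state as the kernel (`m_i`, `l_i`, `acc`) and rescales the old state by `alpha` whenever the running maximum changes.

```python
import math


def softmax_weighted_sum(scores, values):
    """Reference: softmax over all scores at once, then a weighted sum of values."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total


def online_softmax_weighted_sum(scores, values, tile_n):
    """Same result, visiting the scores one tile of tile_n entries at a time."""
    m_i = -math.inf  # running maximum of the scores seen so far
    l_i = 0.0        # running softmax denominator
    acc = 0.0        # running unnormalized weighted sum
    for start in range(0, len(scores), tile_n):
        s_tile = scores[start:start + tile_n]
        v_tile = values[start:start + tile_n]
        m_ij = max(m_i, max(s_tile))
        p = [math.exp(s - m_ij) for s in s_tile]
        # Rescale the old state to the new maximum; exp(-inf) is 0.0 on the first tile.
        alpha = math.exp(m_i - m_ij)
        l_i = l_i * alpha + sum(p)
        acc = acc * alpha + sum(pj * vj for pj, vj in zip(p, v_tile))
        m_i = m_ij
    return acc / l_i


scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25, -2.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
full = softmax_weighted_sum(scores, values)
tiled = online_softmax_weighted_sum(scores, values, tile_n=3)
print(abs(full - tiled) < 1e-9)
```

The two results agree up to floating-point rounding, which is why the kernel can defer the division by `l_i` until after the loop.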
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_step3_exp2(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 3: Use exp2 with flush_to_zero for faster math
|
||||
- exp2(x) = 2^x is faster than exp(x) = e^x on GPU
|
||||
- Requires scaling adjustment: multiply qk_scale by 1/ln(2), since e^x = 2^(x / ln 2)
|
||||
- flush_to_zero handles denormals efficiently
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
k_tile = k_tile.reshape((TILE_N, TILE_D))
|
||||
k_t = ct.permute(k_tile, (1, 0))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step3(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step3_exp2,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
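The rescaling behind `fmha_step3_exp2` can be checked on the CPU. This is a minimal sketch in plain Python, assuming nothing about cuTile itself: `softmax_exp` and `softmax_exp2` are hypothetical helpers, and the head dimension of 128 is only an illustrative value. It shows that folding `INV_LOG_2` into the scale and switching to base-2 exponentials yields the same softmax.

```python
import math

INV_LOG_2 = 1.0 / math.log(2)  # 1 / ln(2), the constant the kernels multiply in


def softmax_exp(scores, scale):
    """Softmax of scale * scores using the natural exponential."""
    scaled = [s * scale for s in scores]
    m = max(scaled)
    p = [math.exp(x - m) for x in scaled]
    total = sum(p)
    return [x / total for x in p]


def softmax_exp2(scores, scale):
    """Same softmax, computed with base-2 exponentials."""
    scale_log2 = scale * INV_LOG_2
    # Taking the max before scaling is only valid when scale_log2 > 0.
    m = max(scores) * scale_log2
    p = [2.0 ** (s * scale_log2 - m) for s in scores]
    total = sum(p)
    return [x / total for x in p]


scores = [3.0, -1.5, 0.25, 7.0]
scale = 1.0 / math.sqrt(128)  # illustrative head dimension
a = softmax_exp(scores, scale)
b = softmax_exp2(scores, scale)
max_diff = max(abs(x - y) for x, y in zip(a, b))
print(max_diff < 1e-12)
```

Note that the kernels take the row maximum of the unscaled scores and scale it afterwards; as the comment in the sketch says, that shortcut assumes a positive scale.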
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_step4_load_order(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 4: Optimize K load with order parameter
|
||||
- Use order=(0,1,3,2) to load K already transposed
|
||||
- Avoids explicit ct.permute() operation
|
||||
- Reduces memory traffic
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2)
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step4(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step4_load_order,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
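The cuTile kernels in this file share one loop-bound computation for the causal case. The sketch below reproduces that arithmetic in plain Python so the bounds can be inspected per query block. `cdiv` is a stand-in that assumes `ct.cdiv` is ceiling division, `causal_tile_range` is a hypothetical helper, and the tile sizes are chosen for illustration only.

```python
def cdiv(a, b):
    """Ceiling division for non-negative integers (assumed behavior of ct.cdiv)."""
    return (a + b - 1) // b


def causal_tile_range(bid_x, tile_m, tile_n, k_seqlen):
    """Return (Tc, mask_start) for query block bid_x, mirroring the kernels."""
    m_end = (bid_x + 1) * tile_m             # one past the last query row of the block
    tc = cdiv(min(m_end, k_seqlen), tile_n)  # number of K/V tiles the block visits
    mask_start = (bid_x * tile_m) // tile_n  # first tile that needs the causal mask
    return tc, mask_start


for bid_x in (0, 3):
    tc, mask_start = causal_tile_range(bid_x, tile_m=64, tile_n=64, k_seqlen=2048)
    print(bid_x, tc, mask_start)

# With unequal tile sizes, more than one tile per block can need the mask.
print(causal_tile_range(1, tile_m=128, tile_n=64, k_seqlen=2048))
```

With equal tile sizes only the last visited tile (the one on the diagonal) satisfies `j >= mask_start`, so every earlier tile skips the `ct.where` masking entirely.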
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_step5_latency(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 5: Add latency hints for better pipelining
|
||||
- latency=2 for K load (prefetch)
|
||||
- latency=4 for V load (more prefetch distance)
|
||||
- Helps overlap memory loads with computation
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step5(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step5_latency,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
|
||||
|
||||
@ct.kernel(occupancy=2)
|
||||
def fmha_step6_occupancy(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 6: Add occupancy hint
|
||||
- @ct.kernel(occupancy=2) improves SM utilization
|
||||
- Allows multiple thread blocks per SM
|
||||
- Better for hiding memory latency
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step6(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step6_occupancy,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
|
||||
|
||||
@ct.kernel(occupancy=2)
|
||||
def fmha_step7_approx_div(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 7: Use approximate division for final normalization
|
||||
- ct.truediv with rounding_mode=APPROX is faster
|
||||
- Acceptable accuracy loss for inference
|
||||
- This matches TileGym's optimized implementation
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step7(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step7_approx_div,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
|
||||
|
||||
|
||||
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
|
||||
"""Run TileGym's optimized FMHA for comparison"""
|
||||
try:
|
||||
import tilegym
|
||||
return tilegym.ops.fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
|
||||
except ImportError:
|
||||
logger.log("[WARN] TileGym not available for comparison")
|
||||
return None
|
||||
|
||||
|
||||
def run_benchmark(seq_len, iterations=100, check_correct=True):
|
||||
global DEVICE
|
||||
DEVICE = get_device()
|
||||
|
||||
logger.section(f"FMHA OPTIMIZATION TUTORIAL - SEQ_LEN={seq_len}")
|
||||
logger.log("Configuration:")
|
||||
logger.log(f" - Batch: {BATCH}")
|
||||
logger.log(f" - Heads: {N_HEADS}")
|
||||
logger.log(f" - Head Dim: {HEAD_DIM}")
|
||||
logger.log(f" - Sequence Length: {seq_len}")
|
||||
logger.log(f" - Tile M: {TILE_M}")
|
||||
logger.log(f" - Tile N: {TILE_N}")
|
||||
logger.log(" - Precision: float16")
|
||||
logger.log(" - Causal: True")
|
||||
logger.log(f" - Iterations: {iterations}")
|
||||
logger.log(f" - Device: {DEVICE}")
|
||||
|
||||
q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
|
||||
k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
|
||||
v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
|
||||
sm_scale = 1.0 / math.sqrt(HEAD_DIM)
|
||||
|
||||
flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)
|
||||
|
||||
ref_output = reference_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
|
||||
steps = [
|
||||
(0, "PyTorch Baseline", "torch.nn.functional.scaled_dot_product_attention",
|
||||
lambda: step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True),
|
||||
["PyTorch SDPA (backend selected by PyTorch)", "Highly optimized baseline"]),
|
||||
]
|
||||
|
||||
if CUTILE_AVAILABLE:
|
||||
steps.extend([
|
||||
(2, "Basic cuTile + MMA", "Tiled FMHA with ct.mma() for Tensor Cores",
|
||||
lambda: run_step2(q, k, v, sm_scale, is_causal=True),
|
||||
["@ct.kernel decorator", "ct.mma() for QK and PV products", "Online softmax with exp()"]),
|
||||
|
||||
(3, "+ exp2 + flush_to_zero", "Faster exponential math",
|
||||
lambda: run_step3(q, k, v, sm_scale, is_causal=True),
|
||||
["ct.exp2() instead of ct.exp()", "flush_to_zero=True for denormals", "qk_scale *= 1/log(2)"]),
|
||||
|
||||
(4, "+ Load Order Transpose", "Avoid explicit transpose",
|
||||
lambda: run_step4(q, k, v, sm_scale, is_causal=True),
|
||||
["order=(0,1,3,2) for K load", "K loaded already transposed", "Removes ct.permute() call"]),
|
||||
|
||||
(5, "+ Latency Hints", "Better memory pipelining",
|
||||
lambda: run_step5(q, k, v, sm_scale, is_causal=True),
|
||||
["latency=2 for K load", "latency=4 for V load", "Overlaps loads with compute"]),
|
||||
|
||||
(6, "+ Occupancy=2", "Better SM utilization",
|
||||
lambda: run_step6(q, k, v, sm_scale, is_causal=True),
|
||||
["@ct.kernel(occupancy=2)", "Multiple blocks per SM", "Hides memory latency"]),
|
||||
|
||||
(7, "+ Approx Division (Final)", "Fast final normalization",
|
||||
lambda: run_step7(q, k, v, sm_scale, is_causal=True),
|
||||
["ct.truediv with APPROX mode", "Matches TileGym implementation", "Full optimization achieved"]),
|
||||
])
|
||||
|
||||
baseline_latency = None
|
||||
prev_latency = None
|
||||
|
||||
for step_idx, name, desc, fn, changes in steps:
|
||||
logger.subsection(f"Step {step_idx}: {name}")
|
||||
logger.log(f"Description: {desc}")
|
||||
logger.log("Key Changes:")
|
||||
for change in changes:
|
||||
logger.log(f" - {change}")
|
||||
|
||||
try:
|
||||
output = fn()
|
||||
latency_ms = benchmark_fn(fn, warmup=10, iterations=iterations)
|
||||
tflops = flops * 1e-12 / (latency_ms * 1e-3)
|
||||
|
||||
if baseline_latency is None:
|
||||
baseline_latency = latency_ms
|
||||
speedup_baseline = 1.0
|
||||
else:
|
||||
speedup_baseline = baseline_latency / latency_ms
|
||||
|
||||
if prev_latency is None:
|
||||
speedup_prev = 1.0
|
||||
else:
|
||||
speedup_prev = prev_latency / latency_ms
|
||||
|
||||
if check_correct and output is not None:
|
||||
correct = verify_correctness(output, ref_output)
|
||||
else:
|
||||
correct = True
|
||||
|
||||
logger.log("\nResults:")
|
||||
logger.log(f" Latency: {latency_ms:.3f} ms")
|
||||
logger.log(f" TFLOPS: {tflops:.2f}")
|
||||
logger.log(f" vs Baseline: {speedup_baseline:.2f}x")
|
||||
logger.log(f" vs Previous: {speedup_prev:.2f}x")
|
||||
logger.log(f" Correct: {'Yes' if correct else 'No'}")
|
||||
|
||||
result = BenchmarkResult(
|
||||
step=step_idx,
|
||||
name=name,
|
||||
description=desc,
|
||||
latency_ms=latency_ms,
|
||||
tflops=tflops,
|
||||
speedup_vs_baseline=speedup_baseline,
|
||||
speedup_vs_previous=speedup_prev,
|
||||
correct=correct,
|
||||
key_changes=changes
|
||||
)
|
||||
logger.add_result(result)
|
||||
|
||||
prev_latency = latency_ms
|
||||
|
||||
except Exception as e:
|
||||
logger.log(f"\n[ERROR] Step {step_idx} failed: {e}")
|
||||
import traceback
|
||||
logger.log(traceback.format_exc())
|
||||
|
||||
tilegym_output = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
if tilegym_output is not None:
|
||||
logger.subsection("TileGym Reference (for comparison)")
|
||||
tilegym_fn = lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
|
||||
tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
|
||||
tilegym_speedup = baseline_latency / tilegym_latency if baseline_latency else 1.0
|
||||
|
||||
logger.log("TileGym FMHA:")
|
||||
logger.log(f" Latency: {tilegym_latency:.3f} ms")
|
||||
logger.log(f" TFLOPS: {tilegym_tflops:.2f}")
|
||||
logger.log(f" vs Baseline: {tilegym_speedup:.2f}x")
|
||||
|
||||
result = BenchmarkResult(
|
||||
step=99,
|
||||
name="TileGym Reference",
|
||||
description="TileGym's optimized FMHA implementation",
|
||||
latency_ms=tilegym_latency,
|
||||
tflops=tilegym_tflops,
|
||||
speedup_vs_baseline=tilegym_speedup,
|
||||
speedup_vs_previous=1.0,
|
||||
correct=True,
|
||||
key_changes=["Full TileGym implementation", "Pre-tuned for sm121", "Production ready"]
|
||||
)
|
||||
logger.add_result(result)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="FMHA Optimization Tutorial")
|
||||
parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
|
||||
parser.add_argument("--seq-len", type=int, default=2048, help="Sequence length (default matches TileGym)")
|
||||
parser.add_argument("--correctness-check", action="store_true", help="Enable correctness checking (off by default; unchecked steps are reported as correct)")
|
||||
parser.add_argument("--output-dir", type=str, default=".", help="Output directory for logs")
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.section("FMHA OPTIMIZATION TUTORIAL")
|
||||
logger.log("From Naive to Optimized cuTile Implementation")
|
||||
logger.log("Target Platform: DGX Spark (sm121)")
|
||||
logger.log(f"Tile Sizes: TILE_M={TILE_M}, TILE_N={TILE_N} (hardcoded from TileGym)")
|
||||
logger.log("Note: TileGym supports autotuning, but we use pre-determined optimal values")
|
||||
|
||||
run_benchmark(
|
||||
seq_len=args.seq_len,
|
||||
iterations=args.iterations,
|
||||
check_correct=args.correctness_check
|
||||
)
|
||||
|
||||
logger.section("FINAL SUMMARY")
|
||||
logger.log("\n| Step | Name | Latency (ms) | TFLOPS | vs Baseline | Correct |")
|
||||
logger.log("|------|------|--------------|--------|-------------|---------|")
|
||||
for r in logger.results:
|
||||
logger.log(f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {'Yes' if r.correct else 'No'} |")
|
||||
|
||||
json_path = f"{args.output_dir}/fmha_tutorial_results.json"
|
||||
md_path = f"{args.output_dir}/fmha_tutorial_results.md"
|
||||
log_path = f"{args.output_dir}/fmha_tutorial_log.txt"
|
||||
|
||||
logger.export_json(json_path)
|
||||
logger.export_markdown(md_path)
|
||||
|
||||
with open(log_path, 'w') as f:
|
||||
f.write('\n'.join(logger.logs))
|
||||
|
||||
logger.log("\nResults exported to:")
|
||||
logger.log(f" - {json_path}")
|
||||
logger.log(f" - {md_path}")
|
||||
logger.log(f" - {log_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
891
nvidia/cutile-kernels/assets/fmha_scaling_analysis.py
Normal file
@ -0,0 +1,891 @@
|
||||
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
FMHA Scaling Analysis: How Optimizations Impact Performance at Different Sizes

This script demonstrates:
1. How FMHA performance scales with sequence length
2. Which optimizations provide the most benefit at larger sizes
3. Target-specific configurations for different GPU architectures

Target Platforms (from TileGym):
- DGX Spark (sm120/sm121): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
- Blackwell B300 (sm100): TILE_M=256, TILE_N=128 (occupancy=1) or TILE_M=128, TILE_N=128 (occupancy=2), num_ctas=1

Usage:
    python fmha_scaling_analysis.py [--iterations N]
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
|
||||
LOG_SEPARATOR = "=" * 80
|
||||
LOG_SUBSEPARATOR = "-" * 60
|
||||
|
||||
@dataclass
class StepResult:
    step: int
    name: str
    latency_ms: float
    tflops: float
    speedup_vs_baseline: float

@dataclass
class SeqLenResult:
    seq_len: int
    steps: List[StepResult]
    best_step: int
    best_speedup: float
    tilegym_latency_ms: float
    tilegym_tflops: float
    tilegym_speedup: float
|
||||
|
||||
class Logger:
    def __init__(self):
        self.results: List[SeqLenResult] = []
        self.logs: List[str] = []

    def log(self, msg: str):
        print(msg)
        self.logs.append(msg)

    def section(self, title: str):
        self.log(f"\n{LOG_SEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SEPARATOR)

    def subsection(self, title: str):
        self.log(f"\n{LOG_SUBSEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SUBSEPARATOR)

logger = Logger()
|
||||
|
||||
BATCH = 4
|
||||
N_HEADS = 32
|
||||
HEAD_DIM = 128
|
||||
INV_LOG_2 = 1.0 / math.log(2)
|
||||
|
||||
SEQ_LENS = [1024, 2048, 4096, 8192, 16384]
|
||||
|
||||
|
||||
def get_fmha_config():
|
||||
"""
|
||||
Get target-specific FMHA configuration (from TileGym attention.py)
|
||||
|
||||
Returns configs matching TileGym's _fmha_autotune_configs():
|
||||
- sm120/sm121 (DGX Spark): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
|
||||
- sm100 (Blackwell B300): Two configs to try via autotuning
|
||||
"""
|
||||
gpu_capability = torch.cuda.get_device_capability()
|
||||
|
||||
if gpu_capability in [(12, 0), (12, 1)]:
|
||||
return [
|
||||
SimpleNamespace(
|
||||
name="DGX Spark (sm121)",
|
||||
TILE_M=64,
|
||||
TILE_N=64,
|
||||
num_ctas=1,
|
||||
occupancy=2
|
||||
)
|
||||
]
|
||||
else:
|
||||
return [
|
||||
SimpleNamespace(
|
||||
name="Blackwell B300 (sm100) - Config 1",
|
||||
TILE_M=256,
|
||||
TILE_N=128,
|
||||
num_ctas=1,
|
||||
occupancy=1
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="Blackwell B300 (sm100) - Config 2",
|
||||
TILE_M=128,
|
||||
TILE_N=128,
|
||||
num_ctas=1,
|
||||
occupancy=2
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_device():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
raise RuntimeError("CUDA not available")
|
||||
|
||||
def compute_flops(batch, heads, seq_len, head_dim, causal=True):
|
||||
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
|
||||
total_flops = 2 * flops_per_matmul
|
||||
if causal:
|
||||
total_flops *= 0.5
|
||||
return total_flops
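A worked example may make the FLOP accounting concrete. The sketch below repeats `compute_flops` as written in the script and adds `to_tflops`, a hypothetical helper that applies the conversion `flops * 1e-12 / (latency_ms * 1e-3)`. The 2.0 ms latency is a made-up input for the arithmetic, not a measurement.

```python
def compute_flops(batch, heads, seq_len, head_dim, causal=True):
    """Same formula as the script: two matmuls (QK^T and PV), halved when causal."""
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    return total_flops


def to_tflops(flops, latency_ms):
    """Convert a FLOP count and a latency in milliseconds to TFLOPS."""
    return flops * 1e-12 / (latency_ms * 1e-3)


# Batch 4, 32 heads, head dimension 128, as configured in this script.
flops_2k = compute_flops(4, 32, 2048, 128)
flops_4k = compute_flops(4, 32, 4096, 128)

print(flops_2k)                             # equals 2**37 for these sizes
print(flops_4k / flops_2k)                  # doubling seq_len quadruples the work
print(to_tflops(flops_2k, latency_ms=2.0))  # 2.0 ms is a made-up latency
```

Because the work grows with the square of the sequence length, a kernel has to keep latency growth to 4x per doubling of `seq_len` just to hold its TFLOPS figure constant.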
|
||||
|
||||
def benchmark_fn(fn, warmup=10, iterations=100):
    """Benchmark fn with Triton's do_bench_cudagraph; fall back to manual wall-clock timing.

    Returns the latency in milliseconds. The warmup and iterations arguments
    only apply to the manual fallback.
    """
    try:
        import triton
        # Use triton's cudagraph benchmark - same as TileGym
        ms = triton.testing.do_bench_cudagraph(fn)
        return ms
    except Exception:
        # Triton is missing (ImportError) or the cudagraph benchmark failed:
        # fall back to manual timing
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()

        start = time.perf_counter()
        for _ in range(iterations):
            fn()
        torch.cuda.synchronize()
        end = time.perf_counter()

        return (end - start) / iterations * 1000
|
||||
|
||||
def reference_fmha(q, k, v, sm_scale, is_causal=True):
|
||||
return torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
|
||||
)
|
||||
|
||||
try:
|
||||
import cuda.tile as ct
|
||||
from cuda.tile import RoundingMode as RMd
|
||||
CUTILE_AVAILABLE = True
|
||||
except ImportError:
|
||||
CUTILE_AVAILABLE = False
|
||||
logger.log("[WARN] cuTile not available.")
|
||||
|
||||
if CUTILE_AVAILABLE:
|
||||
ConstInt = ct.Constant[int]
|
||||
ConstBool = ct.Constant[bool]
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_basic(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""Step 1: Basic cuTile - no optimizations"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
k_tile = k_tile.reshape((TILE_N, TILE_D))
|
||||
k_t = ct.permute(k_tile, (1, 0))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
qk = qk * qk_scale
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True)
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk - m_ij
|
||||
|
||||
p = ct.exp(qk)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp(m_i - m_ij)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_math_opt(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""Step 2: Math optimizations - exp2 + flush_to_zero"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
k_tile = k_tile.reshape((TILE_N, TILE_D))
|
||||
k_t = ct.permute(k_tile, (1, 0))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_memory_opt(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""Step 3: Memory optimizations - load order + latency hints"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
@ct.kernel(occupancy=2)
|
||||
def fmha_full_opt_occ2(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""Step 4a: Full optimization with occupancy=2 (for sm120/sm121)"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
m_i = m_ij
|
||||
|
||||
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
@ct.kernel(occupancy=1)
|
||||
def fmha_full_opt_occ1(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""Step 4b: Full optimization with occupancy=1 (for sm100 Blackwell)"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
m_i = m_ij
|
||||
|
||||
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_kernel(kernel_fn, q, k, v, sm_scale, tile_m, tile_n, is_causal=True):
    batch_size, num_heads, seq_len, head_dim = q.shape
    o = torch.empty_like(q)
    grid = (math.ceil(seq_len / tile_m), batch_size * num_heads, 1)
    ct.launch(
        torch.cuda.current_stream(),
        grid,
        kernel_fn,
        (q, k, v, o, sm_scale, head_dim, num_heads, tile_m, tile_n, is_causal)
    )
    return o
|
||||
|
||||
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
    try:
        import tilegym
        return tilegym.ops.fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
    except ImportError:
        return None
|
||||
|
||||
|
||||
def run_scaling_analysis(iterations=100):
|
||||
device = get_device()
|
||||
gpu_cap = torch.cuda.get_device_capability()
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
|
||||
configs = get_fmha_config()
|
||||
primary_cfg = configs[0]
|
||||
TILE_M = primary_cfg.TILE_M
|
||||
TILE_N = primary_cfg.TILE_N
|
||||
|
||||
logger.section("FMHA SCALING ANALYSIS (TileGym Benchmark Match)")
|
||||
logger.log("Matching TileGym bench_fused_attention.py configuration")
|
||||
logger.log(f"\nGPU: {gpu_name} (sm_{gpu_cap[0]}{gpu_cap[1]})")
|
||||
|
||||
logger.subsection("TARGET-SPECIFIC CONFIGURATION (from TileGym)")
|
||||
for cfg in configs:
|
||||
logger.log(f"\n {cfg.name}:")
|
||||
logger.log(f" TILE_M={cfg.TILE_M}, TILE_N={cfg.TILE_N}")
|
||||
logger.log(f" num_ctas={cfg.num_ctas}, occupancy={cfg.occupancy}")
|
||||
|
||||
logger.log(f"\nUsing primary config: TILE_M={TILE_M}, TILE_N={TILE_N}, occupancy={primary_cfg.occupancy}")
|
||||
logger.log("\nTest Configuration (matches TileGym bench_fused_attention.py):")
|
||||
logger.log(f" Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}")
|
||||
logger.log(" Causal: True, Precision: float16")
|
||||
logger.log(f" Sequence Lengths: {SEQ_LENS}")
|
||||
logger.log(" Benchmark: triton.testing.do_bench_cudagraph (same as TileGym)")
|
||||
|
||||
logger.section("OPTIMIZATION STEPS")
|
||||
logger.log(f"""
|
||||
Step 0: PyTorch Baseline
|
||||
- torch.nn.functional.scaled_dot_product_attention
|
||||
- Dispatches to a fused attention backend (e.g. Flash Attention or cuDNN, depending on the PyTorch build)
|
||||
- Highly optimized reference
|
||||
|
||||
Step 1: Basic cuTile (TILE_M={TILE_M}, TILE_N={TILE_N})
|
||||
- @ct.kernel with ct.mma() for Tensor Cores
|
||||
- Standard exp() for softmax
|
||||
- Explicit transpose with ct.permute()
|
||||
- No memory/occupancy hints
|
||||
|
||||
Step 2: Math Optimizations
|
||||
- ct.exp2() instead of ct.exp() (faster on GPU)
|
||||
- flush_to_zero=True for denormals
|
||||
- Scale adjustment: multiply by 1/log(2)
|
||||
|
||||
Step 3: Memory Optimizations
|
||||
- Load order=(0,1,3,2) for implicit K transpose
|
||||
- Latency hints: K=2, V=4 for prefetching
|
||||
- Overlaps memory loads with computation
|
||||
|
||||
Step 4: Full Optimization (Target-Specific)
|
||||
- @ct.kernel(occupancy={primary_cfg.occupancy}) for {'sm120/121' if primary_cfg.occupancy == 2 else 'sm100'}
|
||||
- ct.truediv with APPROX rounding mode
|
||||
- Matches TileGym production implementation
|
||||
""")
|
||||
|
||||
logger.section("PLATFORM DIFFERENCES: DGX Spark vs Blackwell B300")
|
||||
logger.log("""
|
||||
| Parameter    | DGX Spark (sm121) | Blackwell B300 (sm100) |
|--------------|-------------------|------------------------|
| TILE_M       | 64                | 256 or 128             |
| TILE_N       | 64                | 128                    |
| num_ctas     | 1                 | 1                      |
| occupancy    | 2                 | 1 or 2                 |
|
||||
|
||||
Why the difference?
|
||||
- B300 has more SMs and larger shared memory -> can use bigger tiles
|
||||
- B300 benefits from larger tiles (256x128) with lower occupancy
|
||||
- DGX Spark needs smaller tiles (64x64) with higher occupancy to hide latency
|
||||
- B300's higher memory bandwidth makes larger tiles more efficient
|
||||
""")
|
||||
|
||||
all_results = []
|
||||
|
||||
select_kernel = fmha_full_opt_occ2 if primary_cfg.occupancy == 2 else fmha_full_opt_occ1
|
||||
|
||||
for seq_len in SEQ_LENS:
|
||||
logger.subsection(f"Sequence Length: {seq_len}")
|
||||
|
||||
q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
|
||||
k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
|
||||
v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
|
||||
sm_scale = 1.0 / math.sqrt(HEAD_DIM)
|
||||
|
||||
flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)
|
||||
|
||||
steps_results = []
|
||||
|
||||
baseline_fn = lambda: reference_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
baseline_latency = benchmark_fn(baseline_fn, warmup=10, iterations=iterations)
|
||||
baseline_tflops = flops * 1e-12 / (baseline_latency * 1e-3)
|
||||
steps_results.append(StepResult(0, "PyTorch Baseline", baseline_latency, baseline_tflops, 1.0))
|
||||
|
||||
if CUTILE_AVAILABLE:
|
||||
kernels = [
|
||||
(1, "Basic cuTile", fmha_basic),
|
||||
(2, "Math Opt (exp2)", fmha_math_opt),
|
||||
(3, "Memory Opt (order+latency)", fmha_memory_opt),
|
||||
(4, f"Full Opt (occ={primary_cfg.occupancy})", select_kernel),
|
||||
]
|
||||
|
||||
for step, name, kernel in kernels:
|
||||
try:
|
||||
fn = lambda kernel=kernel: run_kernel(kernel, q, k, v, sm_scale, TILE_M, TILE_N, is_causal=True)
|
||||
latency = benchmark_fn(fn, warmup=10, iterations=iterations)
|
||||
tflops = flops * 1e-12 / (latency * 1e-3)
|
||||
speedup = baseline_latency / latency
|
||||
steps_results.append(StepResult(step, name, latency, tflops, speedup))
|
||||
except Exception as e:
|
||||
logger.log(f" [ERROR] Step {step} failed: {e}")
|
||||
|
||||
tilegym_latency = 0.0
|
||||
tilegym_tflops = 0.0
|
||||
tilegym_speedup = 0.0
|
||||
tilegym_out = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
if tilegym_out is not None:
|
||||
tilegym_fn = lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
|
||||
tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
|
||||
tilegym_speedup = baseline_latency / tilegym_latency
|
||||
|
||||
best_step = max(steps_results, key=lambda x: x.speedup_vs_baseline)
|
||||
|
||||
result = SeqLenResult(
|
||||
seq_len=seq_len,
|
||||
steps=steps_results,
|
||||
best_step=best_step.step,
|
||||
best_speedup=best_step.speedup_vs_baseline,
|
||||
tilegym_latency_ms=tilegym_latency,
|
||||
tilegym_tflops=tilegym_tflops,
|
||||
tilegym_speedup=tilegym_speedup,
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
logger.log("\n | Step | Name | Latency (ms) | TFLOPS | Speedup |")
|
||||
logger.log(" |------|------|--------------|--------|---------|")
|
||||
|
||||
for sr in steps_results:
|
||||
logger.log(f" | {sr.step} | {sr.name:<28} | {sr.latency_ms:>10.3f} | {sr.tflops:>6.2f} | {sr.speedup_vs_baseline:>6.2f}x |")
|
||||
if tilegym_latency > 0:
|
||||
logger.log(f" | TG | TileGym Reference | {tilegym_latency:>10.3f} | {tilegym_tflops:>6.2f} | {tilegym_speedup:>6.2f}x |")
|
||||
|
||||
logger.log(f"\n Best: Step {best_step.step} ({best_step.name}) with {best_step.speedup_vs_baseline:.2f}x speedup")
|
||||
|
||||
logger.results = all_results
|
||||
return all_results
|
||||
|
||||
|
||||
def print_summary(results: List[SeqLenResult]):
|
||||
configs = get_fmha_config()
|
||||
primary_cfg = configs[0]
|
||||
|
||||
logger.section("SCALING SUMMARY")
|
||||
|
||||
logger.log(f"\nTarget Config: TILE_M={primary_cfg.TILE_M}, TILE_N={primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}")
|
||||
|
||||
logger.log("\n## Performance vs Sequence Length\n")
|
||||
logger.log("| Seq Len | Baseline (ms) | Full Opt (ms) | Speedup | TileGym (ms) | TG Speedup |")
|
||||
logger.log("|---------|---------------|---------------|---------|--------------|------------|")
|
||||
for r in results:
|
||||
baseline = next((s for s in r.steps if s.step == 0), None)
|
||||
full_opt = next((s for s in r.steps if s.step == 4), None)
|
||||
if baseline and full_opt:
|
||||
logger.log(f"| {r.seq_len:>7} | {baseline.latency_ms:>13.3f} | {full_opt.latency_ms:>13.3f} | {full_opt.speedup_vs_baseline:>6.2f}x | {r.tilegym_latency_ms:>12.3f} | {r.tilegym_speedup:>9.2f}x |")
|
||||
|
||||
logger.log("\n## Optimization Impact by Sequence Length\n")
|
||||
logger.log("| Seq Len | Basic | +Math | +Memory | +Full | Best |")
|
||||
logger.log("|---------|-------|-------|---------|-------|------|")
|
||||
for r in results:
|
||||
row = f"| {r.seq_len:>7} |"
|
||||
for step in [1, 2, 3, 4]:
|
||||
sr = next((s for s in r.steps if s.step == step), None)
|
||||
if sr:
|
||||
row += f" {sr.speedup_vs_baseline:>5.2f}x |"
|
||||
else:
|
||||
row += " N/A |"
|
||||
row += f" {r.best_speedup:>4.2f}x |"
|
||||
logger.log(row)
|
||||
|
||||
logger.section("KEY INSIGHTS")
|
||||
logger.log("""
|
||||
## Why Larger Sequences Benefit More from Optimization
|
||||
|
||||
1. **Memory Bandwidth Dominance**
|
||||
- Attention has O(N²) memory complexity for the QK^T matrix
|
||||
- At seq_len=8192: 8192² × 4 bytes = 256MB per head per batch
|
||||
- Memory optimizations (order, latency hints) have larger impact
|
||||
|
||||
2. **More K-Loop Iterations**
|
||||
- At seq_len=512: 8 K-tiles (512/64) for sm121, 4 K-tiles (512/128) for sm100
- At seq_len=8192: 128 K-tiles (8192/64) for sm121, 64 K-tiles (8192/128) for sm100
|
||||
- Latency hiding through pipelining amortizes over more iterations
|
||||
|
||||
3. **Better Occupancy Utilization**
|
||||
- More tiles = more parallelism opportunities
|
||||
- sm121 uses occupancy=2 (smaller tiles, more blocks)
|
||||
- sm100 uses occupancy=1 with larger tiles (256x128)
|
||||
|
||||
4. **Platform-Specific Tuning**
|
||||
- DGX Spark (sm121): 64x64 tiles, occupancy=2 - optimized for bandwidth-limited workloads
|
||||
- B300 (sm100): 256x128 tiles, occupancy=1 - optimized for compute-heavy workloads
|
||||
|
||||
## Optimization Priority by Problem Size
|
||||
|
||||
Small (seq_len <= 1024):
|
||||
- Basic cuTile often sufficient
|
||||
- Focus on correctness first
|
||||
|
||||
Medium (1024 < seq_len <= 4096):
|
||||
- Math optimizations (exp2) provide ~5% gain
|
||||
- Memory optimizations start to matter
|
||||
|
||||
Large (seq_len > 4096):
|
||||
- Full optimization stack critical
|
||||
- Platform-specific tuning essential
|
||||
- Memory pipelining becomes essential
|
||||
""")
|
||||
|
||||
|
||||
def export_results(results: List[SeqLenResult], output_dir: str):
|
||||
configs = get_fmha_config()
|
||||
primary_cfg = configs[0]
|
||||
|
||||
data = {
|
||||
"config": {
|
||||
"batch": BATCH,
|
||||
"n_heads": N_HEADS,
|
||||
"head_dim": HEAD_DIM,
|
||||
"tile_m": primary_cfg.TILE_M,
|
||||
"tile_n": primary_cfg.TILE_N,
|
||||
"occupancy": primary_cfg.occupancy,
|
||||
"num_ctas": primary_cfg.num_ctas,
|
||||
"platform": primary_cfg.name,
|
||||
},
|
||||
"results": [
|
||||
{
|
||||
"seq_len": r.seq_len,
|
||||
"steps": [asdict(s) for s in r.steps],
|
||||
"best_step": r.best_step,
|
||||
"best_speedup": r.best_speedup,
|
||||
"tilegym_latency_ms": r.tilegym_latency_ms,
|
||||
"tilegym_tflops": r.tilegym_tflops,
|
||||
"tilegym_speedup": r.tilegym_speedup,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
}
|
||||
|
||||
json_path = f"{output_dir}/fmha_scaling_results.json"
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
md_path = f"{output_dir}/fmha_scaling_results.md"
|
||||
with open(md_path, 'w') as f:
|
||||
f.write("# FMHA Scaling Analysis Results\n\n")
|
||||
f.write("## Configuration\n")
|
||||
f.write(f"- Platform: {primary_cfg.name}\n")
|
||||
f.write(f"- Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}\n")
|
||||
f.write(f"- Tile: {primary_cfg.TILE_M}x{primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}\n\n")
|
||||
|
||||
f.write("## Target-Specific Configs (from TileGym)\n\n")
|
||||
f.write("| Platform | TILE_M | TILE_N | num_ctas | occupancy |\n")
|
||||
f.write("|----------|--------|--------|----------|----------|\n")
|
||||
f.write("| DGX Spark (sm121) | 64 | 64 | 1 | 2 |\n")
|
||||
f.write("| B300 (sm100) Config 1 | 256 | 128 | 1 | 1 |\n")
|
||||
f.write("| B300 (sm100) Config 2 | 128 | 128 | 1 | 2 |\n\n")
|
||||
|
||||
f.write("## Results by Sequence Length\n\n")
|
||||
for r in results:
|
||||
f.write(f"### Seq Len = {r.seq_len}\n\n")
|
||||
f.write("| Step | Name | Latency (ms) | TFLOPS | Speedup |\n")
|
||||
f.write("|------|------|--------------|--------|--------|\n")
|
||||
for s in r.steps:
|
||||
f.write(f"| {s.step} | {s.name} | {s.latency_ms:.3f} | {s.tflops:.2f} | {s.speedup_vs_baseline:.2f}x |\n")
|
||||
if r.tilegym_latency_ms > 0:
|
||||
f.write(f"| TG | TileGym Reference | {r.tilegym_latency_ms:.3f} | {r.tilegym_tflops:.2f} | {r.tilegym_speedup:.2f}x |\n")
|
||||
f.write("\n")
|
||||
|
||||
log_path = f"{output_dir}/fmha_scaling_log.txt"
|
||||
with open(log_path, 'w') as f:
|
||||
f.write('\n'.join(logger.logs))
|
||||
|
||||
logger.log("\nResults exported to:")
|
||||
logger.log(f" - {json_path}")
|
||||
logger.log(f" - {md_path}")
|
||||
logger.log(f" - {log_path}")
|
||||
|
||||
|
||||
def main():
    parser = argparse.ArgumentParser(description="FMHA Scaling Analysis")
    parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
    parser.add_argument("--output-dir", type=str, default=".", help="Output directory")
    args = parser.parse_args()

    results = run_scaling_analysis(iterations=args.iterations)
    print_summary(results)
    export_results(results, args.output_dir)


if __name__ == "__main__":
    main()
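
All of the cuTile kernels in this script follow the same streaming-softmax recurrence: a running maximum `m_i`, a running denominator `l_i`, and an accumulator `acc` that is rescaled by `alpha` whenever the maximum changes. From Step 2 onward they also replace `exp` with a base-2 exponential after multiplying the scores by `INV_LOG_2`. The sketch below is a plain-Python illustration of why that recurrence gives the same answer as an ordinary softmax. It is not the kernel code: the function names are hypothetical, values are scalars rather than `head_dim`-wide vectors, and `INV_LOG_2` is assumed to equal 1/ln(2), since its definition lies outside this excerpt.

```python
import math

# Assumption: mirrors the kernels' INV_LOG_2 constant, taken to be 1/ln(2).
INV_LOG_2 = 1.0 / math.log(2.0)


def softmax_weighted_sum(scores, values):
    """Reference: full softmax over one row of scores, then a weighted sum."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return sum(e * v for e, v in zip(exps, values)) / sum(exps)


def online_softmax_exp2(scores, values, tile_n):
    """Streaming version: consume tile_n scores at a time, as the K-loop does."""
    m_i = -math.inf  # running row maximum, in log2 units
    l_i = 0.0        # running softmax denominator
    acc = 0.0        # running unnormalized output
    for start in range(0, len(scores), tile_n):
        # exp(x) == 2 ** (x * INV_LOG_2), so scale once and use base 2
        tile = [s * INV_LOG_2 for s in scores[start:start + tile_n]]
        tile_values = values[start:start + tile_n]
        m_ij = max(m_i, max(tile))
        p = [2.0 ** (s - m_ij) for s in tile]
        alpha = 2.0 ** (m_i - m_ij)  # rescales state accumulated so far
        l_i = l_i * alpha + sum(p)
        acc = acc * alpha + sum(pj * v for pj, v in zip(p, tile_values))
        m_i = m_ij
    return acc / l_i


if __name__ == "__main__":
    scores = [0.3, -1.2, 2.5, 0.0, 1.1, -0.7]
    values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    reference = softmax_weighted_sum(scores, values)
    for tile_n in (1, 2, 4, 6):
        streamed = online_softmax_exp2(scores, values, tile_n)
        print(tile_n, abs(streamed - reference) < 1e-9)
```

On the first tile `m_i` is negative infinity, so `alpha` evaluates to zero and the empty running state contributes nothing, which is why the kernels can initialize `m_i` with `-math.inf` without a special case.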
|
||||
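
The throughput and tile-count figures the script reports come from a few lines of arithmetic. The sketch below restates that arithmetic in isolation so the numbers can be checked by hand. The helper names are hypothetical, the tile sizes are taken from the platform table in the script (TILE_N of 64 for sm121 and 128 for sm100), and the tile count is an upper bound, because under a causal mask most query tiles stop before the end of the sequence.

```python
import math


def causal_attention_flops(batch, heads, seq_len, head_dim):
    """Two matmuls (QK^T and PV), 2 FLOPs per multiply-add, halved for the mask."""
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    return 2 * flops_per_matmul * 0.5


def to_tflops(flops, latency_ms):
    """Convert a FLOP count and a latency in milliseconds to TFLOP/s."""
    return flops * 1e-12 / (latency_ms * 1e-3)


def max_k_tiles(seq_len, tile_n):
    """Upper bound on K-loop iterations for one query tile."""
    return math.ceil(seq_len / tile_n)


def score_matrix_bytes(seq_len, bytes_per_element=4):
    """Size of a full float32 QK^T matrix for one head of one batch element."""
    return seq_len * seq_len * bytes_per_element


if __name__ == "__main__":
    for seq_len in (512, 8192):
        print(
            seq_len,
            max_k_tiles(seq_len, 64),    # TILE_N for sm121
            max_k_tiles(seq_len, 128),   # TILE_N for sm100
            score_matrix_bytes(seq_len) // (1024 * 1024),  # MiB
        )
```

The number of K-loop iterations depends on `TILE_N`, not `TILE_M`: `TILE_M` sets how many query tiles (and therefore grid blocks) a sequence produces, while `TILE_N` sets how many steps each block takes through K and V.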
1412
nvidia/nemoclaw-applications/README.md
Normal file
File diff suppressed because it is too large
105
nvidia/register-to-brev/README.md
Normal file
@ -0,0 +1,105 @@
# Register DGX Spark to Brev

> Link your DGX Spark to Brev for remote access and shared environments

## Table of Contents

- [Overview](#overview)
- [Instructions](#instructions)
- [Troubleshooting](#troubleshooting)

---

## Overview

## Basic idea

NVIDIA Brev is an AI development platform that makes GPU environments remotely accessible, shareable, and easy to standardize using preconfigured setups called Launchables.

This walkthrough will help you connect your NVIDIA DGX Spark to Brev so it shows up as a managed GPU environment in Brev. After a one-time registration, your Spark becomes remotely accessible and shareable.

## What you'll accomplish

You’ll register your DGX Spark with Brev. Once registered, it appears as a healthy node in the Brev web UI and CLI, ready to share access and accept workloads whenever needed.

## What to know before starting

Brev automates the complex configuration, but the following is useful to know when establishing the initial connection:

* **Terminal Basics**:
  * Familiarity with the command line to run a few simple setup commands

## Prerequisites

Your DGX Spark [device is set up](https://docs.nvidia.com/dgx/dgx-spark/first-boot.html). You will also need the following:

* **Brev Account**:
  * An NVIDIA Brev account. Create one [here](https://login.brev.nvidia.com/signin) if you don’t have one.

* **Permissions**:
  * Administrative (root or sudo) access on the DGX Spark device to run the registration command.

## Time & risk

* **Estimated time:** 5-10 minutes
* **Risk level:** Low - Registration configures the Spark for secure remote access without altering your existing workloads
* **Rollback:** The Brev configuration can be removed through the UI or the CLI

## Instructions

## Step 1. Log in to Brev

Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you’re in the correct org by clicking the org button at the top right of the page. Then go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.

Click the “Register Compute” button and follow the instructions in the pop-up window.

## Step 2. Complete Pop-up Instructions

* Install the Brev CLI
* Configure your compute
  * Add a name for your compute
  * To configure SSH, ensure the “Enable SSH access” toggle is on
* Run the registration command

## Step 3. Follow Registration Flow

In the CLI, you’ll be walked through registration. Continue through the prompts until registration is complete.

## Step 4. Confirm Spark in Brev UI

* Go to the [Brev UI](https://brev.nvidia.com)
* Navigate to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section
* Confirm that the DGX Spark appears as a registered node with an **Available** status

## Step 5. Next Steps

Your Spark is now integrated into Brev as a secure, remotely accessible GPU environment.

Now that your hardware is connected, you can:

* **Share Access Anywhere:** Access your machine from anywhere and share access with others through the Brev UI under [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute).

## Step 6. Cleanup

If you ever decide to deregister your Spark from Brev, you can do so through either the Brev UI or the Brev CLI.

With the CLI, run:

```bash
brev deregister
```

In the UI:

* Go to the [Brev UI](https://brev.nvidia.com)
* Navigate to the section listing “GPU Environments” and look under “Registered Compute”
* Click the “Deregister” menu item on the Spark you wish to remove from Brev.
* Confirm your selection.

## Troubleshooting

| Symptom | Cause | Fix |
|---------|-------|-----|
| Your DGX Spark is showing up in the wrong org | You registered your DGX Spark to the wrong org | Run `brev set <my-org>` and then redo the registration |
| Unable to `brev shell <name>` | Need to refresh | `brev refresh` |

For the latest known issues, please review the [DGX Spark User Guide](https://docs.nvidia.com/dgx/dgx-spark/known-issues.html).