chore: Regenerate all playbooks

GitLab CI 2026-06-03 15:15:33 +00:00
parent 3eff7461e1
commit 231a45230d
7 changed files with 4704 additions and 0 deletions

@@ -21,11 +21,13 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
### NVIDIA
- [CLI Coding Agent](nvidia/cli-coding-agent/)
- [Comfy UI](nvidia/comfy-ui/)
- [Connect Three DGX Spark in a Ring Topology](nvidia/connect-three-sparks/)
- [Set Up Local Network Access](nvidia/connect-to-your-spark/)
- [Connect Two Sparks](nvidia/connect-two-sparks/)
- [CUDA-X Data Science](nvidia/cuda-x-data-science/)
- [cuTile Kernels](nvidia/cutile-kernels/)
- [DGX Dashboard](nvidia/dgx-dashboard/)
- [FLUX.1 Dreambooth LoRA Fine-tuning](nvidia/flux-finetuning/)
- [Run Hermes Agent with Local Models](nvidia/hermes-agent/)
@@ -41,6 +43,7 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
- [NCCL for Two Sparks](nvidia/nccl/)
- [Fine-tune with NeMo](nvidia/nemo-fine-tune/)
- [Run NemoClaw with a Local LLM](nvidia/nemoclaw/)
- [🦞 Set Up Example NemoClaw Agents 🦞](nvidia/nemoclaw-applications/)
- [Nemotron-3-Nano with llama.cpp](nvidia/nemotron/)
- [NIM on Spark](nvidia/nim-llm/)
- [NVFP4 Quantization](nvidia/nvfp4-quantization/)
@@ -52,6 +55,7 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
- [Fine-tune with Pytorch](nvidia/pytorch-fine-tune/)
- [RAG Application in AI Workbench](nvidia/rag-ai-workbench/)
- [Spark & Reachy Photo Booth](nvidia/reachy-photo-booth/)
- [Register DGX Spark to Brev](nvidia/register-to-brev/)
- [SGLang for Inference](nvidia/sglang/)
- [Single-cell RNA Sequencing](nvidia/single-cell/)
- [Speculative Decoding](nvidia/speculative-decoding/)

@@ -0,0 +1,474 @@
# CLI Coding Agent
> Build local CLI coding agents with Ollama
## Table of Contents
- [Overview](#overview)
- [Claude Code](#claude-code)
- [OpenCode](#opencode)
- [Codex CLI](#codex-cli)
- [Troubleshooting](#troubleshooting)
---
## Overview
## Basic idea
Use [Ollama](https://ollama.com) on [DGX Spark](https://www.nvidia.com/en-us/products/workstations/dgx-spark/) to run a local coding model and connect a CLI coding agent. This
playbook supports three options: **[Claude Code](https://docs.claude.com/en/docs/claude-code)**, **[OpenCode](https://opencode.ai)**, and **[Codex CLI](https://github.com/openai/codex)**. Each
agent is wired up with Ollama's built-in [launch method](https://ollama.com/blog/launch) (`ollama launch <agent>`), so you
can work without environment variables, provider config files, or external cloud APIs.
## Choose your CLI agent
Pick the tab that matches the CLI agent you want to use:
- **Claude Code**: Fastest path to a working CLI agent with a local Ollama model.
- **OpenCode**: Open-source CLI launched directly from Ollama.
- **Codex CLI**: OpenAI Codex CLI launched directly from Ollama against the local model.
## What you'll accomplish
You will run a local coding model ([Qwen3.6](https://ollama.com/library/qwen3.6)) on your DGX Spark with Ollama, launch your
chosen CLI agent against it with a single command, and complete a small coding task end-to-end.
## What to know before starting
- Comfort with Linux command line basics
- Experience running terminal-based tools and editors
- Familiarity with Python for the short coding task
## Prerequisites
- DGX Spark access with NVIDIA DGX OS 7.3.1 (Ubuntu 24.04.3 LTS base)
- Internet access to download model weights
- [Ollama](https://ollama.com/download) v0.15 or newer (required for [`ollama launch`](https://ollama.com/blog/launch))
- GPU memory depends on the Qwen3.6 variant you choose:
- `qwen3.6:latest` (35B-a3b, MoE) — ~24GB, 256K context
- `qwen3.6:35b-a3b-nvfp4` — ~22GB, NVIDIA FP4 build tuned for Blackwell (DGX Spark)
- `qwen3.6:35b-a3b-q8_0` — ~39GB, higher-quality quant
- `qwen3.6:35b-a3b-bf16` — ~71GB, full precision (fits Spark's unified memory)
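To make the sizing concrete, the approximate footprints listed above can be put into a small lookup. This is a minimal sketch: the sizes are the rounded figures from this list, while `variants_that_fit` and the 40 GB budget are hypothetical names and values chosen for illustration. Real usage also needs headroom for the KV cache and for other processes sharing the Spark's unified memory.

```python
# Approximate model footprints in GB, copied from the list above.
QWEN36_VARIANTS_GB = {
    "qwen3.6:latest": 24,
    "qwen3.6:35b-a3b-nvfp4": 22,
    "qwen3.6:35b-a3b-q8_0": 39,
    "qwen3.6:35b-a3b-bf16": 71,
}


def variants_that_fit(budget_gb):
    """Return variant tags whose approximate size is within budget_gb, largest first."""
    fitting = [(size, tag) for tag, size in QWEN36_VARIANTS_GB.items() if size <= budget_gb]
    return [tag for size, tag in sorted(fitting, reverse=True)]


# Hypothetical budget of 40 GB left over for model weights.
print(variants_that_fit(40))
```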
## Time & risk
* **Duration**: ~15-25 minutes (mostly model download time)
* **Risk level**: Low
* Large model downloads can fail if network connectivity is unstable
* Ollama versions older than 0.15 do not support `ollama launch`
* **Rollback**: Stop Ollama and delete the downloaded model from `~/.ollama/models`
* **Last Updated:** 04/16/2026
* Switched to `ollama launch` method and upgraded the default model to Qwen3.6
## Claude Code
## Step 1. Confirm your environment
**Description**: Verify the OS version and GPU are visible before installing anything.
```bash
cat /etc/os-release | head -n 2
nvidia-smi
```
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
## Step 2. Install or update Ollama
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
```bash
curl -fsSL https://ollama.com/install.sh | sh
ollama --version
```
If Ollama is already installed, just verify the version:
```bash
ollama --version
```
Expected output should show Ollama v0.15 or newer.
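If you want to script the version check, the comparison can be sketched as below. The exact text printed by `ollama --version` is an assumption here (the sample strings are illustrative, not captured output), and `parse_ollama_version` and `supports_launch` are hypothetical helper names.

```python
import re


def parse_ollama_version(text):
    """Extract (major, minor, patch) from text such as 'ollama version is 0.15.2'."""
    match = re.search(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"no version number found in {text!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def supports_launch(text, minimum=(0, 15, 0)):
    """True when the reported version is at least the minimum this playbook requires."""
    return parse_ollama_version(text) >= minimum


# Sample strings are assumptions about the output format, not captured output.
print(supports_launch("ollama version is 0.15.0"))
print(supports_launch("ollama version is 0.14.3"))
```

In practice you would feed the function the captured output of `ollama --version`, for example via `subprocess.run(..., capture_output=True, text=True).stdout`.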
## Step 3. Pull Qwen3.6
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
```bash
ollama pull qwen3.6
```
Optional variants if you want different memory footprints or precision:
```bash
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
```
Expected output should show `qwen3.6` (and any optional variants) in `ollama list`.
## Step 4. Test local inference (optional)
**Description**: Run a quick prompt to confirm the model loads.
```bash
ollama run qwen3.6
```
Try a prompt like:
```text
Write a short README checklist for a Python project.
```
Expected output should show the model responding in the terminal. When you are done, type `/bye` or press `Ctrl+D` to exit the interactive session before continuing.
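The same check can be made non-interactively through Ollama's local REST API. The sketch below assumes the service is listening on the default `localhost:11434` and uses the `/api/generate` endpoint with streaming disabled; `build_generate_request` and `generate` are hypothetical helper names, and the network call is left commented out so the block can be read without a running server.

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434/api/generate"


def build_generate_request(model, prompt):
    """Build a non-streaming POST request for Ollama's /api/generate endpoint."""
    payload = {"model": model, "prompt": prompt, "stream": False}
    return urllib.request.Request(
        OLLAMA_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def generate(model, prompt, timeout=300):
    """Send the request and return the generated text (needs a running Ollama server)."""
    request = build_generate_request(model, prompt)
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.loads(resp.read())["response"]


# Example (requires the service to be up and the model pulled):
# print(generate("qwen3.6", "Write a short README checklist for a Python project."))
```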
## Step 5. Launch Claude Code with Ollama
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [Claude Code](https://docs.claude.com/en/docs/claude-code) against your local model. No environment variables or config files are required.
```bash
ollama launch claude
```
Expected output should show Claude Code starting and using the local Qwen3.6 model. Qwen3.6 ships with a 256K context window by default; adjust context length through Ollama's settings if you need to tune it further.
## Step 6. Complete a small coding task
**Description**: Create a tiny repo with a stub function and a test, then let Claude Code implement the function so the test passes.
```bash
mkdir -p ~/cli-agent-demo
cd ~/cli-agent-demo
printf 'def add(a, b):\n    """Return the sum of a and b."""\n    pass\n' > math_utils.py
printf 'import math_utils\n\n\ndef test_add():\n    assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
```
If you do not already have pytest installed:
```bash
python3 -m pip install -U pytest
```
In Claude Code:
```text
Please implement add() in math_utils.py and make sure the test passes.
```
Run the test:
```bash
python3 -m pytest -q
```
Expected output should show the test passing.
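For reference, the smallest change to `math_utils.py` that makes the test pass looks something like the sketch below. The agent's actual edit may differ in style or add extras such as type hints; this is just one valid outcome.

```python
def add(a, b):
    """Return the sum of a and b."""
    return a + b
```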
## Step 7. Cleanup and rollback
**Description**: Remove the model and stop the service if you no longer need them. Remove the model first, because `ollama rm` talks to the running Ollama service.
> [!WARNING]
> This will delete the downloaded model files.
```bash
ollama rm qwen3.6
```
To stop the service:
```bash
sudo systemctl stop ollama
```
## Step 8. Next steps
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
- Use Claude Code on multi-file refactors or test-generation tasks
- Explore the full 256K context window on larger codebases
## OpenCode
## Step 1. Confirm your environment
**Description**: Verify the OS version and GPU are visible before installing anything.
```bash
cat /etc/os-release | head -n 2
nvidia-smi
```
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
## Step 2. Install or update Ollama
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
```bash
curl -fsSL https://ollama.com/install.sh | sh
ollama --version
```
If Ollama is already installed, just verify the version:
```bash
ollama --version
```
Expected output should show Ollama v0.15 or newer.
## Step 3. Pull Qwen3.6
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
```bash
ollama pull qwen3.6
```
Optional variants if you want different memory footprints or precision:
```bash
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
```
Expected output should show `qwen3.6` in `ollama list`.
## Step 4. Test local inference (optional)
**Description**: Run a quick prompt to confirm the model loads.
```bash
ollama run qwen3.6
```
Try a prompt like:
```text
Write a short README checklist for a Python project.
```
Expected output should show the model responding. When you are done, type `/bye` or press `Ctrl+D` to exit before continuing.
## Step 5. Launch OpenCode with Ollama
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [OpenCode](https://opencode.ai) against your local model. No [`opencode.json`](https://opencode.ai/docs/config/) provider configuration is required.
```bash
ollama launch opencode
```
If you want to pre-configure OpenCode without launching immediately:
```bash
ollama launch opencode --config
```
Expected output should show OpenCode starting with Ollama preselected as the provider and Qwen3.6 as the model. Qwen3.6 ships with a 256K context window by default.
## Step 6. Complete a small coding task
**Description**: Create a tiny repo with a stub function and a test, then let OpenCode implement the function so the test passes.
```bash
mkdir -p ~/cli-agent-demo
cd ~/cli-agent-demo
printf 'def add(a, b):\n    """Return the sum of a and b."""\n    pass\n' > math_utils.py
printf 'import math_utils\n\n\ndef test_add():\n    assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
```
If you do not already have pytest installed:
```bash
python3 -m pip install -U pytest
```
In OpenCode:
```text
Please implement add() in math_utils.py and make sure the test passes.
```
Run the test:
```bash
python3 -m pytest -q
```
Expected output should show the test passing.
## Step 7. Cleanup and rollback
**Description**: Remove the model and stop the service if you no longer need them. Remove the model first, because `ollama rm` talks to the running Ollama service.
> [!WARNING]
> This will delete the downloaded model files.
```bash
ollama rm qwen3.6
```
To stop the service:
```bash
sudo systemctl stop ollama
```
## Step 8. Next steps
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
- Use OpenCode on multi-file changes or test-generation tasks
- Explore the full 256K context window on larger codebases
## Codex CLI
## Step 1. Confirm your environment
**Description**: Verify the OS version and GPU are visible before installing anything.
```bash
cat /etc/os-release | head -n 2
nvidia-smi
```
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
## Step 2. Install or update Ollama
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
```bash
curl -fsSL https://ollama.com/install.sh | sh
ollama --version
```
If Ollama is already installed, just verify the version:
```bash
ollama --version
```
Expected output should show Ollama v0.15 or newer.
## Step 3. Pull Qwen3.6
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
```bash
ollama pull qwen3.6
```
Optional variants if you want different memory footprints or precision:
```bash
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
```
Expected output should show `qwen3.6` in `ollama list`.
## Step 4. Test local inference (optional)
**Description**: Run a quick prompt to confirm the model loads.
```bash
ollama run qwen3.6
```
Try a prompt like:
```text
Write a short README checklist for a Python project.
```
Expected output should show the model responding. When you are done, type `/bye` or press `Ctrl+D` to exit before continuing.
## Step 5. Launch Codex CLI with Ollama
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [Codex CLI](https://github.com/openai/codex) against your local model. No `~/.codex/config.toml` and no manual `npm install -g @openai/codex` are required — Ollama handles the Codex integration.
```bash
ollama launch codex
```
Expected output should show Codex CLI starting with Ollama as the provider and Qwen3.6 as the model. Qwen3.6 ships with a 256K context window by default, which is well suited to Codex's agentic workflows.
## Step 6. Complete a small coding task
**Description**: Create a tiny repo with a stub function and a test, then let Codex implement the function so the test passes.
```bash
mkdir -p ~/cli-agent-demo
cd ~/cli-agent-demo
printf 'def add(a, b):\n    """Return the sum of a and b."""\n    pass\n' > math_utils.py
printf 'import math_utils\n\n\ndef test_add():\n    assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
```
If you do not already have pytest installed:
```bash
python3 -m pip install -U pytest
```
In Codex:
```text
Please implement add() in math_utils.py and make sure the test passes.
```
Run the test:
```bash
python3 -m pytest -q
```
Expected output should show the test passing.
## Step 7. Cleanup and rollback
**Description**: Remove the model and stop the service if you no longer need them. Remove the model first, because `ollama rm` talks to the running Ollama service.
> [!WARNING]
> This will delete the downloaded model files.
```bash
ollama rm qwen3.6
```
To stop the service:
```bash
sudo systemctl stop ollama
```
## Step 8. Next steps
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
- Use Codex CLI on multi-file changes or test-generation tasks
- Explore the full 256K context window on larger codebases
## Troubleshooting
| Symptom | Cause | Fix |
|---------|-------|-----|
| `ollama: command not found` | Ollama not installed or PATH not updated | Rerun `curl -fsSL https://ollama.com/install.sh \| sh` and open a new shell |
| `ollama launch` reports unknown command | Ollama is older than v0.15 | Update Ollama: `curl -fsSL https://ollama.com/install.sh \| sh` |
| Model load fails with version error or HTTP 412 | Ollama version is too old for the model | Update Ollama: `curl -fsSL https://ollama.com/install.sh \| sh` |
| `model not found` when launching an agent | Model was not pulled | Run `ollama pull qwen3.6` and retry |
| `connection refused` to localhost:11434 | Ollama service not running | Start with `ollama serve` or `sudo systemctl start ollama` |
| `ollama launch <agent>` exits immediately | Agent integration failed to initialize | Re-run `ollama launch <agent>`; if it persists, check `journalctl -u ollama` |
| Slow responses or OOM errors | Model variant too large for GPU memory | Switch to `qwen3.6:35b-a3b-nvfp4` or close other GPU workloads |
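
To tell a stopped service apart from other failures, you can probe the port directly. This is a minimal sketch using only the standard library; `is_listening` is a hypothetical helper name, and it only confirms that something accepts TCP connections on the port, not that it is Ollama.

```python
import socket


def is_listening(host="127.0.0.1", port=11434, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds within timeout seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers connection refused, timeouts, and unreachable hosts.
        return False


if is_listening():
    print("Port 11434 is reachable")
else:
    print("Nothing is listening on 11434; start the Ollama service")
```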
> [!NOTE]
> DGX Spark uses a Unified Memory Architecture (UMA), which enables dynamic memory sharing
> between the GPU and CPU. If you see memory pressure, flush the buffer cache with:
> ```bash
> sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
> ```

@@ -0,0 +1,859 @@
# cuTile Kernels
> Run cuTile kernel benchmarks, FMHA implementation, and LLM inference on DGX Spark and B300
## Table of Contents
- [Overview](#overview)
- [Kernel Benchmarks](#kernel-benchmarks)
- [End-to-End Inference](#end-to-end-inference)
- [FMHA Implementation](#fmha-implementation)
- [Attention Basics](#attention-basics)
- [Flash Attention Algorithm](#flash-attention-algorithm)
- [cuTile Pseudocode → Actual Mapping](#cutile-pseudocode-actual-mapping)
- [Kernel Pseudocode](#kernel-pseudocode)
- [cuTile Implementation](#cutile-implementation)
- [Launching the Kernel](#launching-the-kernel)
- [Optimizations](#optimizations)
- [Platform Configuration](#platform-configuration)
- [Performance Results](#performance-results)
- [Common Issues](#common-issues)
- [Companion Scripts](#companion-scripts)
- [References](#references)
- [Platform Comparison](#platform-comparison)
- [End-to-End Throughput](#end-to-end-throughput)
- [CUDA Kernel Time](#cuda-kernel-time)
- [cuTile Kernel Breakdown](#cutile-kernel-breakdown)
- [Troubleshooting](#troubleshooting)
---
## Overview
## Basic idea
[TileGym](https://github.com/NVIDIA/TileGym) is NVIDIA's benchmark suite and integration framework for cuTile kernels - high-performance GPU kernels written using the cuTile Python DSL. cuTile compiles to Tile IR, enabling developers to write efficient kernels without low-level CUDA programming.
This playbook covers three workflows:
1. **[Kernel Benchmarks](kernel-benchmarks)** - Run standalone cuTile kernel benchmarks (FMHA, MatMul, RMSNorm, etc.)
2. **[End-to-End Inference](e2e-inference)** - Run LLM inference with cuTile-optimized kernels via monkey-patching
3. **[FMHA Implementation](fmha)** - Step-by-step tutorial building a Flash Multi-Head Attention kernel from pseudocode to optimized cuTile, with companion scripts to run and benchmark
The same cuTile code runs on both DGX Spark (sm_121) and B300 (sm_103) - cuTile JIT compiles to the appropriate GPU architecture automatically.
## What you'll accomplish
- Run the TileGym benchmark suite on DGX Spark
- Run Qwen2-7B or DeepSeek-V2-Lite inference with cuTile-optimized kernels
- Observe performance scaling between DGX Spark and B300
- Build an FMHA kernel step-by-step from pseudocode to optimized cuTile implementation
## What to know before starting
- Basic familiarity with Docker and command-line tools
- Understanding of GPU compute concepts (TFLOPS, memory bandwidth)
- No CUDA programming experience required
- HuggingFace account with access token (for LLM inference)
## Prerequisites
**Hardware Requirements:**
- DGX Spark with Ubuntu 24.04 or B300 cloud instance
- Minimum 16GB GPU memory for LLM inference
- At least 50GB available storage space for model downloads
**Software Requirements:**
- Docker installed and configured: `docker ps`
- CUDA Toolkit 13.x with Tile IR support
- HuggingFace token for model access (LLM inference only)
- Network access for pulling containers and downloading models
Verify Docker is available:
```bash
docker ps
```
If you get a permission error:
```bash
sudo usermod -aG docker $USER
newgrp docker
```
## Kernel support matrix
| Kernel | Category | Data Types | Description |
|--------|----------|------------|-------------|
| **FMHA** | Attention | float16, float8 | Flash Multi-Head Attention |
| **MLA** | Attention | bfloat16, float8 | Multi-head Latent Attention |
| **MLA Decoding** | Attention | float16, float8 | MLA for decode phase |
| **MatMul** | Matrix Ops | float16, float8 | Matrix multiplication |
| **BMM** | Matrix Ops | float16 | Batched matrix multiplication |
| **Group GEMM** | Matrix Ops | float16, float8 | Grouped GEMM for MoE |
| **RMSNorm** | Normalization | float16, bfloat16 | Root mean square normalization |
| **RoPE** | Positional | float16 | Rotary position embedding |
| **SiLU** | Activation | float16, float32 | SiLU activation with multiply |
| **SwiGLU** | Activation | float16, float32 | SwiGLU fused operation |
| **Softmax** | Activation | float16 | Softmax normalization |
| **Dropout** | Regularization | float16, float32 | Dropout forward |
## Model support for LLM inference
| Model | Supported Kernels | Batch Size | Output Tokens | Notes |
|-------|-------------------|------------|---------------|-------|
| **Qwen2-7B** | RoPE, RMSNorm, SwiGLU, FMHA | 16 | 50 | Standard transformer |
| **DeepSeek-V2-Lite** | RoPE, RMSNorm, SiLU, MLA, MoE | 1 | 100 | MLA attention, MoE layers |
## Ancillary files
All required assets can be found in the [TileGym repository](https://github.com/NVIDIA/TileGym).
- `tests/benchmark/run_all.sh` - Run all kernel benchmarks
- `modeling/transformers/bench_qwen.sh` - Qwen2-7B benchmark script
- `modeling/transformers/bench_deepseek.sh` - DeepSeek-V2-Lite benchmark script
- `modeling/transformers/infer.py` - Main inference script with TileGym integration
- [`assets/fmha_optimization_tutorial.py`](assets/fmha_optimization_tutorial.py) - FMHA step-by-step optimization tutorial
- [`assets/fmha_scaling_analysis.py`](assets/fmha_scaling_analysis.py) - FMHA scaling analysis across sequence lengths
## Time & risk
* **Estimated time:** 30-45 minutes (including model download for LLM inference)
* **Risk level:** Low
* Large downloads may fail due to network issues
* First run includes JIT compilation overhead
* **Rollback:** Remove Docker container to undo all changes
* **Last Updated:** February 2026
* First Publication
## Kernel Benchmarks
## Step 1. Pull CUDA NGC container with CTK 13.x
```bash
docker pull nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04
```
Launch an interactive session with GPU access:
```bash
docker run --gpus all -it --rm \
-v ~/TileGym:/workspace/TileGym \
nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 \
/bin/bash
```
> [!NOTE]
> The `-v` flag mounts a local directory to persist the TileGym repository. The `--rm` flag automatically removes the container when you exit; omit it if you want to keep the container for later use.
Or if running outside a container, install Tile IR directly:
```bash
# Requires root privileges (hence sudo)
sudo apt-get install cuda-tile-ir-13-1 cuda-compiler-13-1
```
## Step 2. Clone TileGym repository
```bash
git clone https://github.com/NVIDIA/TileGym
cd TileGym
pip install .
```
## Step 3. Run benchmark suite
```bash
cd tests/benchmark/
bash run_all.sh
```
> [!NOTE]
> The benchmark runs sequentially to ensure accurate timing results. This may take 10-15 minutes to complete all kernels.
## Step 4. View results
Results show cuTile performance for each kernel and sequence length.
Expected output should look like:
```text
==========================================
Running bench_fused_attention.py...
==========================================
fused-attention-batch4-head32-d128-fwd-causal=True-float16-TFLOPS:
N_CTX CuTile
0 1024.0 58.188262
1 2048.0 80.906892
2 4096.0 86.189532
3 8192.0 88.891086
4 16384.0 89.491869
✓ PASSED: bench_fused_attention.py
```
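If you save the benchmark output, a few lines of Python are enough to pull out the peak throughput. The sketch below parses the sample table shown above; the column layout is assumed to match that sample, and `parse_rows` is a hypothetical helper name.

```python
# Sample rows copied from the expected output above.
SAMPLE = """\
     N_CTX     CuTile
0   1024.0  58.188262
1   2048.0  80.906892
2   4096.0  86.189532
3   8192.0  88.891086
4  16384.0  89.491869
"""


def parse_rows(text):
    """Return (n_ctx, tflops) pairs from lines shaped like '<index> <N_CTX> <TFLOPS>'."""
    rows = []
    for line in text.splitlines():
        parts = line.split()
        if len(parts) == 3 and parts[0].isdigit():
            rows.append((int(float(parts[1])), float(parts[2])))
    return rows


rows = parse_rows(SAMPLE)
best_ctx, best_tflops = max(rows, key=lambda row: row[1])
print(f"peak {best_tflops:.1f} TFLOPS at N_CTX={best_ctx}")
```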
## Step 5. Run individual benchmarks
To run specific kernel benchmarks:
```bash
# Flash Multi-Head Attention
python bench_fused_attention.py
# Matrix Multiplication
python bench_matrix_multiplication.py
# RMSNorm
python bench_rmsnorm.py
# RoPE
python bench_rope.py
# SwiGLU
python bench_swiglu.py
```
## Step 6. Clean up
Exit the container:
```bash
exit
```
Remove this workflow's containers (if you ran without `--rm`):
```bash
# Preferred: remove only containers from this workflow's image
docker rm $(docker ps -a --filter ancestor=nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 --format '{{.ID}}')
# Alternative: prune all stopped containers (will prompt for confirmation)
# docker container prune
```
Remove the image (optional):
```bash
docker rmi nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04
```
## Step 7. Repeat on B300
Repeat Steps 1-6 on B300 hardware to observe scaling. See the **Platform Comparison** tab for expected scaling results.
## End-to-End Inference
## Step 1. Set up environment
If you haven't already, pull the CUDA container and clone TileGym (see **Kernel Benchmarks** tab for details).
First, clone TileGym on the host:
```bash
mkdir -p ~/TileGym
git clone https://github.com/NVIDIA/TileGym ~/TileGym
```
Then launch the container with the repository mounted:
```bash
docker run --gpus all -it --rm \
-v ~/TileGym:/workspace/TileGym \
-v ~/.cache/huggingface:/root/.cache/huggingface \
nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 \
/bin/bash
```
> [!NOTE]
> The `-v ~/.cache/huggingface:/root/.cache/huggingface` mounts your HuggingFace cache to avoid re-downloading models.
Install TileGym inside the container:
```bash
cd /workspace/TileGym
pip install .
```
Set your HuggingFace token for accessing gated models:
```bash
export HF_TOKEN=<your_huggingface_token>
```
> [!WARNING]
> You need a HuggingFace account and access token. Get one at https://huggingface.co/settings/tokens
## Step 2. Run inference benchmark
Navigate to the transformers benchmark directory:
```bash
cd modeling/transformers
```
**Option A: Run Qwen2-7B benchmark**
```bash
./bench_qwen.sh
```
Configuration: Model `Qwen/Qwen2-7B`, Batch size 16, Output length 50 tokens.
**Option B: Run DeepSeek-V2-Lite benchmark**
```bash
./bench_deepseek.sh
```
Configuration: Model `deepseek-ai/DeepSeek-V2-Lite-Chat`, Batch size 1, Output length 100 tokens.
Both scripts run two configurations:
1. **PyTorch baseline** - Standard HuggingFace inference
2. **TileGym cuTile** - With cuTile kernel replacements
## Step 3. View results
**Sample DGX Spark (GB10) Results for Qwen2-7B:**
```text
========================================
Benchmark Results
========================================
Qwen2-7B_naive_bfloat16 | 15.66 tokens/s | 51.10s | 51151.0ms CUDA
Qwen2-7B_cutile_attn | 18.52 tokens/s | 43.20s | 43079.7ms CUDA
========================================
```
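The sample numbers can be sanity-checked with a little arithmetic. The sketch below assumes that the reported throughput counts every generated token across the batch (batch size 16 × 50 output tokens, from the Qwen2-7B configuration in Step 2); that interpretation is an assumption, though it is consistent with the figures shown.

```python
# Figures copied from the sample Qwen2-7B results above.
baseline_tps = 15.66       # Qwen2-7B_naive_bfloat16
cutile_tps = 18.52         # Qwen2-7B_cutile_attn
baseline_seconds = 51.10
batch_size, output_tokens = 16, 50  # benchmark configuration from Step 2

speedup = cutile_tps / baseline_tps
total_tokens = batch_size * output_tokens
implied_tps = total_tokens / baseline_seconds  # assumed definition of tokens/s

print(f"speedup: {speedup:.2f}x")
print(f"tokens generated: {total_tokens}")
print(f"implied baseline throughput: {implied_tps:.2f} tokens/s")
```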
**cuTile Kernel Breakdown (DGX Spark - Qwen2):**
| Kernel | CUDA Time (ms) | Calls |
|--------|----------------|-------|
| `fmha_kernel` | 4185.9 | 28 |
| `swiglu_forward_kernel` | 2459.8 | 1400 |
| `attention_decode_kernel_grouped` | 2271.8 | 1372 |
| `rms_norm_kernel_static_persistent` | 634.7 | 57 |
| `rope_kernel` | 355.6 | 1400 |
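
To see how much of the run the listed cuTile kernels account for, sum the table and divide by the total CUDA time. This sketch assumes the table belongs to the `Qwen2-7B_cutile_attn` row of the sample results (43079.7 ms CUDA); the source does not state this explicitly.

```python
# CUDA time (ms) and call counts copied from the table above.
KERNELS = {
    "fmha_kernel": (4185.9, 28),
    "swiglu_forward_kernel": (2459.8, 1400),
    "attention_decode_kernel_grouped": (2271.8, 1372),
    "rms_norm_kernel_static_persistent": (634.7, 57),
    "rope_kernel": (355.6, 1400),
}
TOTAL_CUDA_MS = 43079.7  # assumption: total for the Qwen2-7B_cutile_attn run

cutile_ms = sum(ms for ms, _ in KERNELS.values())
share = cutile_ms / TOTAL_CUDA_MS

print(f"cuTile kernels: {cutile_ms:.1f} ms ({share:.1%} of CUDA time)")
for name, (ms, calls) in KERNELS.items():
    print(f"{name}: {ms / calls:.2f} ms per call")
```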
## Step 4. How TileGym monkey-patching works
TileGym replaces PyTorch model operations with cuTile kernels. The snippet below is taken from TileGym's [`src/tilegym/transformers/monkey_patch.py`](https://github.com/NVIDIA/TileGym/blob/main/src/tilegym/transformers/monkey_patch.py) and invoked from [`modeling/transformers/infer.py`](https://github.com/NVIDIA/TileGym/blob/main/modeling/transformers/infer.py):
```python
from tilegym.transformers import apply_tilegym_kernel_to_qwen2

apply_tilegym_kernel_to_qwen2(
    rope=True,        # Replace RoPE with cuTile kernel
    rms_norm=True,    # Replace RMSNorm with cuTile kernel
    swiglu=True,      # Replace SwiGLU with cuTile kernel
    attn=True,        # Replace attention with cuTile FMHA
    use_cutile=True,  # Use cuTile backend (vs Triton)
)
```
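The general technique behind this is ordinary Python monkey-patching: rebinding an attribute on a class or module at runtime so existing callers pick up the replacement. The sketch below illustrates the idea with a toy layer in pure Python. `TinyRMSNorm` and `fused_forward` are hypothetical stand-ins, not TileGym or PyTorch code, and TileGym's actual patching logic may differ.

```python
class TinyRMSNorm:
    """Stand-in layer for illustration; not a PyTorch or TileGym class."""

    def forward(self, xs, eps=1e-6):
        mean_sq = sum(x * x for x in xs) / len(xs)
        return [x / (mean_sq + eps) ** 0.5 for x in xs]


def fused_forward(self, xs, eps=1e-6):
    """Hypothetical 'optimized' replacement with the same signature and result."""
    scale = (sum(x * x for x in xs) / len(xs) + eps) ** -0.5
    return [x * scale for x in xs]


layer = TinyRMSNorm()            # instance created before patching
original = TinyRMSNorm.forward   # keep a handle so the patch can be undone

TinyRMSNorm.forward = fused_forward  # the monkey-patch: rebind the class attribute
print(layer.forward([3.0, 4.0]))     # the existing instance now runs fused_forward

TinyRMSNorm.forward = original       # restore the original implementation
```

Because the method is looked up on the class at call time, instances created before the patch also use the replacement, which is why a model can be patched after it is constructed.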
**Patched Kernels for Qwen2:**
| Kernel | PyTorch Operation | cuTile Replacement |
|--------|-------------------|-------------------|
| `rms_norm_kernel_static_persistent` | `nn.RMSNorm` | Persistent RMSNorm |
| `rope_kernel` | Rotary position embedding | Fused RoPE |
| `fmha_kernel` | `F.scaled_dot_product_attention` | Flash Attention |
| `swiglu_forward_kernel` | SiLU + Mul | Fused SwiGLU |
| `attention_decode_kernel_grouped` | Decode attention | Grouped decode |
**Patched Kernels for DeepSeek-V2:** (see [`src/tilegym/transformers/monkey_patch.py`](https://github.com/NVIDIA/TileGym/blob/main/src/tilegym/transformers/monkey_patch.py))
```python
from tilegym.transformers import apply_tilegym_kernel_to_deepseek_v2

apply_tilegym_kernel_to_deepseek_v2(
    rope=True,        # Replace RoPE with cuTile kernel
    rms_norm=True,    # Replace RMSNorm with cuTile kernel
    swiglu=True,      # Replace SiLU+Mul with cuTile kernel
    attn=True,        # Replace MLA attention with cuTile
    moe=True,         # Replace MoE routing with cuTile
    use_cutile=True,
)
```
| Kernel | PyTorch Operation | cuTile Replacement |
|--------|-------------------|-------------------|
| `prefill_mla` | MLA prefill attention | Multi-head Latent Attention |
| `_mla_decoding_split_kv` | MLA decode attention | Split-KV decoding |
| `fused_moe_kernel` | MoE expert routing | Fused MoE |
| `group_gemm_kernel` | Expert FFN | Grouped GEMM |
## Step 5. Platform-specific tuning (Advanced)
cuTile exposes two complementary performance-tuning mechanisms:
- **[`ct.ByTarget`](https://docs.nvidia.com/cuda/cutile-python/performance.html)** - Select different kernel launch parameters per GPU architecture (`sm_<major><minor>`). The compiler picks the value matching the current target at JIT time; if no entry matches, the `default` value is used. See the [Performance Tuning](https://docs.nvidia.com/cuda/cutile-python/performance.html) and [Execution Model](https://docs.nvidia.com/cuda/cutile-python/execution.html) pages.
  - **`num_ctas`** - Number of Cooperative Thread Arrays (thread blocks) launched per kernel invocation. Tune to the number of SMs on the target GPU.
  - **`occupancy`** - Hint for the number of concurrent CTAs the compiler should target per SM. Higher occupancy hides memory latency but increases register/shared-memory pressure. See the [Execution Model](https://docs.nvidia.com/cuda/cutile-python/execution.html) documentation.
- **[`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html)** - Search a list of candidate values at runtime and pick the fastest configuration. Results are reported via [`cuda.tile.tune.TuningResult`](https://docs.nvidia.com/cuda/cutile-python/generated/cuda.tile.tune.TuningResult.html) / [`Measurement`](https://docs.nvidia.com/cuda/cutile-python/generated/cuda.tile.tune.Measurement.html).
```python
import cuda.tile as ct


@ct.kernel(
    # num_ctas: how many thread blocks to launch.
    # Use ByTarget to pick an arch-specific value at JIT time.
    num_ctas=ct.ByTarget({
        "sm_103": 8,   # B300 - more SMs, launch more CTAs
        "sm_121": 4,   # DGX Spark - fewer SMs (48), use fewer CTAs
        "default": 1,  # Fallback for any other GPU architecture
    }),
    # occupancy: hint for concurrent CTAs per SM (latency hiding vs. register pressure).
    occupancy=ct.ByTarget({
        "sm_103": 16,  # B300 - high occupancy, plenty of registers/SMEM
        "sm_121": 12,  # DGX Spark - moderate occupancy
        "default": 8,  # Conservative fallback
    }),
    opt_level=3,  # Maximum compiler optimization level
)
def optimized_kernel(A, B, C):
    # Same kernel code works on all platforms;
    # ByTarget swaps in the arch-specific launch params automatically.
    ...
```
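The selection rule described for `ct.ByTarget` (use the entry for the current target, otherwise `default`) can be modelled in plain Python. This sketch only illustrates that lookup rule; `select_for_target` is a hypothetical helper and says nothing about how cuTile implements the feature internally.

```python
def select_for_target(values, target):
    """Return the entry for target, falling back to the 'default' entry."""
    if target in values:
        return values[target]
    return values["default"]


# Values copied from the num_ctas example above.
num_ctas = {"sm_103": 8, "sm_121": 4, "default": 1}

print(select_for_target(num_ctas, "sm_121"))  # DGX Spark entry
print(select_for_target(num_ctas, "sm_90"))   # no entry, so the default applies
```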
For automatic tuning, use [`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to search over candidate values and pick the fastest configuration at runtime:
```python
import cuda.tile as ct


@ct.kernel(
    # autotune: benchmark each value and pick the fastest.
    num_ctas=ct.autotune([1, 2, 4, 8, 16]),
    occupancy=ct.autotune([8, 12, 16, 24]),
    opt_level=3,
)
def autotuned_kernel(A, B, C):
    ...
```
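Conceptually, autotuning is a measure-and-pick loop: run each candidate, record the best time, keep the fastest. The sketch below shows that idea in plain Python with a sleep-based stand-in for a kernel launch. `pick_fastest`, `time_once`, and `fake_launch` are hypothetical names, and the real `ct.autotune` search strategy and measurement method may differ.

```python
import time


def time_once(fn, *args):
    """Return the wall-clock seconds taken by one call to fn(*args)."""
    start = time.perf_counter()
    fn(*args)
    return time.perf_counter() - start


def pick_fastest(candidates, measure, repeats=3):
    """Measure each candidate `repeats` times; return (fastest, best time per candidate)."""
    timings = {c: min(measure(c) for _ in range(repeats)) for c in candidates}
    return min(timings, key=timings.get), timings


def fake_launch(num_ctas):
    """Hypothetical stand-in for a kernel launch; cheapest at num_ctas == 4."""
    time.sleep(0.002 * (1 + abs(num_ctas - 4)))


best, timings = pick_fastest([1, 2, 4, 8, 16], lambda n: time_once(fake_launch, n))
print("fastest candidate:", best)
```

Taking the minimum over several repeats is a common way to reduce noise from one-off delays such as JIT compilation on the first call.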
## Step 6. Repeat on B300
Repeat Steps 1-3 on B300 hardware. The **same code runs without modification** - cuTile JIT compiles for sm_103 automatically.
See the **Platform Comparison** tab for detailed scaling results.
## FMHA Implementation
## FMHA Implementation Guide
> [!NOTE]
> This is a guide to understanding FMHA implementation in cuTile, not a complete reference. For comprehensive documentation, see the [cuTile Python Documentation](https://docs.nvidia.com/cuda/cutile-python/).
### Attention Basics
Attention allows a neural network to focus on relevant parts of the input. In transformers (GPT, LLaMA, Qwen), each position computes how much to attend to every other position using three vectors:
- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What do I contain?"
- **Value (V)**: "Here is my content"
```text
Attention(Q, K, V) = softmax(Q × K^T / √d) × V

Shapes:
  Q, K, V = [batch, heads, seq_len, head_dim]
  Q × K^T = [batch, heads, seq_len, seq_len]   # Attention scores
  Output  = [batch, heads, seq_len, head_dim]
```
For autoregressive models, **causal masking** ensures each token only attends to previous tokens by setting future scores to -infinity before softmax.
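As a reference point, the formula and the causal mask can be written out in a few lines of pure Python for a single head. This is an unoptimized sketch for checking understanding on tiny inputs, not how the cuTile kernel is implemented; the function names are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V, causal=False):
    """Single-head attention on lists of rows: softmax(Q K^T / sqrt(d)) V."""
    d = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        if causal:
            # Future positions (j > i) get -inf, so their softmax weight is 0.
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
        weights = softmax(scores)
        out.append([sum(w * v[c] for w, v in zip(weights, V)) for c in range(len(V[0]))])
    return out


Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
# With causal masking, token 0 can only attend to itself, so its output is V[0].
print(attention(Q, K, V, causal=True)[0])  # → [1.0, 2.0]
```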
### Flash Attention Algorithm
Standard attention materializes a [seq_len × seq_len] score matrix for every batch element and head (e.g., 2 GB each in float16 for seq_len=32768, since 32768² × 2 bytes = 2 GiB). Flash Attention avoids this by processing in tiles with **online softmax**:
```text
m = -infinity   # Running maximum
l = 0           # Running sum of exp(x - m)
acc = 0         # Running weighted sum of values

FOR each K,V tile:
    scores = Q_tile @ K_tile.T * scale
    m_new = max(m, max(scores))
    correction = exp(m - m_new)
    l = l * correction + sum(exp(scores - m_new))
    acc = acc * correction + exp(scores - m_new) @ V_tile
    m = m_new

output = acc / l
```
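The recurrence above can be checked numerically against a one-pass softmax. This is a minimal sketch for a single query row in plain Python; the function names and the sample scores and values are illustrative assumptions, and the tiling is done by slicing lists rather than loading GPU tiles.

```python
import math

def online_softmax_attention(score_tiles, value_tiles):
    """One query row. score_tiles[t][j] is the score for key j of tile t,
    value_tiles[t][j] is the matching value vector."""
    m = -math.inf                      # running maximum
    l = 0.0                            # running sum of exp(score - m)
    acc = [0.0] * len(value_tiles[0][0])
    for scores, values in zip(score_tiles, value_tiles):
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)   # exp(-inf) == 0.0 on the first tile
        p = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pj * vj[c] for pj, vj in zip(p, values))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

def direct_softmax_attention(scores, values):
    """Reference: softmax over all scores at once, then weighted sum."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    return [sum(pj * vj[c] for pj, vj in zip(p, values)) / total
            for c in range(len(values[0]))]

scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5], [4.0, -2.0]]

tiled = online_softmax_attention(
    [scores[0:2], scores[2:4], scores[4:6]],
    [values[0:2], values[2:4], values[4:6]],
)
direct = direct_softmax_attention(scores, values)
print(max(abs(a - b) for a, b in zip(tiled, direct)))  # tiny rounding difference
```

The two results agree up to floating-point rounding, which is why the kernel never needs the full score matrix in memory.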
### cuTile Pseudocode → Actual Mapping
| Concept | Pseudocode | cuTile |
|---|---|---|
| Define kernel | `KERNEL fmha(...)` | `@ct.kernel()` |
| Get block ID | `block_x = BLOCK_ID_X` | `bid_x = ct.bid(0)` |
| Create indices | `range(0, N)` | `ct.arange(N, dtype=ct.int32)` |
| Create constant tile | `tile = zeros(M, N)` | `ct.full((M, N), 0.0, dtype)` |
| Load from memory | `tile = LOAD(ptr, shape)` | `ct.load(tensor, index, shape)` |
| Store to memory | `STORE(ptr, tile)` | `ct.store(tensor, index, tile)` |
| Matrix multiply | `C = A @ B + C` | `ct.mma(A, B, C)` |
| Reduction | `max_val = MAX(tile, axis)` | `ct.max(tile, axis, keepdims)` |
### Kernel Pseudocode
```text
KERNEL fmha(Q, K, V, Out, scale, TILE_M, TILE_N):
    tile_row = BLOCK_ID_X
    batch_head = BLOCK_ID_Y
    batch = batch_head // num_heads
    head = batch_head % num_heads

    m_i = full(TILE_M, -infinity)
    l_i = full(TILE_M, 0)
    acc = zeros(TILE_M, head_dim)
    q = LOAD(Q[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :])

    FOR j = 0 to num_k_tiles:
        k = LOAD(K[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        v = LOAD(V[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        scores = MMA(q, transpose(k)) * scale
        IF causal AND in_mask_region:
            scores = WHERE(valid_mask, scores, -infinity)
        m_new = max(m_i, row_max(scores))
        correction = exp(m_i - m_new)
        p = exp(scores - m_new)
        l_i = l_i * correction + row_sum(p)
        acc = acc * correction + MMA(p, v)
        m_i = m_new

    out = acc / l_i
    STORE(Out[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :], out)
```
### cuTile Implementation
```python
import cuda.tile as ct
import math

ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]

@ct.kernel()
def fmha_kernel(Q, K, V, Out, qk_scale: float, TILE_D: ConstInt, H: ConstInt,
                TILE_M: ConstInt, TILE_N: ConstInt, CAUSAL: ConstBool):
    bid_x, bid_y = ct.bid(0), ct.bid(1)
    batch_idx, head_idx = bid_y // H, bid_y % H
    offs_m = (bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32))[:, None]
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)[None, :]
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0),
                shape=(1, 1, TILE_M, TILE_D)).reshape((TILE_M, TILE_D))
    k_seqlen = K.shape[2]
    if CAUSAL:
        Tc = ct.cdiv(min((bid_x + 1) * TILE_M, k_seqlen), TILE_N)
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        mask_start = k_seqlen // TILE_N
    for j in range(0, Tc):
        k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0),
                         shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        k_t = ct.permute(k_tile, (1, 0))
        qk = ct.mma(q, k_t, ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32))
        qk = qk * qk_scale
        if CAUSAL and j >= mask_start:
            offs_n = j * TILE_N + offs_n_tile
            qk = ct.where(offs_m >= offs_n, qk,
                          ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
        m_ij = ct.maximum(m_i, ct.max(qk, axis=-1, keepdims=True))
        qk = qk - m_ij
        p = ct.exp(qk)
        alpha = ct.exp(m_i - m_ij)
        l_i = l_i * alpha + ct.sum(p, axis=-1, keepdims=True)
        acc = acc * alpha
        v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0),
                         shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        acc = ct.mma(p.astype(Q.dtype), v_tile, acc)
        m_i = m_ij
    acc = (acc / l_i).reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
```
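The causal branch relies on two integers per query tile: `Tc`, the number of K/V tiles to visit, and `mask_start`, the first K tile that needs an element-wise mask. The sketch below re-implements that arithmetic in plain Python and brute-force checks it; the helper names are hypothetical and `cdiv` is assumed to be ceiling division, as in `ct.cdiv`.

```python
def cdiv(a, b):
    """Ceiling division, assumed to match ct.cdiv."""
    return (a + b - 1) // b

def causal_tile_bounds(bid_x, tile_m, tile_n, k_seqlen):
    """Mirror the kernel's Tc / mask_start arithmetic for one query tile."""
    tc = cdiv(min((bid_x + 1) * tile_m, k_seqlen), tile_n)
    mask_start = (bid_x * tile_m) // tile_n
    return tc, mask_start

def tile_visibility(bid_x, j, tile_m, tile_n):
    """Classify K tile j for query tile bid_x under the rule row >= col."""
    rows = range(bid_x * tile_m, (bid_x + 1) * tile_m)
    cols = range(j * tile_n, (j + 1) * tile_n)
    visible = [r >= c for r in rows for c in cols]
    if all(visible):
        return "full"
    if not any(visible):
        return "none"
    return "partial"

def check_bounds(seq_len, tile_m, tile_n):
    """Tiles before mask_start must be fully visible; tiles from Tc on fully masked."""
    for bid_x in range(seq_len // tile_m):
        tc, mask_start = causal_tile_bounds(bid_x, tile_m, tile_n, seq_len)
        for j in range(seq_len // tile_n):
            kind = tile_visibility(bid_x, j, tile_m, tile_n)
            if j < mask_start and kind != "full":
                return False
            if j >= tc and kind != "none":
                return False
    return True

print(causal_tile_bounds(3, 64, 64, 2048))  # (4, 3)
print(all(check_bounds(256, m, n) for m, n in [(64, 64), (128, 64), (64, 128)]))  # True
```

For query tile 3 with 64×64 tiles, the loop visits four K tiles and only the last one (index 3, the diagonal tile) pays for the `ct.where` mask.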
### Launching the Kernel
```python
import math

import torch
import cuda.tile as ct

def run_fmha(q, k, v, sm_scale, is_causal=True):
    # fmha_kernel is the kernel defined in the previous section.
    TILE_M, TILE_N = 64, 64  # Platform-specific (see below)
    batch, num_heads, seq_len, head_dim = q.shape
    out = torch.empty_like(q)
    # One CTA per (query tile, batch * head) pair.
    grid = (math.ceil(seq_len / TILE_M), batch * num_heads, 1)
    ct.launch(
        torch.cuda.current_stream(), grid, fmha_kernel,
        (q, k, v, out, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
    )
    return out
```
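To make the launch geometry concrete, the following sketch computes the grid for the benchmark shape used by the companion script (batch 4, 32 heads) and shows how a block's Y index decodes back into a batch and head, as the kernel does with `bid_y // H` and `bid_y % H`. The helper names are hypothetical.

```python
import math

def fmha_grid(batch, num_heads, seq_len, tile_m):
    """Same grid expression as run_fmha: (query tiles, batch * heads, 1)."""
    return (math.ceil(seq_len / tile_m), batch * num_heads, 1)

def decode_block(bid_y, num_heads):
    """Recover (batch_idx, head_idx) from the flattened Y block index."""
    return bid_y // num_heads, bid_y % num_heads

grid = fmha_grid(batch=4, num_heads=32, seq_len=2048, tile_m=64)
print(grid)                  # (32, 128, 1)
print(decode_block(37, 32))  # (1, 5)
```

Every `(batch, head)` pair maps to exactly one Y index, so no two CTAs write the same output tile.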
### Optimizations
#### exp2 + flush_to_zero
`exp2(x) = 2^x` is faster than `exp(x)` on GPU. Requires scale adjustment by `1/log(2)`.
```python
# Convert natural-exp scale to base-2 so we can use the faster ct.exp2 intrinsic.
# exp(x) == exp2(x / log(2)) == exp2(x * INV_LOG_2).
INV_LOG_2 = 1.0 / math.log(2)          # ≈ 1.4427
qk_scale_log2 = qk_scale * INV_LOG_2   # Pre-multiply the softmax scale once

# ... in loop:
# Fuse the tile-max computation with the scale multiplication,
# then fold it into the running max carried over from earlier tiles.
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
# Subtract the running max for numerical stability (online softmax).
qk = qk * qk_scale_log2 - m_ij
# flush_to_zero=True: flush denormals to 0 -> avoids slow denormal handling on GPU.
p = ct.exp2(qk, flush_to_zero=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)  # Correction factor for previous acc/l_i
```
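The base change leaves the softmax result unchanged, which can be confirmed on the CPU. This is a minimal sketch in plain Python, assuming a positive softmax scale (so scaling commutes with taking the maximum); the function names and sample scores are illustrative.

```python
import math

INV_LOG_2 = 1.0 / math.log(2)

def softmax_exp(scores, scale):
    """Softmax of scale * scores using the natural exponential."""
    m = max(scores) * scale
    p = [math.exp(s * scale - m) for s in scores]
    total = sum(p)
    return [x / total for x in p]

def softmax_exp2(scores, scale):
    """Same softmax, but with the scale folded into base 2 as in the kernel."""
    scale_log2 = scale * INV_LOG_2
    m = max(scores) * scale_log2
    p = [2.0 ** (s * scale_log2 - m) for s in scores]
    total = sum(p)
    return [x / total for x in p]

scores = [3.0, -1.5, 0.25, 7.0, 2.0]
scale = 1.0 / math.sqrt(128)   # head_dim = 128, as in the benchmark configuration

a = softmax_exp(scores, scale)
b = softmax_exp2(scores, scale)
print(max(abs(x - y) for x, y in zip(a, b)))  # rounding-level difference
```

Because the factor `1/log(2)` is folded into the scale once, the inner loop pays no extra multiplication for switching to `exp2`.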
#### Load Order Transpose
Load K already transposed using `order` parameter, avoiding explicit permute.
```python
# order=(0,1,3,2) swaps the last two axes during the load,
# producing K^T directly in registers -- no extra ct.permute() needed.
# shape is expressed in the transposed layout: (1, 1, TILE_D, TILE_N),
# so the tile index is also given in that layout: (batch, head, 0, j).
k_t = ct.load(K, index=(batch_idx, head_idx, 0, j), shape=(1, 1, TILE_D, TILE_N),
              order=(0, 1, 3, 2)).reshape((TILE_D, TILE_N))
```
#### Latency Hints
Prefetch data to overlap memory loads with computation. See the [Performance Tuning docs](https://docs.nvidia.com/cuda/cutile-python/performance.html) for the full list of load/store hints (e.g. `allow_tma`, `latency`).
```python
# latency=N tells the compiler to issue this load N loop iterations in
# advance of its use, so the memory transfer overlaps with the MMA work
# from earlier iterations. Larger latency = deeper software pipeline but
# more register pressure.
k_t = ct.load(K, ..., latency=2)     # Prefetch K 2 iterations ahead
v_tile = ct.load(V, ..., latency=4)  # Prefetch V 4 iterations ahead (used later in the loop)
```
#### Occupancy
Allow multiple thread blocks per SM to hide memory latency. See the [Execution Model docs](https://docs.nvidia.com/cuda/cutile-python/execution.html) for details on how `occupancy` interacts with registers and shared memory.
```python
# occupancy=N is a hint to the compiler to target N concurrent CTAs per SM.
# Higher occupancy -> more warps available to hide memory latency,
# but constrains the per-CTA register/SMEM budget.
@ct.kernel(occupancy=2)  # 2 thread blocks (CTAs) co-resident per SM
def fmha_optimized(...):
    ...
```
#### Approximate Division
Use fast approximate division for final normalization.
```python
from cuda.tile import RoundingMode as RMd

# RMd.APPROX -> hardware approximate reciprocal/divide (MUFU), much faster
# than IEEE-compliant division. Safe here because it's the final softmax
# normalization step where a small ULP error is acceptable.
# flush_to_zero=True flushes denormals to 0 to avoid the slow path.
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
```
### Platform Configuration
The same kernel code works on all platforms; only configuration parameters change. Use [`ct.ByTarget`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to select values per architecture, or [`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to search candidate values automatically.
| Platform | TILE_M | TILE_N | Occupancy | Rationale |
|---|---|---|---|---|
| DGX Spark (sm_121) | 64 | 64 | 2 | Smaller tiles, higher occupancy for 48 SMs |
| B300 (sm_103) | 256 | 128 | 1 | Large tiles maximize HBM3e throughput |
| B300 alternate | 128 | 128 | 2 | Higher occupancy, balanced parallelism |
```python
import cuda.tile as ct

@ct.kernel(
    # TILE_M / TILE_N: rows/cols of the Q and K/V tiles processed per CTA.
    # Larger tiles -> more arithmetic intensity; smaller tiles -> higher occupancy.
    # occupancy: target concurrent CTAs per SM (latency hiding vs. register pressure).
    occupancy=ct.ByTarget({
        "sm_121": 2,   # DGX Spark (48 SMs): 2 CTAs/SM for latency hiding
        "sm_103": 1,   # B300: larger tiles already saturate the SM
        "default": 1,  # Conservative fallback for other architectures
    }),
    opt_level=3  # Maximum compiler optimization level
)
def fmha_kernel(...):
    ...
```
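`TILE_M` and `TILE_N` are passed as launch arguments rather than decorator options, so the host code has to pick them. One possible approach is a lookup keyed on compute capability, sketched below in plain Python. The `TILE_CONFIGS` table copies the first two rows of the table above; `pick_tile_config` and the fallback values are hypothetical.

```python
# Per-architecture launch parameters, taken from the Platform Configuration table.
TILE_CONFIGS = {
    "sm_121": {"TILE_M": 64, "TILE_N": 64, "occupancy": 2},    # DGX Spark
    "sm_103": {"TILE_M": 256, "TILE_N": 128, "occupancy": 1},  # B300
}

# Assumption: a conservative fallback for architectures not listed in the table.
DEFAULT_CONFIG = {"TILE_M": 64, "TILE_N": 64, "occupancy": 1}

def pick_tile_config(major, minor):
    """Return launch parameters for a compute capability such as (12, 1)."""
    return TILE_CONFIGS.get(f"sm_{major}{minor}", DEFAULT_CONFIG)

print(pick_tile_config(12, 1))  # DGX Spark settings
print(pick_tile_config(10, 3))  # B300 settings
print(pick_tile_config(9, 0))   # falls back to DEFAULT_CONFIG
```

On a machine with PyTorch and a GPU, the `(major, minor)` pair can be obtained from `torch.cuda.get_device_capability()`.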
### Performance Results
> **Note:** PyTorch SDPA is used for correctness verification only, not performance comparison.
#### DGX Spark (sm_121) — Seq 2048
| Step | Optimization | Latency (ms) | TFLOPS |
|---|---|---|---|
| 1 | Basic cuTile | 2.19 | 62.8 |
| 2 | + exp2 | 2.07 | 66.5 |
| 3 | + Load Order | 2.07 | 66.3 |
| 4 | + Latency Hints | 2.07 | 66.5 |
| 5 | + Occupancy=2 | 1.73 | 79.5 |
| 6 | + Approx Div (Final) | 1.69 | 81.1 |
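
The TFLOPS column can be reproduced from the latency column. The sketch below assumes the benchmark configuration from the companion script (batch 4, 32 heads, head dim 128) and its FLOP model: two matmuls of `2 * seq_len² * head_dim` FLOPs per head, halved for causal attention. The `fmha_tflops` name is hypothetical.

```python
def fmha_tflops(latency_ms, batch=4, heads=32, seq_len=2048, head_dim=128, causal=True):
    """Convert a measured latency into TFLOPS using the tutorial's FLOP model."""
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total_flops = 2 * flops_per_matmul      # Q @ K^T and P @ V
    if causal:
        total_flops *= 0.5                  # roughly half the tiles are skipped
    return total_flops / (latency_ms * 1e-3) / 1e12

print(round(fmha_tflops(2.19), 1))  # 62.8, matching step 1
print(round(fmha_tflops(1.69), 1))  # close to step 6; the table latency is rounded
```

Small mismatches against the table are expected because the published latencies are rounded to two decimals.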
#### B300 (sm_103) — Various Seq Lengths
| Seq Len | Latency (ms) | TFLOPS | vs Spark |
|---|---|---|---|
| 1024 | 0.074 | 465 | 5.7x |
| 2048 | 0.178 | 770 | 9.5x |
| 4096 | 0.550 | 999 | 15.1x |
| 8192 | 1.897 | 1159 | 14.6x |
| 16384 | 7.014 | 1254 | 14.2x |
### Common Issues
| Issue | Solution |
|---|---|
| Shape mismatch in ct.mma | Ensure A is (M,K), B is (K,N), C is (M,N) |
| dtype errors | Use `.astype()` before mma; accumulator should be float32 |
| Incorrect results with causal | Check mask_start calculation and `offs_m >= offs_n` logic |
| Low performance | Try different TILE_M/N, check occupancy, verify latency hints |
### Companion Scripts
The following scripts are included in this playbook and can be run on DGX Spark or B300:
- **[`assets/fmha_optimization_tutorial.py`](assets/fmha_optimization_tutorial.py)** — Step-by-step optimization tutorial. Builds the FMHA kernel from basic to fully optimized, matching the progression in this guide.
- **[`assets/fmha_scaling_analysis.py`](assets/fmha_scaling_analysis.py)** — Scaling analysis across sequence lengths. Benchmarks each optimization level and generates performance data.
```bash
# Run the optimization tutorial (DGX Spark)
python assets/fmha_optimization_tutorial.py --correctness-check

# Run the scaling analysis
python assets/fmha_scaling_analysis.py --iterations 100
```
### References
- [cuTile Python Documentation](https://docs.nvidia.com/cuda/cutile-python/)
- [Tile IR Specification](https://docs.nvidia.com/cuda/tile-ir/)
- [TileGym (pre-optimized kernels)](https://github.com/NVIDIA/TileGym)
- [NVIDIA Blog: Tuning Flash Attention for Peak Performance in CUDA Tile](https://developer.nvidia.com/blog/tuning-flash-attention-for-peak-performance-in-nvidia-cuda-tile/)
- [Flash Attention Paper](https://arxiv.org/abs/2205.14135)
## Platform Comparison
## DGX Spark vs B300 Performance Comparison
This page summarizes performance scaling between DGX Spark (GB10) and B300 for both kernel benchmarks and end-to-end LLM inference.
## Kernel Benchmark Scaling
Use the ratios below as a reference for how kernel performance scales from DGX Spark (GB10) to B300.
| Kernel | Metric | B300 / GB10 |
|--------|--------|-------------|
| FMHA (causal, 8192) | TFLOPS | 13.7x |
| FMHA (non-causal, 8192) | TFLOPS | 15.1x |
| MatMul (8192) | TFLOPS | 18.9x |
| BMM (batch8, 4096) | TFLOPS | 19.4x |
| Group GEMM (4096) | TFLOPS | 23.9x |
| RMSNorm (4096) | GB/s | 33.1x |
| RoPE (16384) | GB/s | 22.8x |
**Key Observations:**
- Compute-heavy kernels typically scale 14-24x from GB10 to B300
- Memory-bound kernels can scale 20-33x due to HBM bandwidth advantage
## Qwen2-7B Performance
### End-to-End Throughput
| Configuration | DGX Spark | B300 | Platform Speedup |
|---------------|-----------|------|------------------|
| **cuTile** | 18.52 tok/s | 257.33 tok/s | **13.9x** |
### CUDA Kernel Time
| Configuration | DGX Spark | B300 | Platform Speedup |
|---------------|-----------|------|------------------|
| **cuTile** | 43,080 ms | 2,954 ms | **14.6x** |
### cuTile Kernel Breakdown
**DGX Spark (GB10):**
| Kernel | CUDA Time (ms) | Calls |
|--------|----------------|-------|
| `fmha_kernel` | 4,185.9 | 28 |
| `swiglu_forward_kernel` | 2,459.8 | 1,400 |
| `attention_decode_kernel_grouped` | 2,271.8 | 1,372 |
| `rms_norm_kernel_static_persistent` | 634.7 | 57 |
| `rope_kernel` | 355.6 | 1,400 |
**B300:**
| Kernel | CUDA Time (ms) | Speedup vs Spark |
|--------|----------------|------------------|
| `fmha_kernel` | 337.9 | 12.4x |
| `swiglu_forward_kernel` | 226.3 | 10.9x |
| `attention_decode_kernel_grouped` | 111.0 | 20.5x |
| `rms_norm_kernel_static_persistent` | 29.7 | 21.4x |
| `rope_kernel` | 16.7 | 21.3x |
**Same code, different architectures:** cuTile JIT-compiles the same kernels for sm_121 (Spark) and sm_103 (B300).
## Platform Specifications
| Specification | DGX Spark (GB10) | B300 |
|---------------|------------------|------|
| Compute Capability | sm_121 (12.1) | sm_103 (10.3) |
| SMs | 48 | 132 |
| Memory | 128 GB LPDDR5x | 192 GB HBM3e |
| Memory Bandwidth | 273 GB/s | 8 TB/s |
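
As a rough sanity check, the raw bandwidth ratio from this table can be compared with the measured scaling of the memory-bound kernels. This is a back-of-envelope sketch; it assumes 1 TB/s = 1000 GB/s and takes the measured ratios from the Kernel Benchmark Scaling table.

```python
SPARK_BW_GBS = 273.0    # DGX Spark (GB10), from the specification table
B300_BW_GBS = 8000.0    # B300: 8 TB/s, assuming 1 TB = 1000 GB

def bandwidth_ratio(fast_gbs, slow_gbs):
    """Upper-level estimate of how much a purely bandwidth-bound kernel can speed up."""
    return fast_gbs / slow_gbs

ratio = bandwidth_ratio(B300_BW_GBS, SPARK_BW_GBS)
print(f"Peak bandwidth ratio: {ratio:.1f}x")

# Measured B300 / GB10 ratios for the memory-bound kernels.
measured = {"RMSNorm (4096)": 33.1, "RoPE (16384)": 22.8}
for name, speedup in measured.items():
    print(f"{name}: {speedup:.1f}x measured, {speedup / ratio:.2f} of the peak ratio")
```

The measured ratios sit on either side of the peak ratio of about 29x, which is consistent with these kernels being limited mainly by memory bandwidth rather than compute.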
## Troubleshooting
| Symptom | Cause | Fix |
|---------|-------|-----|
| `docker: permission denied` | User not in docker group | `sudo usermod -aG docker $USER && newgrp docker` |
| `401 Client Error: Unauthorized` | Missing HuggingFace token | `export HF_TOKEN=<your_token>` |
| `ModuleNotFoundError: tilegym` | TileGym not installed | `cd TileGym && pip install .` |
| `RuntimeError: CUDA out of memory` | Model too large | Reduce batch size or use smaller model |
| `Killed` during model load | Out of system memory | Clear cache: `sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'` |
| Slow first run | JIT compilation | Normal - cuTile compiles kernels on first run |
| `FileNotFoundError: input_prompt_small.txt` | Missing input file | Run from `modeling/transformers` directory |
| `torch.cuda.OutOfMemoryError` | Insufficient GPU memory | Reduce `--batch_size` parameter |
| `ImportError: cuda.tile` | Missing Tile IR | Install: `apt-get install cuda-tile-ir-13-1` |
| Benchmark hangs | GPU busy or locked | Check `nvidia-smi` for other processes |
> [!NOTE]
> DGX Spark uses a Unified Memory Architecture (UMA), which enables dynamic memory sharing between the GPU and CPU.
> With many applications still updating to take advantage of UMA, you may encounter memory issues even when within
> the memory capacity of DGX Spark. If that happens, manually flush the buffer cache with:
```bash
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
```
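
Before flushing, it can help to see how much memory the page cache is actually holding. The sketch below parses text in the Linux `/proc/meminfo` format; the helper names and the sample values are hypothetical, and the parser takes a string so it can be tried without a DGX Spark.

```python
def parse_meminfo(text):
    """Parse /proc/meminfo-style text into a dict of integer values in kB."""
    info = {}
    for line in text.splitlines():
        key, _, rest = line.partition(":")
        fields = rest.split()
        if fields and fields[0].isdigit():
            info[key.strip()] = int(fields[0])
    return info

def page_cache_gib(info):
    """Memory held by buffers and page cache, in GiB."""
    return (info.get("Cached", 0) + info.get("Buffers", 0)) / (1024 ** 2)

# Hypothetical sample; on a real system use open("/proc/meminfo").read().
sample = (
    "MemTotal:       125000000 kB\n"
    "MemAvailable:    60000000 kB\n"
    "Buffers:          1048576 kB\n"
    "Cached:          30408704 kB\n"
)

info = parse_meminfo(sample)
print(f"Page cache: {page_cache_gib(info):.1f} GiB")  # Page cache: 30.0 GiB
```

If the cache figure is small, dropping caches is unlikely to help and the workload itself probably needs less memory.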
> [!TIP]
> First run of cuTile kernels includes JIT compilation overhead. Subsequent runs will be faster as compiled kernels are cached.
For the latest known issues, please review the [DGX Spark User Guide](https://docs.nvidia.com/dgx/dgx-spark/known-issues.html).

View File

@ -0,0 +1,959 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python3
"""
FMHA Optimization Tutorial: From Naive to Optimized cuTile Implementation
This script demonstrates step-by-step optimization of Flash Multi-Head Attention
using NVIDIA cuTile, starting from a basic implementation and progressively
adding optimizations until reaching TileGym-level performance.
Target Platform: DGX Spark (sm_121) with pre-determined optimal tile sizes.
Note: TileGym supports autotuning, but we use hardcoded values for this tutorial.
Configuration (matches TileGym bench_fused_attention.py):
- Batch: 4, Heads: 32, Head Dim: 128
- Sequence Lengths: 1024, 2048, 4096, 8192, 16384
- Benchmark: triton.testing.do_bench_cudagraph
Usage:
python fmha_optimization_tutorial.py [--iterations N] [--correctness-check]
"""
import argparse
import json
import math
import time
from dataclasses import dataclass, asdict
from typing import List, Optional
import sys
import torch
LOG_SEPARATOR = "=" * 80
LOG_SUBSEPARATOR = "-" * 60
@dataclass
class BenchmarkResult:
    step: int
    name: str
    description: str
    latency_ms: float
    tflops: float
    speedup_vs_baseline: float
    speedup_vs_previous: float
    correct: bool
    key_changes: List[str]
class Logger:
    def __init__(self):
        self.results: List[BenchmarkResult] = []
        self.logs: List[str] = []

    def log(self, msg: str):
        print(msg)
        self.logs.append(msg)

    def section(self, title: str):
        self.log(f"\n{LOG_SEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SEPARATOR)

    def subsection(self, title: str):
        self.log(f"\n{LOG_SUBSEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SUBSEPARATOR)

    def add_result(self, result: BenchmarkResult):
        self.results.append(result)

    def export_json(self, filepath: str):
        data = {
            "results": [asdict(r) for r in self.results],
            "logs": self.logs
        }
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

    def export_markdown(self, filepath: str):
        with open(filepath, 'w') as f:
            f.write("# FMHA Optimization Tutorial Results\n\n")
            f.write("## Summary Table\n\n")
            f.write("| Step | Name | Latency (ms) | TFLOPS | vs Baseline | vs Previous | Correct |\n")
            f.write("|------|------|--------------|--------|-------------|-------------|--------|\n")
            for r in self.results:
                f.write(f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {r.speedup_vs_previous:.2f}x | {'Yes' if r.correct else 'No'} |\n")
            f.write("\n## Detailed Steps\n\n")
            for r in self.results:
                f.write(f"### Step {r.step}: {r.name}\n\n")
                f.write(f"**Description**: {r.description}\n\n")
                f.write("**Key Changes**:\n")
                for change in r.key_changes:
                    f.write(f"- {change}\n")
                f.write(f"\n**Performance**: {r.latency_ms:.3f}ms, {r.tflops:.2f} TFLOPS, {r.speedup_vs_baseline:.2f}x vs baseline\n\n")
logger = Logger()
BATCH = 4
N_HEADS = 32
HEAD_DIM = 128
INV_LOG_2 = 1.0 / math.log(2)
TILE_M = 64
TILE_N = 64
OCCUPANCY = 2
NUM_CTAS = 1
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    raise RuntimeError("CUDA not available")

DEVICE = None

def compute_flops(batch, heads, seq_len, head_dim, causal=True):
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    return total_flops
def benchmark_fn(fn, warmup=10, iterations=100):
    """Benchmark using triton's do_bench_cudagraph for accurate timing (matches TileGym)"""
    try:
        import triton
        return triton.testing.do_bench_cudagraph(fn)
    except Exception:
        # Triton is missing or CUDA graph capture failed: fall back to
        # wall-clock timing. warmup/iterations only apply to this path.
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iterations):
            fn()
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / iterations * 1000  # seconds -> milliseconds per call
def verify_correctness(output, reference, atol=1e-2, rtol=1e-2):
    try:
        torch.testing.assert_close(output, reference, atol=atol, rtol=rtol)
        return True
    except AssertionError:
        max_diff = (output - reference).abs().max().item()
        logger.log(f" [WARN] Max difference: {max_diff:.6f}")
        return max_diff < 0.1

def reference_fmha(q, k, v, sm_scale, is_causal=True):
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
    )

def step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True):
    return reference_fmha(q, k, v, sm_scale, is_causal)

try:
    import cuda.tile as ct
    from cuda.tile import RoundingMode as RMd
    CUTILE_AVAILABLE = True
except ImportError:
    CUTILE_AVAILABLE = False
    logger.log("[WARN] cuTile not available. Only PyTorch baseline will run.")
if CUTILE_AVAILABLE:
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]
@ct.kernel()
def fmha_step2_mma(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""
Step 2: Basic cuTile FMHA with MMA (Tensor Cores)
- Uses ct.mma() for matrix multiply
- Standard exp() for softmax
- Online softmax algorithm
"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
k_tile = k_tile.reshape((TILE_N, TILE_D))
k_t = ct.permute(k_tile, (1, 0))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
qk = qk * qk_scale
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True)
m_ij = ct.maximum(m_i, m_ij)
qk = qk - m_ij
p = ct.exp(qk)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = acc / l_i
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
def run_step2(q, k, v, sm_scale, is_causal=True):
batch_size, num_heads, seq_len, head_dim = q.shape
o = torch.empty_like(q)
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
ct.launch(
torch.cuda.current_stream(),
grid,
fmha_step2_mma,
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
)
return o
@ct.kernel()
def fmha_step3_exp2(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""
Step 3: Use exp2 with flush_to_zero for faster math
- exp2(x) = 2^x is faster than exp(x) = e^x on GPU
- Requires scaling adjustment: multiply by 1/log(2)
- flush_to_zero handles denormals efficiently
"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
k_tile = k_tile.reshape((TILE_N, TILE_D))
k_t = ct.permute(k_tile, (1, 0))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = acc / l_i
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
def run_step3(q, k, v, sm_scale, is_causal=True):
batch_size, num_heads, seq_len, head_dim = q.shape
o = torch.empty_like(q)
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
ct.launch(
torch.cuda.current_stream(),
grid,
fmha_step3_exp2,
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
)
return o
@ct.kernel()
def fmha_step4_load_order(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""
Step 4: Optimize K load with order parameter
- Use order=(0,1,3,2) to load K already transposed
- Avoids explicit ct.permute() operation
- Reduces memory traffic
"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_t = ct.load(
K,
index=(batch_idx, head_idx, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2)
)
k_t = k_t.reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = acc / l_i
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
def run_step4(q, k, v, sm_scale, is_causal=True):
batch_size, num_heads, seq_len, head_dim = q.shape
o = torch.empty_like(q)
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
ct.launch(
torch.cuda.current_stream(),
grid,
fmha_step4_load_order,
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
)
return o
@ct.kernel()
def fmha_step5_latency(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""
Step 5: Add latency hints for better pipelining
- latency=2 for K load (prefetch)
- latency=4 for V load (more prefetch distance)
- Helps overlap memory loads with computation
"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_t = ct.load(
K,
index=(batch_idx, head_idx, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2
)
k_t = k_t.reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(
V,
index=(batch_idx, head_idx, j, 0),
shape=(1, 1, TILE_N, TILE_D),
latency=4
)
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = acc / l_i
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
def run_step5(q, k, v, sm_scale, is_causal=True):
batch_size, num_heads, seq_len, head_dim = q.shape
o = torch.empty_like(q)
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
ct.launch(
torch.cuda.current_stream(),
grid,
fmha_step5_latency,
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
)
return o
@ct.kernel(occupancy=2)
def fmha_step6_occupancy(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""
Step 6: Add occupancy hint
- @ct.kernel(occupancy=2) improves SM utilization
- Allows multiple thread blocks per SM
- Better for hiding memory latency
"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_t = ct.load(
K,
index=(batch_idx, head_idx, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2
)
k_t = k_t.reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(
V,
index=(batch_idx, head_idx, j, 0),
shape=(1, 1, TILE_N, TILE_D),
latency=4
)
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = acc / l_i
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
def run_step6(q, k, v, sm_scale, is_causal=True):
batch_size, num_heads, seq_len, head_dim = q.shape
o = torch.empty_like(q)
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
ct.launch(
torch.cuda.current_stream(),
grid,
fmha_step6_occupancy,
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
)
return o
@ct.kernel(occupancy=2)
def fmha_step7_approx_div(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""
Step 7: Use approximate division for final normalization
- ct.truediv with rounding_mode=APPROX is faster
- Acceptable accuracy loss for inference
- This matches TileGym's optimized implementation
"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_t = ct.load(
K,
index=(batch_idx, head_idx, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2
)
k_t = k_t.reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(
V,
index=(batch_idx, head_idx, j, 0),
shape=(1, 1, TILE_N, TILE_D),
latency=4
)
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
def run_step7(q, k, v, sm_scale, is_causal=True):
batch_size, num_heads, seq_len, head_dim = q.shape
o = torch.empty_like(q)
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
ct.launch(
torch.cuda.current_stream(),
grid,
fmha_step7_approx_div,
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
)
return o
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
"""Run TileGym's optimized FMHA for comparison"""
try:
import tilegym
return tilegym.ops.fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
except ImportError:
logger.log("[WARN] TileGym not available for comparison")
return None
def run_benchmark(seq_len, iterations=100, check_correct=True):
global DEVICE
DEVICE = get_device()
logger.section(f"FMHA OPTIMIZATION TUTORIAL - SEQ_LEN={seq_len}")
logger.log("Configuration:")
logger.log(f" - Batch: {BATCH}")
logger.log(f" - Heads: {N_HEADS}")
logger.log(f" - Head Dim: {HEAD_DIM}")
logger.log(f" - Sequence Length: {seq_len}")
logger.log(f" - Tile M: {TILE_M}")
logger.log(f" - Tile N: {TILE_N}")
logger.log(" - Precision: float16")
logger.log(" - Causal: True")
logger.log(f" - Iterations: {iterations}")
logger.log(f" - Device: {DEVICE}")
q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
sm_scale = 1.0 / math.sqrt(HEAD_DIM)
flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)
ref_output = reference_fmha(q, k, v, sm_scale, is_causal=True)
steps = [
(0, "PyTorch Baseline", "torch.nn.functional.scaled_dot_product_attention",
lambda: step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True),
["PyTorch SDPA with cuDNN backend", "Highly optimized baseline"]),
]
if CUTILE_AVAILABLE:
steps.extend([
(2, "Basic cuTile + MMA", "Tiled FMHA with ct.mma() for Tensor Cores",
lambda: run_step2(q, k, v, sm_scale, is_causal=True),
["@ct.kernel decorator", "ct.mma() for QK and PV products", "Online softmax with exp()"]),
(3, "+ exp2 + flush_to_zero", "Faster exponential math",
lambda: run_step3(q, k, v, sm_scale, is_causal=True),
["ct.exp2() instead of ct.exp()", "flush_to_zero=True for denormals", "qk_scale *= 1/log(2)"]),
(4, "+ Load Order Transpose", "Avoid explicit transpose",
lambda: run_step4(q, k, v, sm_scale, is_causal=True),
["order=(0,1,3,2) for K load", "K loaded already transposed", "Removes ct.permute() call"]),
(5, "+ Latency Hints", "Better memory pipelining",
lambda: run_step5(q, k, v, sm_scale, is_causal=True),
["latency=2 for K load", "latency=4 for V load", "Overlaps loads with compute"]),
(6, "+ Occupancy=2", "Better SM utilization",
lambda: run_step6(q, k, v, sm_scale, is_causal=True),
["@ct.kernel(occupancy=2)", "Multiple blocks per SM", "Hides memory latency"]),
(7, "+ Approx Division (Final)", "Fast final normalization",
lambda: run_step7(q, k, v, sm_scale, is_causal=True),
["ct.truediv with APPROX mode", "Matches TileGym implementation", "Full optimization achieved"]),
])
baseline_latency = None
prev_latency = None
for step_idx, name, desc, fn, changes in steps:
logger.subsection(f"Step {step_idx}: {name}")
logger.log(f"Description: {desc}")
logger.log("Key Changes:")
for change in changes:
logger.log(f" - {change}")
try:
output = fn()
latency_ms = benchmark_fn(fn, warmup=10, iterations=iterations)
tflops = flops * 1e-12 / (latency_ms * 1e-3)
if baseline_latency is None:
baseline_latency = latency_ms
speedup_baseline = 1.0
else:
speedup_baseline = baseline_latency / latency_ms
if prev_latency is None:
speedup_prev = 1.0
else:
speedup_prev = prev_latency / latency_ms
if check_correct and output is not None:
correct = verify_correctness(output, ref_output)
else:
correct = True
logger.log("\nResults:")
logger.log(f" Latency: {latency_ms:.3f} ms")
logger.log(f" TFLOPS: {tflops:.2f}")
logger.log(f" vs Baseline: {speedup_baseline:.2f}x")
logger.log(f" vs Previous: {speedup_prev:.2f}x")
logger.log(f" Correct: {'Yes' if correct else 'No'}")
result = BenchmarkResult(
step=step_idx,
name=name,
description=desc,
latency_ms=latency_ms,
tflops=tflops,
speedup_vs_baseline=speedup_baseline,
speedup_vs_previous=speedup_prev,
correct=correct,
key_changes=changes
)
logger.add_result(result)
prev_latency = latency_ms
except Exception as e:
logger.log(f"\n[ERROR] Step {step_idx} failed: {e}")
import traceback
logger.log(traceback.format_exc())
tilegym_output = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
if tilegym_output is not None:
logger.subsection("TileGym Reference (for comparison)")
tilegym_fn = lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
tilegym_speedup = baseline_latency / tilegym_latency if baseline_latency else 1.0
logger.log("TileGym FMHA:")
logger.log(f" Latency: {tilegym_latency:.3f} ms")
logger.log(f" TFLOPS: {tilegym_tflops:.2f}")
logger.log(f" vs Baseline: {tilegym_speedup:.2f}x")
result = BenchmarkResult(
step=99,
name="TileGym Reference",
description="TileGym's optimized FMHA implementation",
latency_ms=tilegym_latency,
tflops=tilegym_tflops,
speedup_vs_baseline=tilegym_speedup,
speedup_vs_previous=1.0,
correct=True,
key_changes=["Full TileGym implementation", "Pre-tuned for sm121", "Production ready"]
)
logger.add_result(result)
def main():
    parser = argparse.ArgumentParser(description="FMHA Optimization Tutorial")
    parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
    parser.add_argument("--seq-len", type=int, default=2048, help="Sequence length (default matches TileGym)")
    parser.add_argument("--correctness-check", action="store_true", help="Enable correctness checking")
    parser.add_argument("--output-dir", type=str, default=".", help="Output directory for logs")
    args = parser.parse_args()

    logger.section("FMHA OPTIMIZATION TUTORIAL")
    logger.log("From Naive to Optimized cuTile Implementation")
    logger.log("Target Platform: DGX Spark (sm121)")
    logger.log(f"Tile Sizes: TILE_M={TILE_M}, TILE_N={TILE_N} (hardcoded from TileGym)")
    logger.log("Note: TileGym supports autotuning, but we use pre-determined optimal values")

    run_benchmark(
        seq_len=args.seq_len,
        iterations=args.iterations,
        check_correct=args.correctness_check
    )

    logger.section("FINAL SUMMARY")
    logger.log("\n| Step | Name | Latency (ms) | TFLOPS | vs Baseline | Correct |")
    logger.log("|------|------|--------------|--------|-------------|---------|")
    for r in logger.results:
        logger.log(f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {'Yes' if r.correct else 'No'} |")

    json_path = f"{args.output_dir}/fmha_tutorial_results.json"
    md_path = f"{args.output_dir}/fmha_tutorial_results.md"
    log_path = f"{args.output_dir}/fmha_tutorial_log.txt"
    logger.export_json(json_path)
    logger.export_markdown(md_path)
    with open(log_path, 'w') as f:
        f.write('\n'.join(logger.logs))

    logger.log("\nResults exported to:")
    logger.log(f" - {json_path}")
    logger.log(f" - {md_path}")
    logger.log(f" - {log_path}")


if __name__ == "__main__":
    main()
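
The kernels in this file all share the same online-softmax recurrence (`m_i`, `l_i`, `acc`, `alpha`) with the base-2 exponential introduced in Step 3. As a rough illustration of why that recurrence gives the same answer as an ordinary softmax, here is a minimal pure-Python sketch for a single query row with scalar values and no causal mask. The function names `softmax_weighted_sum` and `online_softmax_exp2` are hypothetical helpers written for this note, not part of cuTile or TileGym, and the sketch assumes `scale > 0` (as the kernels do when they multiply the row maximum by `qk_scale_log2`).

```python
import math

INV_LOG_2 = 1.0 / math.log(2)  # same constant the kernels use


def softmax_weighted_sum(scores, values, scale):
    """Reference: softmax(scale * scores) dotted with values, in one shot."""
    scaled = [s * scale for s in scores]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def online_softmax_exp2(scores, values, scale, tile_n=4):
    """Tiled online softmax in base 2, following the m_i / l_i / acc updates."""
    scale_log2 = scale * INV_LOG_2  # exp(x * scale) == 2 ** (x * scale_log2)
    m_i = -math.inf  # running row maximum (in log2 units)
    l_i = 0.0        # running softmax denominator
    acc = 0.0        # running unnormalized output
    for start in range(0, len(scores), tile_n):
        s_tile = scores[start:start + tile_n]
        v_tile = values[start:start + tile_n]
        # New running maximum; valid because scale > 0 preserves ordering.
        m_ij = max(m_i, max(s_tile) * scale_log2)
        p = [2.0 ** (s * scale_log2 - m_ij) for s in s_tile]
        # Rescale earlier partial sums to the new maximum (0.0 on the first tile).
        alpha = 2.0 ** (m_i - m_ij)
        l_i = l_i * alpha + sum(p)
        acc = acc * alpha + sum(pj * vj for pj, vj in zip(p, v_tile))
        m_i = m_ij
    return acc / l_i  # final normalization, the step that Step 7 approximates
```

Calling both functions on the same inputs with different `tile_n` values should agree to floating-point tolerance, which is the property the `Correct` column of the benchmark checks for the real kernels.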

View File

@ -0,0 +1,891 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python3
"""
FMHA Scaling Analysis: How Optimizations Impact Performance at Different Sizes

This script demonstrates:
1. How FMHA performance scales with sequence length
2. Which optimizations provide the most benefit at larger sizes
3. Target-specific configurations for different GPU architectures

Target Platforms (from TileGym):
- DGX Spark (sm120/sm121): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
- Blackwell B300 (sm100): TILE_M=256, TILE_N=128 (occupancy=1) or TILE_M=128, TILE_N=128 (occupancy=2), num_ctas=1

Usage:
    python fmha_scaling_analysis.py [--iterations N]
"""
import argparse
import json
import math
import time
from dataclasses import dataclass, asdict
from typing import List
from types import SimpleNamespace
import torch
LOG_SEPARATOR = "=" * 80
LOG_SUBSEPARATOR = "-" * 60
@dataclass
class StepResult:
    step: int
    name: str
    latency_ms: float
    tflops: float
    speedup_vs_baseline: float


@dataclass
class SeqLenResult:
    seq_len: int
    steps: List[StepResult]
    best_step: int
    best_speedup: float
    tilegym_latency_ms: float
    tilegym_tflops: float
    tilegym_speedup: float


class Logger:
    def __init__(self):
        self.results: List[SeqLenResult] = []
        self.logs: List[str] = []

    def log(self, msg: str):
        print(msg)
        self.logs.append(msg)

    def section(self, title: str):
        self.log(f"\n{LOG_SEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SEPARATOR)

    def subsection(self, title: str):
        self.log(f"\n{LOG_SUBSEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SUBSEPARATOR)


logger = Logger()
BATCH = 4
N_HEADS = 32
HEAD_DIM = 128
INV_LOG_2 = 1.0 / math.log(2)
SEQ_LENS = [1024, 2048, 4096, 8192, 16384]
def get_fmha_config():
    """
    Get target-specific FMHA configuration (from TileGym attention.py)

    Returns configs matching TileGym's _fmha_autotune_configs():
    - sm120/sm121 (DGX Spark): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
    - sm100 (Blackwell B300): Two configs to try via autotuning
    """
    gpu_capability = torch.cuda.get_device_capability()
    if gpu_capability in [(12, 0), (12, 1)]:
        return [
            SimpleNamespace(
                name="DGX Spark (sm121)",
                TILE_M=64,
                TILE_N=64,
                num_ctas=1,
                occupancy=2
            )
        ]
    else:
        return [
            SimpleNamespace(
                name="Blackwell B300 (sm100) - Config 1",
                TILE_M=256,
                TILE_N=128,
                num_ctas=1,
                occupancy=1
            ),
            SimpleNamespace(
                name="Blackwell B300 (sm100) - Config 2",
                TILE_M=128,
                TILE_N=128,
                num_ctas=1,
                occupancy=2
            ),
        ]
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    raise RuntimeError("CUDA not available")


def compute_flops(batch, heads, seq_len, head_dim, causal=True):
    # Two matmuls (QK^T and PV), each 2 * B * H * N * N * D FLOPs
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total_flops = 2 * flops_per_matmul
    if causal:
        # Causal masking skips roughly half of the score matrix
        total_flops *= 0.5
    return total_flops
def benchmark_fn(fn, warmup=10, iterations=100):
    """Benchmark fn and return its latency in milliseconds.

    Prefers triton's do_bench_cudagraph (same as TileGym). The `warmup` and
    `iterations` arguments apply only to the manual-timing fallback.
    """
    try:
        import triton
        # Use triton's cudagraph benchmark - same as TileGym
        return triton.testing.do_bench_cudagraph(fn)
    except Exception:
        # Fallback to manual timing (Exception already covers ImportError)
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iterations):
            fn()
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / iterations * 1000
def reference_fmha(q, k, v, sm_scale, is_causal=True):
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
    )


try:
    import cuda.tile as ct
    from cuda.tile import RoundingMode as RMd
    CUTILE_AVAILABLE = True
except ImportError:
    CUTILE_AVAILABLE = False
    logger.log("[WARN] cuTile not available.")
if CUTILE_AVAILABLE:
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]
@ct.kernel()
def fmha_basic(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""Step 1: Basic cuTile - no optimizations"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
k_tile = k_tile.reshape((TILE_N, TILE_D))
k_t = ct.permute(k_tile, (1, 0))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
qk = qk * qk_scale
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True)
m_ij = ct.maximum(m_i, m_ij)
qk = qk - m_ij
p = ct.exp(qk)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = acc / l_i
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
@ct.kernel()
def fmha_math_opt(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""Step 2: Math optimizations - exp2 + flush_to_zero"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
k_tile = k_tile.reshape((TILE_N, TILE_D))
k_t = ct.permute(k_tile, (1, 0))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = acc / l_i
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
@ct.kernel()
def fmha_memory_opt(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""Step 3: Memory optimizations - load order + latency hints"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_t = ct.load(
K,
index=(batch_idx, head_idx, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2
)
k_t = k_t.reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(
V,
index=(batch_idx, head_idx, j, 0),
shape=(1, 1, TILE_N, TILE_D),
latency=4
)
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = acc / l_i
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
@ct.kernel(occupancy=2)
def fmha_full_opt_occ2(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""Step 4a: Full optimization with occupancy=2 (for sm120/sm121)"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_t = ct.load(
K,
index=(batch_idx, head_idx, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2
)
k_t = k_t.reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(
V,
index=(batch_idx, head_idx, j, 0),
shape=(1, 1, TILE_N, TILE_D),
latency=4
)
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
@ct.kernel(occupancy=1)
def fmha_full_opt_occ1(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""Step 4b: Full optimization with occupancy=1 (for sm100 Blackwell)"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_t = ct.load(
K,
index=(batch_idx, head_idx, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2
)
k_t = k_t.reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(
V,
index=(batch_idx, head_idx, j, 0),
shape=(1, 1, TILE_N, TILE_D),
latency=4
)
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
def run_kernel(kernel_fn, q, k, v, sm_scale, tile_m, tile_n, is_causal=True):
batch_size, num_heads, seq_len, head_dim = q.shape
o = torch.empty_like(q)
grid = (math.ceil(seq_len / tile_m), batch_size * num_heads, 1)
ct.launch(
torch.cuda.current_stream(),
grid,
kernel_fn,
(q, k, v, o, sm_scale, head_dim, num_heads, tile_m, tile_n, is_causal)
)
return o
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
try:
import tilegym
return tilegym.ops.fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
except ImportError:
return None
def run_scaling_analysis(iterations=100):
device = get_device()
gpu_cap = torch.cuda.get_device_capability()
gpu_name = torch.cuda.get_device_name()
configs = get_fmha_config()
primary_cfg = configs[0]
TILE_M = primary_cfg.TILE_M
TILE_N = primary_cfg.TILE_N
logger.section("FMHA SCALING ANALYSIS (TileGym Benchmark Match)")
logger.log("Matching TileGym bench_fused_attention.py configuration")
logger.log(f"\nGPU: {gpu_name} (sm_{gpu_cap[0]}{gpu_cap[1]})")
logger.subsection("TARGET-SPECIFIC CONFIGURATION (from TileGym)")
for cfg in configs:
logger.log(f"\n {cfg.name}:")
logger.log(f" TILE_M={cfg.TILE_M}, TILE_N={cfg.TILE_N}")
logger.log(f" num_ctas={cfg.num_ctas}, occupancy={cfg.occupancy}")
logger.log(f"\nUsing primary config: TILE_M={TILE_M}, TILE_N={TILE_N}, occupancy={primary_cfg.occupancy}")
logger.log("\nTest Configuration (matches TileGym bench_fused_attention.py):")
logger.log(f" Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}")
logger.log(" Causal: True, Precision: float16")
logger.log(f" Sequence Lengths: {SEQ_LENS}")
logger.log(" Benchmark: triton.testing.do_bench_cudagraph (same as TileGym)")
logger.section("OPTIMIZATION STEPS")
logger.log(f"""
Step 0: PyTorch Baseline
- torch.nn.functional.scaled_dot_product_attention
- Uses cuDNN Flash Attention backend
- Highly optimized reference
Step 1: Basic cuTile (TILE_M={TILE_M}, TILE_N={TILE_N})
- @ct.kernel with ct.mma() for Tensor Cores
- Standard exp() for softmax
- Explicit transpose with ct.permute()
- No memory/occupancy hints
Step 2: Math Optimizations
- ct.exp2() instead of ct.exp() (faster on GPU)
- flush_to_zero=True for denormals
- Scale adjustment: multiply by 1/log(2)
Step 3: Memory Optimizations
- Load order=(0,1,3,2) for implicit K transpose
- Latency hints: K=2, V=4 for prefetching
- Overlaps memory loads with computation
Step 4: Full Optimization (Target-Specific)
- @ct.kernel(occupancy={primary_cfg.occupancy}) for {'sm120/121' if primary_cfg.occupancy == 2 else 'sm100'}
- ct.truediv with APPROX rounding mode
- Matches TileGym production implementation
""")
logger.section("PLATFORM DIFFERENCES: DGX Spark vs Blackwell B300")
logger.log("""
| Parameter | DGX Spark (sm121) | Blackwell B300 (sm100) |
|--------------|-------------------|------------------------|
| TILE_M | 64 | 256 or 128 |
| TILE_N | 64 | 128 |
| num_ctas | 1 | 1 |
| occupancy | 2 | 1 or 2 |
Why the difference?
- B300 has more SMs and larger shared memory -> can use bigger tiles
- B300 benefits from larger tiles (256x128) with lower occupancy
- DGX Spark needs smaller tiles (64x64) with higher occupancy to hide latency
- B300's higher memory bandwidth makes larger tiles more efficient
""")
all_results = []
select_kernel = fmha_full_opt_occ2 if primary_cfg.occupancy == 2 else fmha_full_opt_occ1
for seq_len in SEQ_LENS:
logger.subsection(f"Sequence Length: {seq_len}")
q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
sm_scale = 1.0 / math.sqrt(HEAD_DIM)
flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)
steps_results = []
baseline_fn = lambda: reference_fmha(q, k, v, sm_scale, is_causal=True)
baseline_latency = benchmark_fn(baseline_fn, warmup=10, iterations=iterations)
baseline_tflops = flops * 1e-12 / (baseline_latency * 1e-3)
steps_results.append(StepResult(0, "PyTorch Baseline", baseline_latency, baseline_tflops, 1.0))
if CUTILE_AVAILABLE:
kernels = [
(1, "Basic cuTile", fmha_basic),
(2, "Math Opt (exp2)", fmha_math_opt),
(3, "Memory Opt (order+latency)", fmha_memory_opt),
(4, f"Full Opt (occ={primary_cfg.occupancy})", select_kernel),
]
for step, name, kernel in kernels:
try:
fn = lambda kernel=kernel: run_kernel(kernel, q, k, v, sm_scale, TILE_M, TILE_N, is_causal=True)
latency = benchmark_fn(fn, warmup=10, iterations=iterations)
tflops = flops * 1e-12 / (latency * 1e-3)
speedup = baseline_latency / latency
steps_results.append(StepResult(step, name, latency, tflops, speedup))
except Exception as e:
logger.log(f" [ERROR] Step {step} failed: {e}")
tilegym_latency = 0.0
tilegym_tflops = 0.0
tilegym_speedup = 0.0
tilegym_out = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
if tilegym_out is not None:
tilegym_fn = lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
tilegym_speedup = baseline_latency / tilegym_latency
best_step = max(steps_results, key=lambda x: x.speedup_vs_baseline)
result = SeqLenResult(
seq_len=seq_len,
steps=steps_results,
best_step=best_step.step,
best_speedup=best_step.speedup_vs_baseline,
tilegym_latency_ms=tilegym_latency,
tilegym_tflops=tilegym_tflops,
tilegym_speedup=tilegym_speedup,
)
all_results.append(result)
logger.log("\n | Step | Name | Latency (ms) | TFLOPS | Speedup |")
logger.log(" |------|------|--------------|--------|---------|")
for sr in steps_results:
logger.log(f" | {sr.step} | {sr.name:<28} | {sr.latency_ms:>10.3f} | {sr.tflops:>6.2f} | {sr.speedup_vs_baseline:>6.2f}x |")
if tilegym_latency > 0:
logger.log(f" | TG | TileGym Reference | {tilegym_latency:>10.3f} | {tilegym_tflops:>6.2f} | {tilegym_speedup:>6.2f}x |")
logger.log(f"\n Best: Step {best_step.step} ({best_step.name}) with {best_step.speedup_vs_baseline:.2f}x speedup")
logger.results = all_results
return all_results
def print_summary(results: List[SeqLenResult]):
configs = get_fmha_config()
primary_cfg = configs[0]
logger.section("SCALING SUMMARY")
logger.log(f"\nTarget Config: TILE_M={primary_cfg.TILE_M}, TILE_N={primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}")
logger.log("\n## Performance vs Sequence Length\n")
logger.log("| Seq Len | Baseline (ms) | Full Opt (ms) | Speedup | TileGym (ms) | TG Speedup |")
logger.log("|---------|---------------|---------------|---------|--------------|------------|")
for r in results:
baseline = next((s for s in r.steps if s.step == 0), None)
full_opt = next((s for s in r.steps if s.step == 4), None)
if baseline and full_opt:
logger.log(f"| {r.seq_len:>7} | {baseline.latency_ms:>13.3f} | {full_opt.latency_ms:>13.3f} | {full_opt.speedup_vs_baseline:>6.2f}x | {r.tilegym_latency_ms:>12.3f} | {r.tilegym_speedup:>9.2f}x |")
logger.log("\n## Optimization Impact by Sequence Length\n")
logger.log("| Seq Len | Basic | +Math | +Memory | +Full | Best |")
logger.log("|---------|-------|-------|---------|-------|------|")
for r in results:
row = f"| {r.seq_len:>7} |"
for step in [1, 2, 3, 4]:
sr = next((s for s in r.steps if s.step == step), None)
if sr:
row += f" {sr.speedup_vs_baseline:>5.2f}x |"
else:
row += " N/A |"
row += f" {r.best_speedup:>4.2f}x |"
logger.log(row)
logger.section("KEY INSIGHTS")
logger.log("""
## Why Larger Sequences Benefit More from Optimization
1. **Memory Bandwidth Dominance**
- Attention has O(N²) memory complexity for the QK^T matrix
- At seq_len=8192: 8192² × 4 bytes = 256MB per head per batch
- Memory optimizations (order, latency hints) have larger impact
2. **More K-Loop Iterations**
- At seq_len=512: up to 8 K-tiles (512/64) for sm121, 4 K-tiles (512/128) for sm100
- At seq_len=8192: up to 128 K-tiles (8192/64) for sm121, 64 K-tiles (8192/128) for sm100
- Latency hiding through pipelining amortizes over more iterations
3. **Better Occupancy Utilization**
- More tiles = more parallelism opportunities
- sm121 uses occupancy=2 (smaller tiles, more blocks)
- sm100 uses occupancy=1 with larger tiles (256x128)
4. **Platform-Specific Tuning**
- DGX Spark (sm121): 64x64 tiles, occupancy=2 - optimized for bandwidth-limited workloads
- B300 (sm100): 256x128 tiles, occupancy=1 - optimized for compute-heavy workloads
## Optimization Priority by Problem Size
Small (seq_len <= 1024):
- Basic cuTile often sufficient
- Focus on correctness first
Medium (1024 < seq_len <= 4096):
- Math optimizations (exp2) provide ~5% gain
- Memory optimizations start to matter
Large (seq_len > 4096):
- Full optimization stack critical
- Platform-specific tuning essential
- Memory pipelining becomes essential
""")
def export_results(results: List[SeqLenResult], output_dir: str):
    configs = get_fmha_config()
    primary_cfg = configs[0]

    data = {
        "config": {
            "batch": BATCH,
            "n_heads": N_HEADS,
            "head_dim": HEAD_DIM,
            "tile_m": primary_cfg.TILE_M,
            "tile_n": primary_cfg.TILE_N,
            "occupancy": primary_cfg.occupancy,
            "num_ctas": primary_cfg.num_ctas,
            "platform": primary_cfg.name,
        },
        "results": [
            {
                "seq_len": r.seq_len,
                "steps": [asdict(s) for s in r.steps],
                "best_step": r.best_step,
                "best_speedup": r.best_speedup,
                "tilegym_latency_ms": r.tilegym_latency_ms,
                "tilegym_tflops": r.tilegym_tflops,
                "tilegym_speedup": r.tilegym_speedup,
            }
            for r in results
        ]
    }

    # Machine-readable results
    json_path = f"{output_dir}/fmha_scaling_results.json"
    with open(json_path, 'w') as f:
        json.dump(data, f, indent=2)

    # Human-readable Markdown report
    md_path = f"{output_dir}/fmha_scaling_results.md"
    with open(md_path, 'w') as f:
        f.write("# FMHA Scaling Analysis Results\n\n")
        f.write("## Configuration\n")
        f.write(f"- Platform: {primary_cfg.name}\n")
        f.write(f"- Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}\n")
        f.write(f"- Tile: {primary_cfg.TILE_M}x{primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}\n\n")
        f.write("## Target-Specific Configs (from TileGym)\n\n")
        f.write("| Platform | TILE_M | TILE_N | num_ctas | occupancy |\n")
        f.write("|----------|--------|--------|----------|----------|\n")
        f.write("| DGX Spark (sm121) | 64 | 64 | 1 | 2 |\n")
        f.write("| B300 (sm100) Config 1 | 256 | 128 | 1 | 1 |\n")
        f.write("| B300 (sm100) Config 2 | 128 | 128 | 1 | 2 |\n\n")
        f.write("## Results by Sequence Length\n\n")
        for r in results:
            f.write(f"### Seq Len = {r.seq_len}\n\n")
            f.write("| Step | Name | Latency (ms) | TFLOPS | Speedup |\n")
            f.write("|------|------|--------------|--------|--------|\n")
            for s in r.steps:
                f.write(f"| {s.step} | {s.name} | {s.latency_ms:.3f} | {s.tflops:.2f} | {s.speedup_vs_baseline:.2f}x |\n")
            # Append the TileGym reference row only when it was measured
            if r.tilegym_latency_ms > 0:
                f.write(f"| TG | TileGym Reference | {r.tilegym_latency_ms:.3f} | {r.tilegym_tflops:.2f} | {r.tilegym_speedup:.2f}x |\n")
            f.write("\n")

    # Full log of everything printed during the run
    log_path = f"{output_dir}/fmha_scaling_log.txt"
    with open(log_path, 'w') as f:
        f.write('\n'.join(logger.logs))

    logger.log("\nResults exported to:")
    logger.log(f" - {json_path}")
    logger.log(f" - {md_path}")
    logger.log(f" - {log_path}")
def main():
    parser = argparse.ArgumentParser(description="FMHA Scaling Analysis")
    parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
    parser.add_argument("--output-dir", type=str, default=".", help="Output directory")
    args = parser.parse_args()

    results = run_scaling_analysis(iterations=args.iterations)
    print_summary(results)
    export_results(results, args.output_dir)


if __name__ == "__main__":
    main()
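
The figures quoted in the KEY INSIGHTS text can be reproduced with a few lines of arithmetic. The sketch below is not part of the benchmark script: the helper names `qk_matrix_bytes` and `k_tiles` are hypothetical, the 4-byte element size is the one the summary assumes, and the tile widths 64 and 256 are simply the divisors the summary uses for sm121 and sm100.

```python
def qk_matrix_bytes(seq_len: int, bytes_per_element: int = 4) -> int:
    """Bytes needed for one seq_len x seq_len score matrix (one head, one batch item)."""
    return seq_len * seq_len * bytes_per_element


def k_tiles(seq_len: int, tile_size: int) -> int:
    """Number of K-loop iterations, i.e. ceil(seq_len / tile_size)."""
    return -(-seq_len // tile_size)


if __name__ == "__main__":
    for seq_len in (512, 8192):
        mib = qk_matrix_bytes(seq_len) / 2**20
        print(
            f"seq_len={seq_len}: QK^T = {mib:.0f} MiB, "
            f"K-tiles = {k_tiles(seq_len, 64)} (tile 64), "
            f"{k_tiles(seq_len, 256)} (tile 256)"
        )
```

Growing the sequence length from 512 to 8192 multiplies the score matrix by 256 while the K-loop only gets 16 times longer, which is why the summary treats memory behavior as the dominant concern at large sizes.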


@ -0,0 +1,105 @@
# Register DGX Spark to Brev
> Link your DGX Spark to Brev for remote access and shared environments
## Table of Contents
- [Overview](#overview)
- [Instructions](#instructions)
- [Troubleshooting](#troubleshooting)
---
## Overview
## Basic idea
NVIDIA Brev is an AI development platform that makes GPU environments remotely accessible, shareable, and easy to standardize using preconfigured setups called Launchables.
This walkthrough will help you connect your NVIDIA DGX Spark to Brev so it shows up as a managed GPU environment in Brev. After a one-time registration, your Spark becomes remotely accessible and shareable.
## What you'll accomplish
Youll register your DGX Spark with Brev and it will be visible as a healthy node in the Brev web UI and CLI, ready to share access and accept workloads whenever needed.
## What to know before starting
Brev automates most of the configuration, but the following will be useful when establishing the initial connection:
* **Terminal Basics**:
  * Familiarity with the command line to run a few simple setup commands
## Prerequisites
Your DGX Spark [device is set up](https://docs.nvidia.com/dgx/dgx-spark/first-boot.html). You will also need the following:
* **Brev Account**:
  * Have an NVIDIA Brev account. Create one [here](https://login.brev.nvidia.com/signin) if you don’t have one.
* **Permissions**:
  * You have administrative (root or sudo) access on the DGX Spark device to run the registration command.
## Time & risk
* **Estimated time:** 5-10 minutes
* **Risk level:** Low - Registration configures the Spark for secure remote access without altering your existing workloads
* **Rollback:** The Brev configuration can be removed through either the UI or the CLI
## Instructions
## Step 1. Log in to Brev
Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you’re in the correct org by clicking the org button at the top right of the page. Once logged in, go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.
Click the “Register Compute” button and follow the instructions in the pop-up window.
## Step 2. Complete the Pop-up Instructions
* Install the Brev CLI
* Configure your compute
  * Add a name for the compute
  * To configure SSH, ensure the “Enable SSH access” toggle is on
* Run the registration command
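
Before running the registration command, it can help to confirm that the CLI installed earlier in this step is on your `PATH` and that you have the administrative access listed in the prerequisites. The following is a minimal sketch, not part of Brev: the `preflight` function and its result keys are hypothetical names, and it only inspects the local machine without calling Brev or changing anything.

```python
import os
import shutil


def preflight(cli_name: str = "brev") -> dict:
    """Run read-only checks and return a name -> bool mapping."""
    return {
        # True when an executable with this name is found on PATH
        "cli_on_path": shutil.which(cli_name) is not None,
        # True when the script already runs as root (POSIX only)
        "is_root": hasattr(os, "geteuid") and os.geteuid() == 0,
        # True when sudo is available for running the command as a regular user
        "sudo_on_path": shutil.which("sudo") is not None,
    }


if __name__ == "__main__":
    for check, ok in preflight().items():
        print(f"{check}: {'ok' if ok else 'missing'}")
```

If `cli_on_path` is reported as missing, repeat the CLI installation from the pop-up. If both `is_root` and `sudo_on_path` are missing, the registration command is unlikely to have the access it needs.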
## Step 3. Follow Registration Flow
In the CLI, youll be walked through registration. Go through the flow until registration is complete.
## Step 4. Confirm Spark in Brev UI
* Go to the [Brev UI](https://brev.nvidia.com)
* Navigate to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section
* Confirm that the DGX Spark appears as a registered node with an **Available** status
## Step 5. Next Steps
Your Spark is now integrated into Brev as a secure, remotely accessible GPU environment.
Now that your hardware is connected, you can:
* **Share Access Anywhere:** Access your machine from anywhere and share access with others through the Brev UI under [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute).
## Step 6. Cleanup
If you ever decide to deregister your Spark from Brev, you can do so through either the Brev UI or the Brev CLI.
With the CLI, run:
```bash
brev deregister
```
In the UI:
* Go to the [Brev UI](https://brev.nvidia.com)
* Navigate to the section listing “GPU Environments” and look under “Registered Compute”
* Click the “Deregister” menu item on the Spark you wish to delete from Brev.
* Confirm your selection.
## Troubleshooting
| Symptom | Cause | Fix |
|---------|-------|-----|
| Your DGX Spark is showing up in the wrong org | You registered your DGX Spark to the wrong org | Run `brev set <my-org>` and then redo the registration |
| Unable to connect with `brev shell <name>` | The local Brev state needs a refresh | Run `brev refresh`, then retry `brev shell <name>` |
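
The table above can also be read as a small lookup from symptom to commands. The sketch below is illustrative only: `fix_commands` and the symptom keys `wrong_org` and `shell_fails` are hypothetical names, and it only builds the command lines from the table without executing them. The registration command itself is not included, because it comes from the pop-up in the Brev UI.

```python
from typing import List


def fix_commands(symptom: str, org: str = "<my-org>", name: str = "<name>") -> List[List[str]]:
    """Return the commands from the troubleshooting table as argv lists."""
    if symptom == "wrong_org":
        # Switch to the intended org, then redo the registration from the pop-up
        return [["brev", "set", org]]
    if symptom == "shell_fails":
        # Refresh first, then retry the shell connection
        return [["brev", "refresh"], ["brev", "shell", name]]
    raise ValueError(f"Unknown symptom: {symptom!r}")


if __name__ == "__main__":
    for argv in fix_commands("shell_fails", name="my-spark"):
        print(" ".join(argv))
```

Returning argv lists rather than shell strings keeps the org and node names as separate arguments, so they can be passed to `subprocess.run` later without any quoting concerns.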
For the latest known issues, please review the [DGX Spark User Guide](https://docs.nvidia.com/dgx/dgx-spark/known-issues.html).