Compare commits

1 Commit

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| Ramzey Ghanaim | cda3f97231 | Merge 050f799875 into 3eff7461e1 | 2026-06-02 15:40:43 -07:00 |
11 changed files with 141 additions and 4808 deletions

@@ -21,13 +21,11 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
### NVIDIA
- [CLI Coding Agent](nvidia/cli-coding-agent/)
- [Comfy UI](nvidia/comfy-ui/)
- [Connect Three DGX Spark in a Ring Topology](nvidia/connect-three-sparks/)
- [Set Up Local Network Access](nvidia/connect-to-your-spark/)
- [Connect Two Sparks](nvidia/connect-two-sparks/)
- [CUDA-X Data Science](nvidia/cuda-x-data-science/)
- [cuTile Kernels](nvidia/cutile-kernels/)
- [DGX Dashboard](nvidia/dgx-dashboard/)
- [FLUX.1 Dreambooth LoRA Fine-tuning](nvidia/flux-finetuning/)
- [Run Hermes Agent with Local Models](nvidia/hermes-agent/)
@@ -43,7 +41,6 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
- [NCCL for Two Sparks](nvidia/nccl/)
- [Fine-tune with NeMo](nvidia/nemo-fine-tune/)
- [Run NemoClaw with a Local LLM](nvidia/nemoclaw/)
- [🦞 Set Up Example NemoClaw Agents 🦞](nvidia/nemoclaw-applications/)
- [Nemotron-3-Nano with llama.cpp](nvidia/nemotron/)
- [NIM on Spark](nvidia/nim-llm/)
- [NVFP4 Quantization](nvidia/nvfp4-quantization/)
@@ -55,7 +52,6 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
- [Fine-tune with Pytorch](nvidia/pytorch-fine-tune/)
- [RAG Application in AI Workbench](nvidia/rag-ai-workbench/)
- [Spark & Reachy Photo Booth](nvidia/reachy-photo-booth/)
- [Register DGX Spark to Brev](nvidia/register-to-brev/)
- [SGLang for Inference](nvidia/sglang/)
- [Single-cell RNA Sequencing](nvidia/single-cell/)
- [Speculative Decoding](nvidia/speculative-decoding/)

@@ -1,474 +0,0 @@
# CLI Coding Agent
> Build local CLI coding agents with Ollama
## Table of Contents
- [Overview](#overview)
- [Claude Code](#claude-code)
- [OpenCode](#opencode)
- [Codex CLI](#codex-cli)
- [Troubleshooting](#troubleshooting)
---
## Overview
## Basic idea
Use [Ollama](https://ollama.com) on [DGX Spark](https://www.nvidia.com/en-us/products/workstations/dgx-spark/) to run a local coding model and connect a CLI coding agent. This
playbook supports three options: **[Claude Code](https://docs.claude.com/en/docs/claude-code)**, **[OpenCode](https://opencode.ai)**, and **[Codex CLI](https://github.com/openai/codex)**. Each
agent is wired up with Ollama's built-in [launch method](https://ollama.com/blog/launch) (`ollama launch <agent>`), so you
can work without environment variables, provider config files, or external cloud APIs.
## Choose your CLI agent
Pick the tab that matches the CLI agent you want to use:
- **Claude Code**: Fastest path to a working CLI agent with a local Ollama model.
- **OpenCode**: Open-source CLI launched directly from Ollama.
- **Codex CLI**: OpenAI Codex CLI launched directly from Ollama against the local model.
## What you'll accomplish
You will run a local coding model ([Qwen3.6](https://ollama.com/library/qwen3.6)) on your DGX Spark with Ollama, launch your
chosen CLI agent against it with a single command, and complete a small coding task end-to-end.
## What to know before starting
- Comfort with Linux command line basics
- Experience running terminal-based tools and editors
- Familiarity with Python for the short coding task
## Prerequisites
- DGX Spark access with NVIDIA DGX OS 7.3.1 (Ubuntu 24.04.3 LTS base)
- Internet access to download model weights
- [Ollama](https://ollama.com/download) v0.15 or newer (required for [`ollama launch`](https://ollama.com/blog/launch))
- GPU memory depends on the Qwen3.6 variant you choose:
  - `qwen3.6:latest` (35B-a3b, MoE): ~24GB, 256K context
  - `qwen3.6:35b-a3b-nvfp4`: ~22GB, NVIDIA FP4 build tuned for Blackwell (DGX Spark)
  - `qwen3.6:35b-a3b-q8_0`: ~39GB, higher-quality quant
  - `qwen3.6:35b-a3b-bf16`: ~71GB, full precision (fits Spark's unified memory)
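The variant list above can be turned into a quick selection helper. The sketch below is illustrative only: the sizes are the approximate figures quoted in the list, while `largest_variant_that_fits` and its `headroom_gb` allowance are hypothetical names and values chosen for this example, not part of Ollama.

```python
# Approximate footprints (GB) copied from the prerequisites list above.
VARIANT_SIZES_GB = {
    "qwen3.6:35b-a3b-nvfp4": 22,
    "qwen3.6:latest": 24,
    "qwen3.6:35b-a3b-q8_0": 39,
    "qwen3.6:35b-a3b-bf16": 71,
}


def largest_variant_that_fits(free_gb, headroom_gb=8):
    """Return the biggest variant whose weights fit in free_gb minus headroom.

    headroom_gb is an assumed allowance for KV cache and other processes.
    Returns None when no variant fits.
    """
    budget = free_gb - headroom_gb
    fitting = [(size, tag) for tag, size in VARIANT_SIZES_GB.items() if size <= budget]
    return max(fitting)[1] if fitting else None


print(largest_variant_that_fits(64))
```

Raise `headroom_gb` if you plan to use long contexts, since the KV cache grows with context length.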
## Time & risk
* **Duration**: ~15-25 minutes (mostly model download time)
* **Risk level**: Low
  * Large model downloads can fail if network connectivity is unstable
  * Ollama versions older than 0.15 do not support `ollama launch`
* **Rollback**: Stop Ollama and delete the downloaded model from `~/.ollama/models`
* **Last Updated:** 04/16/2026
  * Switched to `ollama launch` method and upgraded the default model to Qwen3.6
## Claude Code
## Step 1. Confirm your environment
**Description**: Verify the OS version and GPU are visible before installing anything.
```bash
cat /etc/os-release | head -n 2
nvidia-smi
```
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
## Step 2. Install or update Ollama
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
```bash
curl -fsSL https://ollama.com/install.sh | sh
ollama --version
```
If Ollama is already installed, just verify the version:
```bash
ollama --version
```
Expected output should show Ollama v0.15 or newer.
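If you want to script the version check, a minimal sketch follows. It assumes the version string contains a dotted number such as `0.15.2`; the sample string in the code is an assumed format, not a captured log, and both function names are hypothetical.

```python
import re


def parse_ollama_version(text):
    """Extract (major, minor, patch) from `ollama --version` output."""
    match = re.search(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"no version number found in {text!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def supports_launch(version_text, minimum=(0, 15, 0)):
    """True when the reported version is at least v0.15."""
    return parse_ollama_version(version_text) >= minimum


# Assumed sample output; replace with the real text from `ollama --version`.
print(supports_launch("ollama version is 0.15.2"))
```

Comparing integer tuples avoids the string-comparison trap where `"0.9"` sorts after `"0.15"`.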
## Step 3. Pull Qwen3.6
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
```bash
ollama pull qwen3.6
```
Optional variants if you want different memory footprints or precision:
```bash
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
```
Expected output should show `qwen3.6` (and any optional variants) in `ollama list`.
## Step 4. Test local inference (optional)
**Description**: Run a quick prompt to confirm the model loads.
```bash
ollama run qwen3.6
```
Try a prompt like:
```text
Write a short README checklist for a Python project.
```
Expected output should show the model responding in the terminal. When you are done, type `/bye` or press `Ctrl+D` to exit the interactive session before continuing.
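The same check can be scripted against Ollama's local REST API instead of the interactive prompt. This is a sketch under the assumption that the server listens on the default `localhost:11434`; `build_generate_request` and `generate` are hypothetical helper names.

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434/api/generate"  # default local endpoint


def build_generate_request(prompt, model="qwen3.6"):
    """Build a non-streaming POST request for Ollama's /api/generate."""
    payload = {"model": model, "prompt": prompt, "stream": False}
    return urllib.request.Request(
        OLLAMA_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def generate(prompt, model="qwen3.6", timeout=300):
    """Send the prompt and return the generated text (needs a running server)."""
    request = build_generate_request(prompt, model)
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.loads(resp.read())["response"]
```

With the service running, `generate("Write a short README checklist for a Python project.")` should return the same kind of answer you saw in the terminal.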
## Step 5. Launch Claude Code with Ollama
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [Claude Code](https://docs.claude.com/en/docs/claude-code) against your local model. No environment variables or config files are required.
```bash
ollama launch claude
```
Expected output should show Claude Code starting and using the local Qwen3.6 model. Qwen3.6 ships with a 256K context window by default; adjust context length through Ollama's settings if you need to tune it further.
## Step 6. Complete a small coding task
**Description**: Create a tiny repo and let Claude Code implement a function and tests.
```bash
mkdir -p ~/cli-agent-demo
cd ~/cli-agent-demo
printf 'def add(a, b):\n    """Return the sum of a and b."""\n    pass\n' > math_utils.py
printf 'import math_utils\n\n\ndef test_add():\n    assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
```
If you do not already have pytest installed:
```bash
python3 -m pip install -U pytest
```
In Claude Code:
```text
Please implement add() in math_utils.py and make sure the test passes.
```
Run the test:
```bash
python3 -m pytest -q
```
Expected output should show the test passing.
## Step 7. Cleanup and rollback
**Description**: Remove the model and stop services if you no longer need them.
Remove the model first, because `ollama rm` needs the Ollama service to be running:
> [!WARNING]
> This will delete the downloaded model files.
```bash
ollama rm qwen3.6
```
Then stop the service:
```bash
sudo systemctl stop ollama
```
## Step 8. Next steps
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
- Use Claude Code on multi-file refactors or test-generation tasks
- Explore the full 256K context window on larger codebases
## OpenCode
## Step 1. Confirm your environment
**Description**: Verify the OS version and GPU are visible before installing anything.
```bash
cat /etc/os-release | head -n 2
nvidia-smi
```
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
## Step 2. Install or update Ollama
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
```bash
curl -fsSL https://ollama.com/install.sh | sh
ollama --version
```
If Ollama is already installed, just verify the version:
```bash
ollama --version
```
Expected output should show Ollama v0.15 or newer.
## Step 3. Pull Qwen3.6
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
```bash
ollama pull qwen3.6
```
Optional variants if you want different memory footprints or precision:
```bash
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
```
Expected output should show `qwen3.6` in `ollama list`.
## Step 4. Test local inference (optional)
**Description**: Run a quick prompt to confirm the model loads.
```bash
ollama run qwen3.6
```
Try a prompt like:
```text
Write a short README checklist for a Python project.
```
Expected output should show the model responding. When you are done, type `/bye` or press `Ctrl+D` to exit before continuing.
## Step 5. Launch OpenCode with Ollama
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [OpenCode](https://opencode.ai) against your local model. No [`opencode.json`](https://opencode.ai/docs/config/) provider configuration is required.
```bash
ollama launch opencode
```
If you want to pre-configure OpenCode without launching immediately:
```bash
ollama launch opencode --config
```
Expected output should show OpenCode starting with Ollama preselected as the provider and Qwen3.6 as the model. Qwen3.6 ships with a 256K context window by default.
## Step 6. Complete a small coding task
**Description**: Create a tiny repo and let OpenCode implement a function and tests.
```bash
mkdir -p ~/cli-agent-demo
cd ~/cli-agent-demo
printf 'def add(a, b):\n    """Return the sum of a and b."""\n    pass\n' > math_utils.py
printf 'import math_utils\n\n\ndef test_add():\n    assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
```
If you do not already have pytest installed:
```bash
python3 -m pip install -U pytest
```
In OpenCode:
```text
Please implement add() in math_utils.py and make sure the test passes.
```
Run the test:
```bash
python3 -m pytest -q
```
Expected output should show the test passing.
## Step 7. Cleanup and rollback
**Description**: Remove the model and stop services if you no longer need them.
Remove the model first, because `ollama rm` needs the Ollama service to be running:
> [!WARNING]
> This will delete the downloaded model files.
```bash
ollama rm qwen3.6
```
Then stop the service:
```bash
sudo systemctl stop ollama
```
## Step 8. Next steps
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
- Use OpenCode on multi-file changes or test-generation tasks
- Explore the full 256K context window on larger codebases
## Codex CLI
## Step 1. Confirm your environment
**Description**: Verify the OS version and GPU are visible before installing anything.
```bash
cat /etc/os-release | head -n 2
nvidia-smi
```
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
## Step 2. Install or update Ollama
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
```bash
curl -fsSL https://ollama.com/install.sh | sh
ollama --version
```
If Ollama is already installed, just verify the version:
```bash
ollama --version
```
Expected output should show Ollama v0.15 or newer.
## Step 3. Pull Qwen3.6
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
```bash
ollama pull qwen3.6
```
Optional variants if you want different memory footprints or precision:
```bash
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
```
Expected output should show `qwen3.6` in `ollama list`.
## Step 4. Test local inference (optional)
**Description**: Run a quick prompt to confirm the model loads.
```bash
ollama run qwen3.6
```
Try a prompt like:
```text
Write a short README checklist for a Python project.
```
Expected output should show the model responding. When you are done, type `/bye` or press `Ctrl+D` to exit before continuing.
## Step 5. Launch Codex CLI with Ollama
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [Codex CLI](https://github.com/openai/codex) against your local model. No `~/.codex/config.toml` and no manual `npm install -g @openai/codex` are required — Ollama handles the Codex integration.
```bash
ollama launch codex
```
Expected output should show Codex CLI starting with Ollama as the provider and Qwen3.6 as the model. Qwen3.6 ships with a 256K context window by default, which is well suited to Codex's agentic workflows.
## Step 6. Complete a small coding task
**Description**: Create a tiny repo and let Codex implement a function and tests.
```bash
mkdir -p ~/cli-agent-demo
cd ~/cli-agent-demo
printf 'def add(a, b):\n    """Return the sum of a and b."""\n    pass\n' > math_utils.py
printf 'import math_utils\n\n\ndef test_add():\n    assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
```
If you do not already have pytest installed:
```bash
python3 -m pip install -U pytest
```
In Codex:
```text
Please implement add() in math_utils.py and make sure the test passes.
```
Run the test:
```bash
python3 -m pytest -q
```
Expected output should show the test passing.
## Step 7. Cleanup and rollback
**Description**: Remove the model and stop services if you no longer need them.
Remove the model first, because `ollama rm` needs the Ollama service to be running:
> [!WARNING]
> This will delete the downloaded model files.
```bash
ollama rm qwen3.6
```
Then stop the service:
```bash
sudo systemctl stop ollama
```
## Step 8. Next steps
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
- Use Codex CLI on multi-file changes or test-generation tasks
- Explore the full 256K context window on larger codebases
## Troubleshooting
| Symptom | Cause | Fix |
|---------|-------|-----|
| `ollama: command not found` | Ollama not installed or PATH not updated | Rerun `curl -fsSL https://ollama.com/install.sh \| sh` and open a new shell |
| `ollama launch` reports unknown command | Ollama is older than v0.15 | Update Ollama: `curl -fsSL https://ollama.com/install.sh \| sh` |
| Model load fails with version error or HTTP 412 | Ollama version is too old for the model | Update Ollama: `curl -fsSL https://ollama.com/install.sh \| sh` |
| `model not found` when launching an agent | Model was not pulled | Run `ollama pull qwen3.6` and retry |
| `connection refused` to localhost:11434 | Ollama service not running | Start with `ollama serve` or `sudo systemctl start ollama` |
| `ollama launch <agent>` exits immediately | Agent integration failed to initialize | Re-run `ollama launch <agent>`; if it persists, check `journalctl -u ollama` |
| Slow responses or OOM errors | Model variant too large for GPU memory | Switch to `qwen3.6:35b-a3b-nvfp4` or close other GPU workloads |
> [!NOTE]
> DGX Spark uses a Unified Memory Architecture (UMA), which enables dynamic memory sharing
> between the GPU and CPU. If you see memory pressure, flush the buffer cache with:
> ```bash
> sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
> ```
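
For the `connection refused` row in the table above, a small probe can tell you whether anything is listening on the Ollama port before you launch an agent. This is a sketch that only tests TCP reachability; `ollama_is_listening` is a hypothetical helper name, and the default port 11434 is taken from the table.

```python
import socket


def ollama_is_listening(host="127.0.0.1", port=11434, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers connection refused, timeouts, and unreachable hosts.
        return False
```

A `False` result suggests starting the service with `sudo systemctl start ollama` before retrying `ollama launch <agent>`.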

@@ -1,859 +0,0 @@
# cuTile Kernels
> Run cuTile kernel benchmarks, FMHA implementation, and LLM inference on DGX Spark and B300
## Table of Contents
- [Overview](#overview)
- [Kernel Benchmarks](#kernel-benchmarks)
- [End-to-End Inference](#end-to-end-inference)
- [FMHA Implementation](#fmha-implementation)
- [Attention Basics](#attention-basics)
- [Flash Attention Algorithm](#flash-attention-algorithm)
- [cuTile Pseudocode → Actual Mapping](#cutile-pseudocode-actual-mapping)
- [Kernel Pseudocode](#kernel-pseudocode)
- [cuTile Implementation](#cutile-implementation)
- [Launching the Kernel](#launching-the-kernel)
- [Optimizations](#optimizations)
- [Platform Configuration](#platform-configuration)
- [Performance Results](#performance-results)
- [Common Issues](#common-issues)
- [Companion Scripts](#companion-scripts)
- [References](#references)
- [Platform Comparison](#platform-comparison)
- [End-to-End Throughput](#end-to-end-throughput)
- [CUDA Kernel Time](#cuda-kernel-time)
- [cuTile Kernel Breakdown](#cutile-kernel-breakdown)
- [Troubleshooting](#troubleshooting)
---
## Overview
## Basic idea
[TileGym](https://github.com/NVIDIA/TileGym) is NVIDIA's benchmark suite and integration framework for cuTile kernels - high-performance GPU kernels written using the cuTile Python DSL. cuTile compiles to Tile IR, enabling developers to write efficient kernels without low-level CUDA programming.
This playbook covers three workflows:
1. **[Kernel Benchmarks](#kernel-benchmarks)** - Run standalone cuTile kernel benchmarks (FMHA, MatMul, RMSNorm, etc.)
2. **[End-to-End Inference](#end-to-end-inference)** - Run LLM inference with cuTile-optimized kernels via monkey-patching
3. **[FMHA Implementation](#fmha-implementation)** - Step-by-step tutorial building a Flash Multi-Head Attention kernel from pseudocode to optimized cuTile, with companion scripts to run and benchmark
The same cuTile code runs on both DGX Spark (sm_121) and B300 (sm_103) - cuTile JIT compiles to the appropriate GPU architecture automatically.
## What you'll accomplish
- Run the TileGym benchmark suite on DGX Spark
- Run Qwen2-7B or DeepSeek-V2-Lite inference with cuTile-optimized kernels
- Observe performance scaling between DGX Spark and B300
- Build an FMHA kernel step-by-step from pseudocode to optimized cuTile implementation
## What to know before starting
- Basic familiarity with Docker and command-line tools
- Understanding of GPU compute concepts (TFLOPS, memory bandwidth)
- No CUDA programming experience required
- HuggingFace account with access token (for LLM inference)
## Prerequisites
**Hardware Requirements:**
- DGX Spark with Ubuntu 24.04 or B300 cloud instance
- Minimum 16GB GPU memory for LLM inference
- At least 50GB available storage space for model downloads
**Software Requirements:**
- Docker installed and configured: `docker ps`
- CUDA Toolkit 13.x with Tile IR support
- HuggingFace token for model access (LLM inference only)
- Network access for pulling containers and downloading models
Verify Docker is available:
```bash
docker ps
```
If you get a permission error:
```bash
sudo usermod -aG docker $USER
newgrp docker
```
## Kernel support matrix
| Kernel | Category | Data Types | Description |
|--------|----------|------------|-------------|
| **FMHA** | Attention | float16, float8 | Flash Multi-Head Attention |
| **MLA** | Attention | bfloat16, float8 | Multi-head Latent Attention |
| **MLA Decoding** | Attention | float16, float8 | MLA for decode phase |
| **MatMul** | Matrix Ops | float16, float8 | Matrix multiplication |
| **BMM** | Matrix Ops | float16 | Batched matrix multiplication |
| **Group GEMM** | Matrix Ops | float16, float8 | Grouped GEMM for MoE |
| **RMSNorm** | Normalization | float16, bfloat16 | Root mean square normalization |
| **RoPE** | Positional | float16 | Rotary position embedding |
| **SiLU** | Activation | float16, float32 | SiLU activation with multiply |
| **SwiGLU** | Activation | float16, float32 | SwiGLU fused operation |
| **Softmax** | Activation | float16 | Softmax normalization |
| **Dropout** | Regularization | float16, float32 | Dropout forward |
## Model support for LLM inference
| Model | Supported Kernels | Batch Size | Output Tokens | Notes |
|-------|-------------------|------------|---------------|-------|
| **Qwen2-7B** | RoPE, RMSNorm, SwiGLU, FMHA | 16 | 50 | Standard transformer |
| **DeepSeek-V2-Lite** | RoPE, RMSNorm, SiLU, MLA, MoE | 1 | 100 | MLA attention, MoE layers |
## Ancillary files
All required assets can be found in the [TileGym repository](https://github.com/NVIDIA/TileGym).
- `tests/benchmark/run_all.sh` - Run all kernel benchmarks
- `modeling/transformers/bench_qwen.sh` - Qwen2-7B benchmark script
- `modeling/transformers/bench_deepseek.sh` - DeepSeek-V2-Lite benchmark script
- `modeling/transformers/infer.py` - Main inference script with TileGym integration
- [`assets/fmha_optimization_tutorial.py`](assets/fmha_optimization_tutorial.py) - FMHA step-by-step optimization tutorial
- [`assets/fmha_scaling_analysis.py`](assets/fmha_scaling_analysis.py) - FMHA scaling analysis across sequence lengths
## Time & risk
* **Estimated time:** 30-45 minutes (including model download for LLM inference)
* **Risk level:** Low
  * Large downloads may fail due to network issues
  * First run includes JIT compilation overhead
* **Rollback:** Remove Docker container to undo all changes
* **Last Updated:** February 2026
  * First Publication
## Kernel Benchmarks
## Step 1. Pull CUDA NGC container with CTK 13.x
```bash
docker pull nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04
```
Launch an interactive session with GPU access:
```bash
docker run --gpus all -it --rm \
-v ~/TileGym:/workspace/TileGym \
nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 \
/bin/bash
```
> [!NOTE]
> The `-v` flag mounts a local directory to persist the TileGym repository. The `--rm` flag automatically removes the container when you exit; omit it if you want to keep the container for later use.
Or if running outside a container, install Tile IR directly:
```bash
## Requires root privileges - run with sudo or as root
sudo apt-get install cuda-tile-ir-13-1 cuda-compiler-13-1
```
## Step 2. Clone TileGym repository
```bash
git clone https://github.com/NVIDIA/TileGym
cd TileGym
pip install .
```
## Step 3. Run benchmark suite
```bash
cd tests/benchmark/
bash run_all.sh
```
> [!NOTE]
> The benchmark runs sequentially to ensure accurate timing results. This may take 10-15 minutes to complete all kernels.
## Step 4. View results
Results show cuTile performance for each kernel and sequence length.
Expected output should look like:
```text
==========================================
Running bench_fused_attention.py...
==========================================
fused-attention-batch4-head32-d128-fwd-causal=True-float16-TFLOPS:
N_CTX CuTile
0 1024.0 58.188262
1 2048.0 80.906892
2 4096.0 86.189532
3 8192.0 88.891086
4 16384.0 89.491869
✓ PASSED: bench_fused_attention.py
```
## Step 5. Run individual benchmarks
To run specific kernel benchmarks:
```bash
## Flash Multi-Head Attention
python bench_fused_attention.py
## Matrix Multiplication
python bench_matrix_multiplication.py
## RMSNorm
python bench_rmsnorm.py
## RoPE
python bench_rope.py
## SwiGLU
python bench_swiglu.py
```
## Step 6. Clean up
Exit the container:
```bash
exit
```
Remove this workflow's containers (if you ran without `--rm`):
```bash
## Preferred: remove only containers from this workflow's image
docker rm $(docker ps -a --filter ancestor=nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 --format '{{.ID}}')
## Alternative: prune all stopped containers (will prompt for confirmation)
## docker container prune
```
Remove the image (optional):
```bash
docker rmi nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04
```
## Step 7. Repeat on B300
Repeat Steps 1-6 on B300 hardware to observe scaling. See the **Platform Comparison** tab for expected scaling results.
## End-to-End Inference
## Step 1. Set up environment
If you haven't already, pull the CUDA container and clone TileGym (see **Kernel Benchmarks** tab for details).
First, clone TileGym on the host:
```bash
mkdir -p ~/TileGym
git clone https://github.com/NVIDIA/TileGym ~/TileGym
```
Then launch the container with the repository mounted:
```bash
docker run --gpus all -it --rm \
-v ~/TileGym:/workspace/TileGym \
-v ~/.cache/huggingface:/root/.cache/huggingface \
nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 \
/bin/bash
```
> [!NOTE]
> The `-v ~/.cache/huggingface:/root/.cache/huggingface` mounts your HuggingFace cache to avoid re-downloading models.
Install TileGym inside the container:
```bash
cd /workspace/TileGym
pip install .
```
Set your HuggingFace token for accessing gated models:
```bash
export HF_TOKEN=<your_huggingface_token>
```
> [!WARNING]
> You need a HuggingFace account and access token. Get one at https://huggingface.co/settings/tokens
## Step 2. Run inference benchmark
Navigate to the transformers benchmark directory:
```bash
cd modeling/transformers
```
**Option A: Run Qwen2-7B benchmark**
```bash
./bench_qwen.sh
```
Configuration: Model `Qwen/Qwen2-7B`, Batch size 16, Output length 50 tokens.
**Option B: Run DeepSeek-V2-Lite benchmark**
```bash
./bench_deepseek.sh
```
Configuration: Model `deepseek-ai/DeepSeek-V2-Lite-Chat`, Batch size 1, Output length 100 tokens.
Both scripts run two configurations:
1. **PyTorch baseline** - Standard HuggingFace inference
2. **TileGym cuTile** - With cuTile kernel replacements
## Step 3. View results
**Sample DGX Spark (GB10) Results for Qwen2-7B:**
```text
========================================
Benchmark Results
========================================
Qwen2-7B_naive_bfloat16 | 15.66 tokens/s | 51.10s | 51151.0ms CUDA
Qwen2-7B_cutile_attn | 18.52 tokens/s | 43.20s | 43079.7ms CUDA
========================================
```
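To put the two result lines in perspective, the relative gain can be computed directly from the sample figures above. The numbers are copied from that output; nothing else is measured here.

```python
# Figures copied from the sample Qwen2-7B results above.
baseline_tps, cutile_tps = 15.66, 18.52
baseline_cuda_ms, cutile_cuda_ms = 51151.0, 43079.7

throughput_gain = cutile_tps / baseline_tps - 1.0
cuda_time_saved = 1.0 - cutile_cuda_ms / baseline_cuda_ms

print(f"throughput gain: {throughput_gain:.1%}")
print(f"CUDA time saved: {cuda_time_saved:.1%}")
```

For this sample run that works out to roughly 18% more tokens per second and about 16% less CUDA time.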
**cuTile Kernel Breakdown (DGX Spark - Qwen2):**
| Kernel | CUDA Time (ms) | Calls |
|--------|----------------|-------|
| `fmha_kernel` | 4185.9 | 28 |
| `swiglu_forward_kernel` | 2459.8 | 1400 |
| `attention_decode_kernel_grouped` | 2271.8 | 1372 |
| `rms_norm_kernel_static_persistent` | 634.7 | 57 |
| `rope_kernel` | 355.6 | 1400 |
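The breakdown is easier to read once total time is split into share and per-call cost. The sketch below only rearranges the numbers from the table above.

```python
# (CUDA time in ms, call count) copied from the breakdown table above.
kernels = {
    "fmha_kernel": (4185.9, 28),
    "swiglu_forward_kernel": (2459.8, 1400),
    "attention_decode_kernel_grouped": (2271.8, 1372),
    "rms_norm_kernel_static_persistent": (634.7, 57),
    "rope_kernel": (355.6, 1400),
}

total_ms = sum(ms for ms, _ in kernels.values())
per_call_ms = {name: ms / calls for name, (ms, calls) in kernels.items()}

for name, (ms, calls) in kernels.items():
    share = ms / total_ms
    print(f"{name:36s} {share:6.1%} of cuTile time, {per_call_ms[name]:8.3f} ms/call")
```

In this sample, `fmha_kernel` is called only 28 times but each call is long, while `swiglu_forward_kernel` and `rope_kernel` are short kernels called 1400 times each.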
## Step 4. How TileGym monkey-patching works
TileGym replaces PyTorch model operations with cuTile kernels. The snippet below is taken from TileGym's [`src/tilegym/transformers/monkey_patch.py`](https://github.com/NVIDIA/TileGym/blob/main/src/tilegym/transformers/monkey_patch.py) and invoked from [`modeling/transformers/infer.py`](https://github.com/NVIDIA/TileGym/blob/main/modeling/transformers/infer.py):
```python
from tilegym.transformers import apply_tilegym_kernel_to_qwen2

apply_tilegym_kernel_to_qwen2(
    rope=True,        # Replace RoPE with cuTile kernel
    rms_norm=True,    # Replace RMSNorm with cuTile kernel
    swiglu=True,      # Replace SwiGLU with cuTile kernel
    attn=True,        # Replace attention with cuTile FMHA
    use_cutile=True,  # Use cuTile backend (vs Triton)
)
```
**Patched Kernels for Qwen2:**
| Kernel | PyTorch Operation | cuTile Replacement |
|--------|-------------------|-------------------|
| `rms_norm_kernel_static_persistent` | `nn.RMSNorm` | Persistent RMSNorm |
| `rope_kernel` | Rotary position embedding | Fused RoPE |
| `fmha_kernel` | `F.scaled_dot_product_attention` | Flash Attention |
| `swiglu_forward_kernel` | SiLU + Mul | Fused SwiGLU |
| `attention_decode_kernel_grouped` | Decode attention | Grouped decode |
**Patched Kernels for DeepSeek-V2:** (see [`src/tilegym/transformers/monkey_patch.py`](https://github.com/NVIDIA/TileGym/blob/main/src/tilegym/transformers/monkey_patch.py))
```python
from tilegym.transformers import apply_tilegym_kernel_to_deepseek_v2

apply_tilegym_kernel_to_deepseek_v2(
    rope=True,        # Replace RoPE with cuTile kernel
    rms_norm=True,    # Replace RMSNorm with cuTile kernel
    swiglu=True,      # Replace SiLU+Mul with cuTile kernel
    attn=True,        # Replace MLA attention with cuTile
    moe=True,         # Replace MoE routing with cuTile
    use_cutile=True,
)
```
| Kernel | PyTorch Operation | cuTile Replacement |
|--------|-------------------|-------------------|
| `prefill_mla` | MLA prefill attention | Multi-head Latent Attention |
| `_mla_decoding_split_kv` | MLA decode attention | Split-KV decoding |
| `fused_moe_kernel` | MoE expert routing | Fused MoE |
| `group_gemm_kernel` | Expert FFN | Grouped GEMM |
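The general technique behind these `apply_tilegym_kernel_to_*` calls can be illustrated in plain Python. This is a toy sketch of monkey-patching, not TileGym's implementation: `ToyRMSNorm`, `fast_forward`, and `apply_patch` are hypothetical names, and how TileGym performs the swap internally is not shown here.

```python
import math


class ToyRMSNorm:
    """Stand-in for a framework layer (hypothetical, not a PyTorch class)."""

    def __init__(self, eps=1e-6):
        self.eps = eps

    def forward(self, xs):
        mean_sq = sum(x * x for x in xs) / len(xs)
        return [x / math.sqrt(mean_sq + self.eps) for x in xs]


def fast_forward(self, xs):
    """Replacement with the same signature and the same result."""
    inv_rms = 1.0 / math.sqrt(sum(x * x for x in xs) / len(xs) + self.eps)
    return [x * inv_rms for x in xs]


def apply_patch(rms_norm=True):
    """Swap the class attribute so existing and future instances use it."""
    if rms_norm:
        ToyRMSNorm.forward = fast_forward


layer = ToyRMSNorm()                 # created before patching
before = layer.forward([3.0, 4.0])
apply_patch(rms_norm=True)
after = layer.forward([3.0, 4.0])    # same object, now runs fast_forward
print(ToyRMSNorm.forward is fast_forward)
```

Because the replacement keeps the original signature, the calling code does not change, which is why the model code needs no edits after patching.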
## Step 5. Platform-specific tuning (Advanced)
cuTile exposes two complementary performance-tuning mechanisms:
- **[`ct.ByTarget`](https://docs.nvidia.com/cuda/cutile-python/performance.html)** - Select different kernel launch parameters per GPU architecture (`sm_<major><minor>`). The compiler picks the value matching the current target at JIT time; if no entry matches, the `default` value is used. See the [Performance Tuning](https://docs.nvidia.com/cuda/cutile-python/performance.html) and [Execution Model](https://docs.nvidia.com/cuda/cutile-python/execution.html) pages.
  - **`num_ctas`** - Number of Cooperative Thread Arrays (thread blocks) launched per kernel invocation. Tune to the number of SMs on the target GPU.
  - **`occupancy`** - Hint for the number of concurrent CTAs the compiler should target per SM. Higher occupancy hides memory latency but increases register/shared-memory pressure. See the [Execution Model](https://docs.nvidia.com/cuda/cutile-python/execution.html) documentation.
- **[`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html)** - Search a list of candidate values at runtime and pick the fastest configuration. Results are reported via [`cuda.tile.tune.TuningResult`](https://docs.nvidia.com/cuda/cutile-python/generated/cuda.tile.tune.TuningResult.html) / [`Measurement`](https://docs.nvidia.com/cuda/cutile-python/generated/cuda.tile.tune.Measurement.html).
```python
import cuda.tile as ct

@ct.kernel(
    # num_ctas: how many thread blocks to launch.
    # Use ByTarget to pick an arch-specific value at JIT time.
    num_ctas=ct.ByTarget({
        "sm_103": 8,    # B300 - more SMs, launch more CTAs
        "sm_121": 4,    # DGX Spark - fewer SMs (48), use fewer CTAs
        "default": 1,   # Fallback for any other GPU architecture
    }),
    # occupancy: hint for concurrent CTAs per SM (latency hiding vs. register pressure).
    occupancy=ct.ByTarget({
        "sm_103": 16,   # B300 - high occupancy, plenty of registers/SMEM
        "sm_121": 12,   # DGX Spark - moderate occupancy
        "default": 8,   # Conservative fallback
    }),
    opt_level=3,  # Maximum compiler optimization level
)
def optimized_kernel(A, B, C):
    # Same kernel code works on all platforms;
    # ByTarget swaps in the arch-specific launch params automatically.
    ...
```
For automatic tuning, use [`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to search over candidate values and pick the fastest configuration at runtime:
```python
import cuda.tile as ct

@ct.kernel(
    # autotune: benchmark each value and pick the fastest.
    num_ctas=ct.autotune([1, 2, 4, 8, 16]),
    occupancy=ct.autotune([8, 12, 16, 24]),
    opt_level=3,
)
def autotuned_kernel(A, B, C):
    ...
```
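Conceptually, the two mechanisms reduce to a table lookup with a fallback and a timed search over candidates. The plain-Python sketch below illustrates only that idea; `resolve_by_target`, `timed`, and `pick_fastest` are hypothetical helpers and do not reflect how cuTile implements `ct.ByTarget` or `ct.autotune`.

```python
import time


def resolve_by_target(table, target):
    """Per-target lookup: arch-specific entry, else the `default` entry."""
    return table.get(target, table["default"])


def timed(run, repeats=3):
    """Wrap run(value) into a function returning its best wall-clock time."""
    def measure(value):
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            run(value)
            best = min(best, time.perf_counter() - start)
        return best
    return measure


def pick_fastest(candidates, measure):
    """Timed search: keep the candidate with the lowest measured cost."""
    return min(candidates, key=measure)


# Values copied from the num_ctas example above.
num_ctas = {"sm_103": 8, "sm_121": 4, "default": 1}
print(resolve_by_target(num_ctas, "sm_121"))   # DGX Spark entry
print(resolve_by_target(num_ctas, "sm_90"))    # no entry, uses default
```

In practice you would pass `timed(launch_with_setting)` as `measure`, where `launch_with_setting` is whatever function launches your kernel with the candidate value. A fixed table costs nothing at runtime, whereas a search pays a one-time benchmarking cost to adapt to hardware you did not anticipate.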
## Step 6. Repeat on B300
Repeat Steps 1-3 on B300 hardware. The **same code runs without modification** - cuTile JIT compiles for sm_103 automatically.
See the **Platform Comparison** tab for detailed scaling results.
## FMHA Implementation
## FMHA Implementation Guide
> [!NOTE]
> This is a guide to understanding FMHA implementation in cuTile, not a complete reference. For comprehensive documentation, see the [cuTile Python Documentation](https://docs.nvidia.com/cuda/cutile-python/).
### Attention Basics
Attention allows a neural network to focus on relevant parts of the input. In transformers (GPT, LLaMA, Qwen), each position computes how much to attend to every other position using three vectors:
- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What do I contain?"
- **Value (V)**: "Here is my content"
```text
Attention(Q, K, V) = softmax(Q × K^T / √d) × V
Shapes:
    Q, K, V = [batch, heads, seq_len, head_dim]
    Q × K^T = [batch, heads, seq_len, seq_len]   # Attention scores
    Output  = [batch, heads, seq_len, head_dim]
```
For autoregressive models, **causal masking** ensures each token only attends to previous tokens by setting future scores to -infinity before softmax.
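A small pure-Python sketch of causal masking follows. It uses a 3×3 score matrix with made-up values, and `causal_softmax` is a hypothetical helper written for illustration rather than taken from TileGym.

```python
import math


def causal_softmax(scores):
    """Row-wise softmax where position i only sees positions j <= i."""
    out = []
    for i, row in enumerate(scores):
        masked = [s if j <= i else -math.inf for j, s in enumerate(row)]
        row_max = max(masked)                 # finite: the diagonal is never masked
        exps = [math.exp(s - row_max) for s in masked]   # exp(-inf) is 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


weights = causal_softmax([[0.0, 1.0, 2.0],
                          [0.0, 1.0, 2.0],
                          [0.0, 1.0, 2.0]])
for row in weights:
    print([round(w, 3) for w in row])
```

Each row still sums to 1, but all weight above the diagonal is exactly zero, so the first token attends only to itself.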
### Flash Attention Algorithm
Standard attention materializes a [seq_len × seq_len] score matrix for every batch element and head (e.g., 2 GiB in float16 for seq_len=32768, since 32768² elements × 2 bytes = 2 GiB). Flash Attention avoids this by processing in tiles with **online softmax**:
```text
m = -infinity        # Running maximum
l = 0                # Running sum of exp(x - m)
acc = 0              # Running weighted sum of values

FOR each K,V tile:
    scores = Q_tile @ K_tile.T * scale
    m_new = max(m, max(scores))
    correction = exp(m - m_new)
    l = l * correction + sum(exp(scores - m_new))
    acc = acc * correction + exp(scores - m_new) @ V_tile
    m = m_new

output = acc / l
```
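The recurrence above can be verified numerically. This is a minimal sketch for a single query row with scalar values, using only the standard library; the function names and the `tile` parameter are illustrative assumptions. It shows that the tiled running-max update produces the same result as a softmax computed over the full row at once.

```python
import math


def online_softmax_attention(scores, values, tile=2):
    """One query row: scores[j] is the scaled score for key j, values[j] a scalar."""
    m, l, acc = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        correction = math.exp(m - m_new)  # Rescales the earlier partial sums
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * correction + sum(p)
        acc = acc * correction + sum(pi * vi for pi, vi in zip(p, v_tile))
        m = m_new
    return acc / l


def full_softmax_attention(scores, values):
    """Reference: softmax over the whole row, then the weighted sum."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)


scores = [0.5, 2.0, -1.0, 3.5, 0.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
tiled = online_softmax_attention(scores, values, tile=2)
full = full_softmax_attention(scores, values)
```

On the first tile `m` is `-infinity`, so `correction` evaluates to 0 and the empty running sums are discarded cleanly.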
### cuTile Pseudocode → Actual Mapping
| Concept | Pseudocode | cuTile |
|---|---|---|
| Define kernel | `KERNEL fmha(...)` | `@ct.kernel()` |
| Get block ID | `block_x = BLOCK_ID_X` | `bid_x = ct.bid(0)` |
| Create indices | `range(0, N)` | `ct.arange(N, dtype=ct.int32)` |
| Create constant tile | `tile = zeros(M, N)` | `ct.full((M, N), 0.0, dtype)` |
| Load from memory | `tile = LOAD(ptr, shape)` | `ct.load(tensor, index, shape)` |
| Store to memory | `STORE(ptr, tile)` | `ct.store(tensor, index, tile)` |
| Matrix multiply | `C = A @ B + C` | `ct.mma(A, B, C)` |
| Reduction | `max_val = MAX(tile, axis)` | `ct.max(tile, axis, keepdims)` |
### Kernel Pseudocode
```text
KERNEL fmha(Q, K, V, Out, scale, TILE_M, TILE_N):
    tile_row = BLOCK_ID_X
    batch_head = BLOCK_ID_Y
    batch = batch_head // num_heads
    head = batch_head % num_heads

    m_i = full(TILE_M, -infinity)
    l_i = full(TILE_M, 0)
    acc = zeros(TILE_M, head_dim)
    q = LOAD(Q[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :])

    FOR j = 0 to num_k_tiles:
        k = LOAD(K[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        v = LOAD(V[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        scores = MMA(q, transpose(k)) * scale
        IF causal AND in_mask_region:
            scores = WHERE(valid_mask, scores, -infinity)
        m_new = max(m_i, row_max(scores))
        correction = exp(m_i - m_new)
        p = exp(scores - m_new)
        l_i = l_i * correction + row_sum(p)
        acc = acc * correction + MMA(p, v)
        m_i = m_new

    out = acc / l_i
    STORE(Out[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :], out)
```
### cuTile Implementation
```python
import cuda.tile as ct
import math

ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]


@ct.kernel()
def fmha_kernel(Q, K, V, Out, qk_scale: float, TILE_D: ConstInt, H: ConstInt,
                TILE_M: ConstInt, TILE_N: ConstInt, CAUSAL: ConstBool):
    bid_x, bid_y = ct.bid(0), ct.bid(1)
    batch_idx, head_idx = bid_y // H, bid_y % H

    offs_m = (bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32))[:, None]
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)[None, :]

    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0),
                shape=(1, 1, TILE_M, TILE_D)).reshape((TILE_M, TILE_D))

    k_seqlen = K.shape[2]
    if CAUSAL:
        Tc = ct.cdiv(min((bid_x + 1) * TILE_M, k_seqlen), TILE_N)
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        mask_start = k_seqlen // TILE_N

    for j in range(0, Tc):
        k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0),
                         shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        k_t = ct.permute(k_tile, (1, 0))
        qk = ct.mma(q, k_t, ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32))
        qk = qk * qk_scale

        if CAUSAL and j >= mask_start:
            offs_n = j * TILE_N + offs_n_tile
            qk = ct.where(offs_m >= offs_n, qk,
                          ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

        m_ij = ct.maximum(m_i, ct.max(qk, axis=-1, keepdims=True))
        qk = qk - m_ij
        p = ct.exp(qk)
        alpha = ct.exp(m_i - m_ij)
        l_i = l_i * alpha + ct.sum(p, axis=-1, keepdims=True)
        acc = acc * alpha

        v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0),
                         shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        acc = ct.mma(p.astype(Q.dtype), v_tile, acc)
        m_i = m_ij

    acc = (acc / l_i).reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
```
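The causal branch of the kernel limits how many K/V tiles each query tile visits (`Tc`) and from which tile onward the mask must be applied (`mask_start`). The sketch below reproduces that integer arithmetic in plain Python so it can be checked by hand; `causal_tile_range` is a hypothetical helper, and the ceiling division stands in for `ct.cdiv`.

```python
def causal_tile_range(bid_x, tile_m, tile_n, k_seqlen):
    """Return (Tc, mask_start) as computed in the CAUSAL branch of the kernel."""
    m_end = (bid_x + 1) * tile_m                # One past the last query row of this tile
    tc = -(-min(m_end, k_seqlen) // tile_n)     # Ceiling division, mirroring ct.cdiv
    mask_start = (bid_x * tile_m) // tile_n     # First K tile that overlaps the diagonal
    return tc, mask_start


# Query tile 3 with 64x64 tiles: K tiles 0..2 are fully visible, tile 3 needs the mask.
square = causal_tile_range(bid_x=3, tile_m=64, tile_n=64, k_seqlen=2048)

# With TILE_M=128 and TILE_N=64, two K tiles straddle the diagonal of query tile 1.
rect = causal_tile_range(bid_x=1, tile_m=128, tile_n=64, k_seqlen=2048)
```

Tiles with index below `mask_start` lie entirely below the diagonal, which is why the kernel skips the `ct.where` for them.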
### Launching the Kernel
```python
import math

import torch
import cuda.tile as ct


def run_fmha(q, k, v, sm_scale, is_causal=True):
    TILE_M, TILE_N = 64, 64  # Platform-specific (see below)
    batch, num_heads, seq_len, head_dim = q.shape
    out = torch.empty_like(q)
    # One CTA per (query tile, batch * head) pair.
    grid = (math.ceil(seq_len / TILE_M), batch * num_heads, 1)
    ct.launch(
        torch.cuda.current_stream(), grid, fmha_kernel,
        (q, k, v, out, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
    )
    return out
```
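The launch grid and the block-ID decoding inside the kernel are two halves of the same mapping. As a rough illustration, the sketch below computes the grid for the benchmark configuration and decodes one `bid_y` back into `(batch_idx, head_idx)`; `fmha_grid` and `decode_block` are hypothetical helper names that mirror the expressions used in `run_fmha` and `fmha_kernel`.

```python
import math


def fmha_grid(batch, num_heads, seq_len, tile_m):
    """Grid shape used by the launcher: (query tiles, batch * heads, 1)."""
    return (math.ceil(seq_len / tile_m), batch * num_heads, 1)


def decode_block(bid_y, num_heads):
    """Inverse mapping used in the kernel: bid_y -> (batch_idx, head_idx)."""
    return bid_y // num_heads, bid_y % num_heads


grid = fmha_grid(batch=4, num_heads=32, seq_len=2048, tile_m=64)
batch_idx, head_idx = decode_block(bid_y=37, num_heads=32)
```

For batch 4, 32 heads, and sequence length 2048 this gives 32 × 128 = 4096 CTAs.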
### Optimizations
#### exp2 + flush_to_zero
`exp2(x) = 2^x` is faster than `exp(x)` on GPU. Requires scale adjustment by `1/log(2)`.
```python
# Convert the natural-exp scale to base 2 so we can use the faster ct.exp2 intrinsic.
# exp(x) == exp2(x / log(2)) == exp2(x * INV_LOG_2).
INV_LOG_2 = 1.0 / math.log(2)          # ≈ 1.4427
qk_scale_log2 = qk_scale * INV_LOG_2   # Pre-multiply the softmax scale once
# ... in loop (qk holds the raw, unscaled Q @ K^T scores):
# Fuse the scale multiplication into the running-max update.
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
# Scale the scores and subtract the running max for numerical stability (online softmax).
qk = qk * qk_scale_log2 - m_ij
# flush_to_zero=True: flush denormals to 0 -> avoids slow denormal handling on GPU.
p = ct.exp2(qk, flush_to_zero=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)  # Correction factor for previous acc/l_i
```
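The scale adjustment rests on the identity exp(x) = 2^(x / ln 2). A quick standard-library check, with `exp_via_exp2` as an illustrative helper name:

```python
import math

INV_LOG_2 = 1.0 / math.log(2)  # ≈ 1.4427


def exp_via_exp2(x):
    """Evaluate e**x as 2**(x * INV_LOG_2), the form used by the exp2 optimization."""
    return 2.0 ** (x * INV_LOG_2)


samples = [-5.0, -0.5, 0.0, 1.0, 3.0]
pairs = [(exp_via_exp2(x), math.exp(x)) for x in samples]
```

Because the softmax scale is a constant, folding `INV_LOG_2` into it costs nothing per element inside the loop.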
#### Load Order Transpose
Load K already transposed using `order` parameter, avoiding explicit permute.
```python
# order=(0,1,3,2) swaps the last two axes during the load,
# producing K^T directly in registers -- no extra ct.permute() needed.
# shape is expressed in the transposed layout: (1, 1, TILE_D, TILE_N),
# so the sequence-tile index j moves to the last position of index.
k_t = ct.load(K, index=(batch_idx, head_idx, 0, j), shape=(1, 1, TILE_D, TILE_N),
              order=(0, 1, 3, 2)).reshape((TILE_D, TILE_N))
```
#### Latency Hints
Prefetch data to overlap memory loads with computation. See the [Performance Tuning docs](https://docs.nvidia.com/cuda/cutile-python/performance.html) for the full list of load/store hints (e.g. `allow_tma`, `latency`).
```python
# latency=N tells the compiler to issue this load N loop iterations in
# advance of its use, so the memory transfer overlaps with the MMA work
# from earlier iterations. Larger latency = deeper software pipeline but
# more register pressure.
k_t = ct.load(K, ..., latency=2)     # Prefetch K 2 iterations ahead
v_tile = ct.load(V, ..., latency=4)  # Prefetch V 4 iterations ahead (used later in the loop)
```
#### Occupancy
Allow multiple thread blocks per SM to hide memory latency. See the [Execution Model docs](https://docs.nvidia.com/cuda/cutile-python/execution.html) for details on how `occupancy` interacts with registers and shared memory.
```python
# occupancy=N is a hint to the compiler to target N concurrent CTAs per SM.
# Higher occupancy -> more warps available to hide memory latency,
# but constrains the per-CTA register/SMEM budget.
@ct.kernel(occupancy=2)  # 2 thread blocks (CTAs) co-resident per SM
def fmha_optimized(Q, K, V, Out, qk_scale: float, TILE_D: ConstInt, H: ConstInt,
                   TILE_M: ConstInt, TILE_N: ConstInt, CAUSAL: ConstBool):
    ...  # Same body as fmha_kernel
```
#### Approximate Division
Use fast approximate division for final normalization.
```python
from cuda.tile import RoundingMode as RMd
# RMd.APPROX -> hardware approximate reciprocal/divide (MUFU), much faster
# than IEEE-compliant division. Safe here because it's the final softmax
# normalization step where a small ULP error is acceptable.
# flush_to_zero=True flushes denormals to 0 to avoid the slow path.
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
```
### Platform Configuration
The same kernel code works on all platforms; only configuration parameters change. Use [`ct.ByTarget`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to select values per architecture, or [`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to search candidate values automatically.
| Platform | TILE_M | TILE_N | Occupancy | Rationale |
|---|---|---|---|---|
| DGX Spark (sm_121) | 64 | 64 | 2 | Smaller tiles, higher occupancy for 48 SMs |
| B300 (sm_103) | 256 | 128 | 1 | Large tiles maximize HBM3e throughput |
| B300 alternate | 128 | 128 | 2 | Higher occupancy, balanced parallelism |
```python
import cuda.tile as ct

# ConstInt / ConstBool are the ct.Constant aliases defined in the implementation above.
@ct.kernel(
    # TILE_M / TILE_N: rows/cols of the Q and K/V tiles processed per CTA.
    # Larger tiles -> more arithmetic intensity; smaller tiles -> higher occupancy.
    # occupancy: target concurrent CTAs per SM (latency hiding vs. register pressure).
    occupancy=ct.ByTarget({
        "sm_121": 2,   # DGX Spark (48 SMs): 2 CTAs/SM for latency hiding
        "sm_103": 1,   # B300: larger tiles already saturate the SM
        "default": 1,  # Conservative fallback for other architectures
    }),
    opt_level=3  # Maximum compiler optimization level
)
def fmha_kernel(Q, K, V, Out, qk_scale: float, TILE_D: ConstInt, H: ConstInt,
                TILE_M: ConstInt, TILE_N: ConstInt, CAUSAL: ConstBool):
    ...  # Kernel body unchanged
```
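`TILE_M` and `TILE_N` are passed as launch-time constants rather than decorator arguments, so the launcher needs its own per-platform selection. One possible sketch, using the values from the table above: `FMHA_CONFIGS`, `DEFAULT_CONFIG`, and `pick_fmha_config` are hypothetical names, and the fallback values are an assumption rather than a tuned configuration. In a real launcher the `(major, minor)` pair could come from `torch.cuda.get_device_capability()`.

```python
# Tile configurations taken from the Platform Configuration table.
FMHA_CONFIGS = {
    "sm_121": {"TILE_M": 64, "TILE_N": 64, "occupancy": 2},    # DGX Spark
    "sm_103": {"TILE_M": 256, "TILE_N": 128, "occupancy": 1},  # B300
}

# Assumed conservative fallback for architectures not listed in the table.
DEFAULT_CONFIG = {"TILE_M": 64, "TILE_N": 64, "occupancy": 1}


def pick_fmha_config(major, minor):
    """Map a compute capability such as (12, 1) to a tile configuration."""
    return FMHA_CONFIGS.get(f"sm_{major}{minor}", DEFAULT_CONFIG)


spark_cfg = pick_fmha_config(12, 1)
b300_cfg = pick_fmha_config(10, 3)
other_cfg = pick_fmha_config(9, 0)
```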
### Performance Results
> [!NOTE]
> PyTorch SDPA is used for correctness verification only, not performance comparison.
#### DGX Spark (sm_121) — Seq 2048
| Step | Optimization | Latency (ms) | TFLOPS |
|---|---|---|---|
| 1 | Basic cuTile | 2.19 | 62.8 |
| 2 | + exp2 | 2.07 | 66.5 |
| 3 | + Load Order | 2.07 | 66.3 |
| 4 | + Latency Hints | 2.07 | 66.5 |
| 5 | + Occupancy=2 | 1.73 | 79.5 |
| 6 | + Approx Div (Final) | 1.69 | 81.1 |
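
The TFLOPS column can be reproduced from the latency. The sketch below uses the FLOP count from the companion script (two matmuls of `2 * batch * heads * seq_len² * head_dim` FLOPs each, halved for causal attention) and its benchmark configuration (batch 4, 32 heads, head dim 128). `fmha_tflops` is an illustrative name, and it assumes the table reports causal attention.

```python
def fmha_tflops(latency_ms, batch=4, heads=32, seq_len=2048, head_dim=128, causal=True):
    """Convert a measured latency into TFLOPS using the companion script's FLOP count."""
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total_flops = 2 * flops_per_matmul  # Q @ K^T and P @ V
    if causal:
        total_flops *= 0.5              # Only the lower triangle is computed
    return total_flops / (latency_ms * 1e-3) / 1e12


baseline_tflops = fmha_tflops(2.19)  # Step 1 latency from the table
```

Rows computed this way can differ from the table in the last digit because the published latencies are rounded to two decimals.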
#### B300 (sm_103) — Various Seq Lengths
| Seq Len | Latency (ms) | TFLOPS | vs Spark |
|---|---|---|---|
| 1024 | 0.074 | 465 | 5.7x |
| 2048 | 0.178 | 770 | 9.5x |
| 4096 | 0.550 | 999 | 15.1x |
| 8192 | 1.897 | 1159 | 14.6x |
| 16384 | 7.014 | 1254 | 14.2x |
### Common Issues
| Issue | Solution |
|---|---|
| Shape mismatch in ct.mma | Ensure A is (M,K), B is (K,N), C is (M,N) |
| dtype errors | Use `.astype()` before mma; accumulator should be float32 |
| Incorrect results with causal | Check mask_start calculation and `offs_m >= offs_n` logic |
| Low performance | Try different TILE_M/N, check occupancy, verify latency hints |
### Companion Scripts
The following scripts are included in this playbook and can be run on DGX Spark or B300:
- **[`assets/fmha_optimization_tutorial.py`](assets/fmha_optimization_tutorial.py)** — Step-by-step optimization tutorial. Builds the FMHA kernel from basic to fully optimized, matching the progression in this guide.
- **[`assets/fmha_scaling_analysis.py`](assets/fmha_scaling_analysis.py)** — Scaling analysis across sequence lengths. Benchmarks each optimization level and generates performance data.
```bash
# Run the optimization tutorial (DGX Spark)
python assets/fmha_optimization_tutorial.py --correctness-check
# Run the scaling analysis
python assets/fmha_scaling_analysis.py --iterations 100
```
### References
- [cuTile Python Documentation](https://docs.nvidia.com/cuda/cutile-python/)
- [Tile IR Specification](https://docs.nvidia.com/cuda/tile-ir/)
- [TileGym (pre-optimized kernels)](https://github.com/NVIDIA/TileGym)
- [NVIDIA Blog: Tuning Flash Attention for Peak Performance in CUDA Tile](https://developer.nvidia.com/blog/tuning-flash-attention-for-peak-performance-in-nvidia-cuda-tile/)
- [Flash Attention Paper](https://arxiv.org/abs/2205.14135)
## Platform Comparison
## DGX Spark vs B300 Performance Comparison
This page summarizes performance scaling between DGX Spark (GB10) and B300 for both kernel benchmarks and end-to-end LLM inference.
## Kernel Benchmark Scaling
Use the ratios below as a reference for how kernel performance scales from DGX Spark (GB10) to B300.
| Kernel | Metric | B300 / GB10 |
|--------|--------|-------------|
| FMHA (causal, 8192) | TFLOPS | 13.7x |
| FMHA (non-causal, 8192) | TFLOPS | 15.1x |
| MatMul (8192) | TFLOPS | 18.9x |
| BMM (batch8, 4096) | TFLOPS | 19.4x |
| Group GEMM (4096) | TFLOPS | 23.9x |
| RMSNorm (4096) | GB/s | 33.1x |
| RoPE (16384) | GB/s | 22.8x |
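**Key Observations:**
- Compute-bound kernels (FMHA, MatMul, BMM, Group GEMM) scale roughly 14-24x from GB10 to B300
- Memory-bound kernels (RMSNorm, RoPE) scale roughly 23-33x, in line with the ~29x memory-bandwidth gap between the platforms (273 GB/s vs 8 TB/s)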
## Qwen2-7B Performance
### End-to-End Throughput
| Configuration | DGX Spark | B300 | Platform Speedup |
|---------------|-----------|------|------------------|
| **cuTile** | 18.52 tok/s | 257.33 tok/s | **13.9x** |
### CUDA Kernel Time
| Configuration | DGX Spark | B300 | Platform Speedup |
|---------------|-----------|------|------------------|
| **cuTile** | 43,080 ms | 2,954 ms | **14.6x** |
### cuTile Kernel Breakdown
**DGX Spark (GB10):**
| Kernel | CUDA Time (ms) | Calls |
|--------|----------------|-------|
| `fmha_kernel` | 4,185.9 | 28 |
| `swiglu_forward_kernel` | 2,459.8 | 1,400 |
| `attention_decode_kernel_grouped` | 2,271.8 | 1,372 |
| `rms_norm_kernel_static_persistent` | 634.7 | 57 |
| `rope_kernel` | 355.6 | 1,400 |
**B300:**
| Kernel | CUDA Time (ms) | Speedup vs Spark |
|--------|----------------|------------------|
| `fmha_kernel` | 337.9 | 12.4x |
| `swiglu_forward_kernel` | 226.3 | 10.9x |
| `attention_decode_kernel_grouped` | 111.0 | 20.5x |
| `rms_norm_kernel_static_persistent` | 29.7 | 21.4x |
| `rope_kernel` | 16.7 | 21.3x |
**Same code, different architectures** - cuTile JIT compiles for sm_121 (Spark) and sm_103 (B300)
## Platform Specifications
| Specification | DGX Spark (GB10) | B300 |
|---------------|------------------|------|
| Compute Capability | sm_121 (12.1) | sm_103 (10.3) |
| SMs | 48 | 132 |
| Memory | 128 GB LPDDR5x | 192 GB HBM3e |
| Memory Bandwidth | 273 GB/s | 8 TB/s |
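
The specification ratios give a rough baseline for the measured speedups. A small sketch of the arithmetic, assuming 8 TB/s means 8000 GB/s (decimal units); the variable names are illustrative:

```python
# Values from the Platform Specifications table.
SPARK = {"sms": 48, "bandwidth_gbs": 273.0}
B300 = {"sms": 132, "bandwidth_gbs": 8000.0}  # 8 TB/s, assuming decimal units

bandwidth_ratio = B300["bandwidth_gbs"] / SPARK["bandwidth_gbs"]
sm_ratio = B300["sms"] / SPARK["sms"]
```

The bandwidth ratio of about 29x sits inside the range measured for the memory-bound kernels, while the SM-count ratio of 2.75x is far below the measured compute speedups, so SM count alone does not account for them.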
## Troubleshooting
| Symptom | Cause | Fix |
|---------|-------|-----|
| `docker: permission denied` | User not in docker group | `sudo usermod -aG docker $USER && newgrp docker` |
| `401 Client Error: Unauthorized` | Missing HuggingFace token | `export HF_TOKEN=<your_token>` |
| `ModuleNotFoundError: tilegym` | TileGym not installed | `cd TileGym && pip install .` |
| `RuntimeError: CUDA out of memory` | Model too large | Reduce batch size or use smaller model |
| `Killed` during model load | Out of system memory | Clear cache: `sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'` |
| Slow first run | JIT compilation | Normal - cuTile compiles kernels on first run |
| `FileNotFoundError: input_prompt_small.txt` | Missing input file | Run from `modeling/transformers` directory |
| `torch.cuda.OutOfMemoryError` | Insufficient GPU memory | Reduce `--batch_size` parameter |
| `ImportError: cuda.tile` | Missing Tile IR | Install: `apt-get install cuda-tile-ir-13-1` |
| Benchmark hangs | GPU busy or locked | Check `nvidia-smi` for other processes |
> [!NOTE]
> DGX Spark uses a Unified Memory Architecture (UMA), which enables dynamic memory sharing between the GPU and CPU.
> With many applications still updating to take advantage of UMA, you may encounter memory issues even when within
> the memory capacity of DGX Spark. If that happens, manually flush the buffer cache with:
```bash
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
```
> [!TIP]
> First run of cuTile kernels includes JIT compilation overhead. Subsequent runs will be faster as compiled kernels are cached.
For the latest known issues, please review the [DGX Spark User Guide](https://docs.nvidia.com/dgx/dgx-spark/known-issues.html).

@ -1,959 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python3
"""
FMHA Optimization Tutorial: From Naive to Optimized cuTile Implementation
This script demonstrates step-by-step optimization of Flash Multi-Head Attention
using NVIDIA cuTile, starting from a basic implementation and progressively
adding optimizations until reaching TileGym-level performance.
Target Platform: DGX Spark (sm121) with pre-determined optimal tile sizes.
Note: TileGym supports autotuning, but we use hardcoded values for this tutorial.
Configuration (matches TileGym bench_fused_attention.py):
- Batch: 4, Heads: 32, Head Dim: 128
- Sequence Lengths: 1024, 2048, 4096, 8192, 16384
- Benchmark: triton.testing.do_bench_cudagraph
Usage:
python fmha_optimization_tutorial.py [--iterations N] [--correctness-check]
"""
import argparse
import json
import math
import time
from dataclasses import dataclass, asdict
from typing import List, Optional
import sys
import torch
LOG_SEPARATOR = "=" * 80
LOG_SUBSEPARATOR = "-" * 60
@dataclass
class BenchmarkResult:
    step: int
    name: str
    description: str
    latency_ms: float
    tflops: float
    speedup_vs_baseline: float
    speedup_vs_previous: float
    correct: bool
    key_changes: List[str]


class Logger:
    def __init__(self):
        self.results: List[BenchmarkResult] = []
        self.logs: List[str] = []

    def log(self, msg: str):
        print(msg)
        self.logs.append(msg)

    def section(self, title: str):
        self.log(f"\n{LOG_SEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SEPARATOR)

    def subsection(self, title: str):
        self.log(f"\n{LOG_SUBSEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SUBSEPARATOR)

    def add_result(self, result: BenchmarkResult):
        self.results.append(result)

    def export_json(self, filepath: str):
        data = {
            "results": [asdict(r) for r in self.results],
            "logs": self.logs
        }
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

    def export_markdown(self, filepath: str):
        with open(filepath, 'w') as f:
            f.write("# FMHA Optimization Tutorial Results\n\n")
            f.write("## Summary Table\n\n")
            f.write("| Step | Name | Latency (ms) | TFLOPS | vs Baseline | vs Previous | Correct |\n")
            f.write("|------|------|--------------|--------|-------------|-------------|--------|\n")
            for r in self.results:
                f.write(f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {r.speedup_vs_previous:.2f}x | {'Yes' if r.correct else 'No'} |\n")
            f.write("\n## Detailed Steps\n\n")
            for r in self.results:
                f.write(f"### Step {r.step}: {r.name}\n\n")
                f.write(f"**Description**: {r.description}\n\n")
                f.write("**Key Changes**:\n")
                for change in r.key_changes:
                    f.write(f"- {change}\n")
                f.write(f"\n**Performance**: {r.latency_ms:.3f}ms, {r.tflops:.2f} TFLOPS, {r.speedup_vs_baseline:.2f}x vs baseline\n\n")


logger = Logger()
BATCH = 4
N_HEADS = 32
HEAD_DIM = 128
INV_LOG_2 = 1.0 / math.log(2)
TILE_M = 64
TILE_N = 64
OCCUPANCY = 2
NUM_CTAS = 1


def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    raise RuntimeError("CUDA not available")


DEVICE = None


def compute_flops(batch, heads, seq_len, head_dim, causal=True):
    # Two matmuls (Q @ K^T and P @ V), each 2 * B * H * S * S * D FLOPs.
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total_flops = 2 * flops_per_matmul
    if causal:
        # Causal attention only computes the lower triangle.
        total_flops *= 0.5
    return total_flops


def benchmark_fn(fn, warmup=10, iterations=100):
    """Benchmark using triton's do_bench_cudagraph for accurate timing (matches TileGym)"""
    try:
        import triton
        ms = triton.testing.do_bench_cudagraph(fn)
        return ms
    except Exception:
        # Triton is missing or CUDA-graph benchmarking failed:
        # fall back to wall-clock timing around synchronized launches.
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iterations):
            fn()
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / iterations * 1000


def verify_correctness(output, reference, atol=1e-2, rtol=1e-2):
    try:
        torch.testing.assert_close(output, reference, atol=atol, rtol=rtol)
        return True
    except AssertionError:
        max_diff = (output - reference).abs().max().item()
        logger.log(f" [WARN] Max difference: {max_diff:.6f}")
        return max_diff < 0.1


def reference_fmha(q, k, v, sm_scale, is_causal=True):
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
    )


def step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True):
    return reference_fmha(q, k, v, sm_scale, is_causal)


try:
    import cuda.tile as ct
    from cuda.tile import RoundingMode as RMd
    CUTILE_AVAILABLE = True
except ImportError:
    CUTILE_AVAILABLE = False
    logger.log("[WARN] cuTile not available. Only PyTorch baseline will run.")
if CUTILE_AVAILABLE:
    ConstInt = ct.Constant[int]
    ConstBool = ct.Constant[bool]

    @ct.kernel()
    def fmha_step2_mma(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """
        Step 2: Basic cuTile FMHA with MMA (Tensor Cores)
        - Uses ct.mma() for matrix multiply
        - Standard exp() for softmax
        - Online softmax algorithm
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            k_tile = k_tile.reshape((TILE_N, TILE_D))
            k_t = ct.permute(k_tile, (1, 0))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            qk = qk * qk_scale
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True)
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk - m_ij
            p = ct.exp(qk)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp(m_i - m_ij)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    def run_step2(q, k, v, sm_scale, is_causal=True):
        batch_size, num_heads, seq_len, head_dim = q.shape
        o = torch.empty_like(q)
        grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
        ct.launch(
            torch.cuda.current_stream(),
            grid,
            fmha_step2_mma,
            (q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
        )
        return o
    @ct.kernel()
    def fmha_step3_exp2(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """
        Step 3: Use exp2 with flush_to_zero for faster math
        - exp2(x) = 2^x is faster than exp(x) = e^x on GPU
        - Requires scaling adjustment: multiply by 1/log(2)
        - flush_to_zero handles denormals efficiently
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            k_tile = k_tile.reshape((TILE_N, TILE_D))
            k_t = ct.permute(k_tile, (1, 0))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    def run_step3(q, k, v, sm_scale, is_causal=True):
        batch_size, num_heads, seq_len, head_dim = q.shape
        o = torch.empty_like(q)
        grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
        ct.launch(
            torch.cuda.current_stream(),
            grid,
            fmha_step3_exp2,
            (q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
        )
        return o
    @ct.kernel()
    def fmha_step4_load_order(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """
        Step 4: Optimize K load with order parameter
        - Use order=(0,1,3,2) to load K already transposed
        - Avoids explicit ct.permute() operation
        - Reduces memory traffic
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2)
            )
            k_t = k_t.reshape((TILE_D, TILE_N))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    def run_step4(q, k, v, sm_scale, is_causal=True):
        batch_size, num_heads, seq_len, head_dim = q.shape
        o = torch.empty_like(q)
        grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
        ct.launch(
            torch.cuda.current_stream(),
            grid,
            fmha_step4_load_order,
            (q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
        )
        return o
    @ct.kernel()
    def fmha_step5_latency(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """
        Step 5: Add latency hints for better pipelining
        - latency=2 for K load (prefetch)
        - latency=4 for V load (more prefetch distance)
        - Helps overlap memory loads with computation
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2
            )
            k_t = k_t.reshape((TILE_D, TILE_N))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    def run_step5(q, k, v, sm_scale, is_causal=True):
        batch_size, num_heads, seq_len, head_dim = q.shape
        o = torch.empty_like(q)
        grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
        ct.launch(
            torch.cuda.current_stream(),
            grid,
            fmha_step5_latency,
            (q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
        )
        return o
    @ct.kernel(occupancy=2)
    def fmha_step6_occupancy(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """
        Step 6: Add occupancy hint
        - @ct.kernel(occupancy=2) improves SM utilization
        - Allows multiple thread blocks per SM
        - Better for hiding memory latency
        """
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        qk_scale_log2 = qk_scale * INV_LOG_2
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_t = ct.load(
                K,
                index=(batch_idx, head_idx, 0, j),
                shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2
            )
            k_t = k_t.reshape((TILE_D, TILE_N))
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk * qk_scale_log2 - m_ij
            p = ct.exp2(qk, flush_to_zero=True)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(
                V,
                index=(batch_idx, head_idx, j, 0),
                shape=(1, 1, TILE_N, TILE_D),
                latency=4
            )
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

    def run_step6(q, k, v, sm_scale, is_causal=True):
        batch_size, num_heads, seq_len, head_dim = q.shape
        o = torch.empty_like(q)
        grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
        ct.launch(
            torch.cuda.current_stream(),
            grid,
            fmha_step6_occupancy,
            (q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
        )
        return o
@ct.kernel(occupancy=2)
def fmha_step7_approx_div(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""
Step 7: Use approximate division for final normalization
- ct.truediv with rounding_mode=APPROX is faster
- Acceptable accuracy loss for inference
- This matches TileGym's optimized implementation
"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_t = ct.load(
K,
index=(batch_idx, head_idx, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2
)
k_t = k_t.reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(
V,
index=(batch_idx, head_idx, j, 0),
shape=(1, 1, TILE_N, TILE_D),
latency=4
)
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
def run_step7(q, k, v, sm_scale, is_causal=True):
batch_size, num_heads, seq_len, head_dim = q.shape
o = torch.empty_like(q)
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
ct.launch(
torch.cuda.current_stream(),
grid,
fmha_step7_approx_div,
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
)
return o
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
    """Run TileGym's optimized FMHA for comparison"""
    try:
        import tilegym
        return tilegym.ops.fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
    except ImportError:
        logger.log("[WARN] TileGym not available for comparison")
        return None
def run_benchmark(seq_len, iterations=100, check_correct=True):
global DEVICE
DEVICE = get_device()
logger.section(f"FMHA OPTIMIZATION TUTORIAL - SEQ_LEN={seq_len}")
logger.log("Configuration:")
logger.log(f" - Batch: {BATCH}")
logger.log(f" - Heads: {N_HEADS}")
logger.log(f" - Head Dim: {HEAD_DIM}")
logger.log(f" - Sequence Length: {seq_len}")
logger.log(f" - Tile M: {TILE_M}")
logger.log(f" - Tile N: {TILE_N}")
logger.log(" - Precision: float16")
logger.log(" - Causal: True")
logger.log(f" - Iterations: {iterations}")
logger.log(f" - Device: {DEVICE}")
q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
sm_scale = 1.0 / math.sqrt(HEAD_DIM)
flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)
ref_output = reference_fmha(q, k, v, sm_scale, is_causal=True)
steps = [
(0, "PyTorch Baseline", "torch.nn.functional.scaled_dot_product_attention",
lambda: step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True),
["PyTorch SDPA with cuDNN backend", "Highly optimized baseline"]),
]
if CUTILE_AVAILABLE:
steps.extend([
(2, "Basic cuTile + MMA", "Tiled FMHA with ct.mma() for Tensor Cores",
lambda: run_step2(q, k, v, sm_scale, is_causal=True),
["@ct.kernel decorator", "ct.mma() for QK and PV products", "Online softmax with exp()"]),
(3, "+ exp2 + flush_to_zero", "Faster exponential math",
lambda: run_step3(q, k, v, sm_scale, is_causal=True),
["ct.exp2() instead of ct.exp()", "flush_to_zero=True for denormals", "qk_scale *= 1/log(2)"]),
(4, "+ Load Order Transpose", "Avoid explicit transpose",
lambda: run_step4(q, k, v, sm_scale, is_causal=True),
["order=(0,1,3,2) for K load", "K loaded already transposed", "Removes ct.permute() call"]),
(5, "+ Latency Hints", "Better memory pipelining",
lambda: run_step5(q, k, v, sm_scale, is_causal=True),
["latency=2 for K load", "latency=4 for V load", "Overlaps loads with compute"]),
(6, "+ Occupancy=2", "Better SM utilization",
lambda: run_step6(q, k, v, sm_scale, is_causal=True),
["@ct.kernel(occupancy=2)", "Multiple blocks per SM", "Hides memory latency"]),
(7, "+ Approx Division (Final)", "Fast final normalization",
lambda: run_step7(q, k, v, sm_scale, is_causal=True),
["ct.truediv with APPROX mode", "Matches TileGym implementation", "Full optimization achieved"]),
])
baseline_latency = None
prev_latency = None
for step_idx, name, desc, fn, changes in steps:
logger.subsection(f"Step {step_idx}: {name}")
logger.log(f"Description: {desc}")
logger.log("Key Changes:")
for change in changes:
logger.log(f" - {change}")
try:
output = fn()
latency_ms = benchmark_fn(fn, warmup=10, iterations=iterations)
tflops = flops * 1e-12 / (latency_ms * 1e-3)
if baseline_latency is None:
baseline_latency = latency_ms
speedup_baseline = 1.0
else:
speedup_baseline = baseline_latency / latency_ms
if prev_latency is None:
speedup_prev = 1.0
else:
speedup_prev = prev_latency / latency_ms
if check_correct and output is not None:
correct = verify_correctness(output, ref_output)
else:
correct = True
logger.log("\nResults:")
logger.log(f" Latency: {latency_ms:.3f} ms")
logger.log(f" TFLOPS: {tflops:.2f}")
logger.log(f" vs Baseline: {speedup_baseline:.2f}x")
logger.log(f" vs Previous: {speedup_prev:.2f}x")
logger.log(f" Correct: {'Yes' if correct else 'No'}")
result = BenchmarkResult(
step=step_idx,
name=name,
description=desc,
latency_ms=latency_ms,
tflops=tflops,
speedup_vs_baseline=speedup_baseline,
speedup_vs_previous=speedup_prev,
correct=correct,
key_changes=changes
)
logger.add_result(result)
prev_latency = latency_ms
except Exception as e:
logger.log(f"\n[ERROR] Step {step_idx} failed: {e}")
import traceback
logger.log(traceback.format_exc())
tilegym_output = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
if tilegym_output is not None:
logger.subsection("TileGym Reference (for comparison)")
tilegym_fn = lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
tilegym_speedup = baseline_latency / tilegym_latency if baseline_latency else 1.0
logger.log("TileGym FMHA:")
logger.log(f" Latency: {tilegym_latency:.3f} ms")
logger.log(f" TFLOPS: {tilegym_tflops:.2f}")
logger.log(f" vs Baseline: {tilegym_speedup:.2f}x")
result = BenchmarkResult(
step=99,
name="TileGym Reference",
description="TileGym's optimized FMHA implementation",
latency_ms=tilegym_latency,
tflops=tilegym_tflops,
speedup_vs_baseline=tilegym_speedup,
speedup_vs_previous=1.0,
correct=True,
key_changes=["Full TileGym implementation", "Pre-tuned for sm121", "Production ready"]
)
logger.add_result(result)
def main():
    parser = argparse.ArgumentParser(description="FMHA Optimization Tutorial")
    parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
    parser.add_argument("--seq-len", type=int, default=2048, help="Sequence length (default matches TileGym)")
    parser.add_argument("--correctness-check", action="store_true", help="Enable correctness checking")
    parser.add_argument("--output-dir", type=str, default=".", help="Output directory for logs")
    args = parser.parse_args()

    logger.section("FMHA OPTIMIZATION TUTORIAL")
    logger.log("From Naive to Optimized cuTile Implementation")
    logger.log("Target Platform: DGX Spark (sm121)")
    logger.log(f"Tile Sizes: TILE_M={TILE_M}, TILE_N={TILE_N} (hardcoded from TileGym)")
    logger.log("Note: TileGym supports autotuning, but we use pre-determined optimal values")

    run_benchmark(
        seq_len=args.seq_len,
        iterations=args.iterations,
        check_correct=args.correctness_check
    )

    logger.section("FINAL SUMMARY")
    logger.log("\n| Step | Name | Latency (ms) | TFLOPS | vs Baseline | Correct |")
    logger.log("|------|------|--------------|--------|-------------|---------|")
    for r in logger.results:
        logger.log(f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {'Yes' if r.correct else 'No'} |")

    json_path = f"{args.output_dir}/fmha_tutorial_results.json"
    md_path = f"{args.output_dir}/fmha_tutorial_results.md"
    log_path = f"{args.output_dir}/fmha_tutorial_log.txt"
    logger.export_json(json_path)
    logger.export_markdown(md_path)
    with open(log_path, 'w') as f:
        f.write('\n'.join(logger.logs))
    logger.log("\nResults exported to:")
    logger.log(f" - {json_path}")
    logger.log(f" - {md_path}")
    logger.log(f" - {log_path}")


if __name__ == "__main__":
    main()

View File

@ -1,891 +0,0 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FMHA Scaling Analysis: How Optimizations Impact Performance at Different Sizes

This script demonstrates:
1. How FMHA performance scales with sequence length
2. Which optimizations provide the most benefit at larger sizes
3. Target-specific configurations for different GPU architectures

Target Platforms (from TileGym):
- DGX Spark (sm120/sm121): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
- Blackwell B300 (sm100): TILE_M x TILE_N = 256x128 or 128x128, num_ctas=1, occupancy=1-2

Usage:
    python fmha_scaling_analysis.py [--iterations N]
"""

import argparse
import json
import math
import time
from dataclasses import dataclass, asdict
from typing import List
from types import SimpleNamespace

import torch

LOG_SEPARATOR = "=" * 80
LOG_SUBSEPARATOR = "-" * 60
@dataclass
class StepResult:
    step: int
    name: str
    latency_ms: float
    tflops: float
    speedup_vs_baseline: float


@dataclass
class SeqLenResult:
    seq_len: int
    steps: List[StepResult]
    best_step: int
    best_speedup: float
    tilegym_latency_ms: float
    tilegym_tflops: float
    tilegym_speedup: float


class Logger:
    def __init__(self):
        self.results: List[SeqLenResult] = []
        self.logs: List[str] = []

    def log(self, msg: str):
        print(msg)
        self.logs.append(msg)

    def section(self, title: str):
        self.log(f"\n{LOG_SEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SEPARATOR)

    def subsection(self, title: str):
        self.log(f"\n{LOG_SUBSEPARATOR}")
        self.log(f" {title}")
        self.log(LOG_SUBSEPARATOR)


logger = Logger()
BATCH = 4
N_HEADS = 32
HEAD_DIM = 128
INV_LOG_2 = 1.0 / math.log(2)
SEQ_LENS = [1024, 2048, 4096, 8192, 16384]
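# --- Illustrative sketch (not part of the original script) ---
# INV_LOG_2 is what lets the kernels below swap ct.exp() for ct.exp2():
# since exp(x) == 2 ** (x / ln 2), folding 1/ln(2) into the softmax scale gives
# the same probabilities from a base-2 exponential. The helper name
# `_softmax_via_exp2` is hypothetical and uses only the standard library; it is
# a sketch of the identity under that assumption, not the cuTile implementation.
import math  # already imported above; repeated so the sketch stands alone


def _softmax_via_exp2(scores, scale):
    """Softmax of `scale * scores` computed with base-2 exponentials."""
    scale_log2 = scale * (1.0 / math.log(2))  # plays the role of qk_scale_log2
    scaled = [s * scale_log2 for s in scores]
    row_max = max(scaled)  # subtract the max for numerical stability
    weights = [2.0 ** (x - row_max) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]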
def get_fmha_config():
    """
    Get target-specific FMHA configuration (from TileGym attention.py)

    Returns configs matching TileGym's _fmha_autotune_configs():
    - sm120/sm121 (DGX Spark): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
    - sm100 (Blackwell B300): Two configs to try via autotuning
    """
    gpu_capability = torch.cuda.get_device_capability()
    if gpu_capability in [(12, 0), (12, 1)]:
        return [
            SimpleNamespace(
                name="DGX Spark (sm121)",
                TILE_M=64,
                TILE_N=64,
                num_ctas=1,
                occupancy=2
            )
        ]
    else:
        return [
            SimpleNamespace(
                name="Blackwell B300 (sm100) - Config 1",
                TILE_M=256,
                TILE_N=128,
                num_ctas=1,
                occupancy=1
            ),
            SimpleNamespace(
                name="Blackwell B300 (sm100) - Config 2",
                TILE_M=128,
                TILE_N=128,
                num_ctas=1,
                occupancy=2
            ),
        ]
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    raise RuntimeError("CUDA not available")


def compute_flops(batch, heads, seq_len, head_dim, causal=True):
    """Approximate attention FLOPs: two matmuls (QK^T and PV), halved when causal."""
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    return total_flops
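# --- Illustrative sketch (not part of the original script) ---
# Worked example of the FLOP model above, as a rough sanity check on reported
# TFLOPS. QK^T and PV each cost 2*B*H*N*N*D FLOPs, and causal masking is
# modeled as skipping half of that work. `_attention_tflops` is a hypothetical
# helper that restates the formula together with the TFLOPS conversion used by
# the benchmark loop (flops * 1e-12 / seconds).
def _attention_tflops(batch, heads, seq_len, head_dim, latency_ms, causal=True):
    flops = 2 * (2.0 * batch * heads * seq_len * seq_len * head_dim)
    if causal:
        flops *= 0.5
    return flops * 1e-12 / (latency_ms * 1e-3)


# For batch=4, heads=32, head_dim=128 at seq_len=2048, the causal FLOP count is
# 2**37 = 137,438,953,472, so a 1 ms kernel corresponds to about 137.4 TFLOPS.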
def benchmark_fn(fn, warmup=10, iterations=100):
    """Benchmark `fn` and return its latency in milliseconds.

    Uses triton's do_bench_cudagraph for accurate timing when available;
    `warmup` and `iterations` only apply to the manual fallback.
    """
    try:
        import triton
        # Use triton's cudagraph benchmark - same as TileGym
        ms = triton.testing.do_bench_cudagraph(fn)
        return ms
    except Exception:
        # ImportError is a subclass of Exception, so this one clause covers both
        # a missing triton install and a failed CUDA-graph benchmark.
        # Fallback to manual timing
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iterations):
            fn()
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / iterations * 1000
def reference_fmha(q, k, v, sm_scale, is_causal=True):
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
    )
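# --- Illustrative sketch (not part of the original script) ---
# The kernels below never materialize a full row of attention scores. They keep
# a running max (m_i), a running denominator (l_i) and an accumulator (acc), and
# rescale the old state by alpha = 2 ** (m_i - m_ij) whenever the max grows.
# The helper here is a hypothetical, standard-library-only rendering of that
# update for a single query row with scalar values; its name, signature and
# chunking scheme are assumptions made for illustration.
import math  # already imported above; repeated so the sketch stands alone


def _online_softmax_weighted_sum(scores, values, scale, chunk):
    """Return sum(softmax(scale * scores) * values), visiting `chunk` keys at a time."""
    scale_log2 = scale / math.log(2)
    m_i, l_i, acc = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), chunk):
        s = [x * scale_log2 for x in scores[start:start + chunk]]
        v = values[start:start + chunk]
        m_ij = max(m_i, max(s))             # new running max
        p = [2.0 ** (x - m_ij) for x in s]  # unnormalized probabilities
        alpha = 2.0 ** (m_i - m_ij)         # 0.0 on the first chunk (m_i = -inf)
        l_i = l_i * alpha + sum(p)
        acc = acc * alpha + sum(pi * vi for pi, vi in zip(p, v))
        m_i = m_ij
    return acc / l_i  # final normalization, mirroring the kernels' last step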
try:
    import cuda.tile as ct
    from cuda.tile import RoundingMode as RMd
    CUTILE_AVAILABLE = True
except ImportError:
    CUTILE_AVAILABLE = False
    logger.log("[WARN] cuTile not available.")
if CUTILE_AVAILABLE:
    ConstInt = ct.Constant[int]
    ConstBool = ct.Constant[bool]

    @ct.kernel()
    def fmha_basic(
        Q, K, V, Out,
        qk_scale: float,
        TILE_D: ConstInt,
        H: ConstInt,
        TILE_M: ConstInt,
        TILE_N: ConstInt,
        CAUSAL: ConstBool,
    ):
        """Step 1: Basic cuTile - no optimizations"""
        # Grid axis 0 indexes query tiles; axis 1 flattens (batch, head).
        bid_x = ct.bid(0)
        bid_y = ct.bid(1)
        batch_idx = bid_y // H
        head_idx = bid_y % H
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
        offs_m = offs_m[:, None]
        offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
        offs_n_tile = offs_n_tile[None, :]
        # Online-softmax state: running max, running denominator, accumulator.
        m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
        q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
        q = q.reshape((TILE_M, TILE_D))
        k_seqlen = K.shape[2]
        if CAUSAL:
            # Only K-tiles up to the end of this query tile can contribute.
            m_end = (bid_x + 1) * TILE_M
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
            mask_start = (bid_x * TILE_M) // TILE_N
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N
        for j in range(0, Tc):
            k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            k_tile = k_tile.reshape((TILE_N, TILE_D))
            k_t = ct.permute(k_tile, (1, 0))  # explicit transpose (removed in later steps)
            qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
            qk = ct.mma(q, k_t, qk)
            qk = qk * qk_scale
            if CAUSAL and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = offs_m >= offs_n
                qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
            m_ij = ct.max(qk, axis=-1, keepdims=True)
            m_ij = ct.maximum(m_i, m_ij)
            qk = qk - m_ij
            p = ct.exp(qk)
            l_ij = ct.sum(p, axis=-1, keepdims=True)
            # Rescale previous state to the new running max.
            alpha = ct.exp(m_i - m_ij)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha
            v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
            v_tile = v_tile.reshape((TILE_N, TILE_D))
            p_cast = p.astype(Q.dtype)
            acc = ct.mma(p_cast, v_tile, acc)
            m_i = m_ij
        acc = acc / l_i
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
@ct.kernel()
def fmha_math_opt(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""Step 2: Math optimizations - exp2 + flush_to_zero"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
k_tile = k_tile.reshape((TILE_N, TILE_D))
k_t = ct.permute(k_tile, (1, 0))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = acc / l_i
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
@ct.kernel()
def fmha_memory_opt(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""Step 3: Memory optimizations - load order + latency hints"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_t = ct.load(
K,
index=(batch_idx, head_idx, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2
)
k_t = k_t.reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(
V,
index=(batch_idx, head_idx, j, 0),
shape=(1, 1, TILE_N, TILE_D),
latency=4
)
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = acc / l_i
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
@ct.kernel(occupancy=2)
def fmha_full_opt_occ2(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""Step 4a: Full optimization with occupancy=2 (for sm120/sm121)"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_t = ct.load(
K,
index=(batch_idx, head_idx, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2
)
k_t = k_t.reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(
V,
index=(batch_idx, head_idx, j, 0),
shape=(1, 1, TILE_N, TILE_D),
latency=4
)
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
@ct.kernel(occupancy=1)
def fmha_full_opt_occ1(
Q, K, V, Out,
qk_scale: float,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
CAUSAL: ConstBool,
):
"""Step 4b: Full optimization with occupancy=1 (for sm100 Blackwell)"""
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
qk_scale_log2 = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
q = q.reshape((TILE_M, TILE_D))
k_seqlen = K.shape[2]
if CAUSAL:
m_end = (bid_x + 1) * TILE_M
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
mask_start = (bid_x * TILE_M) // TILE_N
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
mask_start = k_seqlen // TILE_N
for j in range(0, Tc):
k_t = ct.load(
K,
index=(batch_idx, head_idx, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2
)
k_t = k_t.reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k_t, qk)
if CAUSAL and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = offs_m >= offs_n
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
m_ij = ct.maximum(m_i, m_ij)
qk = qk * qk_scale_log2 - m_ij
p = ct.exp2(qk, flush_to_zero=True)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
l_i = l_i * alpha + l_ij
acc = acc * alpha
v_tile = ct.load(
V,
index=(batch_idx, head_idx, j, 0),
shape=(1, 1, TILE_N, TILE_D),
latency=4
)
v_tile = v_tile.reshape((TILE_N, TILE_D))
p_cast = p.astype(Q.dtype)
acc = ct.mma(p_cast, v_tile, acc)
m_i = m_ij
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
def run_kernel(kernel_fn, q, k, v, sm_scale, tile_m, tile_n, is_causal=True):
    batch_size, num_heads, seq_len, head_dim = q.shape
    o = torch.empty_like(q)
    grid = (math.ceil(seq_len / tile_m), batch_size * num_heads, 1)
    ct.launch(
        torch.cuda.current_stream(),
        grid,
        kernel_fn,
        (q, k, v, o, sm_scale, head_dim, num_heads, tile_m, tile_n, is_causal)
    )
    return o
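# --- Illustrative sketch (not part of the original script) ---
# How the launch grid above maps onto work: axis 0 walks query tiles of height
# tile_m, axis 1 flattens (batch, head) pairs, and each kernel recovers the pair
# from bid_y with integer division and modulo. Both helpers are hypothetical,
# standard-library restatements of that arithmetic, not cuTile APIs.
import math  # already imported above; repeated so the sketch stands alone


def _launch_grid(seq_len, tile_m, batch_size, num_heads):
    """Grid shape as computed in run_kernel."""
    return (math.ceil(seq_len / tile_m), batch_size * num_heads, 1)


def _decode_block_y(bid_y, num_heads):
    """Recover (batch_idx, head_idx) the way the kernels do from ct.bid(1)."""
    return bid_y // num_heads, bid_y % num_heads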
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
    try:
        import tilegym
        return tilegym.ops.fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
    except ImportError:
        return None
def run_scaling_analysis(iterations=100):
device = get_device()
gpu_cap = torch.cuda.get_device_capability()
gpu_name = torch.cuda.get_device_name()
configs = get_fmha_config()
primary_cfg = configs[0]
TILE_M = primary_cfg.TILE_M
TILE_N = primary_cfg.TILE_N
logger.section("FMHA SCALING ANALYSIS (TileGym Benchmark Match)")
logger.log("Matching TileGym bench_fused_attention.py configuration")
logger.log(f"\nGPU: {gpu_name} (sm_{gpu_cap[0]}{gpu_cap[1]})")
logger.subsection("TARGET-SPECIFIC CONFIGURATION (from TileGym)")
for cfg in configs:
logger.log(f"\n {cfg.name}:")
logger.log(f" TILE_M={cfg.TILE_M}, TILE_N={cfg.TILE_N}")
logger.log(f" num_ctas={cfg.num_ctas}, occupancy={cfg.occupancy}")
logger.log(f"\nUsing primary config: TILE_M={TILE_M}, TILE_N={TILE_N}, occupancy={primary_cfg.occupancy}")
logger.log("\nTest Configuration (matches TileGym bench_fused_attention.py):")
logger.log(f" Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}")
logger.log(" Causal: True, Precision: float16")
logger.log(f" Sequence Lengths: {SEQ_LENS}")
logger.log(" Benchmark: triton.testing.do_bench_cudagraph (same as TileGym)")
logger.section("OPTIMIZATION STEPS")
logger.log(f"""
Step 0: PyTorch Baseline
- torch.nn.functional.scaled_dot_product_attention
- Uses cuDNN Flash Attention backend
- Highly optimized reference
Step 1: Basic cuTile (TILE_M={TILE_M}, TILE_N={TILE_N})
- @ct.kernel with ct.mma() for Tensor Cores
- Standard exp() for softmax
- Explicit transpose with ct.permute()
- No memory/occupancy hints
Step 2: Math Optimizations
- ct.exp2() instead of ct.exp() (faster on GPU)
- flush_to_zero=True for denormals
- Scale adjustment: multiply by 1/log(2)
Step 3: Memory Optimizations
- Load order=(0,1,3,2) for implicit K transpose
- Latency hints: K=2, V=4 for prefetching
- Overlaps memory loads with computation
Step 4: Full Optimization (Target-Specific)
- @ct.kernel(occupancy={primary_cfg.occupancy}) for {'sm120/121' if primary_cfg.occupancy == 2 else 'sm100'}
- ct.truediv with APPROX rounding mode
- Matches TileGym production implementation
""")
logger.section("PLATFORM DIFFERENCES: DGX Spark vs Blackwell B300")
logger.log("""
| Parameter | DGX Spark (sm121) | Blackwell B300 (sm100) |
|--------------|-------------------|------------------------|
| TILE_M | 64 | 256 or 128 |
| TILE_N | 64 | 128 |
| num_ctas | 1 | 1 |
| occupancy | 2 | 1 or 2 |
Why the difference?
- B300 has more SMs and larger shared memory -> can use bigger tiles
- B300 benefits from larger tiles (256x128) with lower occupancy
- DGX Spark needs smaller tiles (64x64) with higher occupancy to hide latency
- B300's higher memory bandwidth makes larger tiles more efficient
""")
all_results = []
select_kernel = fmha_full_opt_occ2 if primary_cfg.occupancy == 2 else fmha_full_opt_occ1
for seq_len in SEQ_LENS:
logger.subsection(f"Sequence Length: {seq_len}")
q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
sm_scale = 1.0 / math.sqrt(HEAD_DIM)
flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)
steps_results = []
baseline_fn = lambda: reference_fmha(q, k, v, sm_scale, is_causal=True)
baseline_latency = benchmark_fn(baseline_fn, warmup=10, iterations=iterations)
baseline_tflops = flops * 1e-12 / (baseline_latency * 1e-3)
steps_results.append(StepResult(0, "PyTorch Baseline", baseline_latency, baseline_tflops, 1.0))
if CUTILE_AVAILABLE:
kernels = [
(1, "Basic cuTile", fmha_basic),
(2, "Math Opt (exp2)", fmha_math_opt),
(3, "Memory Opt (order+latency)", fmha_memory_opt),
(4, f"Full Opt (occ={primary_cfg.occupancy})", select_kernel),
]
for step, name, kernel in kernels:
try:
fn = lambda kernel=kernel: run_kernel(kernel, q, k, v, sm_scale, TILE_M, TILE_N, is_causal=True)
latency = benchmark_fn(fn, warmup=10, iterations=iterations)
tflops = flops * 1e-12 / (latency * 1e-3)
speedup = baseline_latency / latency
steps_results.append(StepResult(step, name, latency, tflops, speedup))
except Exception as e:
logger.log(f" [ERROR] Step {step} failed: {e}")
tilegym_latency = 0.0
tilegym_tflops = 0.0
tilegym_speedup = 0.0
tilegym_out = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
if tilegym_out is not None:
tilegym_fn = lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
tilegym_speedup = baseline_latency / tilegym_latency
best_step = max(steps_results, key=lambda x: x.speedup_vs_baseline)
result = SeqLenResult(
seq_len=seq_len,
steps=steps_results,
best_step=best_step.step,
best_speedup=best_step.speedup_vs_baseline,
tilegym_latency_ms=tilegym_latency,
tilegym_tflops=tilegym_tflops,
tilegym_speedup=tilegym_speedup,
)
all_results.append(result)
logger.log("\n | Step | Name | Latency (ms) | TFLOPS | Speedup |")
logger.log(" |------|------|--------------|--------|---------|")
for sr in steps_results:
logger.log(f" | {sr.step} | {sr.name:<28} | {sr.latency_ms:>10.3f} | {sr.tflops:>6.2f} | {sr.speedup_vs_baseline:>6.2f}x |")
if tilegym_latency > 0:
logger.log(f" | TG | TileGym Reference | {tilegym_latency:>10.3f} | {tilegym_tflops:>6.2f} | {tilegym_speedup:>6.2f}x |")
logger.log(f"\n Best: Step {best_step.step} ({best_step.name}) with {best_step.speedup_vs_baseline:.2f}x speedup")
logger.results = all_results
return all_results
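# --- Illustrative sketch (not part of the original script) ---
# The summary below argues that longer sequences mean more K-loop iterations
# and a larger score matrix. Under causal masking the trip count also depends
# on the query tile: the kernels use
#   Tc = cdiv(min((bid_x + 1) * TILE_M, k_seqlen), TILE_N)
# so early query tiles visit few K-tiles and only the last ones visit them all.
# The two helpers are hypothetical, standard-library restatements of that
# arithmetic; the 4-byte element size assumes float32 scores.
def _causal_k_tiles(bid_x, tile_m, tile_n, k_seqlen):
    """Number of K-tiles visited by query tile `bid_x` when CAUSAL is true."""
    m_end = (bid_x + 1) * tile_m
    return -(-min(m_end, k_seqlen) // tile_n)  # ceiling division


def _score_matrix_mib(seq_len, bytes_per_elem=4):
    """Size in MiB of one full seq_len x seq_len score matrix."""
    return seq_len * seq_len * bytes_per_elem / 2**20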
def print_summary(results: List[SeqLenResult]):
configs = get_fmha_config()
primary_cfg = configs[0]
logger.section("SCALING SUMMARY")
logger.log(f"\nTarget Config: TILE_M={primary_cfg.TILE_M}, TILE_N={primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}")
logger.log("\n## Performance vs Sequence Length\n")
logger.log("| Seq Len | Baseline (ms) | Full Opt (ms) | Speedup | TileGym (ms) | TG Speedup |")
logger.log("|---------|---------------|---------------|---------|--------------|------------|")
for r in results:
baseline = next((s for s in r.steps if s.step == 0), None)
full_opt = next((s for s in r.steps if s.step == 4), None)
if baseline and full_opt:
logger.log(f"| {r.seq_len:>7} | {baseline.latency_ms:>13.3f} | {full_opt.latency_ms:>13.3f} | {full_opt.speedup_vs_baseline:>6.2f}x | {r.tilegym_latency_ms:>12.3f} | {r.tilegym_speedup:>9.2f}x |")
logger.log("\n## Optimization Impact by Sequence Length\n")
logger.log("| Seq Len | Basic | +Math | +Memory | +Full | Best |")
logger.log("|---------|-------|-------|---------|-------|------|")
for r in results:
row = f"| {r.seq_len:>7} |"
for step in [1, 2, 3, 4]:
sr = next((s for s in r.steps if s.step == step), None)
if sr:
row += f" {sr.speedup_vs_baseline:>5.2f}x |"
else:
row += " N/A |"
row += f" {r.best_speedup:>4.2f}x |"
logger.log(row)
logger.section("KEY INSIGHTS")
logger.log("""
## Why Larger Sequences Benefit More from Optimization
1. **Memory Bandwidth Dominance**
- Attention has O(N²) memory complexity in the sequence length N for the QK^T matrix
- At seq_len=8192: 8192² × 4 bytes = 256MB per head per batch
- Memory optimizations (order, latency hints) have larger impact
2. **More K-Loop Iterations**
- At seq_len=512: 8 K-tiles (512/64) for sm121, 4 K-tiles (512/128) for sm100
- At seq_len=8192: 128 K-tiles for sm121, 64 K-tiles for sm100 (K-tiles are TILE_N wide)
- Latency hiding through pipelining amortizes over more iterations
3. **Better Occupancy Utilization**
- More tiles = more parallelism opportunities
- sm121 uses occupancy=2 (smaller tiles, more blocks)
- sm100 uses occupancy=1 with larger tiles (256x128)
4. **Platform-Specific Tuning**
- DGX Spark (sm121): 64x64 tiles, occupancy=2 - optimized for bandwidth-limited workloads
- B300 (sm100): 256x128 tiles, occupancy=1 - optimized for compute-heavy workloads
## Optimization Priority by Problem Size
Small (seq_len <= 1024):
- Basic cuTile often sufficient
- Focus on correctness first
Medium (1024 < seq_len <= 4096):
- Math optimizations (exp2) provide ~5% gain
- Memory optimizations start to matter
Large (seq_len > 4096):
- Full optimization stack critical
- Platform-specific tuning essential
- Memory pipelining becomes essential
""")
def export_results(results: List[SeqLenResult], output_dir: str):
    configs = get_fmha_config()
    primary_cfg = configs[0]
    data = {
        "config": {
            "batch": BATCH,
            "n_heads": N_HEADS,
            "head_dim": HEAD_DIM,
            "tile_m": primary_cfg.TILE_M,
            "tile_n": primary_cfg.TILE_N,
            "occupancy": primary_cfg.occupancy,
            "num_ctas": primary_cfg.num_ctas,
            "platform": primary_cfg.name,
        },
        "results": [
            {
                "seq_len": r.seq_len,
                "steps": [asdict(s) for s in r.steps],
                "best_step": r.best_step,
                "best_speedup": r.best_speedup,
                "tilegym_latency_ms": r.tilegym_latency_ms,
                "tilegym_tflops": r.tilegym_tflops,
                "tilegym_speedup": r.tilegym_speedup,
            }
            for r in results
        ]
    }
    json_path = f"{output_dir}/fmha_scaling_results.json"
    with open(json_path, 'w') as f:
        json.dump(data, f, indent=2)
    md_path = f"{output_dir}/fmha_scaling_results.md"
    with open(md_path, 'w') as f:
        f.write("# FMHA Scaling Analysis Results\n\n")
        f.write("## Configuration\n")
        f.write(f"- Platform: {primary_cfg.name}\n")
        f.write(f"- Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}\n")
        f.write(f"- Tile: {primary_cfg.TILE_M}x{primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}\n\n")
        f.write("## Target-Specific Configs (from TileGym)\n\n")
        f.write("| Platform | TILE_M | TILE_N | num_ctas | occupancy |\n")
        f.write("|----------|--------|--------|----------|----------|\n")
        f.write("| DGX Spark (sm121) | 64 | 64 | 1 | 2 |\n")
        f.write("| B300 (sm100) Config 1 | 256 | 128 | 1 | 1 |\n")
        f.write("| B300 (sm100) Config 2 | 128 | 128 | 1 | 2 |\n\n")
        f.write("## Results by Sequence Length\n\n")
        for r in results:
            f.write(f"### Seq Len = {r.seq_len}\n\n")
            f.write("| Step | Name | Latency (ms) | TFLOPS | Speedup |\n")
            f.write("|------|------|--------------|--------|--------|\n")
            for s in r.steps:
                f.write(f"| {s.step} | {s.name} | {s.latency_ms:.3f} | {s.tflops:.2f} | {s.speedup_vs_baseline:.2f}x |\n")
            if r.tilegym_latency_ms > 0:
                f.write(f"| TG | TileGym Reference | {r.tilegym_latency_ms:.3f} | {r.tilegym_tflops:.2f} | {r.tilegym_speedup:.2f}x |\n")
            f.write("\n")
    log_path = f"{output_dir}/fmha_scaling_log.txt"
    with open(log_path, 'w') as f:
        f.write('\n'.join(logger.logs))
    logger.log(f"\nResults exported to:")
    logger.log(f" - {json_path}")
    logger.log(f" - {md_path}")
    logger.log(f" - {log_path}")
def main():
    parser = argparse.ArgumentParser(description="FMHA Scaling Analysis")
    parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
    parser.add_argument("--output-dir", type=str, default=".", help="Output directory")
    args = parser.parse_args()
    results = run_scaling_analysis(iterations=args.iterations)
    print_summary(results)
    export_results(results, args.output_dir)
if __name__ == "__main__":
    main()
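
The figures quoted in the KEY INSIGHTS text of the script above can be checked with a few lines of arithmetic. This is a minimal sketch, not part of the benchmark: the helper names `qk_matrix_bytes` and `k_loop_tiles` are hypothetical, the 4-byte element size is taken from the "× 4 bytes" in the text, and the tile extent is passed in explicitly because which tile dimension the K-loop steps over is an assumption here.

```python
import math

BYTES_PER_ELEM = 4  # assumption: 4-byte scores, matching "x 4 bytes" in the text


def qk_matrix_bytes(seq_len: int, bytes_per_elem: int = BYTES_PER_ELEM) -> int:
    """Bytes for one seq_len x seq_len QK^T score matrix (per head, per batch)."""
    return seq_len * seq_len * bytes_per_elem


def k_loop_tiles(seq_len: int, tile: int) -> int:
    """Number of tiles of extent `tile` needed to cover seq_len."""
    return math.ceil(seq_len / tile)


if __name__ == "__main__":
    # 8192^2 * 4 bytes, expressed in MiB
    print(qk_matrix_bytes(8192) // 2**20, "MiB")
    # Tile counts for the two tile extents quoted in the text (64 and 256)
    for seq_len in (512, 8192):
        print(seq_len, k_loop_tiles(seq_len, 64), k_loop_tiles(seq_len, 256))
```

The quadratic term is why the score matrix grows from 1 MiB at `seq_len=512` to 256 MiB at `seq_len=8192`, while the tile count grows only linearly.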

View File

@ -1,6 +1,7 @@
# Run models with llama.cpp on DGX Spark
> Build llama.cpp with CUDA and serve models via an OpenAI-compatible API
> Build llama.cpp with CUDA and serve models via an OpenAI-compatible API (Nemotron 3 Nano Omni as example)
## Table of Contents
@ -14,148 +15,167 @@
## Basic idea
[llama.cpp](https://github.com/ggml-org/llama.cpp) is a lightweight C/C++ inference stack for large language models. You build it with CUDA so it fully utilizes the DGX Spark GB10 GPU, then load GGUF weights and expose chat through `llama-server`'s OpenAI-compatible HTTP API.
[llama.cpp](https://github.com/ggml-org/llama.cpp) is a lightweight C/C++ inference stack for large language models. You build it with CUDA so tensor work runs on the DGX Spark GB10 GPU, then load GGUF weights and expose chat through `llama-server`'s OpenAI-compatible HTTP API.
This playbook walks through that stack end to end using MTP-enabled **Qwen3.6-35B-A3B** as the hands-on example. Checkpoint choices and paths for all supported models are summarized in the matrix below; commands are in the instructions.
This playbook walks through that stack end to end using **Nemotron 3 Nano Omni** as the hands-on example: an NVIDIA MoE family that runs well from quantized GGUF on Spark. Checkpoint choices and paths for all supported models are summarized in the matrix below; commands are in the instructions.
## What you'll accomplish
You will build llama.cpp with CUDA for GB10, download a **Qwen3.6-35B-A3B** checkpoint, and run **`llama-server`** with GPU offload. You get:
You will build llama.cpp with CUDA for GB10, download a **Nemotron 3 Nano Omni** example checkpoint, and run **`llama-server`** with GPU offload. You get:
- Local inference through llama.cpp (no separate Python inference framework required)
- An OpenAI-compatible `/v1/chat/completions` endpoint for tools and apps
- A concrete validation that the **Qwen3.6-35B-A3B** example runs on this stack on DGX Spark with MTP support.
- Local inference through llama.cpp (no separate Python inference framework required)
- An OpenAI-compatible `/v1/chat/completions` endpoint for tools and apps
- A concrete validation that the **Nemotron 3 Nano Omni** example runs on this stack on DGX Spark
## What to know before starting
- Basic familiarity with Linux command line and terminal commands
- Understanding of git and building from source with CMake
- Basic familiarity with Linux command line and terminal commands
- Understanding of git and building from source with CMake
- Basic knowledge of REST APIs and cURL for testing
- Familiarity with Hugging Face Hub for downloading GGUF files
## Prerequisites
**Hardware requirements**
- NVIDIA DGX Spark with GB10 GPU
- Sufficient unified memory for the model and the KV-Cache being utilized (about 30GB free RAM for the model in the example)
- At least **\~40GB** free disk for the example download plus build artifacts (more if you keep multiple GGUFs)
- NVIDIA DGX Spark with GB10 GPU
- Sufficient unified memory for the example **Q8_0** checkpoint (weights on the order of **~35GB**, plus KV cache and runtime overhead—scale up if you pick a larger quant or longer context)
- At least **~40GB** free disk for the example download plus build artifacts (more if you keep multiple GGUFs)
**Software requirements**
- NVIDIA DGX OS
- Git: `git --version`
- CMake (3.14+): `cmake --version`
- CUDA Toolkit: `nvcc --version`
- NVIDIA DGX OS
- Git: `git --version`
- CMake (3.14+): `cmake --version`
- CUDA Toolkit: `nvcc --version`
- Network access to GitHub and Hugging Face
## Model support matrix
DGX Spark supports any GGUF format model checkpoint with llama.cpp, as long as the system has memory available to host and run the checkpoint.
The following models are supported with llama.cpp on Spark. The instructions use the **Nemotron 3 Nano Omni** example row by default.
| Model | Support Status | HF Handle |
|-------|----------------|-----------|
| **Nemotron 3 Nano Omni** (example walkthrough) | ✅ | `ggml-org/NVIDIA-Nemotron-3-Nano-Omni` |
| **Qwen3.6-35B-A3B** | ✅ | `unsloth/Qwen3.6-35B-A3B-GGUF` |
| **Qwen3.6-27B** | ✅ | `unsloth/Qwen3.6-27B-GGUF` |
| **Gemma 4 31B IT** | ✅ | `ggml-org/gemma-4-31B-it-GGUF` |
| **Gemma 4 26B A4B IT** | ✅ | `ggml-org/gemma-4-26B-A4B-it-GGUF` |
| **Gemma 4 E4B IT** | ✅ | `ggml-org/gemma-4-E4B-it-GGUF` |
| **Gemma 4 E2B IT** | ✅ | `ggml-org/gemma-4-E2B-it-GGUF` |
| **Nemotron-3-Nano** | ✅ | `unsloth/Nemotron-3-Nano-30B-A3B-GGUF` |
## Time & risk
* **Estimated time:** About 30 minutes, plus downloading the example GGUF (\~35GB order of magnitude for the default quant)
* **Risk level:** Low — build is local to your clone; no system-wide installs required for the steps below
* **Rollback:** Remove the `llama.cpp` clone and the model directory under `~/.cache/huggingface/hub/` to reclaim disk space
* **Last updated:** 06/03/2026
* Walkthrough now uses Qwen3.6-35B-A3B as an example
* **Estimated time:** About 30 minutes, plus downloading the example GGUF (~35GB order of magnitude for the default quant)
* **Risk level:** Low — build is local to your clone; no system-wide installs required for the steps below
* **Rollback:** Remove the `llama.cpp` clone and the model directory under `~/models/` to reclaim disk space
* **Last updated:** 04/28/2026
* Walkthrough now uses Nemotron Omni; other model rows stay available
## Instructions
## Step 1. Install the dependencies
## Step 1. Verify prerequisites
Install the required dependencies:
The **example** checkpoint is **`nemotron-3-nano-omni-ga_v1.0-Q8_0.gguf`** from Hugging Face repo **`ggml-org/NVIDIA-Nemotron-3-Nano-Omni`** (full handle: `ggml-org/NVIDIA-Nemotron-3-Nano-Omni/nemotron-3-nano-omni-ga_v1.0-Q8_0.gguf`). Other supported GGUFs—including Qwen3.6, Gemma, and alternate Nemotron Omni builds—use the same build and server steps; change `hf download` and `--model` paths (see the [overview model matrix](overview.md)).
```shell
sudo apt install -y git clang cmake libcurl4-openssl-dev libssl-dev
Ensure the required tools are installed:
```bash
git --version
cmake --version
nvcc --version
```
All commands should return version information. If any are missing, install them before continuing.
Install the Hugging Face CLI:
```bash
python3 -m venv llama-cpp-venv
source llama-cpp-venv/bin/activate
pip install -U "huggingface_hub[cli]"
```
Verify installation:
```bash
hf version
```
## Step 2. Clone the llama.cpp repository
Clone upstream llama.cpp—the framework you are building:
```shell
```bash
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
```
## Step 3. Build llama.cpp with CUDA
Configure CMake with CUDA and GB10's **sm\_121** architecture so GGML's CUDA backend matches your GPU:
Configure CMake with CUDA and GB10's **sm_121** architecture so GGML's CUDA backend matches your GPU:
```shell
cmake -B build -DGGML_NATIVE=ON -DGGML_CUDA=ON -DGGML_CURL=ON -DGGML_RPC=ON -DCMAKE_CUDA_ARCHITECTURES=121a-real
cmake --build build --config Release -j
```bash
mkdir build && cd build
cmake .. -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="121" -DLLAMA_CURL=OFF
make -j8
```
The build usually takes on the order of 5-10 minutes. When it finishes, binaries such as `llama-server` appear under `build/bin/`.
## Step 4. Start llama-server with a model
## Step 4. Download example Nemotron 3 Nano Omni GGUF
llama.cpp loads models in **GGUF** format. This playbook uses the **Q4\_K\_XL** checkpoint from `unsloth/Qwen3.6-35B-A3B-MTP-GGUF`, which provides a good balance between quality and speed on DGX Spark.
llama.cpp loads models in **GGUF** format. This playbook uses the **Q8_0** checkpoint from `ggml-org/NVIDIA-Nemotron-3-Nano-Omni`, which balances quality and memory on DGX Spark GB10 unified memory.
From your `llama.cpp/build` directory, launch the OpenAI-compatible server with GPU offload. The server first downloads the model from Hugging Face if it is not already cached or if an update is available.
All models are saved in the default Hugging Face cache directory, `~/.cache/huggingface/hub`. For instance, this model is saved to `~/.cache/huggingface/hub/models--unsloth--Qwen3.6-35B-A3B-MTP-GGUF`.
It also automatically loads the mmproj file to enable vision capabilities if the model supports them. By default, `llama-server` tries to fit the full model context with capacity for 4 concurrent requests, and it adjusts parameters automatically if needed.
```shell
./bin/llama-server \
-hf unsloth/Qwen3.6-35B-A3B-MTP-GGUF:UD-Q4_K_XL \
--host 0.0.0.0 \
--port 30000
```bash
hf download ggml-org/NVIDIA-Nemotron-3-Nano-Omni \
nemotron-3-nano-omni-ga_v1.0-Q8_0.gguf \
--local-dir ~/models/NVIDIA-Nemotron-3-Nano-Omni
```
To run with MTP speculative decoding, provide the additional parameters shown in the example below. MTP requires a compatible model, such as the `unsloth/Qwen3.6-35B-A3B-MTP-GGUF` checkpoint used in this example. The example also sets the `preserve_thinking` flag, which lets Qwen models use so-called “interleaved thinking” by preserving all prior thinking blocks in the history; this can be useful for agentic workflows.
The file is on the order of **~35GB** (exact size may vary). The download can be resumed if interrupted.
```shell
## Step 5. Start llama-server with Nemotron 3 Nano Omni
From your `llama.cpp/build` directory, launch the OpenAI-compatible server with GPU offload:
```bash
./bin/llama-server \
-hf unsloth/Qwen3.6-35B-A3B-MTP-GGUF:UD-Q4_K_XL \
--model ~/models/NVIDIA-Nemotron-3-Nano-Omni/nemotron-3-nano-omni-ga_v1.0-Q8_0.gguf \
--host 0.0.0.0 \
--port 30000 \
--chat-template-kwargs '{"preserve_thinking": true}' \
--spec-type draft-mtp \
--spec-draft-n-max 3
--n-gpu-layers 99 \
--ctx-size 8192 \
--threads 8
```
**Parameters (short):**
- `--host` / `--port`: bind address and port for the HTTP API
- `--chat-template-kwargs`: sets additional params for the json template parser, must be a valid json object string
- `--spec-type`: comma-separated list of types of speculative decoding to use (default: none, most MTP-compatible models will use “draft-mtp”, but you need to check the model card first)
- `--spec-draft-n-max`: number of tokens to draft for speculative decoding (default: 3\)
- `--host` / `--port`: bind address and port for the HTTP API
- `--n-gpu-layers 99`: offload layers to the GPU (adjust if you use a different model)
- `--ctx-size`: context length (can be increased up to model/server limits; uses more memory)
- `--threads`: CPU threads for non-GPU work
You should see log lines similar to:
```
0.14.322.968 I srv load_model: speculative decoding context initialized
0.14.322.970 I slot load_model: id 0 | task -1 | new slot, n_ctx = 262144
0.14.322.972 I slot load_model: id 1 | task -1 | new slot, n_ctx = 262144
0.14.322.972 I slot load_model: id 2 | task -1 | new slot, n_ctx = 262144
0.14.322.973 I slot load_model: id 3 | task -1 | new slot, n_ctx = 262144
0.14.323.063 I srv load_model: prompt cache is enabled, size limit: 8192 MiB
llama_new_context_with_model: n_ctx = 8192
...
0.14.342.935 I srv llama_server: model loaded
0.14.342.939 I srv llama_server: server is listening on http://0.0.0.0:30000
0.14.342.944 I srv update_slots: all slots are idle
main: server is listening on 0.0.0.0:30000
```
**Keep this terminal open** while testing. Large GGUFs can take a minute or more to load, and the initial download can take a while if the model is not yet cached. A progress bar is shown while the model downloads.
**Keep this terminal open** while testing. Large GGUFs can take a minute or more to load; until you see `server is listening`, nothing accepts connections on port 30000 (see Troubleshooting if `curl` reports connection refused).
The server is ready to accept incoming connections on port 30000 only after you see the `server is listening` message (see Troubleshooting if `curl` reports connection refused).
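
Rather than watching the log by hand, a script can poll the port until something accepts a connection. The following is a minimal sketch under stated assumptions: `wait_for_port` is a hypothetical helper (not part of llama.cpp), and a successful TCP connect is only a coarse signal that the listener is up, so treat the `server is listening` log line as the authoritative indicator.

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout_s: float = 120.0,
                  interval_s: float = 1.0) -> bool:
    """Return True once a TCP connection to host:port succeeds, False on timeout."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            # create_connection raises OSError (e.g. connection refused)
            # while nothing is listening on the port yet.
            with socket.create_connection((host, port), timeout=interval_s):
                return True
        except OSError:
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval_s)
```

For example, `wait_for_port("127.0.0.1", 30000)` blocks for up to two minutes and returns `False` if the server never comes up, which corresponds to the "connection refused" case that `curl` would report.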
## Step 6. Test the API
## Step 5. Test the API
Use a **second terminal on the same machine** that runs `llama-server` (for example another SSH session into DGX Spark). If you run `curl` on your laptop while the server runs only on Spark, use the Spark hostname or IP instead of `localhost`.
Use a **second terminal on the same machine** that runs `llama-server` (for example another SSH session into DGX Spark). If you run `curl` on your laptop while the server runs only on Spark, use the Spark hostname or IP instead of `localhost`.
```shell
```bash
curl -X POST http://127.0.0.1:30000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "unsloth/Qwen3.6-35B-A3B-MTP-GGUF:UD-Q4_K_XL",
"model": "nemotron",
"messages": [{"role": "user", "content": "New York is a great city because..."}],
"max_tokens": 100
}'
@ -178,7 +198,7 @@ Example shape of the response (fields vary by llama.cpp version; `message` may i
}
],
"created": 1765916539,
"model": "$MODEL_PATH",
"model": "nemotron-3-nano-omni-ga_v1.0-Q8_0.gguf",
"object": "chat.completion",
"usage": {
"completion_tokens": 100,
@ -192,35 +212,41 @@ Example shape of the response (fields vary by llama.cpp version; `message` may i
}
```
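
A client usually needs only a few fields from a response of this shape. The sketch below pulls them out with plain dictionary access. It is illustrative only: `summarize_completion` is a hypothetical helper, the `content` key inside `message` follows the usual OpenAI chat format and is assumed rather than confirmed for every llama.cpp version, and the sample dictionary is made-up data.

```python
from typing import Any, Dict


def summarize_completion(response: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the generated text and basic metadata from a chat.completion dict."""
    choice = response["choices"][0]
    message = choice.get("message", {})
    usage = response.get("usage", {})
    return {
        "model": response.get("model"),
        # Assumption: the generated text lives under message["content"].
        "text": message.get("content", ""),
        "completion_tokens": usage.get("completion_tokens"),
    }


if __name__ == "__main__":
    # Made-up sample in the same shape as the example response above.
    sample = {
        "choices": [{"message": {"role": "assistant", "content": "Hello"}}],
        "model": "example.gguf",
        "object": "chat.completion",
        "usage": {"completion_tokens": 1},
    }
    print(summarize_completion(sample))
```

Using `.get` for the optional fields keeps the helper tolerant of the version-to-version differences that the note above the example response mentions.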
## Step 6. Longer completion (with Qwen3.6-35B-A3B)
## Step 7. Longer completion (with Nemotron 3 Nano Omni)
Try a slightly longer prompt to confirm stable generation with **Qwen3.6-35B-A3B**:
Try a slightly longer prompt to confirm stable generation with **Nemotron 3 Nano Omni**:
```shell
```bash
curl -X POST http://127.0.0.1:30000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "unsloth/Qwen3.6-35B-A3B-MTP-GGUF:UD-Q4_K_XL",
"model": "nemotron",
"messages": [{"role": "user", "content": "Solve this step by step: If a train travels 120 miles in 2 hours, what is its average speed?"}],
"max_tokens": 500
}'
```
## Step 7. Cleanup
## Step 8. Cleanup
Stop the server with `Ctrl+C` in the terminal where it is running.
To remove this tutorial's artifacts:
```shell
```bash
rm -rf ~/llama.cpp
rm -rf ~/.cache/huggingface/hub/models--unsloth--Qwen3.6-35B-A3B-MTP-GGUF
rm -rf ~/models/NVIDIA-Nemotron-3-Nano-Omni
```
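
The cache directory removed above follows the Hugging Face Hub naming pattern `models--<org>--<name>`. As a small sketch, the path for any repo handle can be derived like this. `hub_cache_dir` is a hypothetical helper, and the default root is an assumption that holds only when the cache location has not been overridden through environment variables such as `HF_HOME`.

```python
from pathlib import Path


def hub_cache_dir(repo_id: str, cache_root: str = "~/.cache/huggingface/hub") -> Path:
    """Map a Hub repo handle such as 'org/name' to its model cache folder."""
    # The Hub cache replaces '/' with '--' and prefixes model repos with 'models--'.
    folder = "models--" + repo_id.replace("/", "--")
    return Path(cache_root).expanduser() / folder


if __name__ == "__main__":
    print(hub_cache_dir("unsloth/Qwen3.6-35B-A3B-MTP-GGUF"))
```

Printing the path before running `rm -rf` is a cheap way to confirm that the directory you are about to delete is the one you expect.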
## Step 8. Next steps
Deactivate the Python venv if you no longer need `hf`:
1. **Context length:** By default, llama.cpp tries to allocate the maximum context size the model supports, but you can set it manually with `--ctx-size` (or `-c`). Agentic or coding workloads need at least 32768 tokens, preferably 100000 or more.
2. **Other models:** Use `--model` to load any compatible GGUF you have downloaded locally; the llama.cpp server API stays the same. Use `-hf` to let llama.cpp manage downloads and updates automatically. If you use `--model` with a multimodal model, you must also pass the path to the mmproj file with `--mmproj`; with `-hf`, the mmproj file is loaded automatically.
```bash
deactivate
```
## Step 9. Next steps
1. **Context length:** Increase `--ctx-size` for longer chats (watch memory; 1M-token class contexts are possible only when the build, model, and hardware allow).
2. **Other models:** Point `--model` at any compatible GGUF; the llama.cpp server API stays the same.
3. **Integrations:** Point Open WebUI, Continue.dev, or custom clients at `http://<spark-host>:30000/v1` using the OpenAI client pattern.
The server implements the usual OpenAI-style chat features your llama.cpp build enables (including streaming and tool-related flows where supported).
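
For a custom client, the same request that the `curl` examples send can be issued from Python's standard library. This is a minimal sketch, not an official client: `build_chat_request` and `chat` are hypothetical helper names, the `base_url` must point at your own server (for example `http://127.0.0.1:30000/v1`), and the default model name simply mirrors the one used in the `curl` examples.

```python
import json
import urllib.request
from typing import Any, Dict


def build_chat_request(base_url: str, prompt: str, model: str = "nemotron",
                       max_tokens: int = 100) -> urllib.request.Request:
    """Build a POST request for the OpenAI-compatible chat completions endpoint."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        base_url.rstrip("/") + "/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(base_url: str, prompt: str, timeout_s: float = 300.0,
         **kwargs: Any) -> Dict[str, Any]:
    """Send the request and return the decoded JSON response."""
    request = build_chat_request(base_url, prompt, **kwargs)
    with urllib.request.urlopen(request, timeout=timeout_s) as response:
        return json.load(response)
```

Calling `chat("http://127.0.0.1:30000/v1", "New York is a great city because...")` while the server is running returns the response as a dictionary. Keeping request construction separate from sending makes the payload easy to inspect without a live server.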

File diff suppressed because it is too large

View File

@ -1,108 +0,0 @@
# Register DGX Spark to Brev
> Link your DGX Spark to Brev for remote access and shared environments
## Table of Contents
- [Overview](#overview)
- [Instructions](#instructions)
- [Troubleshooting](#troubleshooting)
---
## Overview
## Basic idea
NVIDIA Brev is an AI development platform that makes GPU environments remotely accessible, shareable, and easy to standardize using preconfigured setups called Launchables.
This walkthrough will help you connect your NVIDIA DGX Spark to Brev so it shows up as a managed GPU environment in Brev. After a one-time registration, your Spark becomes remotely accessible and shareable.
## What you'll accomplish
You'll register your DGX Spark with Brev, and it will be visible as a healthy node in the Brev web UI and CLI, ready to share access and accept workloads whenever needed.
## What to know before starting
While Brev automates the complex configuration, understanding a few key concepts when establishing the initial connection will be useful:
* **Terminal Basics**:
* Familiarity with the command line to run a few simple setup commands
## Prerequisites
Your DGX Spark [device is set up](https://docs.nvidia.com/dgx/dgx-spark/first-boot.html). You will also need the following:
* **Brev Account**:
* Have an NVIDIA Brev account. Create one [here](https://login.brev.nvidia.com/signin) if you don't have one.
* **Permissions**:
* You have administrative (root or sudo) access on the DGX Spark device to run the registration command.
## Time & risk
* **Estimated time:** 5-10 minutes
* **Risk level:** Low - Registration configures the Spark for secure remote access without altering your existing workloads
* **Rollback:** The Brev configuration can be removed through the UI and CLI
## Instructions
## Step 1. Log in to Brev
Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you're in the correct org (by clicking the org button on the top right hand side of the page). Once logged in, go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.
Click the “Register Compute” button and follow the instructions in the pop-up window.
## Step 2. Complete Popup Instructions
* Install the Brev CLI
* Configure your compute
* Add a name for compute
* To configure ssh, ensure the “Enable SSH access” toggle is on
* Run the registration command
## Step 3. Follow Registration Flow
In the CLI, you'll be walked through registration. Go through the flow until registration is complete.
## Step 4. Confirm Spark in Brev UI
* Go to the [Brev UI](https://brev.nvidia.com)
* Navigate to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute)
* Confirm that the DGX Spark appears as a registered node with a **Connected** status
## Step 5. Next Steps
Your Spark is now integrated into Brev as a secure, remotely accessible GPU environment.
Now that your hardware is connected, you can:
* **Share Access Anywhere:** Access your machine from anywhere and share access with others through the Brev UI by:
* Adding the user to your [Team](https://brev.nvidia.com/org/team)
* Navigating to your instance in the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section
* In **SSH Access** section of the instance, search for the user you wish to add and click **Modify Access** to enable access
## Step 6. Cleanup
If you ever decide to unregister your Spark with Brev, you can either do so through the Brev UI or the Brev CLI.
With the CLI simply run:
```bash
brev deregister
```
In the UI:
* Go to the [Brev UI](https://brev.nvidia.com)
* Navigate to the section listing “GPU Environments” and look under “Registered Compute”
* Click the “Remove” menu item on the Spark you wish to delete from Brev.
* Confirm your selection.
## Troubleshooting
| Symptom | Cause | Fix |
|---------|-------|-----|
| Your DGX Spark is showing up in the wrong org | You registered your DGX Spark to the wrong org | Run `brev set <my-org>` and then redo the registration |
| Unable to `brev shell <name>` | Need to refresh | `brev refresh` |
For the latest known issues, please review the [DGX Spark User Guide](https://docs.nvidia.com/dgx/dgx-spark/known-issues.html).

View File

@ -52,18 +52,21 @@ You will also need the following:
## Step 1. Log in to Brev
Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you're in the correct org (by clicking the org button on the top right hand side of the page). Once logged in, go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.
Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you're in the correct org (by clicking the org button on the top right-hand side of the page). Once logged in, go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.
Click the “Register Compute” button and follow the instructions in the pop-up window.
## Step 2. Complete Popup Instructions
## Step 2. Complete Pop-up Instructions
* Install the Brev CLI
* Configure your compute
* Add a name for compute
* To configure ssh, ensure the “Enable SSH access” toggle is on
* To configure SSH, ensure the “Enable SSH access” toggle is on
* Run the registration command
> [!IMPORTANT]
> Run the Brev CLI install command **without `sudo`**. Prefixing the installer with `sudo` writes the `brev` binary into root's home directory, which is not on your user shell's `PATH` — the next command will fail with `brev: command not found`. Copy the install command from the pop-up and run it as your normal user.
## Step 3. Follow Registration Flow
In the CLI, you'll be walked through registration. Go through the flow until registration is complete.
@ -80,10 +83,14 @@ Your DGX Station is now integrated into Brev as a secure, remotely accessible GP
Now that your hardware is connected, you can:
* **Share Access Anywhere:** Access your machine from anywhere and share access with others through the Brev UI by:
* Adding the user to your [Team](https://brev.nvidia.com/org/team)
* Navigating to your instance in the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section
* In **SSH Access** section of the instance, search for the user you wish to add and click **Modify Access** to enable access
* **Access your machine from anywhere:** Open the [Brev UI](https://brev.nvidia.com) and launch a session from [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute).
* **Share access with others:** Invite teammates to your DGX Station from the Brev UI:
* Go to the [Brev UI](https://brev.nvidia.com) and open [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute).
* Find your DGX Station in the list and open the row's three-dot (⋯) menu.
* Select **Share Access**.
* Enter the email address of the person you want to share with.
* Choose their role / permission level.
* Confirm to send the invitation.
## Step 6. Cleanup
@ -98,7 +105,7 @@ brev deregister
In the UI:
* Go to the [Brev UI](https://brev.nvidia.com)
* Navigate to the section listing “GPU Environments” and look under “Registered Compute”
* Click the “Remove” menu item on the DGX Station you wish to delete from Brev.
* Click the “Remove” menu item on the device you wish to delete from Brev.
* Confirm your selection.
## Troubleshooting

View File

@ -82,18 +82,21 @@ spec:
content: |
# Step 1. Log in to Brev
Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you're in the correct org (by clicking the org button on the top right hand side of the page). Once logged in, go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.
Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you're in the correct org (by clicking the org button on the top right-hand side of the page). Once logged in, go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.
Click the “Register Compute” button and follow the instructions in the pop-up window.
# Step 2. Complete Popup Instructions
# Step 2. Complete Pop-up Instructions
* Install the Brev CLI
* Configure your compute
* Add a name for compute
* To configure ssh, ensure the “Enable SSH access” toggle is on
* To configure SSH, ensure the “Enable SSH access” toggle is on
* Run the registration command
> [!IMPORTANT]
> Run the Brev CLI install command **without `sudo`**. Prefixing the installer with `sudo` writes the `brev` binary into root's home directory, which is not on your user shell's `PATH` — the next command will fail with `brev: command not found`. Copy the install command from the pop-up and run it as your normal user.
# Step 3. Follow Registration Flow
In the CLI, you'll be walked through registration. Go through the flow until registration is complete.
@ -110,10 +113,14 @@ spec:
Now that your hardware is connected, you can:
* **Share Access Anywhere:** Access your machine from anywhere and share access with others through the Brev UI by:
* Adding the user to your [Team](https://brev.nvidia.com/org/team)
* Navigating to your instance in the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section
* In **SSH Access** section of the instance, search for the user you wish to add and click **Modify Access** to enable access
* **Access your machine from anywhere:** Open the [Brev UI](https://brev.nvidia.com) and launch a session from [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute).
* **Share access with others:** Invite teammates to your DGX Station from the Brev UI:
* Go to the [Brev UI](https://brev.nvidia.com) and open [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute).
* Find your DGX Station in the list and open the row's three-dot (⋯) menu.
* Select **Share Access**.
* Enter the email address of the person you want to share with.
* Choose their role / permission level.
* Confirm to send the invitation.
# Step 6. Cleanup
@ -128,7 +135,7 @@ spec:
In the UI:
* Go to the [Brev UI](https://brev.nvidia.com)
* Navigate to the section listing “GPU Environments” and look under “Registered Compute”
* Click the “Remove” menu item on the DGX Station you wish to delete from Brev.
* Click the “Remove” menu item on the device you wish to delete from Brev.
* Confirm your selection.

View File

@ -107,7 +107,7 @@ spec:
# Time & risk
- **Estimated time:** ~30 minutes for setup. Full d24 training takes on the order of 12+ hours on a single GB300 Ultra.
- **Estimated time:** ~30 minutes for setup. Full d24 training takes on the order of 16+ hours on a single GB300 Ultra.
- **Risk level:** Medium
- Large downloads (FineWeb) can be slow; ensure stable network and disk space.
- API keys (W&B, HF) must be set or `launch.sh` will exit immediately.
@ -184,7 +184,7 @@ spec:
3. **SFT** — downloads synthetic identity conversations, fine-tunes for chat
4. **Report generation** — produces `report.md` with metrics and samples
Training on a single GB300 Ultra takes on the order of 12+ hours for the full d24 run.
Training on a single GB300 Ultra takes on the order of 16+ hours for the full d24 run.
# Step 4. Monitor training