mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-06-21 21:59:30 +00:00
Compare commits
6 Commits
cda3f97231
...
487d4a0894
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
487d4a0894 | ||
|
|
2f703e1793 | ||
|
|
9ce5aae4f3 | ||
|
|
82f68146e5 | ||
|
|
231a45230d | ||
|
|
050f799875 |
@ -21,11 +21,13 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
|
||||
|
||||
### NVIDIA
|
||||
|
||||
- [CLI Coding Agent](nvidia/cli-coding-agent/)
|
||||
- [Comfy UI](nvidia/comfy-ui/)
|
||||
- [Connect Three DGX Spark in a Ring Topology](nvidia/connect-three-sparks/)
|
||||
- [Set Up Local Network Access](nvidia/connect-to-your-spark/)
|
||||
- [Connect Two Sparks](nvidia/connect-two-sparks/)
|
||||
- [CUDA-X Data Science](nvidia/cuda-x-data-science/)
|
||||
- [cuTile Kernels](nvidia/cutile-kernels/)
|
||||
- [DGX Dashboard](nvidia/dgx-dashboard/)
|
||||
- [FLUX.1 Dreambooth LoRA Fine-tuning](nvidia/flux-finetuning/)
|
||||
- [Run Hermes Agent with Local Models](nvidia/hermes-agent/)
|
||||
@ -41,6 +43,7 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
|
||||
- [NCCL for Two Sparks](nvidia/nccl/)
|
||||
- [Fine-tune with NeMo](nvidia/nemo-fine-tune/)
|
||||
- [Run NemoClaw with a Local LLM](nvidia/nemoclaw/)
|
||||
- [🦞 Set Up Example NemoClaw Agents 🦞](nvidia/nemoclaw-applications/)
|
||||
- [Nemotron-3-Nano with llama.cpp](nvidia/nemotron/)
|
||||
- [NIM on Spark](nvidia/nim-llm/)
|
||||
- [NVFP4 Quantization](nvidia/nvfp4-quantization/)
|
||||
@ -52,6 +55,7 @@ Each playbook includes prerequisites, step-by-step instructions, troubleshooting
|
||||
- [Fine-tune with Pytorch](nvidia/pytorch-fine-tune/)
|
||||
- [RAG Application in AI Workbench](nvidia/rag-ai-workbench/)
|
||||
- [Spark & Reachy Photo Booth](nvidia/reachy-photo-booth/)
|
||||
- [Register DGX Spark to Brev](nvidia/register-to-brev/)
|
||||
- [SGLang for Inference](nvidia/sglang/)
|
||||
- [Single-cell RNA Sequencing](nvidia/single-cell/)
|
||||
- [Speculative Decoding](nvidia/speculative-decoding/)
|
||||
|
||||
474
nvidia/cli-coding-agent/README.md
Normal file
474
nvidia/cli-coding-agent/README.md
Normal file
@ -0,0 +1,474 @@
|
||||
# CLI Coding Agent
|
||||
|
||||
> Build local CLI coding agents with Ollama
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Claude Code](#claude-code)
|
||||
- [OpenCode](#opencode)
|
||||
- [Codex CLI](#codex-cli)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
## Basic idea
|
||||
|
||||
Use [Ollama](https://ollama.com) on [DGX Spark](https://www.nvidia.com/en-us/products/workstations/dgx-spark/) to run a local coding model and connect a CLI coding agent. This
|
||||
playbook supports three options: **[Claude Code](https://docs.claude.com/en/docs/claude-code)**, **[OpenCode](https://opencode.ai)**, and **[Codex CLI](https://github.com/openai/codex)**. Each
|
||||
agent is wired up with Ollama's built-in [launch method](https://ollama.com/blog/launch) (`ollama launch <agent>`), so you
|
||||
can work without environment variables, provider config files, or external cloud APIs.
|
||||
|
||||
## Choose your CLI agent
|
||||
|
||||
Pick the tab that matches the CLI agent you want to use:
|
||||
|
||||
- **Claude Code**: Fastest path to a working CLI agent with a local Ollama model.
|
||||
- **OpenCode**: Open-source CLI launched directly from Ollama.
|
||||
- **Codex CLI**: OpenAI Codex CLI launched directly from Ollama against the local model.
|
||||
|
||||
## What you'll accomplish
|
||||
|
||||
You will run a local coding model ([Qwen3.6](https://ollama.com/library/qwen3.6)) on your DGX Spark with Ollama, launch your
|
||||
chosen CLI agent against it with a single command, and complete a small coding task end-to-end.
|
||||
|
||||
## What to know before starting
|
||||
|
||||
- Comfort with Linux command line basics
|
||||
- Experience running terminal-based tools and editors
|
||||
- Familiarity with Python for the short coding task
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- DGX Spark access with NVIDIA DGX OS 7.3.1 (Ubuntu 24.04.3 LTS base)
|
||||
- Internet access to download model weights
|
||||
- [Ollama](https://ollama.com/download) v0.15 or newer (required for [`ollama launch`](https://ollama.com/blog/launch))
|
||||
- GPU memory depends on the Qwen3.6 variant you choose:
|
||||
- `qwen3.6:latest` (35B-a3b, MoE) — ~24GB, 256K context
|
||||
- `qwen3.6:35b-a3b-nvfp4` — ~22GB, NVIDIA FP4 build tuned for Blackwell (DGX Spark)
|
||||
- `qwen3.6:35b-a3b-q8_0` — ~39GB, higher-quality quant
|
||||
- `qwen3.6:35b-a3b-bf16` — ~71GB, full precision (fits Spark's unified memory)
|
||||
|
||||
## Time & risk
|
||||
|
||||
* **Duration**: ~15-25 minutes (mostly model download time)
|
||||
* **Risk level**: Low
|
||||
* Large model downloads can fail if network connectivity is unstable
|
||||
* Ollama versions older than 0.15 do not support `ollama launch`
|
||||
* **Rollback**: Stop Ollama and delete the downloaded model from `~/.ollama/models`
|
||||
* **Last Updated:** 04/16/2026
|
||||
* Switched to `ollama launch` method and upgraded the default model to Qwen3.6
|
||||
|
||||
## Claude Code
|
||||
|
||||
## Step 1. Confirm your environment
|
||||
|
||||
**Description**: Verify the OS version and GPU are visible before installing anything.
|
||||
|
||||
```bash
|
||||
cat /etc/os-release | head -n 2
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
|
||||
|
||||
## Step 2. Install or update Ollama
|
||||
|
||||
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
|
||||
|
||||
```bash
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
ollama --version
|
||||
```
|
||||
|
||||
If Ollama is already installed, just verify the version:
|
||||
|
||||
```bash
|
||||
ollama --version
|
||||
```
|
||||
|
||||
Expected output should show Ollama v0.15 or newer.
|
||||
|
||||
## Step 3. Pull Qwen3.6
|
||||
|
||||
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6
|
||||
```
|
||||
|
||||
Optional variants if you want different memory footprints or precision:
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
|
||||
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
|
||||
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
|
||||
```
|
||||
|
||||
Expected output should show `qwen3.6` (and any optional variants) in `ollama list`.
|
||||
|
||||
## Step 4. Test local inference (optional)
|
||||
|
||||
**Description**: Run a quick prompt to confirm the model loads.
|
||||
|
||||
```bash
|
||||
ollama run qwen3.6
|
||||
```
|
||||
|
||||
Try a prompt like:
|
||||
|
||||
```text
|
||||
Write a short README checklist for a Python project.
|
||||
```
|
||||
|
||||
Expected output should show the model responding in the terminal. When you are done, type `/bye` or press `Ctrl+D` to exit the interactive session before continuing.
|
||||
|
||||
## Step 5. Launch Claude Code with Ollama
|
||||
|
||||
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [Claude Code](https://docs.claude.com/en/docs/claude-code) against your local model. No environment variables or config files are required.
|
||||
|
||||
```bash
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
Expected output should show Claude Code starting and using the local Qwen3.6 model. Qwen3.6 ships with a 256K context window by default; adjust context length through Ollama's settings if you need to tune it further.
|
||||
|
||||
## Step 6. Complete a small coding task
|
||||
|
||||
**Description**: Create a tiny repo and let Claude Code implement a function and tests.
|
||||
|
||||
```bash
|
||||
mkdir -p ~/cli-agent-demo
|
||||
cd ~/cli-agent-demo
|
||||
|
||||
printf 'def add(a, b):\n """Return the sum of a and b."""\n pass\n' > math_utils.py
|
||||
printf 'import math_utils\n\n\ndef test_add():\n assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
|
||||
```
|
||||
|
||||
If you do not already have pytest installed:
|
||||
|
||||
```bash
|
||||
python3 -m pip install -U pytest
|
||||
```
|
||||
|
||||
In Claude Code:
|
||||
|
||||
```text
|
||||
Please implement add() in math_utils.py and make sure the test passes.
|
||||
```
|
||||
|
||||
Run the test:
|
||||
|
||||
```bash
|
||||
python3 -m pytest -q
|
||||
```
|
||||
|
||||
Expected output should show the test passing.
|
||||
|
||||
## Step 7. Cleanup and rollback
|
||||
|
||||
**Description**: Remove the model and stop services if you no longer need them.
|
||||
|
||||
To stop the service:
|
||||
|
||||
```bash
|
||||
sudo systemctl stop ollama
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> This will delete the downloaded model files.
|
||||
|
||||
```bash
|
||||
ollama rm qwen3.6
|
||||
```
|
||||
|
||||
## Step 8. Next steps
|
||||
|
||||
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
|
||||
- Use Claude Code on multi-file refactors or test-generation tasks
|
||||
- Explore the full 256K context window on larger codebases
|
||||
|
||||
## OpenCode
|
||||
|
||||
## Step 1. Confirm your environment
|
||||
|
||||
**Description**: Verify the OS version and GPU are visible before installing anything.
|
||||
|
||||
```bash
|
||||
cat /etc/os-release | head -n 2
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
|
||||
|
||||
## Step 2. Install or update Ollama
|
||||
|
||||
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
|
||||
|
||||
```bash
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
ollama --version
|
||||
```
|
||||
|
||||
If Ollama is already installed, just verify the version:
|
||||
|
||||
```bash
|
||||
ollama --version
|
||||
```
|
||||
|
||||
Expected output should show Ollama v0.15 or newer.
|
||||
|
||||
## Step 3. Pull Qwen3.6
|
||||
|
||||
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6
|
||||
```
|
||||
|
||||
Optional variants if you want different memory footprints or precision:
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
|
||||
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
|
||||
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
|
||||
```
|
||||
|
||||
Expected output should show `qwen3.6` in `ollama list`.
|
||||
|
||||
## Step 4. Test local inference (optional)
|
||||
|
||||
**Description**: Run a quick prompt to confirm the model loads.
|
||||
|
||||
```bash
|
||||
ollama run qwen3.6
|
||||
```
|
||||
|
||||
Try a prompt like:
|
||||
|
||||
```text
|
||||
Write a short README checklist for a Python project.
|
||||
```
|
||||
|
||||
Expected output should show the model responding. When you are done, type `/bye` or press `Ctrl+D` to exit before continuing.
|
||||
|
||||
## Step 5. Launch OpenCode with Ollama
|
||||
|
||||
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [OpenCode](https://opencode.ai) against your local model. No [`opencode.json`](https://opencode.ai/docs/config/) provider configuration is required.
|
||||
|
||||
```bash
|
||||
ollama launch opencode
|
||||
```
|
||||
|
||||
If you want to pre-configure OpenCode without launching immediately:
|
||||
|
||||
```bash
|
||||
ollama launch opencode --config
|
||||
```
|
||||
|
||||
Expected output should show OpenCode starting with Ollama preselected as the provider and Qwen3.6 as the model. Qwen3.6 ships with a 256K context window by default.
|
||||
|
||||
## Step 6. Complete a small coding task
|
||||
|
||||
**Description**: Create a tiny repo and let OpenCode implement a function and tests.
|
||||
|
||||
```bash
|
||||
mkdir -p ~/cli-agent-demo
|
||||
cd ~/cli-agent-demo
|
||||
|
||||
printf 'def add(a, b):\n """Return the sum of a and b."""\n pass\n' > math_utils.py
|
||||
printf 'import math_utils\n\n\ndef test_add():\n assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
|
||||
```
|
||||
|
||||
If you do not already have pytest installed:
|
||||
|
||||
```bash
|
||||
python3 -m pip install -U pytest
|
||||
```
|
||||
|
||||
In OpenCode:
|
||||
|
||||
```text
|
||||
Please implement add() in math_utils.py and make sure the test passes.
|
||||
```
|
||||
|
||||
Run the test:
|
||||
|
||||
```bash
|
||||
python3 -m pytest -q
|
||||
```
|
||||
|
||||
Expected output should show the test passing.
|
||||
|
||||
## Step 7. Cleanup and rollback
|
||||
|
||||
**Description**: Remove the model and stop services if you no longer need them.
|
||||
|
||||
To stop the service:
|
||||
|
||||
```bash
|
||||
sudo systemctl stop ollama
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> This will delete the downloaded model files.
|
||||
|
||||
```bash
|
||||
ollama rm qwen3.6
|
||||
```
|
||||
|
||||
## Step 8. Next steps
|
||||
|
||||
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
|
||||
- Use OpenCode on multi-file changes or test-generation tasks
|
||||
- Explore the full 256K context window on larger codebases
|
||||
|
||||
## Codex CLI
|
||||
|
||||
## Step 1. Confirm your environment
|
||||
|
||||
**Description**: Verify the OS version and GPU are visible before installing anything.
|
||||
|
||||
```bash
|
||||
cat /etc/os-release | head -n 2
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
Expected output should show Ubuntu 24.04.3 LTS (DGX OS 7.3.1 base) and a detected GPU.
|
||||
|
||||
## Step 2. Install or update Ollama
|
||||
|
||||
**Description**: Install [Ollama](https://ollama.com/download) or ensure it is recent enough to support [`ollama launch`](https://ollama.com/blog/launch).
|
||||
|
||||
```bash
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
ollama --version
|
||||
```
|
||||
|
||||
If Ollama is already installed, just verify the version:
|
||||
|
||||
```bash
|
||||
ollama --version
|
||||
```
|
||||
|
||||
Expected output should show Ollama v0.15 or newer.
|
||||
|
||||
## Step 3. Pull Qwen3.6
|
||||
|
||||
**Description**: Download the [Qwen3.6](https://ollama.com/library/qwen3.6) model weights to your Spark node.
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6
|
||||
```
|
||||
|
||||
Optional variants if you want different memory footprints or precision:
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.6:35b-a3b-nvfp4 # NVIDIA FP4 build tuned for Blackwell (~22GB)
|
||||
ollama pull qwen3.6:35b-a3b-q8_0 # Higher-quality 8-bit quant (~39GB)
|
||||
ollama pull qwen3.6:35b-a3b-bf16 # Full precision (~71GB)
|
||||
```
|
||||
|
||||
Expected output should show `qwen3.6` in `ollama list`.
|
||||
|
||||
## Step 4. Test local inference (optional)
|
||||
|
||||
**Description**: Run a quick prompt to confirm the model loads.
|
||||
|
||||
```bash
|
||||
ollama run qwen3.6
|
||||
```
|
||||
|
||||
Try a prompt like:
|
||||
|
||||
```text
|
||||
Write a short README checklist for a Python project.
|
||||
```
|
||||
|
||||
Expected output should show the model responding. When you are done, type `/bye` or press `Ctrl+D` to exit before continuing.
|
||||
|
||||
## Step 5. Launch Codex CLI with Ollama
|
||||
|
||||
**Description**: Use Ollama's built-in [launch method](https://ollama.com/blog/launch) to start [Codex CLI](https://github.com/openai/codex) against your local model. No `~/.codex/config.toml` and no manual `npm install -g @openai/codex` are required — Ollama handles the Codex integration.
|
||||
|
||||
```bash
|
||||
ollama launch codex
|
||||
```
|
||||
|
||||
Expected output should show Codex CLI starting with Ollama as the provider and Qwen3.6 as the model. Qwen3.6 ships with a 256K context window by default, which is well suited to Codex's agentic workflows.
|
||||
|
||||
## Step 6. Complete a small coding task
|
||||
|
||||
**Description**: Create a tiny repo and let Codex implement a function and tests.
|
||||
|
||||
```bash
|
||||
mkdir -p ~/cli-agent-demo
|
||||
cd ~/cli-agent-demo
|
||||
|
||||
printf 'def add(a, b):\n """Return the sum of a and b."""\n pass\n' > math_utils.py
|
||||
printf 'import math_utils\n\n\ndef test_add():\n assert math_utils.add(1, 2) == 3\n' > test_math_utils.py
|
||||
```
|
||||
|
||||
If you do not already have pytest installed:
|
||||
|
||||
```bash
|
||||
python3 -m pip install -U pytest
|
||||
```
|
||||
|
||||
In Codex:
|
||||
|
||||
```text
|
||||
Please implement add() in math_utils.py and make sure the test passes.
|
||||
```
|
||||
|
||||
Run the test:
|
||||
|
||||
```bash
|
||||
python3 -m pytest -q
|
||||
```
|
||||
|
||||
Expected output should show the test passing.
|
||||
|
||||
## Step 7. Cleanup and rollback
|
||||
|
||||
**Description**: Remove the model and stop services if you no longer need them.
|
||||
|
||||
To stop the service:
|
||||
|
||||
```bash
|
||||
sudo systemctl stop ollama
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> This will delete the downloaded model files.
|
||||
|
||||
```bash
|
||||
ollama rm qwen3.6
|
||||
```
|
||||
|
||||
## Step 8. Next steps
|
||||
|
||||
- Try the `qwen3.6:35b-a3b-nvfp4` or `bf16` variants for different quality/VRAM tradeoffs
|
||||
- Use Codex CLI on multi-file changes or test-generation tasks
|
||||
- Explore the full 256K context window on larger codebases
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Symptom | Cause | Fix |
|
||||
|---------|-------|-----|
|
||||
| `ollama: command not found` | Ollama not installed or PATH not updated | Rerun `curl -fsSL https://ollama.com/install.sh \| sh` and open a new shell |
|
||||
| `ollama launch` reports unknown command | Ollama is older than v0.15 | Update Ollama: `curl -fsSL https://ollama.com/install.sh \| sh` |
|
||||
| Model load fails with version error or HTTP 412 | Ollama version is too old for the model | Update Ollama: `curl -fsSL https://ollama.com/install.sh \| sh` |
|
||||
| `model not found` when launching an agent | Model was not pulled | Run `ollama pull qwen3.6` and retry |
|
||||
| `connection refused` to localhost:11434 | Ollama service not running | Start with `ollama serve` or `sudo systemctl start ollama` |
|
||||
| `ollama launch <agent>` exits immediately | Agent integration failed to initialize | Re-run `ollama launch <agent>`; if it persists, check `journalctl -u ollama` |
|
||||
| Slow responses or OOM errors | Model variant too large for GPU memory | Switch to `qwen3.6:35b-a3b-nvfp4` or close other GPU workloads |
|
||||
|
||||
> [!NOTE]
|
||||
> DGX Spark uses a Unified Memory Architecture (UMA), which enables dynamic memory sharing
|
||||
> between the GPU and CPU. If you see memory pressure, flush the buffer cache with:
|
||||
> ```bash
|
||||
> sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
|
||||
> ```
|
||||
859
nvidia/cutile-kernels/README.md
Normal file
859
nvidia/cutile-kernels/README.md
Normal file
@ -0,0 +1,859 @@
|
||||
# cuTile Kernels
|
||||
|
||||
> Run cuTile kernel benchmarks, FMHA implementation, and LLM inference on DGX Spark and B300
|
||||
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Kernel Benchmarks](#kernel-benchmarks)
|
||||
- [End-to-End Inference](#end-to-end-inference)
|
||||
- [FMHA Implementation](#fmha-implementation)
|
||||
- [Attention Basics](#attention-basics)
|
||||
- [Flash Attention Algorithm](#flash-attention-algorithm)
|
||||
- [cuTile Pseudocode → Actual Mapping](#cutile-pseudocode-actual-mapping)
|
||||
- [Kernel Pseudocode](#kernel-pseudocode)
|
||||
- [cuTile Implementation](#cutile-implementation)
|
||||
- [Launching the Kernel](#launching-the-kernel)
|
||||
- [Optimizations](#optimizations)
|
||||
- [Platform Configuration](#platform-configuration)
|
||||
- [Performance Results](#performance-results)
|
||||
- [Common Issues](#common-issues)
|
||||
- [Companion Scripts](#companion-scripts)
|
||||
- [References](#references)
|
||||
- [Platform Comparison](#platform-comparison)
|
||||
- [End-to-End Throughput](#end-to-end-throughput)
|
||||
- [CUDA Kernel Time](#cuda-kernel-time)
|
||||
- [cuTile Kernel Breakdown](#cutile-kernel-breakdown)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
## Basic idea
|
||||
|
||||
[TileGym](https://github.com/NVIDIA/TileGym) is NVIDIA's benchmark suite and integration framework for cuTile kernels - high-performance GPU kernels written using the cuTile Python DSL. cuTile compiles to Tile IR, enabling developers to write efficient kernels without low-level CUDA programming.
|
||||
|
||||
This playbook covers three workflows:
|
||||
1. **[Kernel Benchmarks](kernel-benchmarks)** - Run standalone cuTile kernel benchmarks (FMHA, MatMul, RMSNorm, etc.)
|
||||
2. **[End-to-End Inference](e2e-inference)** - Run LLM inference with cuTile-optimized kernels via monkey-patching
|
||||
3. **[FMHA Implementation](fmha)** - Step-by-step tutorial building a Flash Multi-Head Attention kernel from pseudocode to optimized cuTile, with companion scripts to run and benchmark
|
||||
|
||||
The same cuTile code runs on both DGX Spark (sm_121) and B300 (sm_103) - cuTile JIT compiles to the appropriate GPU architecture automatically.
|
||||
|
||||
## What you'll accomplish
|
||||
|
||||
- Run the TileGym benchmark suite on DGX Spark
|
||||
- Run Qwen2-7B or DeepSeek-V2-Lite inference with cuTile-optimized kernels
|
||||
- Observe performance scaling between DGX Spark and B300
|
||||
- Build an FMHA kernel step-by-step from pseudocode to optimized cuTile implementation
|
||||
|
||||
## What to know before starting
|
||||
|
||||
- Basic familiarity with Docker and command-line tools
|
||||
- Understanding of GPU compute concepts (TFLOPS, memory bandwidth)
|
||||
- No CUDA programming experience required
|
||||
- HuggingFace account with access token (for LLM inference)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
**Hardware Requirements:**
|
||||
- DGX Spark with Ubuntu 24.04 or B300 cloud instance
|
||||
- Minimum 16GB GPU memory for LLM inference
|
||||
- At least 50GB available storage space for model downloads
|
||||
|
||||
**Software Requirements:**
|
||||
- Docker installed and configured: `docker ps`
|
||||
- CUDA Toolkit 13.x with Tile IR support
|
||||
- HuggingFace token for model access (LLM inference only)
|
||||
- Network access for pulling containers and downloading models
|
||||
|
||||
Verify Docker is available:
|
||||
```bash
|
||||
docker ps
|
||||
```
|
||||
|
||||
If you get a permission error:
|
||||
```bash
|
||||
sudo usermod -aG docker $USER
|
||||
newgrp docker
|
||||
```
|
||||
|
||||
## Kernel support matrix
|
||||
|
||||
| Kernel | Category | Data Types | Description |
|
||||
|--------|----------|------------|-------------|
|
||||
| **FMHA** | Attention | float16, float8 | Flash Multi-Head Attention |
|
||||
| **MLA** | Attention | bfloat16, float8 | Multi-head Latent Attention |
|
||||
| **MLA Decoding** | Attention | float16, float8 | MLA for decode phase |
|
||||
| **MatMul** | Matrix Ops | float16, float8 | Matrix multiplication |
|
||||
| **BMM** | Matrix Ops | float16 | Batched matrix multiplication |
|
||||
| **Group GEMM** | Matrix Ops | float16, float8 | Grouped GEMM for MoE |
|
||||
| **RMSNorm** | Normalization | float16, bfloat16 | Root mean square normalization |
|
||||
| **RoPE** | Positional | float16 | Rotary position embedding |
|
||||
| **SiLU** | Activation | float16, float32 | SiLU activation with multiply |
|
||||
| **SwiGLU** | Activation | float16, float32 | SwiGLU fused operation |
|
||||
| **Softmax** | Activation | float16 | Softmax normalization |
|
||||
| **Dropout** | Regularization | float16, float32 | Dropout forward |
|
||||
|
||||
## Model support for LLM inference
|
||||
|
||||
| Model | Supported Kernels | Batch Size | Output Tokens | Notes |
|
||||
|-------|-------------------|------------|---------------|-------|
|
||||
| **Qwen2-7B** | RoPE, RMSNorm, SwiGLU, FMHA | 16 | 50 | Standard transformer |
|
||||
| **DeepSeek-V2-Lite** | RoPE, RMSNorm, SiLU, MLA, MoE | 1 | 100 | MLA attention, MoE layers |
|
||||
|
||||
## Ancillary files
|
||||
|
||||
All required assets can be found in the [TileGym repository](https://github.com/NVIDIA/TileGym).
|
||||
|
||||
- `tests/benchmark/run_all.sh` - Run all kernel benchmarks
|
||||
- `modeling/transformers/bench_qwen.sh` - Qwen2-7B benchmark script
|
||||
- `modeling/transformers/bench_deepseek.sh` - DeepSeek-V2-Lite benchmark script
|
||||
- `modeling/transformers/infer.py` - Main inference script with TileGym integration
|
||||
- [`assets/fmha_optimization_tutorial.py`](assets/fmha_optimization_tutorial.py) - FMHA step-by-step optimization tutorial
|
||||
- [`assets/fmha_scaling_analysis.py`](assets/fmha_scaling_analysis.py) - FMHA scaling analysis across sequence lengths
|
||||
|
||||
## Time & risk
|
||||
|
||||
* **Estimated time:** 30-45 minutes (including model download for LLM inference)
|
||||
* **Risk level:** Low
|
||||
* Large downloads may fail due to network issues
|
||||
* First run includes JIT compilation overhead
|
||||
* **Rollback:** Remove Docker container to undo all changes
|
||||
* **Last Updated:** February 2026
|
||||
* First Publication
|
||||
|
||||
## Kernel Benchmarks
|
||||
|
||||
## Step 1. Pull CUDA NGC container with CTK 13.x
|
||||
|
||||
```bash
|
||||
docker pull nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04
|
||||
```
|
||||
|
||||
Launch an interactive session with GPU access:
|
||||
|
||||
```bash
|
||||
docker run --gpus all -it --rm \
|
||||
-v ~/TileGym:/workspace/TileGym \
|
||||
nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The `-v` flag mounts a local directory to persist the TileGym repository. The `--rm` flag automatically removes the container when you exit; omit it if you want to keep the container for later use.
|
||||
|
||||
Or if running outside a container, install Tile IR directly:
|
||||
|
||||
```bash
|
||||
## Requires root privileges - run with sudo or as root
|
||||
sudo apt-get install cuda-tile-ir-13-1 cuda-compiler-13-1
|
||||
```
|
||||
|
||||
## Step 2. Clone TileGym repository
|
||||
|
||||
```bash
|
||||
git clone https://github.com/NVIDIA/TileGym
|
||||
cd TileGym
|
||||
pip install .
|
||||
```
|
||||
|
||||
## Step 3. Run benchmark suite
|
||||
|
||||
```bash
|
||||
cd tests/benchmark/
|
||||
bash run_all.sh
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The benchmark runs sequentially to ensure accurate timing results. This may take 10-15 minutes to complete all kernels.
|
||||
|
||||
## Step 4. View results
|
||||
|
||||
Results show cuTile performance for each kernel and sequence length.
|
||||
|
||||
Expected output should look like:
|
||||
|
||||
```text
|
||||
==========================================
|
||||
Running bench_fused_attention.py...
|
||||
==========================================
|
||||
fused-attention-batch4-head32-d128-fwd-causal=True-float16-TFLOPS:
|
||||
N_CTX CuTile
|
||||
0 1024.0 58.188262
|
||||
1 2048.0 80.906892
|
||||
2 4096.0 86.189532
|
||||
3 8192.0 88.891086
|
||||
4 16384.0 89.491869
|
||||
✓ PASSED: bench_fused_attention.py
|
||||
```
|
||||
|
||||
## Step 5. Run individual benchmarks
|
||||
|
||||
To run specific kernel benchmarks:
|
||||
|
||||
```bash
|
||||
## Flash Multi-Head Attention
|
||||
python bench_fused_attention.py
|
||||
|
||||
## Matrix Multiplication
|
||||
python bench_matrix_multiplication.py
|
||||
|
||||
## RMSNorm
|
||||
python bench_rmsnorm.py
|
||||
|
||||
## RoPE
|
||||
python bench_rope.py
|
||||
|
||||
## SwiGLU
|
||||
python bench_swiglu.py
|
||||
```
|
||||
|
||||
## Step 6. Clean up
|
||||
|
||||
Exit the container:
|
||||
|
||||
```bash
|
||||
exit
|
||||
```
|
||||
|
||||
Remove this workflow's containers (if you ran without `--rm`):
|
||||
|
||||
```bash
|
||||
## Preferred: remove only containers from this workflow's image
|
||||
docker rm $(docker ps -a --filter ancestor=nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 --format '{{.ID}}')
|
||||
|
||||
## Alternative: prune all stopped containers (will prompt for confirmation)
|
||||
## docker container prune
|
||||
```
|
||||
|
||||
Remove the image (optional):
|
||||
|
||||
```bash
|
||||
docker rmi nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04
|
||||
```
|
||||
|
||||
## Step 7. Repeat on B300
|
||||
|
||||
Repeat Steps 1-6 on B300 hardware to observe scaling. See the **Platform Comparison** tab for expected scaling results.
|
||||
|
||||
## End-to-End Inference
|
||||
|
||||
## Step 1. Set up environment
|
||||
|
||||
If you haven't already, pull the CUDA container and clone TileGym (see **Kernel Benchmarks** tab for details).
|
||||
|
||||
First, clone TileGym on the host:
|
||||
|
||||
```bash
|
||||
mkdir -p ~/TileGym
|
||||
git clone https://github.com/NVIDIA/TileGym ~/TileGym
|
||||
```
|
||||
|
||||
Then launch the container with the repository mounted:
|
||||
|
||||
```bash
|
||||
docker run --gpus all -it --rm \
|
||||
-v ~/TileGym:/workspace/TileGym \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
nvcr.io/nvidia/cuda:13.1-devel-ubuntu24.04 \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The `-v ~/.cache/huggingface:/root/.cache/huggingface` mounts your HuggingFace cache to avoid re-downloading models.
|
||||
|
||||
Install TileGym inside the container:
|
||||
|
||||
```bash
|
||||
cd /workspace/TileGym
|
||||
pip install .
|
||||
```
|
||||
|
||||
Set your HuggingFace token for accessing gated models:
|
||||
|
||||
```bash
|
||||
export HF_TOKEN=<your_huggingface_token>
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> You need a HuggingFace account and access token. Get one at https://huggingface.co/settings/tokens
|
||||
|
||||
## Step 2. Run inference benchmark
|
||||
|
||||
Navigate to the transformers benchmark directory:
|
||||
|
||||
```bash
|
||||
cd modeling/transformers
|
||||
```
|
||||
|
||||
**Option A: Run Qwen2-7B benchmark**
|
||||
|
||||
```bash
|
||||
./bench_qwen.sh
|
||||
```
|
||||
|
||||
Configuration: Model `Qwen/Qwen2-7B`, Batch size 16, Output length 50 tokens.
|
||||
|
||||
**Option B: Run DeepSeek-V2-Lite benchmark**
|
||||
|
||||
```bash
|
||||
./bench_deepseek.sh
|
||||
```
|
||||
|
||||
Configuration: Model `deepseek-ai/DeepSeek-V2-Lite-Chat`, Batch size 1, Output length 100 tokens.
|
||||
|
||||
Both scripts run two configurations:
|
||||
1. **PyTorch baseline** - Standard HuggingFace inference
|
||||
2. **TileGym cuTile** - With cuTile kernel replacements
|
||||
|
||||
## Step 3. View results
|
||||
|
||||
**Sample DGX Spark (GB10) Results for Qwen2-7B:**
|
||||
|
||||
```text
|
||||
========================================
|
||||
Benchmark Results
|
||||
========================================
|
||||
Qwen2-7B_naive_bfloat16 | 15.66 tokens/s | 51.10s | 51151.0ms CUDA
|
||||
Qwen2-7B_cutile_attn | 18.52 tokens/s | 43.20s | 43079.7ms CUDA
|
||||
========================================
|
||||
```
|
||||
|
||||
**cuTile Kernel Breakdown (DGX Spark - Qwen2):**
|
||||
|
||||
| Kernel | CUDA Time (ms) | Calls |
|
||||
|--------|----------------|-------|
|
||||
| `fmha_kernel` | 4185.9 | 28 |
|
||||
| `swiglu_forward_kernel` | 2459.8 | 1400 |
|
||||
| `attention_decode_kernel_grouped` | 2271.8 | 1372 |
|
||||
| `rms_norm_kernel_static_persistent` | 634.7 | 57 |
|
||||
| `rope_kernel` | 355.6 | 1400 |
|
||||
|
||||
## Step 4. How TileGym monkey-patching works
|
||||
|
||||
TileGym replaces PyTorch model operations with cuTile kernels. The snippet below is taken from TileGym's [`src/tilegym/transformers/monkey_patch.py`](https://github.com/NVIDIA/TileGym/blob/main/src/tilegym/transformers/monkey_patch.py) and invoked from [`modeling/transformers/infer.py`](https://github.com/NVIDIA/TileGym/blob/main/modeling/transformers/infer.py):
|
||||
|
||||
```python
|
||||
from tilegym.transformers import apply_tilegym_kernel_to_qwen2
|
||||
|
||||
apply_tilegym_kernel_to_qwen2(
|
||||
rope=True, # Replace RoPE with cuTile kernel
|
||||
rms_norm=True, # Replace RMSNorm with cuTile kernel
|
||||
swiglu=True, # Replace SwiGLU with cuTile kernel
|
||||
attn=True, # Replace attention with cuTile FMHA
|
||||
use_cutile=True # Use cuTile backend (vs Triton)
|
||||
)
|
||||
```
|
||||
|
||||
**Patched Kernels for Qwen2:**
|
||||
|
||||
| Kernel | PyTorch Operation | cuTile Replacement |
|
||||
|--------|-------------------|-------------------|
|
||||
| `rms_norm_kernel_static_persistent` | `nn.RMSNorm` | Persistent RMSNorm |
|
||||
| `rope_kernel` | Rotary position embedding | Fused RoPE |
|
||||
| `fmha_kernel` | `F.scaled_dot_product_attention` | Flash Attention |
|
||||
| `swiglu_forward_kernel` | SiLU + Mul | Fused SwiGLU |
|
||||
| `attention_decode_kernel_grouped` | Decode attention | Grouped decode |
|
||||
|
||||
**Patched Kernels for DeepSeek-V2:** (see [`src/tilegym/transformers/monkey_patch.py`](https://github.com/NVIDIA/TileGym/blob/main/src/tilegym/transformers/monkey_patch.py))
|
||||
|
||||
```python
|
||||
from tilegym.transformers import apply_tilegym_kernel_to_deepseek_v2
|
||||
|
||||
apply_tilegym_kernel_to_deepseek_v2(
|
||||
rope=True, # Replace RoPE with cuTile kernel
|
||||
rms_norm=True, # Replace RMSNorm with cuTile kernel
|
||||
swiglu=True, # Replace SiLU+Mul with cuTile kernel
|
||||
attn=True, # Replace MLA attention with cuTile
|
||||
moe=True, # Replace MoE routing with cuTile
|
||||
use_cutile=True
|
||||
)
|
||||
```
|
||||
|
||||
| Kernel | PyTorch Operation | cuTile Replacement |
|
||||
|--------|-------------------|-------------------|
|
||||
| `prefill_mla` | MLA prefill attention | Multi-head Latent Attention |
|
||||
| `_mla_decoding_split_kv` | MLA decode attention | Split-KV decoding |
|
||||
| `fused_moe_kernel` | MoE expert routing | Fused MoE |
|
||||
| `group_gemm_kernel` | Expert FFN | Grouped GEMM |
|
||||
|
||||
## Step 5. Platform-specific tuning (Advanced)
|
||||
|
||||
cuTile exposes two complementary performance-tuning mechanisms:
|
||||
|
||||
- **[`ct.ByTarget`](https://docs.nvidia.com/cuda/cutile-python/performance.html)** - Select different kernel launch parameters per GPU architecture (`sm_<major><minor>`). The compiler picks the value matching the current target at JIT time; if no entry matches, the `default` value is used. See the [Performance Tuning](https://docs.nvidia.com/cuda/cutile-python/performance.html) and [Execution Model](https://docs.nvidia.com/cuda/cutile-python/execution.html) pages.
|
||||
- **`num_ctas`** - Number of Cooperative Thread Arrays (thread blocks) launched per kernel invocation. Tune to the number of SMs on the target GPU.
|
||||
- **`occupancy`** - Hint for the number of concurrent CTAs the compiler should target per SM. Higher occupancy hides memory latency but increases register/shared-memory pressure. See the [Execution Model](https://docs.nvidia.com/cuda/cutile-python/execution.html) documentation.
|
||||
- **[`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html)** - Search a list of candidate values at runtime and pick the fastest configuration. Results are reported via [`cuda.tile.tune.TuningResult`](https://docs.nvidia.com/cuda/cutile-python/generated/cuda.tile.tune.TuningResult.html) / [`Measurement`](https://docs.nvidia.com/cuda/cutile-python/generated/cuda.tile.tune.Measurement.html).
|
||||
|
||||
```python
|
||||
import cuda.tile as ct
|
||||
|
||||
@ct.kernel(
|
||||
# # num_ctas: how many thread blocks to launch.
|
||||
# # Use ByTarget to pick an arch-specific value at JIT time.
|
||||
num_ctas=ct.ByTarget({
|
||||
"sm_103": 8, # B300 - more SMs, launch more CTAs
|
||||
"sm_121": 4, # DGX Spark - fewer SMs (48), use fewer CTAs
|
||||
"default": 1, # Fallback for any other GPU architecture
|
||||
}),
|
||||
# # occupancy: hint for concurrent CTAs per SM (latency hiding vs. register pressure).
|
||||
occupancy=ct.ByTarget({
|
||||
"sm_103": 16, # B300 - high occupancy, plenty of registers/SMEM
|
||||
"sm_121": 12, # DGX Spark - moderate occupancy
|
||||
"default": 8, # Conservative fallback
|
||||
}),
|
||||
opt_level=3 # Maximum compiler optimization level
|
||||
)
|
||||
def optimized_kernel(A, B, C):
|
||||
# # Same kernel code works on all platforms;
|
||||
# # ByTarget swaps in the arch-specific launch params automatically.
|
||||
...
|
||||
```
|
||||
|
||||
For automatic tuning, use [`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to search over candidate values and pick the fastest configuration at runtime:
|
||||
|
||||
```python
|
||||
@ct.kernel(
|
||||
# # autotune: benchmark each value and pick the fastest.
|
||||
num_ctas=ct.autotune([1, 2, 4, 8, 16]),
|
||||
occupancy=ct.autotune([8, 12, 16, 24]),
|
||||
opt_level=3
|
||||
)
|
||||
def autotuned_kernel(A, B, C):
|
||||
...
|
||||
```
|
||||
|
||||
## Step 6. Repeat on B300
|
||||
|
||||
Repeat Steps 1-3 on B300 hardware. The **same code runs without modification** - cuTile JIT compiles for sm_103 automatically.
|
||||
|
||||
See the **Platform Comparison** tab for detailed scaling results.
|
||||
|
||||
## FMHA Implementation
|
||||
|
||||
## FMHA Implementation Guide
|
||||
|
||||
> [!NOTE]
|
||||
> This is a guide to understanding FMHA implementation in cuTile, not a complete reference. For comprehensive documentation, see the [cuTile Python Documentation](https://docs.nvidia.com/cuda/cutile-python/).
|
||||
|
||||
### Attention Basics
|
||||
|
||||
Attention allows a neural network to focus on relevant parts of the input. In transformers (GPT, LLaMA, Qwen), each position computes how much to attend to every other position using three vectors:
|
||||
|
||||
- **Query (Q)**: "What am I looking for?"
|
||||
- **Key (K)**: "What do I contain?"
|
||||
- **Value (V)**: "Here is my content"
|
||||
|
||||
```text
|
||||
Attention(Q, K, V) = softmax(Q × K^T / √d) × V
|
||||
|
||||
Shapes:
|
||||
Q, K, V = [batch, heads, seq_len, head_dim]
|
||||
Q × K^T = [batch, heads, seq_len, seq_len] # Attention scores
|
||||
Output = [batch, heads, seq_len, head_dim]
|
||||
```
|
||||
|
||||
For autoregressive models, **causal masking** ensures each token only attends to previous tokens by setting future scores to -infinity before softmax.
|
||||
|
||||
### Flash Attention Algorithm
|
||||
|
||||
Standard attention materializes a [seq_len × seq_len] matrix (e.g., 2 GB for seq_len=32768). Flash Attention avoids this by processing in tiles with **online softmax**:
|
||||
|
||||
```text
|
||||
m = -infinity # Running maximum
|
||||
l = 0 # Running sum of exp(x - m)
|
||||
acc = 0 # Running weighted sum of values
|
||||
|
||||
FOR each K,V tile:
|
||||
scores = Q_tile @ K_tile.T * scale
|
||||
m_new = max(m, max(scores))
|
||||
correction = exp(m - m_new)
|
||||
l = l * correction + sum(exp(scores - m_new))
|
||||
acc = acc * correction + exp(scores - m_new) @ V_tile
|
||||
m = m_new
|
||||
|
||||
output = acc / l
|
||||
```
|
||||
|
||||
### cuTile Pseudocode → Actual Mapping
|
||||
|
||||
| Concept | Pseudocode | cuTile |
|
||||
|---|---|---|
|
||||
| Define kernel | `KERNEL fmha(...)` | `@ct.kernel()` |
|
||||
| Get block ID | `block_x = BLOCK_ID_X` | `bid_x = ct.bid(0)` |
|
||||
| Create indices | `range(0, N)` | `ct.arange(N, dtype=ct.int32)` |
|
||||
| Create constant tile | `tile = zeros(M, N)` | `ct.full((M, N), 0.0, dtype)` |
|
||||
| Load from memory | `tile = LOAD(ptr, shape)` | `ct.load(tensor, index, shape)` |
|
||||
| Store to memory | `STORE(ptr, tile)` | `ct.store(tensor, index, tile)` |
|
||||
| Matrix multiply | `C = A @ B + C` | `ct.mma(A, B, C)` |
|
||||
| Reduction | `max_val = MAX(tile, axis)` | `ct.max(tile, axis, keepdims)` |
|
||||
|
||||
### Kernel Pseudocode
|
||||
|
||||
```text
|
||||
KERNEL fmha(Q, K, V, Out, scale, TILE_M, TILE_N):
|
||||
tile_row = BLOCK_ID_X
|
||||
batch_head = BLOCK_ID_Y
|
||||
batch = batch_head // num_heads
|
||||
head = batch_head % num_heads
|
||||
|
||||
m_i = full(TILE_M, -infinity)
|
||||
l_i = full(TILE_M, 0)
|
||||
acc = zeros(TILE_M, head_dim)
|
||||
|
||||
q = LOAD(Q[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :])
|
||||
|
||||
FOR j = 0 to num_k_tiles:
|
||||
k = LOAD(K[batch, head, j*TILE_N : (j+1)*TILE_N, :])
|
||||
v = LOAD(V[batch, head, j*TILE_N : (j+1)*TILE_N, :])
|
||||
scores = MMA(q, transpose(k)) * scale
|
||||
IF causal AND in_mask_region:
|
||||
scores = WHERE(valid_mask, scores, -infinity)
|
||||
m_new = max(m_i, row_max(scores))
|
||||
correction = exp(m_i - m_new)
|
||||
p = exp(scores - m_new)
|
||||
l_i = l_i * correction + row_sum(p)
|
||||
acc = acc * correction + MMA(p, v)
|
||||
m_i = m_new
|
||||
|
||||
out = acc / l_i
|
||||
STORE(Out[batch, head, tile_row*TILE_M :, :], out)
|
||||
```
|
||||
|
||||
### cuTile Implementation
|
||||
|
||||
```python
|
||||
import cuda.tile as ct
|
||||
import math
|
||||
ConstInt = ct.Constant[int]
|
||||
ConstBool = ct.Constant[bool]
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_kernel(Q, K, V, Out, qk_scale: float, TILE_D: ConstInt, H: ConstInt,
|
||||
TILE_M: ConstInt, TILE_N: ConstInt, CAUSAL: ConstBool):
|
||||
bid_x, bid_y = ct.bid(0), ct.bid(1)
|
||||
batch_idx, head_idx = bid_y // H, bid_y % H
|
||||
|
||||
offs_m = (bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32))[:, None]
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0),
|
||||
shape=(1, 1, TILE_M, TILE_D)).reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
Tc = ct.cdiv(min((bid_x + 1) * TILE_M, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
|
||||
k_t = ct.permute(k_tile, (1, 0))
|
||||
|
||||
qk = ct.mma(q, k_t, ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32))
|
||||
qk = qk * qk_scale
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
qk = ct.where(offs_m >= offs_n, qk,
|
||||
ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.maximum(m_i, ct.max(qk, axis=-1, keepdims=True))
|
||||
qk = qk - m_ij
|
||||
p = ct.exp(qk)
|
||||
alpha = ct.exp(m_i - m_ij)
|
||||
l_i = l_i * alpha + ct.sum(p, axis=-1, keepdims=True)
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
|
||||
acc = ct.mma(p.astype(Q.dtype), v_tile, acc)
|
||||
m_i = m_ij
|
||||
|
||||
acc = (acc / l_i).reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
```
|
||||
|
||||
### Launching the Kernel
|
||||
|
||||
```python
|
||||
def run_fmha(q, k, v, sm_scale, is_causal=True):
|
||||
import torch
|
||||
TILE_M, TILE_N = 64, 64 # Platform-specific (see below)
|
||||
batch, num_heads, seq_len, head_dim = q.shape
|
||||
out = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(), grid, fmha_kernel,
|
||||
(q, k, v, out, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return out
|
||||
```
|
||||
|
||||
### Optimizations
|
||||
|
||||
#### exp2 + flush_to_zero
|
||||
|
||||
`exp2(x) = 2^x` is faster than `exp(x)` on GPU. Requires scale adjustment by `1/log(2)`.
|
||||
|
||||
```python
|
||||
## Convert natural-exp scale to base-2 so we can use the faster ct.exp2 intrinsic.
|
||||
## exp(x) == exp2(x / log(2)) == exp2(x * INV_LOG_2).
|
||||
INV_LOG_2 = 1.0 / math.log(2) # ≈ 1.4427
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2 # Pre-multiply the softmax scale once
|
||||
|
||||
## ... in loop:
|
||||
## Fuse the running-max update with the scale multiplication.
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
## Subtract the running max for numerical stability (online softmax).
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
## flush_to_zero=True: flush denormals to 0 -> avoids slow denormal handling on GPU.
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True) # Correction factor for previous acc/l_i
|
||||
```
|
||||
|
||||
#### Load Order Transpose
|
||||
|
||||
Load K already transposed using `order` parameter, avoiding explicit permute.
|
||||
|
||||
```python
|
||||
## order=(0,1,3,2) swaps the last two axes during the load,
|
||||
## producing K^T directly in registers -- no extra ct.permute() needed.
|
||||
## shape is expressed in the transposed layout: (1, 1, TILE_D, TILE_N).
|
||||
k_t = ct.load(K, index=(..., 0, j), shape=(1,1,TILE_D,TILE_N),
|
||||
order=(0,1,3,2)).reshape((TILE_D, TILE_N))
|
||||
```
|
||||
|
||||
#### Latency Hints
|
||||
|
||||
Prefetch data to overlap memory loads with computation. See the [Performance Tuning docs](https://docs.nvidia.com/cuda/cutile-python/performance.html) for the full list of load/store hints (e.g. `allow_tma`, `latency`).
|
||||
|
||||
```python
|
||||
## latency=N tells the compiler to issue this load N loop iterations in
|
||||
## advance of its use, so the memory transfer overlaps with the MMA work
|
||||
## from earlier iterations. Larger latency = deeper software pipeline but
|
||||
## more register pressure.
|
||||
k_t = ct.load(K, ..., latency=2) # Prefetch K 2 iterations ahead
|
||||
v_tile = ct.load(V, ..., latency=4) # Prefetch V 4 iterations ahead (used later in the loop)
|
||||
```
|
||||
|
||||
#### Occupancy
|
||||
|
||||
Allow multiple thread blocks per SM to hide memory latency. See the [Execution Model docs](https://docs.nvidia.com/cuda/cutile-python/execution.html) for details on how `occupancy` interacts with registers and shared memory.
|
||||
|
||||
```python
|
||||
## occupancy=N is a hint to the compiler to target N concurrent CTAs per SM.
|
||||
## Higher occupancy -> more warps available to hide memory latency,
|
||||
## but constrains the per-CTA register/SMEM budget.
|
||||
@ct.kernel(occupancy=2) # 2 thread blocks (CTAs) co-resident per SM
|
||||
def fmha_optimized(...):
|
||||
```
|
||||
|
||||
#### Approximate Division
|
||||
|
||||
Use fast approximate division for final normalization.
|
||||
|
||||
```python
|
||||
from cuda.tile import RoundingMode as RMd
|
||||
## RMd.APPROX -> hardware approximate reciprocal/divide (MUFU), much faster
|
||||
## than IEEE-compliant division. Safe here because it's the final softmax
|
||||
## normalization step where a small ULP error is acceptable.
|
||||
## flush_to_zero=True flushes denormals to 0 to avoid the slow path.
|
||||
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
|
||||
```
|
||||
|
||||
### Platform Configuration
|
||||
|
||||
The same kernel code works on all platforms; only configuration parameters change. Use [`ct.ByTarget`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to select values per architecture, or [`ct.autotune`](https://docs.nvidia.com/cuda/cutile-python/performance.html) to search candidate values automatically.
|
||||
|
||||
| Platform | TILE_M | TILE_N | Occupancy | Rationale |
|
||||
|---|---|---|---|---|
|
||||
| DGX Spark (sm_121) | 64 | 64 | 2 | Smaller tiles, higher occupancy for 48 SMs |
|
||||
| B300 (sm_103) | 256 | 128 | 1 | Large tiles maximize HBM3e throughput |
|
||||
| B300 alternate | 128 | 128 | 2 | Higher occupancy, balanced parallelism |
|
||||
|
||||
```python
|
||||
import cuda.tile as ct
|
||||
|
||||
@ct.kernel(
|
||||
# # TILE_M / TILE_N: rows/cols of the Q and K/V tiles processed per CTA.
|
||||
# # Larger tiles -> more arithmetic intensity; smaller tiles -> higher occupancy.
|
||||
# # occupancy: target concurrent CTAs per SM (latency hiding vs. register pressure).
|
||||
occupancy=ct.ByTarget({
|
||||
"sm_121": 2, # DGX Spark (48 SMs): 2 CTAs/SM for latency hiding
|
||||
"sm_100": 1, # B300: larger tiles already saturate the SM
|
||||
"default": 1, # Conservative fallback for other architectures
|
||||
}),
|
||||
opt_level=3 # Maximum compiler optimization level
|
||||
)
|
||||
def fmha_kernel(...):
|
||||
...
|
||||
```
|
||||
|
||||
### Performance Results
|
||||
|
||||
> **Note:** PyTorch SDPA is used for correctness verification only, not performance comparison.
|
||||
|
||||
#### DGX Spark (sm_121) — Seq 2048
|
||||
|
||||
| Step | Optimization | Latency (ms) | TFLOPS |
|
||||
|---|---|---|---|
|
||||
| 1 | Basic cuTile | 2.19 | 62.8 |
|
||||
| 2 | + exp2 | 2.07 | 66.5 |
|
||||
| 3 | + Load Order | 2.07 | 66.3 |
|
||||
| 4 | + Latency Hints | 2.07 | 66.5 |
|
||||
| 5 | + Occupancy=2 | 1.73 | 79.5 |
|
||||
| 6 | + Approx Div (Final) | 1.69 | 81.1 |
|
||||
|
||||
#### B300 (sm_103) — Various Seq Lengths
|
||||
|
||||
| Seq Len | Latency (ms) | TFLOPS | vs Spark |
|
||||
|---|---|---|---|
|
||||
| 1024 | 0.074 | 465 | 5.7x |
|
||||
| 2048 | 0.178 | 770 | 9.5x |
|
||||
| 4096 | 0.550 | 999 | 15.1x |
|
||||
| 8192 | 1.897 | 1159 | 14.6x |
|
||||
| 16384 | 7.014 | 1254 | 14.2x |
|
||||
|
||||
### Common Issues
|
||||
|
||||
| Issue | Solution |
|
||||
|---|---|
|
||||
| Shape mismatch in ct.mma | Ensure A is (M,K), B is (K,N), C is (M,N) |
|
||||
| dtype errors | Use `.astype()` before mma; accumulator should be float32 |
|
||||
| Incorrect results with causal | Check mask_start calculation and `offs_m >= offs_n` logic |
|
||||
| Low performance | Try different TILE_M/N, check occupancy, verify latency hints |
|
||||
|
||||
### Companion Scripts
|
||||
|
||||
The following scripts are included in this playbook and can be run on DGX Spark or B300:
|
||||
|
||||
- **[`assets/fmha_optimization_tutorial.py`](assets/fmha_optimization_tutorial.py)** — Step-by-step optimization tutorial. Builds the FMHA kernel from basic to fully optimized, matching the progression in this guide.
|
||||
- **[`assets/fmha_scaling_analysis.py`](assets/fmha_scaling_analysis.py)** — Scaling analysis across sequence lengths. Benchmarks each optimization level and generates performance data.
|
||||
|
||||
```bash
|
||||
## Run the optimization tutorial (DGX Spark)
|
||||
python assets/fmha_optimization_tutorial.py --correctness-check
|
||||
|
||||
## Run the scaling analysis
|
||||
python assets/fmha_scaling_analysis.py --iterations 100
|
||||
```
|
||||
|
||||
### References
|
||||
|
||||
- [cuTile Python Documentation](https://docs.nvidia.com/cuda/cutile-python/)
|
||||
- [Tile IR Specification](https://docs.nvidia.com/cuda/tile-ir/)
|
||||
- [TileGym (pre-optimized kernels)](https://github.com/NVIDIA/TileGym)
|
||||
- [NVIDIA Blog: Tuning Flash Attention for Peak Performance in CUDA Tile](https://developer.nvidia.com/blog/tuning-flash-attention-for-peak-performance-in-nvidia-cuda-tile/)
|
||||
- [Flash Attention Paper](https://arxiv.org/abs/2205.14135)
|
||||
|
||||
## Platform Comparison
|
||||
|
||||
## DGX Spark vs B300 Performance Comparison
|
||||
|
||||
This page summarizes performance scaling between DGX Spark (GB10) and B300 for both kernel benchmarks and end-to-end LLM inference.
|
||||
|
||||
## Kernel Benchmark Scaling
|
||||
|
||||
Use the ratios below as a reference for how kernel performance scales from DGX Spark (GB10) to B300.
|
||||
|
||||
| Kernel | Metric | B300 / GB10 |
|
||||
|--------|--------|-------------|
|
||||
| FMHA (causal, 8192) | TFLOPS | 13.7x |
|
||||
| FMHA (non-causal, 8192) | TFLOPS | 15.1x |
|
||||
| MatMul (8192) | TFLOPS | 18.9x |
|
||||
| BMM (batch8, 4096) | TFLOPS | 19.4x |
|
||||
| Group GEMM (4096) | TFLOPS | 23.9x |
|
||||
| RMSNorm (4096) | GB/s | 33.1x |
|
||||
| RoPE (16384) | GB/s | 22.8x |
|
||||
|
||||
**Key Observations:**
|
||||
- Compute-heavy kernels typically scale 14-24x from GB10 to B300
|
||||
- Memory-bound kernels can scale 20-33x due to HBM bandwidth advantage
|
||||
|
||||
## Qwen2-7B Performance
|
||||
|
||||
### End-to-End Throughput
|
||||
|
||||
| Configuration | DGX Spark | B300 | Platform Speedup |
|
||||
|---------------|-----------|------|------------------|
|
||||
| **cuTile** | 18.52 tok/s | 257.33 tok/s | **13.9x** |
|
||||
|
||||
### CUDA Kernel Time
|
||||
|
||||
| Configuration | DGX Spark | B300 | Platform Speedup |
|
||||
|---------------|-----------|------|------------------|
|
||||
| **cuTile** | 43,080 ms | 2,954 ms | **14.6x** |
|
||||
|
||||
### cuTile Kernel Breakdown
|
||||
|
||||
**DGX Spark (GB10):**
|
||||
|
||||
| Kernel | CUDA Time (ms) | Calls |
|
||||
|--------|----------------|-------|
|
||||
| `fmha_kernel` | 4,185.9 | 28 |
|
||||
| `swiglu_forward_kernel` | 2,459.8 | 1,400 |
|
||||
| `attention_decode_kernel_grouped` | 2,271.8 | 1,372 |
|
||||
| `rms_norm_kernel_static_persistent` | 634.7 | 57 |
|
||||
| `rope_kernel` | 355.6 | 1,400 |
|
||||
|
||||
**B300:**
|
||||
|
||||
| Kernel | CUDA Time (ms) | Speedup vs Spark |
|
||||
|--------|----------------|------------------|
|
||||
| `fmha_kernel` | 337.9 | 12.4x |
|
||||
| `swiglu_forward_kernel` | 226.3 | 10.9x |
|
||||
| `attention_decode_kernel_grouped` | 111.0 | 20.5x |
|
||||
| `rms_norm_kernel_static_persistent` | 29.7 | 21.4x |
|
||||
| `rope_kernel` | 16.7 | 21.3x |
|
||||
|
||||
**Same code, different architectures** - cuTile JIT compiles for sm_121 (Spark) and sm_103 (B300)
|
||||
|
||||
## Platform Specifications
|
||||
|
||||
| Specification | DGX Spark (GB10) | B300 |
|
||||
|---------------|------------------|------|
|
||||
| Compute Capability | sm_121 (12.1) | sm_103 (10.3) |
|
||||
| SMs | 48 | 132 |
|
||||
| Memory | 128 GB LPDDR5x | 192 GB HBM3e |
|
||||
| Memory Bandwidth | 273 GB/s | 8 TB/s |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Symptom | Cause | Fix |
|
||||
|---------|-------|-----|
|
||||
| `docker: permission denied` | User not in docker group | `sudo usermod -aG docker $USER && newgrp docker` |
|
||||
| `401 Client Error: Unauthorized` | Missing HuggingFace token | `export HF_TOKEN=<your_token>` |
|
||||
| `ModuleNotFoundError: tilegym` | TileGym not installed | `cd TileGym && pip install .` |
|
||||
| `RuntimeError: CUDA out of memory` | Model too large | Reduce batch size or use smaller model |
|
||||
| `Killed` during model load | Out of system memory | Clear cache: `sync; echo 3 > /proc/sys/vm/drop_caches` |
|
||||
| Slow first run | JIT compilation | Normal - cuTile compiles kernels on first run |
|
||||
| `FileNotFoundError: input_prompt_small.txt` | Missing input file | Run from `modeling/transformers` directory |
|
||||
| `torch.cuda.OutOfMemoryError` | Insufficient GPU memory | Reduce `--batch_size` parameter |
|
||||
| `ImportError: cuda.tile` | Missing Tile IR | Install: `apt-get install cuda-tile-ir-13-1` |
|
||||
| Benchmark hangs | GPU busy or locked | Check `nvidia-smi` for other processes |
|
||||
|
||||
> [!NOTE]
|
||||
> DGX Spark uses a Unified Memory Architecture (UMA), which enables dynamic memory sharing between the GPU and CPU.
|
||||
> With many applications still updating to take advantage of UMA, you may encounter memory issues even when within
|
||||
> the memory capacity of DGX Spark. If that happens, manually flush the buffer cache with:
|
||||
|
||||
```bash
|
||||
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> First run of cuTile kernels includes JIT compilation overhead. Subsequent runs will be faster as compiled kernels are cached.
|
||||
|
||||
For the latest known issues, please review the [DGX Spark User Guide](https://docs.nvidia.com/dgx/dgx-spark/known-issues.html).
|
||||
959
nvidia/cutile-kernels/assets/fmha_optimization_tutorial.py
Normal file
959
nvidia/cutile-kernels/assets/fmha_optimization_tutorial.py
Normal file
@ -0,0 +1,959 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FMHA Optimization Tutorial: From Naive to Optimized cuTile Implementation
|
||||
|
||||
This script demonstrates step-by-step optimization of Flash Multi-Head Attention
|
||||
using NVIDIA cuTile, starting from a basic implementation and progressively
|
||||
adding optimizations until reaching TileGym-level performance.
|
||||
|
||||
Target Platform: DGX Spark (sm121) with pre-determined optimal tile sizes.
|
||||
Note: TileGym supports autotuning, but we use hardcoded values for this tutorial.
|
||||
|
||||
Configuration (matches TileGym bench_fused_attention.py):
|
||||
- Batch: 4, Heads: 32, Head Dim: 128
|
||||
- Sequence Lengths: 1024, 2048, 4096, 8192, 16384
|
||||
- Benchmark: triton.testing.do_bench_cudagraph
|
||||
|
||||
Usage:
|
||||
python fmha_optimization_tutorial.py [--iterations N] [--correctness-check]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List, Optional
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
LOG_SEPARATOR = "=" * 80
|
||||
LOG_SUBSEPARATOR = "-" * 60
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
step: int
|
||||
name: str
|
||||
description: str
|
||||
latency_ms: float
|
||||
tflops: float
|
||||
speedup_vs_baseline: float
|
||||
speedup_vs_previous: float
|
||||
correct: bool
|
||||
key_changes: List[str]
|
||||
|
||||
class Logger:
|
||||
def __init__(self):
|
||||
self.results: List[BenchmarkResult] = []
|
||||
self.logs: List[str] = []
|
||||
|
||||
def log(self, msg: str):
|
||||
print(msg)
|
||||
self.logs.append(msg)
|
||||
|
||||
def section(self, title: str):
|
||||
self.log(f"\n{LOG_SEPARATOR}")
|
||||
self.log(f" {title}")
|
||||
self.log(LOG_SEPARATOR)
|
||||
|
||||
def subsection(self, title: str):
|
||||
self.log(f"\n{LOG_SUBSEPARATOR}")
|
||||
self.log(f" {title}")
|
||||
self.log(LOG_SUBSEPARATOR)
|
||||
|
||||
def add_result(self, result: BenchmarkResult):
|
||||
self.results.append(result)
|
||||
|
||||
def export_json(self, filepath: str):
|
||||
data = {
|
||||
"results": [asdict(r) for r in self.results],
|
||||
"logs": self.logs
|
||||
}
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
def export_markdown(self, filepath: str):
|
||||
with open(filepath, 'w') as f:
|
||||
f.write("# FMHA Optimization Tutorial Results\n\n")
|
||||
f.write("## Summary Table\n\n")
|
||||
f.write("| Step | Name | Latency (ms) | TFLOPS | vs Baseline | vs Previous | Correct |\n")
|
||||
f.write("|------|------|--------------|--------|-------------|-------------|--------|\n")
|
||||
for r in self.results:
|
||||
f.write(f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {r.speedup_vs_previous:.2f}x | {'Yes' if r.correct else 'No'} |\n")
|
||||
f.write("\n## Detailed Steps\n\n")
|
||||
for r in self.results:
|
||||
f.write(f"### Step {r.step}: {r.name}\n\n")
|
||||
f.write(f"**Description**: {r.description}\n\n")
|
||||
f.write("**Key Changes**:\n")
|
||||
for change in r.key_changes:
|
||||
f.write(f"- {change}\n")
|
||||
f.write(f"\n**Performance**: {r.latency_ms:.3f}ms, {r.tflops:.2f} TFLOPS, {r.speedup_vs_baseline:.2f}x vs baseline\n\n")
|
||||
|
||||
logger = Logger()
|
||||
|
||||
BATCH = 4
|
||||
N_HEADS = 32
|
||||
HEAD_DIM = 128
|
||||
INV_LOG_2 = 1.0 / math.log(2)
|
||||
|
||||
TILE_M = 64
|
||||
TILE_N = 64
|
||||
OCCUPANCY = 2
|
||||
NUM_CTAS = 1
|
||||
|
||||
def get_device():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
raise RuntimeError("CUDA not available")
|
||||
|
||||
DEVICE = None
|
||||
|
||||
def compute_flops(batch, heads, seq_len, head_dim, causal=True):
|
||||
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
|
||||
total_flops = 2 * flops_per_matmul
|
||||
if causal:
|
||||
total_flops *= 0.5
|
||||
return total_flops
|
||||
|
||||
def benchmark_fn(fn, warmup=10, iterations=100):
|
||||
"""Benchmark using triton's do_bench_cudagraph for accurate timing (matches TileGym)"""
|
||||
try:
|
||||
import triton
|
||||
ms = triton.testing.do_bench_cudagraph(fn)
|
||||
return ms
|
||||
except (ImportError, Exception):
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start = time.perf_counter()
|
||||
for _ in range(iterations):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
end = time.perf_counter()
|
||||
|
||||
return (end - start) / iterations * 1000
|
||||
|
||||
def verify_correctness(output, reference, atol=1e-2, rtol=1e-2):
|
||||
try:
|
||||
torch.testing.assert_close(output, reference, atol=atol, rtol=rtol)
|
||||
return True
|
||||
except AssertionError:
|
||||
max_diff = (output - reference).abs().max().item()
|
||||
logger.log(f" [WARN] Max difference: {max_diff:.6f}")
|
||||
return max_diff < 0.1
|
||||
|
||||
def reference_fmha(q, k, v, sm_scale, is_causal=True):
|
||||
return torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
|
||||
)
|
||||
|
||||
def step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True):
|
||||
return reference_fmha(q, k, v, sm_scale, is_causal)
|
||||
|
||||
try:
|
||||
import cuda.tile as ct
|
||||
from cuda.tile import RoundingMode as RMd
|
||||
CUTILE_AVAILABLE = True
|
||||
except ImportError:
|
||||
CUTILE_AVAILABLE = False
|
||||
logger.log("[WARN] cuTile not available. Only PyTorch baseline will run.")
|
||||
|
||||
if CUTILE_AVAILABLE:
|
||||
ConstInt = ct.Constant[int]
|
||||
ConstBool = ct.Constant[bool]
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_step2_mma(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 2: Basic cuTile FMHA with MMA (Tensor Cores)
|
||||
- Uses ct.mma() for matrix multiply
|
||||
- Standard exp() for softmax
|
||||
- Online softmax algorithm
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
k_tile = k_tile.reshape((TILE_N, TILE_D))
|
||||
k_t = ct.permute(k_tile, (1, 0))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
qk = qk * qk_scale
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True)
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk - m_ij
|
||||
|
||||
p = ct.exp(qk)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp(m_i - m_ij)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step2(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step2_mma,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_step3_exp2(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 3: Use exp2 with flush_to_zero for faster math
|
||||
- exp2(x) = 2^x is faster than exp(x) = e^x on GPU
|
||||
- Requires scaling adjustment: multiply by 1/log(2)
|
||||
- flush_to_zero handles denormals efficiently
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
k_tile = k_tile.reshape((TILE_N, TILE_D))
|
||||
k_t = ct.permute(k_tile, (1, 0))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step3(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step3_exp2,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_step4_load_order(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 4: Optimize K load with order parameter
|
||||
- Use order=(0,1,3,2) to load K already transposed
|
||||
- Avoids explicit ct.permute() operation
|
||||
- Reduces memory traffic
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2)
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step4(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step4_load_order,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_step5_latency(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 5: Add latency hints for better pipelining
|
||||
- latency=2 for K load (prefetch)
|
||||
- latency=4 for V load (more prefetch distance)
|
||||
- Helps overlap memory loads with computation
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step5(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step5_latency,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
|
||||
|
||||
@ct.kernel(occupancy=2)
|
||||
def fmha_step6_occupancy(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 6: Add occupancy hint
|
||||
- @ct.kernel(occupancy=2) improves SM utilization
|
||||
- Allows multiple thread blocks per SM
|
||||
- Better for hiding memory latency
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step6(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step6_occupancy,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
|
||||
|
||||
@ct.kernel(occupancy=2)
|
||||
def fmha_step7_approx_div(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""
|
||||
Step 7: Use approximate division for final normalization
|
||||
- ct.truediv with rounding_mode=APPROX is faster
|
||||
- Acceptable accuracy loss for inference
|
||||
- This matches TileGym's optimized implementation
|
||||
"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
|
||||
m_i = m_ij
|
||||
|
||||
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_step7(q, k, v, sm_scale, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / TILE_M), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
fmha_step7_approx_div,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
|
||||
)
|
||||
return o
|
||||
|
||||
|
||||
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
|
||||
"""Run TileGym's optimized FMHA for comparison"""
|
||||
try:
|
||||
import tilegym
|
||||
return tilegym.ops.fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
|
||||
except ImportError:
|
||||
logger.log("[WARN] TileGym not available for comparison")
|
||||
return None
|
||||
|
||||
|
||||
def run_benchmark(seq_len, iterations=100, check_correct=True):
|
||||
global DEVICE
|
||||
DEVICE = get_device()
|
||||
|
||||
logger.section(f"FMHA OPTIMIZATION TUTORIAL - SEQ_LEN={seq_len}")
|
||||
logger.log("Configuration:")
|
||||
logger.log(f" - Batch: {BATCH}")
|
||||
logger.log(f" - Heads: {N_HEADS}")
|
||||
logger.log(f" - Head Dim: {HEAD_DIM}")
|
||||
logger.log(f" - Sequence Length: {seq_len}")
|
||||
logger.log(f" - Tile M: {TILE_M}")
|
||||
logger.log(f" - Tile N: {TILE_N}")
|
||||
logger.log(" - Precision: float16")
|
||||
logger.log(" - Causal: True")
|
||||
logger.log(f" - Iterations: {iterations}")
|
||||
logger.log(f" - Device: {DEVICE}")
|
||||
|
||||
q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
|
||||
k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
|
||||
v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=DEVICE)
|
||||
sm_scale = 1.0 / math.sqrt(HEAD_DIM)
|
||||
|
||||
flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)
|
||||
|
||||
ref_output = reference_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
|
||||
steps = [
|
||||
(0, "PyTorch Baseline", "torch.nn.functional.scaled_dot_product_attention",
|
||||
lambda: step0_pytorch_baseline(q, k, v, sm_scale, is_causal=True),
|
||||
["PyTorch SDPA with cuDNN backend", "Highly optimized baseline"]),
|
||||
]
|
||||
|
||||
if CUTILE_AVAILABLE:
|
||||
steps.extend([
|
||||
(2, "Basic cuTile + MMA", "Tiled FMHA with ct.mma() for Tensor Cores",
|
||||
lambda: run_step2(q, k, v, sm_scale, is_causal=True),
|
||||
["@ct.kernel decorator", "ct.mma() for QK and PV products", "Online softmax with exp()"]),
|
||||
|
||||
(3, "+ exp2 + flush_to_zero", "Faster exponential math",
|
||||
lambda: run_step3(q, k, v, sm_scale, is_causal=True),
|
||||
["ct.exp2() instead of ct.exp()", "flush_to_zero=True for denormals", "qk_scale *= 1/log(2)"]),
|
||||
|
||||
(4, "+ Load Order Transpose", "Avoid explicit transpose",
|
||||
lambda: run_step4(q, k, v, sm_scale, is_causal=True),
|
||||
["order=(0,1,3,2) for K load", "K loaded already transposed", "Removes ct.permute() call"]),
|
||||
|
||||
(5, "+ Latency Hints", "Better memory pipelining",
|
||||
lambda: run_step5(q, k, v, sm_scale, is_causal=True),
|
||||
["latency=2 for K load", "latency=4 for V load", "Overlaps loads with compute"]),
|
||||
|
||||
(6, "+ Occupancy=2", "Better SM utilization",
|
||||
lambda: run_step6(q, k, v, sm_scale, is_causal=True),
|
||||
["@ct.kernel(occupancy=2)", "Multiple blocks per SM", "Hides memory latency"]),
|
||||
|
||||
(7, "+ Approx Division (Final)", "Fast final normalization",
|
||||
lambda: run_step7(q, k, v, sm_scale, is_causal=True),
|
||||
["ct.truediv with APPROX mode", "Matches TileGym implementation", "Full optimization achieved"]),
|
||||
])
|
||||
|
||||
baseline_latency = None
|
||||
prev_latency = None
|
||||
|
||||
for step_idx, name, desc, fn, changes in steps:
|
||||
logger.subsection(f"Step {step_idx}: {name}")
|
||||
logger.log(f"Description: {desc}")
|
||||
logger.log("Key Changes:")
|
||||
for change in changes:
|
||||
logger.log(f" - {change}")
|
||||
|
||||
try:
|
||||
output = fn()
|
||||
latency_ms = benchmark_fn(fn, warmup=10, iterations=iterations)
|
||||
tflops = flops * 1e-12 / (latency_ms * 1e-3)
|
||||
|
||||
if baseline_latency is None:
|
||||
baseline_latency = latency_ms
|
||||
speedup_baseline = 1.0
|
||||
else:
|
||||
speedup_baseline = baseline_latency / latency_ms
|
||||
|
||||
if prev_latency is None:
|
||||
speedup_prev = 1.0
|
||||
else:
|
||||
speedup_prev = prev_latency / latency_ms
|
||||
|
||||
if check_correct and output is not None:
|
||||
correct = verify_correctness(output, ref_output)
|
||||
else:
|
||||
correct = True
|
||||
|
||||
logger.log("\nResults:")
|
||||
logger.log(f" Latency: {latency_ms:.3f} ms")
|
||||
logger.log(f" TFLOPS: {tflops:.2f}")
|
||||
logger.log(f" vs Baseline: {speedup_baseline:.2f}x")
|
||||
logger.log(f" vs Previous: {speedup_prev:.2f}x")
|
||||
logger.log(f" Correct: {'Yes' if correct else 'No'}")
|
||||
|
||||
result = BenchmarkResult(
|
||||
step=step_idx,
|
||||
name=name,
|
||||
description=desc,
|
||||
latency_ms=latency_ms,
|
||||
tflops=tflops,
|
||||
speedup_vs_baseline=speedup_baseline,
|
||||
speedup_vs_previous=speedup_prev,
|
||||
correct=correct,
|
||||
key_changes=changes
|
||||
)
|
||||
logger.add_result(result)
|
||||
|
||||
prev_latency = latency_ms
|
||||
|
||||
except Exception as e:
|
||||
logger.log(f"\n[ERROR] Step {step_idx} failed: {e}")
|
||||
import traceback
|
||||
logger.log(traceback.format_exc())
|
||||
|
||||
tilegym_output = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
if tilegym_output is not None:
|
||||
logger.subsection("TileGym Reference (for comparison)")
|
||||
tilegym_fn = lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
|
||||
tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
|
||||
tilegym_speedup = baseline_latency / tilegym_latency if baseline_latency else 1.0
|
||||
|
||||
logger.log("TileGym FMHA:")
|
||||
logger.log(f" Latency: {tilegym_latency:.3f} ms")
|
||||
logger.log(f" TFLOPS: {tilegym_tflops:.2f}")
|
||||
logger.log(f" vs Baseline: {tilegym_speedup:.2f}x")
|
||||
|
||||
result = BenchmarkResult(
|
||||
step=99,
|
||||
name="TileGym Reference",
|
||||
description="TileGym's optimized FMHA implementation",
|
||||
latency_ms=tilegym_latency,
|
||||
tflops=tilegym_tflops,
|
||||
speedup_vs_baseline=tilegym_speedup,
|
||||
speedup_vs_previous=1.0,
|
||||
correct=True,
|
||||
key_changes=["Full TileGym implementation", "Pre-tuned for sm121", "Production ready"]
|
||||
)
|
||||
logger.add_result(result)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="FMHA Optimization Tutorial")
|
||||
parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
|
||||
parser.add_argument("--seq-len", type=int, default=2048, help="Sequence length (default matches TileGym)")
|
||||
parser.add_argument("--correctness-check", action="store_true", help="Enable correctness checking")
|
||||
parser.add_argument("--output-dir", type=str, default=".", help="Output directory for logs")
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.section("FMHA OPTIMIZATION TUTORIAL")
|
||||
logger.log("From Naive to Optimized cuTile Implementation")
|
||||
logger.log("Target Platform: DGX Spark (sm121)")
|
||||
logger.log(f"Tile Sizes: TILE_M={TILE_M}, TILE_N={TILE_N} (hardcoded from TileGym)")
|
||||
logger.log("Note: TileGym supports autotuning, but we use pre-determined optimal values")
|
||||
|
||||
run_benchmark(
|
||||
seq_len=args.seq_len,
|
||||
iterations=args.iterations,
|
||||
check_correct=args.correctness_check
|
||||
)
|
||||
|
||||
logger.section("FINAL SUMMARY")
|
||||
logger.log("\n| Step | Name | Latency (ms) | TFLOPS | vs Baseline | Correct |")
|
||||
logger.log("|------|------|--------------|--------|-------------|---------|")
|
||||
for r in logger.results:
|
||||
logger.log(f"| {r.step} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.2f} | {r.speedup_vs_baseline:.2f}x | {'Yes' if r.correct else 'No'} |")
|
||||
|
||||
json_path = f"{args.output_dir}/fmha_tutorial_results.json"
|
||||
md_path = f"{args.output_dir}/fmha_tutorial_results.md"
|
||||
log_path = f"{args.output_dir}/fmha_tutorial_log.txt"
|
||||
|
||||
logger.export_json(json_path)
|
||||
logger.export_markdown(md_path)
|
||||
|
||||
with open(log_path, 'w') as f:
|
||||
f.write('\n'.join(logger.logs))
|
||||
|
||||
logger.log("\nResults exported to:")
|
||||
logger.log(f" - {json_path}")
|
||||
logger.log(f" - {md_path}")
|
||||
logger.log(f" - {log_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
891
nvidia/cutile-kernels/assets/fmha_scaling_analysis.py
Normal file
891
nvidia/cutile-kernels/assets/fmha_scaling_analysis.py
Normal file
@ -0,0 +1,891 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FMHA Scaling Analysis: How Optimizations Impact Performance at Different Sizes
|
||||
|
||||
This script demonstrates:
|
||||
1. How FMHA performance scales with sequence length
|
||||
2. Which optimizations provide the most benefit at larger sizes
|
||||
3. Target-specific configurations for different GPU architectures
|
||||
|
||||
Target Platforms (from TileGym):
|
||||
- DGX Spark (sm120/sm121): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
|
||||
- Blackwell B300 (sm100): TILE_M=256, TILE_N=128 or 128x128, num_ctas=1, occupancy=1-2
|
||||
|
||||
Usage:
|
||||
python fmha_scaling_analysis.py [--iterations N]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
|
||||
LOG_SEPARATOR = "=" * 80
|
||||
LOG_SUBSEPARATOR = "-" * 60
|
||||
|
||||
@dataclass
|
||||
class StepResult:
|
||||
step: int
|
||||
name: str
|
||||
latency_ms: float
|
||||
tflops: float
|
||||
speedup_vs_baseline: float
|
||||
|
||||
@dataclass
|
||||
class SeqLenResult:
|
||||
seq_len: int
|
||||
steps: List[StepResult]
|
||||
best_step: int
|
||||
best_speedup: float
|
||||
tilegym_latency_ms: float
|
||||
tilegym_tflops: float
|
||||
tilegym_speedup: float
|
||||
|
||||
class Logger:
|
||||
def __init__(self):
|
||||
self.results: List[SeqLenResult] = []
|
||||
self.logs: List[str] = []
|
||||
|
||||
def log(self, msg: str):
|
||||
print(msg)
|
||||
self.logs.append(msg)
|
||||
|
||||
def section(self, title: str):
|
||||
self.log(f"\n{LOG_SEPARATOR}")
|
||||
self.log(f" {title}")
|
||||
self.log(LOG_SEPARATOR)
|
||||
|
||||
def subsection(self, title: str):
|
||||
self.log(f"\n{LOG_SUBSEPARATOR}")
|
||||
self.log(f" {title}")
|
||||
self.log(LOG_SUBSEPARATOR)
|
||||
|
||||
logger = Logger()
|
||||
|
||||
BATCH = 4
|
||||
N_HEADS = 32
|
||||
HEAD_DIM = 128
|
||||
INV_LOG_2 = 1.0 / math.log(2)
|
||||
|
||||
SEQ_LENS = [1024, 2048, 4096, 8192, 16384]
|
||||
|
||||
|
||||
def get_fmha_config():
|
||||
"""
|
||||
Get target-specific FMHA configuration (from TileGym attention.py)
|
||||
|
||||
Returns configs matching TileGym's _fmha_autotune_configs():
|
||||
- sm120/sm121 (DGX Spark): TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2
|
||||
- sm100 (Blackwell B300): Two configs to try via autotuning
|
||||
"""
|
||||
gpu_capability = torch.cuda.get_device_capability()
|
||||
|
||||
if gpu_capability in [(12, 0), (12, 1)]:
|
||||
return [
|
||||
SimpleNamespace(
|
||||
name="DGX Spark (sm121)",
|
||||
TILE_M=64,
|
||||
TILE_N=64,
|
||||
num_ctas=1,
|
||||
occupancy=2
|
||||
)
|
||||
]
|
||||
else:
|
||||
return [
|
||||
SimpleNamespace(
|
||||
name="Blackwell B300 (sm100) - Config 1",
|
||||
TILE_M=256,
|
||||
TILE_N=128,
|
||||
num_ctas=1,
|
||||
occupancy=1
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="Blackwell B300 (sm100) - Config 2",
|
||||
TILE_M=128,
|
||||
TILE_N=128,
|
||||
num_ctas=1,
|
||||
occupancy=2
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_device():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
raise RuntimeError("CUDA not available")
|
||||
|
||||
def compute_flops(batch, heads, seq_len, head_dim, causal=True):
|
||||
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
|
||||
total_flops = 2 * flops_per_matmul
|
||||
if causal:
|
||||
total_flops *= 0.5
|
||||
return total_flops
|
||||
|
||||
def benchmark_fn(fn, warmup=10, iterations=100):
|
||||
"""Benchmark using triton's do_bench_cudagraph for accurate timing"""
|
||||
try:
|
||||
import triton
|
||||
# Use triton's cudagraph benchmark - same as TileGym
|
||||
ms = triton.testing.do_bench_cudagraph(fn)
|
||||
return ms
|
||||
except (ImportError, Exception):
|
||||
# Fallback to manual timing
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start = time.perf_counter()
|
||||
for _ in range(iterations):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
end = time.perf_counter()
|
||||
|
||||
return (end - start) / iterations * 1000
|
||||
|
||||
def reference_fmha(q, k, v, sm_scale, is_causal=True):
|
||||
return torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=sm_scale
|
||||
)
|
||||
|
||||
try:
|
||||
import cuda.tile as ct
|
||||
from cuda.tile import RoundingMode as RMd
|
||||
CUTILE_AVAILABLE = True
|
||||
except ImportError:
|
||||
CUTILE_AVAILABLE = False
|
||||
logger.log("[WARN] cuTile not available.")
|
||||
|
||||
if CUTILE_AVAILABLE:
|
||||
ConstInt = ct.Constant[int]
|
||||
ConstBool = ct.Constant[bool]
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_basic(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""Step 1: Basic cuTile - no optimizations"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
k_tile = k_tile.reshape((TILE_N, TILE_D))
|
||||
k_t = ct.permute(k_tile, (1, 0))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
qk = qk * qk_scale
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True)
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk - m_ij
|
||||
|
||||
p = ct.exp(qk)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp(m_i - m_ij)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_math_opt(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""Step 2: Math optimizations - exp2 + flush_to_zero"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
k_tile = k_tile.reshape((TILE_N, TILE_D))
|
||||
k_t = ct.permute(k_tile, (1, 0))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0), shape=(1, 1, TILE_N, TILE_D))
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
@ct.kernel()
|
||||
def fmha_memory_opt(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""Step 3: Memory optimizations - load order + latency hints"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
@ct.kernel(occupancy=2)
|
||||
def fmha_full_opt_occ2(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""Step 4a: Full optimization with occupancy=2 (for sm120/sm121)"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
m_i = m_ij
|
||||
|
||||
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
@ct.kernel(occupancy=1)
|
||||
def fmha_full_opt_occ1(
|
||||
Q, K, V, Out,
|
||||
qk_scale: float,
|
||||
TILE_D: ConstInt,
|
||||
H: ConstInt,
|
||||
TILE_M: ConstInt,
|
||||
TILE_N: ConstInt,
|
||||
CAUSAL: ConstBool,
|
||||
):
|
||||
"""Step 4b: Full optimization with occupancy=1 (for sm100 Blackwell)"""
|
||||
bid_x = ct.bid(0)
|
||||
bid_y = ct.bid(1)
|
||||
batch_idx = bid_y // H
|
||||
head_idx = bid_y % H
|
||||
|
||||
qk_scale_log2 = qk_scale * INV_LOG_2
|
||||
|
||||
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
|
||||
offs_m = offs_m[:, None]
|
||||
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
|
||||
offs_n_tile = offs_n_tile[None, :]
|
||||
|
||||
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
|
||||
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
|
||||
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
|
||||
|
||||
q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D))
|
||||
q = q.reshape((TILE_M, TILE_D))
|
||||
|
||||
k_seqlen = K.shape[2]
|
||||
if CAUSAL:
|
||||
m_end = (bid_x + 1) * TILE_M
|
||||
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
|
||||
mask_start = (bid_x * TILE_M) // TILE_N
|
||||
else:
|
||||
Tc = ct.cdiv(k_seqlen, TILE_N)
|
||||
mask_start = k_seqlen // TILE_N
|
||||
|
||||
for j in range(0, Tc):
|
||||
k_t = ct.load(
|
||||
K,
|
||||
index=(batch_idx, head_idx, 0, j),
|
||||
shape=(1, 1, TILE_D, TILE_N),
|
||||
order=(0, 1, 3, 2),
|
||||
latency=2
|
||||
)
|
||||
k_t = k_t.reshape((TILE_D, TILE_N))
|
||||
|
||||
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
|
||||
qk = ct.mma(q, k_t, qk)
|
||||
|
||||
if CAUSAL and j >= mask_start:
|
||||
offs_n = j * TILE_N + offs_n_tile
|
||||
mask = offs_m >= offs_n
|
||||
qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
|
||||
|
||||
m_ij = ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2
|
||||
m_ij = ct.maximum(m_i, m_ij)
|
||||
qk = qk * qk_scale_log2 - m_ij
|
||||
|
||||
p = ct.exp2(qk, flush_to_zero=True)
|
||||
l_ij = ct.sum(p, axis=-1, keepdims=True)
|
||||
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
|
||||
l_i = l_i * alpha + l_ij
|
||||
acc = acc * alpha
|
||||
|
||||
v_tile = ct.load(
|
||||
V,
|
||||
index=(batch_idx, head_idx, j, 0),
|
||||
shape=(1, 1, TILE_N, TILE_D),
|
||||
latency=4
|
||||
)
|
||||
v_tile = v_tile.reshape((TILE_N, TILE_D))
|
||||
p_cast = p.astype(Q.dtype)
|
||||
acc = ct.mma(p_cast, v_tile, acc)
|
||||
m_i = m_ij
|
||||
|
||||
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
|
||||
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
|
||||
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
|
||||
|
||||
def run_kernel(kernel_fn, q, k, v, sm_scale, tile_m, tile_n, is_causal=True):
|
||||
batch_size, num_heads, seq_len, head_dim = q.shape
|
||||
o = torch.empty_like(q)
|
||||
grid = (math.ceil(seq_len / tile_m), batch_size * num_heads, 1)
|
||||
ct.launch(
|
||||
torch.cuda.current_stream(),
|
||||
grid,
|
||||
kernel_fn,
|
||||
(q, k, v, o, sm_scale, head_dim, num_heads, tile_m, tile_n, is_causal)
|
||||
)
|
||||
return o
|
||||
|
||||
def run_tilegym_fmha(q, k, v, sm_scale, is_causal=True):
|
||||
try:
|
||||
import tilegym
|
||||
return tilegym.ops.fmha(q, k, v, scaling=sm_scale, is_causal=is_causal, backend="cutile")
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def run_scaling_analysis(iterations=100):
|
||||
device = get_device()
|
||||
gpu_cap = torch.cuda.get_device_capability()
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
|
||||
configs = get_fmha_config()
|
||||
primary_cfg = configs[0]
|
||||
TILE_M = primary_cfg.TILE_M
|
||||
TILE_N = primary_cfg.TILE_N
|
||||
|
||||
logger.section("FMHA SCALING ANALYSIS (TileGym Benchmark Match)")
|
||||
logger.log("Matching TileGym bench_fused_attention.py configuration")
|
||||
logger.log(f"\nGPU: {gpu_name} (sm_{gpu_cap[0]}{gpu_cap[1]})")
|
||||
|
||||
logger.subsection("TARGET-SPECIFIC CONFIGURATION (from TileGym)")
|
||||
for cfg in configs:
|
||||
logger.log(f"\n {cfg.name}:")
|
||||
logger.log(f" TILE_M={cfg.TILE_M}, TILE_N={cfg.TILE_N}")
|
||||
logger.log(f" num_ctas={cfg.num_ctas}, occupancy={cfg.occupancy}")
|
||||
|
||||
logger.log(f"\nUsing primary config: TILE_M={TILE_M}, TILE_N={TILE_N}, occupancy={primary_cfg.occupancy}")
|
||||
logger.log("\nTest Configuration (matches TileGym bench_fused_attention.py):")
|
||||
logger.log(f" Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}")
|
||||
logger.log(" Causal: True, Precision: float16")
|
||||
logger.log(f" Sequence Lengths: {SEQ_LENS}")
|
||||
logger.log(" Benchmark: triton.testing.do_bench_cudagraph (same as TileGym)")
|
||||
|
||||
logger.section("OPTIMIZATION STEPS")
|
||||
logger.log(f"""
|
||||
Step 0: PyTorch Baseline
|
||||
- torch.nn.functional.scaled_dot_product_attention
|
||||
- Uses cuDNN Flash Attention backend
|
||||
- Highly optimized reference
|
||||
|
||||
Step 1: Basic cuTile (TILE_M={TILE_M}, TILE_N={TILE_N})
|
||||
- @ct.kernel with ct.mma() for Tensor Cores
|
||||
- Standard exp() for softmax
|
||||
- Explicit transpose with ct.permute()
|
||||
- No memory/occupancy hints
|
||||
|
||||
Step 2: Math Optimizations
|
||||
- ct.exp2() instead of ct.exp() (faster on GPU)
|
||||
- flush_to_zero=True for denormals
|
||||
- Scale adjustment: multiply by 1/log(2)
|
||||
|
||||
Step 3: Memory Optimizations
|
||||
- Load order=(0,1,3,2) for implicit K transpose
|
||||
- Latency hints: K=2, V=4 for prefetching
|
||||
- Overlaps memory loads with computation
|
||||
|
||||
Step 4: Full Optimization (Target-Specific)
|
||||
- @ct.kernel(occupancy={primary_cfg.occupancy}) for {'sm120/121' if primary_cfg.occupancy == 2 else 'sm100'}
|
||||
- ct.truediv with APPROX rounding mode
|
||||
- Matches TileGym production implementation
|
||||
""")
|
||||
|
||||
logger.section("PLATFORM DIFFERENCES: DGX Spark vs Blackwell B300")
|
||||
logger.log("""
|
||||
| Parameter | DGX Spark (sm121) | Blackwell B300 (sm100) |
|
||||
|--------------|-------------------|------------------------|
|
||||
| TILE_M | 64 | 256 or 128 |
|
||||
| TILE_N | 64 | 128 |
|
||||
| num_ctas | 1 | 1 |
|
||||
| occupancy | 2 | 1 or 2 |
|
||||
|
||||
Why the difference?
|
||||
- B300 has more SMs and larger shared memory -> can use bigger tiles
|
||||
- B300 benefits from larger tiles (256x128) with lower occupancy
|
||||
- DGX Spark needs smaller tiles (64x64) with higher occupancy to hide latency
|
||||
- B300's higher memory bandwidth makes larger tiles more efficient
|
||||
""")
|
||||
|
||||
all_results = []
|
||||
|
||||
select_kernel = fmha_full_opt_occ2 if primary_cfg.occupancy == 2 else fmha_full_opt_occ1
|
||||
|
||||
for seq_len in SEQ_LENS:
|
||||
logger.subsection(f"Sequence Length: {seq_len}")
|
||||
|
||||
q = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
|
||||
k = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
|
||||
v = torch.randn(BATCH, N_HEADS, seq_len, HEAD_DIM, dtype=torch.float16, device=device)
|
||||
sm_scale = 1.0 / math.sqrt(HEAD_DIM)
|
||||
|
||||
flops = compute_flops(BATCH, N_HEADS, seq_len, HEAD_DIM, causal=True)
|
||||
|
||||
steps_results = []
|
||||
|
||||
baseline_fn = lambda: reference_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
baseline_latency = benchmark_fn(baseline_fn, warmup=10, iterations=iterations)
|
||||
baseline_tflops = flops * 1e-12 / (baseline_latency * 1e-3)
|
||||
steps_results.append(StepResult(0, "PyTorch Baseline", baseline_latency, baseline_tflops, 1.0))
|
||||
|
||||
if CUTILE_AVAILABLE:
|
||||
kernels = [
|
||||
(1, "Basic cuTile", fmha_basic),
|
||||
(2, "Math Opt (exp2)", fmha_math_opt),
|
||||
(3, "Memory Opt (order+latency)", fmha_memory_opt),
|
||||
(4, f"Full Opt (occ={primary_cfg.occupancy})", select_kernel),
|
||||
]
|
||||
|
||||
for step, name, kernel in kernels:
|
||||
try:
|
||||
fn = lambda kernel=kernel: run_kernel(kernel, q, k, v, sm_scale, TILE_M, TILE_N, is_causal=True)
|
||||
latency = benchmark_fn(fn, warmup=10, iterations=iterations)
|
||||
tflops = flops * 1e-12 / (latency * 1e-3)
|
||||
speedup = baseline_latency / latency
|
||||
steps_results.append(StepResult(step, name, latency, tflops, speedup))
|
||||
except Exception as e:
|
||||
logger.log(f" [ERROR] Step {step} failed: {e}")
|
||||
|
||||
tilegym_latency = 0.0
|
||||
tilegym_tflops = 0.0
|
||||
tilegym_speedup = 0.0
|
||||
tilegym_out = run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
if tilegym_out is not None:
|
||||
tilegym_fn = lambda: run_tilegym_fmha(q, k, v, sm_scale, is_causal=True)
|
||||
tilegym_latency = benchmark_fn(tilegym_fn, warmup=10, iterations=iterations)
|
||||
tilegym_tflops = flops * 1e-12 / (tilegym_latency * 1e-3)
|
||||
tilegym_speedup = baseline_latency / tilegym_latency
|
||||
|
||||
best_step = max(steps_results, key=lambda x: x.speedup_vs_baseline)
|
||||
|
||||
result = SeqLenResult(
|
||||
seq_len=seq_len,
|
||||
steps=steps_results,
|
||||
best_step=best_step.step,
|
||||
best_speedup=best_step.speedup_vs_baseline,
|
||||
tilegym_latency_ms=tilegym_latency,
|
||||
tilegym_tflops=tilegym_tflops,
|
||||
tilegym_speedup=tilegym_speedup,
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
logger.log("\n | Step | Name | Latency (ms) | TFLOPS | Speedup |")
|
||||
logger.log(" |------|------|--------------|--------|---------|")
|
||||
|
||||
for sr in steps_results:
|
||||
logger.log(f" | {sr.step} | {sr.name:<28} | {sr.latency_ms:>10.3f} | {sr.tflops:>6.2f} | {sr.speedup_vs_baseline:>6.2f}x |")
|
||||
if tilegym_latency > 0:
|
||||
logger.log(f" | TG | TileGym Reference | {tilegym_latency:>10.3f} | {tilegym_tflops:>6.2f} | {tilegym_speedup:>6.2f}x |")
|
||||
|
||||
logger.log(f"\n Best: Step {best_step.step} ({best_step.name}) with {best_step.speedup_vs_baseline:.2f}x speedup")
|
||||
|
||||
logger.results = all_results
|
||||
return all_results
|
||||
|
||||
|
||||
def print_summary(results: List[SeqLenResult]):
|
||||
configs = get_fmha_config()
|
||||
primary_cfg = configs[0]
|
||||
|
||||
logger.section("SCALING SUMMARY")
|
||||
|
||||
logger.log(f"\nTarget Config: TILE_M={primary_cfg.TILE_M}, TILE_N={primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}")
|
||||
|
||||
logger.log("\n## Performance vs Sequence Length\n")
|
||||
logger.log("| Seq Len | Baseline (ms) | Full Opt (ms) | Speedup | TileGym (ms) | TG Speedup |")
|
||||
logger.log("|---------|---------------|---------------|---------|--------------|------------|")
|
||||
for r in results:
|
||||
baseline = next((s for s in r.steps if s.step == 0), None)
|
||||
full_opt = next((s for s in r.steps if s.step == 4), None)
|
||||
if baseline and full_opt:
|
||||
logger.log(f"| {r.seq_len:>7} | {baseline.latency_ms:>13.3f} | {full_opt.latency_ms:>13.3f} | {full_opt.speedup_vs_baseline:>6.2f}x | {r.tilegym_latency_ms:>12.3f} | {r.tilegym_speedup:>9.2f}x |")
|
||||
|
||||
logger.log("\n## Optimization Impact by Sequence Length\n")
|
||||
logger.log("| Seq Len | Basic | +Math | +Memory | +Full | Best |")
|
||||
logger.log("|---------|-------|-------|---------|-------|------|")
|
||||
for r in results:
|
||||
row = f"| {r.seq_len:>7} |"
|
||||
for step in [1, 2, 3, 4]:
|
||||
sr = next((s for s in r.steps if s.step == step), None)
|
||||
if sr:
|
||||
row += f" {sr.speedup_vs_baseline:>5.2f}x |"
|
||||
else:
|
||||
row += " N/A |"
|
||||
row += f" {r.best_speedup:>4.2f}x |"
|
||||
logger.log(row)
|
||||
|
||||
logger.section("KEY INSIGHTS")
|
||||
logger.log("""
|
||||
## Why Larger Sequences Benefit More from Optimization
|
||||
|
||||
1. **Memory Bandwidth Dominance**
|
||||
- Attention has O(N²) memory complexity for the QK^T matrix
|
||||
- At seq_len=8192: 8192² × 4 bytes = 256MB per head per batch
|
||||
- Memory optimizations (order, latency hints) have larger impact
|
||||
|
||||
2. **More K-Loop Iterations**
|
||||
- At seq_len=512: 8 K-tiles (512/64) for sm121, 2 K-tiles (512/256) for sm100
|
||||
- At seq_len=8192: 128 K-tiles for sm121, 32 K-tiles for sm100
|
||||
- Latency hiding through pipelining amortizes over more iterations
|
||||
|
||||
3. **Better Occupancy Utilization**
|
||||
- More tiles = more parallelism opportunities
|
||||
- sm121 uses occupancy=2 (smaller tiles, more blocks)
|
||||
- sm100 uses occupancy=1 with larger tiles (256x128)
|
||||
|
||||
4. **Platform-Specific Tuning**
|
||||
- DGX Spark (sm121): 64x64 tiles, occupancy=2 - optimized for bandwidth-limited workloads
|
||||
- B300 (sm100): 256x128 tiles, occupancy=1 - optimized for compute-heavy workloads
|
||||
|
||||
## Optimization Priority by Problem Size
|
||||
|
||||
Small (seq_len <= 1024):
|
||||
- Basic cuTile often sufficient
|
||||
- Focus on correctness first
|
||||
|
||||
Medium (1024 < seq_len <= 4096):
|
||||
- Math optimizations (exp2) provide ~5% gain
|
||||
- Memory optimizations start to matter
|
||||
|
||||
Large (seq_len > 4096):
|
||||
- Full optimization stack critical
|
||||
- Platform-specific tuning essential
|
||||
- Memory pipelining becomes essential
|
||||
""")
|
||||
|
||||
|
||||
def export_results(results: List[SeqLenResult], output_dir: str):
|
||||
configs = get_fmha_config()
|
||||
primary_cfg = configs[0]
|
||||
|
||||
data = {
|
||||
"config": {
|
||||
"batch": BATCH,
|
||||
"n_heads": N_HEADS,
|
||||
"head_dim": HEAD_DIM,
|
||||
"tile_m": primary_cfg.TILE_M,
|
||||
"tile_n": primary_cfg.TILE_N,
|
||||
"occupancy": primary_cfg.occupancy,
|
||||
"num_ctas": primary_cfg.num_ctas,
|
||||
"platform": primary_cfg.name,
|
||||
},
|
||||
"results": [
|
||||
{
|
||||
"seq_len": r.seq_len,
|
||||
"steps": [asdict(s) for s in r.steps],
|
||||
"best_step": r.best_step,
|
||||
"best_speedup": r.best_speedup,
|
||||
"tilegym_latency_ms": r.tilegym_latency_ms,
|
||||
"tilegym_tflops": r.tilegym_tflops,
|
||||
"tilegym_speedup": r.tilegym_speedup,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
}
|
||||
|
||||
json_path = f"{output_dir}/fmha_scaling_results.json"
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
md_path = f"{output_dir}/fmha_scaling_results.md"
|
||||
with open(md_path, 'w') as f:
|
||||
f.write("# FMHA Scaling Analysis Results\n\n")
|
||||
f.write("## Configuration\n")
|
||||
f.write(f"- Platform: {primary_cfg.name}\n")
|
||||
f.write(f"- Batch: {BATCH}, Heads: {N_HEADS}, Head Dim: {HEAD_DIM}\n")
|
||||
f.write(f"- Tile: {primary_cfg.TILE_M}x{primary_cfg.TILE_N}, occupancy={primary_cfg.occupancy}\n\n")
|
||||
|
||||
f.write("## Target-Specific Configs (from TileGym)\n\n")
|
||||
f.write("| Platform | TILE_M | TILE_N | num_ctas | occupancy |\n")
|
||||
f.write("|----------|--------|--------|----------|----------|\n")
|
||||
f.write("| DGX Spark (sm121) | 64 | 64 | 1 | 2 |\n")
|
||||
f.write("| B300 (sm100) Config 1 | 256 | 128 | 1 | 1 |\n")
|
||||
f.write("| B300 (sm100) Config 2 | 128 | 128 | 1 | 2 |\n\n")
|
||||
|
||||
f.write("## Results by Sequence Length\n\n")
|
||||
for r in results:
|
||||
f.write(f"### Seq Len = {r.seq_len}\n\n")
|
||||
f.write("| Step | Name | Latency (ms) | TFLOPS | Speedup |\n")
|
||||
f.write("|------|------|--------------|--------|--------|\n")
|
||||
for s in r.steps:
|
||||
f.write(f"| {s.step} | {s.name} | {s.latency_ms:.3f} | {s.tflops:.2f} | {s.speedup_vs_baseline:.2f}x |\n")
|
||||
if r.tilegym_latency_ms > 0:
|
||||
f.write(f"| TG | TileGym Reference | {r.tilegym_latency_ms:.3f} | {r.tilegym_tflops:.2f} | {r.tilegym_speedup:.2f}x |\n")
|
||||
f.write("\n")
|
||||
|
||||
log_path = f"{output_dir}/fmha_scaling_log.txt"
|
||||
with open(log_path, 'w') as f:
|
||||
f.write('\n'.join(logger.logs))
|
||||
|
||||
logger.log(f"\nResults exported to:")
|
||||
logger.log(f" - {json_path}")
|
||||
logger.log(f" - {md_path}")
|
||||
logger.log(f" - {log_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="FMHA Scaling Analysis")
|
||||
parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
|
||||
parser.add_argument("--output-dir", type=str, default=".", help="Output directory")
|
||||
args = parser.parse_args()
|
||||
|
||||
results = run_scaling_analysis(iterations=args.iterations)
|
||||
print_summary(results)
|
||||
export_results(results, args.output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -14,11 +14,11 @@
|
||||
|
||||
## Basic idea
|
||||
|
||||
The DGX Dashboard is a web application that runs locally on DGX Spark devices, providing a graphical interface for system updates, resource monitoring, and an integrated JupyterLab environment. Users can access the dashboard locally from the app launcher or remotely through NVIDIA Sync or SSH tunneling. The dashboard is the easiest way to update system packages and firmware when working remotely.
|
||||
The DGX Dashboard is a web application that runs locally on DGX Spark devices, providing a graphical interface for system updates, resource monitoring, and an integrated JupyterLab environment. Users can access the dashboard locally from the app launcher or remotely through NVIDIA Sync, SSH tunneling, or Tailscale. The dashboard is the easiest way to update system packages and firmware when working remotely.
|
||||
|
||||
## What you'll accomplish
|
||||
|
||||
You will learn how to access and use the DGX Dashboard on your DGX Spark device. By the end of this walkthrough, you will be able to launch JupyterLab instances with pre-configured Python environments, monitor GPU performance, manage system updates, and run a sample AI workload using Stable Diffusion. You'll understand multiple access methods including desktop shortcuts, NVIDIA Sync, and manual SSH tunneling.
|
||||
You will learn how to access and use the DGX Dashboard on your DGX Spark device. By the end of this walkthrough, you will be able to launch JupyterLab instances with pre-configured Python environments, monitor GPU performance, manage system updates, and run a sample AI workload using Stable Diffusion. You'll understand multiple access methods including desktop shortcuts, NVIDIA Sync, manual SSH tunneling, and Tailscale.
|
||||
|
||||
## What to know before starting
|
||||
|
||||
@ -98,6 +98,10 @@ Replace `<ASSIGNED_PORT>` with the port number from the YAML file.
|
||||
|
||||
Open your web browser and navigate to `http://localhost:11000`.
|
||||
|
||||
**Option D: Tailscale (alternative to manual SSH tunnels)**
|
||||
|
||||
For secure remote access over your private network without manual SSH tunneling, check out the [Tailscale playbook](../tailscale/README.md#step-12-access-dgx-dashboard-over-tailnet) for instructions on accessing the DGX Dashboard over the tailnet using Tailscale Serve.
|
||||
|
||||
|
||||
## Step 2. Log into DGX Dashboard
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# Run models with llama.cpp on DGX Spark
|
||||
|
||||
> Build llama.cpp with CUDA and serve models via an OpenAI-compatible API (Nemotron 3 Nano Omni as example)
|
||||
|
||||
> Build llama.cpp with CUDA and serve models via an OpenAI-compatible API
|
||||
|
||||
## Table of Contents
|
||||
|
||||
@ -15,167 +14,148 @@
|
||||
|
||||
## Basic idea
|
||||
|
||||
[llama.cpp](https://github.com/ggml-org/llama.cpp) is a lightweight C/C++ inference stack for large language models. You build it with CUDA so tensor work runs on the DGX Spark GB10 GPU, then load GGUF weights and expose chat through `llama-server`’s OpenAI-compatible HTTP API.
|
||||
[llama.cpp](https://github.com/ggml-org/llama.cpp) is a lightweight C/C++ inference stack for large language models. You build it with CUDA so it fully utilizes the DGX Spark GB10 GPU, then load GGUF weights and expose chat through `llama-server`’s OpenAI-compatible HTTP API.
|
||||
|
||||
This playbook walks through that stack end to end using **Nemotron 3 Nano Omni** as the hands-on example: an NVIDIA MoE family that runs well from quantized GGUF on Spark. Checkpoint choices and paths for all supported models are summarized in the matrix below; commands are in the instructions.
|
||||
This playbook walks through that stack end to end using MTP-enabled **Qwen3.6-35B-A3B** as the hands-on example. Checkpoint choices and paths for all supported models are summarized in the matrix below; commands are in the instructions.
|
||||
|
||||
## What you'll accomplish
|
||||
|
||||
You will build llama.cpp with CUDA for GB10, download a **Nemotron 3 Nano Omni** example checkpoint, and run **`llama-server`** with GPU offload. You get:
|
||||
You will build llama.cpp with CUDA for GB10, download a **Qwen3.6-35B-A3B** checkpoint, and run **`llama-server`** with GPU offload. You get:
|
||||
|
||||
- Local inference through llama.cpp (no separate Python inference framework required)
|
||||
- An OpenAI-compatible `/v1/chat/completions` endpoint for tools and apps
|
||||
- A concrete validation that the **Nemotron 3 Nano Omni** example runs on this stack on DGX Spark
|
||||
- Local inference through llama.cpp (no separate Python inference framework required)
|
||||
- An OpenAI-compatible `/v1/chat/completions` endpoint for tools and apps
|
||||
- A concrete validation that the **Qwen3.6-35B-A3B** example runs on this stack on DGX Spark with MTP support.
|
||||
|
||||
## What to know before starting
|
||||
|
||||
- Basic familiarity with Linux command line and terminal commands
|
||||
- Understanding of git and building from source with CMake
|
||||
- Basic familiarity with Linux command line and terminal commands
|
||||
- Understanding of git and building from source with CMake
|
||||
- Basic knowledge of REST APIs and cURL for testing
|
||||
- Familiarity with Hugging Face Hub for downloading GGUF files
|
||||
|
||||
## Prerequisites
|
||||
|
||||
**Hardware requirements**
|
||||
|
||||
- NVIDIA DGX Spark with GB10 GPU
|
||||
- Sufficient unified memory for the example **Q8_0** checkpoint (weights on the order of **~35GB**, plus KV cache and runtime overhead—scale up if you pick a larger quant or longer context)
|
||||
- At least **~40GB** free disk for the example download plus build artifacts (more if you keep multiple GGUFs)
|
||||
- NVIDIA DGX Spark with GB10 GPU
|
||||
- Sufficient unified memory for the model and the KV-Cache being utilized (about 30GB free RAM for the model in the example)
|
||||
- At least **\~40GB** free disk for the example download plus build artifacts (more if you keep multiple GGUFs)
|
||||
|
||||
**Software requirements**
|
||||
|
||||
- NVIDIA DGX OS
|
||||
- Git: `git --version`
|
||||
- CMake (3.14+): `cmake --version`
|
||||
- CUDA Toolkit: `nvcc --version`
|
||||
- NVIDIA DGX OS
|
||||
- Git: `git --version`
|
||||
- CMake (3.14+): `cmake --version`
|
||||
- CUDA Toolkit: `nvcc --version`
|
||||
- Network access to GitHub and Hugging Face
|
||||
|
||||
## Model support matrix
|
||||
|
||||
The following models are supported with llama.cpp on Spark. The instructions use the **Nemotron 3 Nano Omni** example row by default.
|
||||
|
||||
| Model | Support Status | HF Handle |
|
||||
|-------|----------------|-----------|
|
||||
| **Nemotron 3 Nano Omni** (example walkthrough) | ✅ | `ggml-org/NVIDIA-Nemotron-3-Nano-Omni` |
|
||||
| **Qwen3.6-35B-A3B** | ✅ | `unsloth/Qwen3.6-35B-A3B-GGUF` |
|
||||
| **Qwen3.6-27B** | ✅ | `unsloth/Qwen3.6-27B-GGUF` |
|
||||
| **Gemma 4 31B IT** | ✅ | `ggml-org/gemma-4-31B-it-GGUF` |
|
||||
| **Gemma 4 26B A4B IT** | ✅ | `ggml-org/gemma-4-26B-A4B-it-GGUF` |
|
||||
| **Gemma 4 E4B IT** | ✅ | `ggml-org/gemma-4-E4B-it-GGUF` |
|
||||
| **Gemma 4 E2B IT** | ✅ | `ggml-org/gemma-4-E2B-it-GGUF` |
|
||||
| **Nemotron-3-Nano** | ✅ | `unsloth/Nemotron-3-Nano-30B-A3B-GGUF` |
|
||||
DGX Spark supports any GGUF format model checkpoint with llama.cpp, as long as the system has memory available to host and run the checkpoint.
|
||||
|
||||
## Time & risk
|
||||
|
||||
* **Estimated time:** About 30 minutes, plus downloading the example GGUF (~35GB order of magnitude for the default quant)
|
||||
* **Risk level:** Low — build is local to your clone; no system-wide installs required for the steps below
|
||||
* **Rollback:** Remove the `llama.cpp` clone and the model directory under `~/models/` to reclaim disk space
|
||||
* **Last updated:** 04/28/2026
|
||||
* Walkthrough now uses Nemotron Omni; other model rows stay available
|
||||
* **Estimated time:** About 30 minutes, plus downloading the example GGUF (\~35GB order of magnitude for the default quant)
|
||||
* **Risk level:** Low — build is local to your clone; no system-wide installs required for the steps below
|
||||
* **Rollback:** Remove the `llama.cpp` clone and the model directory under `~/.cache/huggingface/hub/` to reclaim disk space
|
||||
* **Last updated:** 06/03/2026
|
||||
* Walkthrough now uses Qwen3.6-35B-A3B as an example
|
||||
|
||||
## Instructions
|
||||
|
||||
## Step 1. Verify prerequisites
|
||||
## Step 1. Install the dependencies
|
||||
|
||||
The **example** checkpoint is **`nemotron-3-nano-omni-ga_v1.0-Q8_0.gguf`** from Hugging Face repo **`ggml-org/NVIDIA-Nemotron-3-Nano-Omni`** (full handle: `ggml-org/NVIDIA-Nemotron-3-Nano-Omni/nemotron-3-nano-omni-ga_v1.0-Q8_0.gguf`). Other supported GGUFs—including Qwen3.6, Gemma, and alternate Nemotron Omni builds—use the same build and server steps; change `hf download` and `--model` paths (see the [overview model matrix](overview.md)).
|
||||
Install the required dependencies:
|
||||
|
||||
Ensure the required tools are installed:
|
||||
|
||||
```bash
|
||||
git --version
|
||||
cmake --version
|
||||
nvcc --version
|
||||
```
|
||||
|
||||
All commands should return version information. If any are missing, install them before continuing.
|
||||
|
||||
Install the Hugging Face CLI:
|
||||
|
||||
```bash
|
||||
python3 -m venv llama-cpp-venv
|
||||
source llama-cpp-venv/bin/activate
|
||||
pip install -U "huggingface_hub[cli]"
|
||||
```
|
||||
|
||||
Verify installation:
|
||||
|
||||
```bash
|
||||
hf version
|
||||
```shell
|
||||
sudo apt install -y git clang cmake libcurl4-openssl-dev libssl-dev
|
||||
```
|
||||
|
||||
## Step 2. Clone the llama.cpp repository
|
||||
|
||||
Clone upstream llama.cpp—the framework you are building:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
```
|
||||
|
||||
## Step 3. Build llama.cpp with CUDA
|
||||
|
||||
Configure CMake with CUDA and GB10’s **sm_121** architecture so GGML’s CUDA backend matches your GPU:
|
||||
Configure CMake with CUDA and GB10’s **sm\_121** architecture so GGML’s CUDA backend matches your GPU:
|
||||
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
cmake .. -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="121" -DLLAMA_CURL=OFF
|
||||
make -j8
|
||||
```shell
|
||||
cmake -B build -DGGML_NATIVE=ON -DGGML_CUDA=ON -DGGML_CURL=ON -DGGML_RPC=ON -DCMAKE_CUDA_ARCHITECTURES=121a-real
|
||||
cmake --build build --config Release -j
|
||||
```
|
||||
|
||||
The build usually takes on the order of 5–10 minutes. When it finishes, binaries such as `llama-server` appear under `build/bin/`.
|
||||
|
||||
## Step 4. Download example Nemotron 3 Nano Omni GGUF
|
||||
## Step 4. Start llama-server with a model
|
||||
|
||||
llama.cpp loads models in **GGUF** format. This playbook uses the **Q8_0** checkpoint from `ggml-org/NVIDIA-Nemotron-3-Nano-Omni`, which balances quality and memory on DGX Spark GB10 unified memory.
|
||||
llama.cpp loads models in **GGUF** format. This playbook uses the **Q4\_K\_XL** checkpoint from `unsloth/Qwen3.6-35B-A3B-MTP-GGUF`, which provides a good balance between quality and speed on DGX Spark.
|
||||
|
||||
```bash
|
||||
hf download ggml-org/NVIDIA-Nemotron-3-Nano-Omni \
|
||||
nemotron-3-nano-omni-ga_v1.0-Q8_0.gguf \
|
||||
--local-dir ~/models/NVIDIA-Nemotron-3-Nano-Omni
|
||||
From your `llama.cpp/build` directory, launch the OpenAI-compatible server with GPU offload. It will load the model from HuggingFace first if it hasn’t been downloaded before or if there are any updates.
|
||||
|
||||
All models are saved in the default HuggingFace cache directory in \~/.cache/huggingface/hub. For instance, this model will be saved into \~/.cache/huggingface/hub/models--unsloth--Qwen3.6-35B-A3B-MTP-GGUF
|
||||
|
||||
It will also automatically load mmproj file to enable vision capabilities if supported by the model. By default, llama-server will try to fit full model context with ability to serve 4 concurrent requests, but it will adjust parameters automatically if needed.
|
||||
|
||||
```shell
|
||||
./bin/llama-server \
|
||||
-hf unsloth/Qwen3.6-35B-A3B-MTP-GGUF:UD-Q4_K_XL \
|
||||
--host 0.0.0.0 \
|
||||
--port 30000
|
||||
```
|
||||
|
||||
The file is on the order of **~35GB** (exact size may vary). The download can be resumed if interrupted.
|
||||
To run with MTP speculative decoding, provide additional parameters as shown in the example below. MTP requires a compatible model, like `unsloth/Qwen3.6-35B-A3B-MTP-GGUF` used in this example. The following example also sets “preserve\_thinking” flag that allows Qwen models to use so-called “interleaved thinking” by preserving all prior thinking blocks in the history which can be useful for agentic workflows.
|
||||
|
||||
## Step 5. Start llama-server with Nemotron 3 Nano Omni
|
||||
|
||||
From your `llama.cpp/build` directory, launch the OpenAI-compatible server with GPU offload:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
./bin/llama-server \
|
||||
--model ~/models/NVIDIA-Nemotron-3-Nano-Omni/nemotron-3-nano-omni-ga_v1.0-Q8_0.gguf \
|
||||
-hf unsloth/Qwen3.6-35B-A3B-MTP-GGUF:UD-Q4_K_XL \
|
||||
--host 0.0.0.0 \
|
||||
--port 30000 \
|
||||
--n-gpu-layers 99 \
|
||||
--ctx-size 8192 \
|
||||
--threads 8
|
||||
--chat-template-kwargs '{"preserve_thinking": true}' \
|
||||
--spec-type draft-mtp \
|
||||
--spec-draft-n-max 3
|
||||
```
|
||||
|
||||
**Parameters (short):**
|
||||
|
||||
- `--host` / `--port`: bind address and port for the HTTP API
|
||||
- `--n-gpu-layers 99`: offload layers to the GPU (adjust if you use a different model)
|
||||
- `--ctx-size`: context length (can be increased up to model/server limits; uses more memory)
|
||||
- `--threads`: CPU threads for non-GPU work
|
||||
- `--host` / `--port`: bind address and port for the HTTP API
|
||||
- `--chat-template-kwargs`: sets additional params for the json template parser, must be a valid json object string
|
||||
- `--spec-type`: comma-separated list of types of speculative decoding to use (default: none, most MTP-compatible models will use “draft-mtp”, but you need to check the model card first)
|
||||
- `--spec-draft-n-max`: number of tokens to draft for speculative decoding (default: 3\)
|
||||
|
||||
You should see log lines similar to:
|
||||
|
||||
```
|
||||
llama_new_context_with_model: n_ctx = 8192
|
||||
0.14.322.968 I srv load_model: speculative decoding context initialized
|
||||
0.14.322.970 I slot load_model: id 0 | task -1 | new slot, n_ctx = 262144
|
||||
0.14.322.972 I slot load_model: id 1 | task -1 | new slot, n_ctx = 262144
|
||||
0.14.322.972 I slot load_model: id 2 | task -1 | new slot, n_ctx = 262144
|
||||
0.14.322.973 I slot load_model: id 3 | task -1 | new slot, n_ctx = 262144
|
||||
0.14.323.063 I srv load_model: prompt cache is enabled, size limit: 8192 MiB
|
||||
|
||||
...
|
||||
main: server is listening on 0.0.0.0:30000
|
||||
0.14.342.935 I srv llama_server: model loaded
|
||||
0.14.342.939 I srv llama_server: server is listening on http://0.0.0.0:30000
|
||||
0.14.342.944 I srv update_slots: all slots are idle
|
||||
|
||||
```
|
||||
|
||||
**Keep this terminal open** while testing. Large GGUFs can take a minute or more to load; until you see `server is listening`, nothing accepts connections on port 30000 (see Troubleshooting if `curl` reports connection refused).
|
||||
**Keep this terminal open** while testing. Large GGUFs can take a minute or more to load, and initial model download can take a while if the model is not downloaded yet. You will see a progress bar when model is being downloaded.
|
||||
|
||||
## Step 6. Test the API
|
||||
The server is only ready to accept incoming connections on port 30000 after you see `server is listening` message (see Troubleshooting if `curl` reports connection refused).
|
||||
|
||||
Use a **second terminal on the same machine** that runs `llama-server` (for example another SSH session into DGX Spark). If you run `curl` on your laptop while the server runs only on Spark, use the Spark hostname or IP instead of `localhost`.
|
||||
## Step 5. Test the API
|
||||
|
||||
```bash
|
||||
Use a **second terminal on the same machine** that runs `llama-server` (for example another SSH session into DGX Spark). If you run `curl` on your laptop while the server runs only on Spark, use the Spark hostname or IP instead of `localhost`.
|
||||
|
||||
```shell
|
||||
curl -X POST http://127.0.0.1:30000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "nemotron",
|
||||
"model": "unsloth/Qwen3.6-35B-A3B-MTP-GGUF:UD-Q4_K_XL",
|
||||
"messages": [{"role": "user", "content": "New York is a great city because..."}],
|
||||
"max_tokens": 100
|
||||
}'
|
||||
@ -198,7 +178,7 @@ Example shape of the response (fields vary by llama.cpp version; `message` may i
|
||||
}
|
||||
],
|
||||
"created": 1765916539,
|
||||
"model": "nemotron-3-nano-omni-ga_v1.0-Q8_0.gguf",
|
||||
"model": "$MODEL_PATH",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"completion_tokens": 100,
|
||||
@ -212,41 +192,35 @@ Example shape of the response (fields vary by llama.cpp version; `message` may i
|
||||
}
|
||||
```
|
||||
|
||||
## Step 7. Longer completion (with Nemotron 3 Nano Omni)
|
||||
## Step 6. Longer completion (with Qwen3.6-35B-A3B)
|
||||
|
||||
Try a slightly longer prompt to confirm stable generation with **Nemotron 3 Nano Omni**:
|
||||
Try a slightly longer prompt to confirm stable generation with **Qwen3.6-35B-A3B**:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
curl -X POST http://127.0.0.1:30000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "nemotron",
|
||||
"model": "unsloth/Qwen3.6-35B-A3B-MTP-GGUF:UD-Q4_K_XL",
|
||||
"messages": [{"role": "user", "content": "Solve this step by step: If a train travels 120 miles in 2 hours, what is its average speed?"}],
|
||||
"max_tokens": 500
|
||||
}'
|
||||
```
|
||||
|
||||
## Step 8. Cleanup
|
||||
## Step 7. Cleanup
|
||||
|
||||
Stop the server with `Ctrl+C` in the terminal where it is running.
|
||||
|
||||
To remove this tutorial’s artifacts:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
rm -rf ~/llama.cpp
|
||||
rm -rf ~/models/NVIDIA-Nemotron-3-Nano-Omni
|
||||
rm -rf ~/.cache/huggingface/hub/models--unsloth--Qwen3.6-35B-A3B-MTP-GGUF
|
||||
```
|
||||
|
||||
Deactivate the Python venv if you no longer need `hf`:
|
||||
## Step 8. Next steps
|
||||
|
||||
```bash
|
||||
deactivate
|
||||
```
|
||||
|
||||
## Step 9. Next steps
|
||||
|
||||
1. **Context length:** Increase `--ctx-size` for longer chats (watch memory; 1M-token class contexts are possible only when the build, model, and hardware allow).
|
||||
2. **Other models:** Point `--model` at any compatible GGUF; the llama.cpp server API stays the same.
|
||||
1. **Context length:** By default, llama.cpp tries to allocate maximum context size supported for the model if possible, but you can also set it manually using `--ctx-size` (or `-c`) to adjust for your needs. For agentic or coding needs you need a minimum of 32768 tokens, preferably 100000 or more.
|
||||
2. **Other models:** You can use `--model` to load any compatible GGUF downloaded locally; the llama.cpp server API stays the same. Use `-hf` to let llama.cpp automatically manage downloads/updates. Please note that if you use `--model` with multi-modal models, you need to provide a path to .mmproj file using `--mmproj` parameter. If you use `-hf` it will load the mmproj file automatically.
|
||||
3. **Integrations:** Point Open WebUI, Continue.dev, or custom clients at `http://<spark-host>:30000/v1` using the OpenAI client pattern.
|
||||
|
||||
The server implements the usual OpenAI-style chat features your llama.cpp build enables (including streaming and tool-related flows where supported).
|
||||
|
||||
1412
nvidia/nemoclaw-applications/README.md
Normal file
1412
nvidia/nemoclaw-applications/README.md
Normal file
File diff suppressed because it is too large
Load Diff
108
nvidia/register-to-brev/README.md
Normal file
108
nvidia/register-to-brev/README.md
Normal file
@ -0,0 +1,108 @@
|
||||
# Register DGX Spark to Brev
|
||||
|
||||
> Link your DGX Spark to Brev for remote access and shared environments
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Instructions](#instructions)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
## Basic idea
|
||||
|
||||
NVIDIA Brev is an AI development platform that makes GPU environments remotely accessible, shareable, and easy to standardize using preconfigured setups called Launchables.
|
||||
|
||||
This walkthrough will help you connect your NVIDIA DGX Spark to Brev so it shows up as a managed GPU environment in Brev. After a one-time registration, your Spark becomes remotely accessible and shareable.
|
||||
|
||||
## What you'll accomplish
|
||||
|
||||
You’ll register your DGX Spark with Brev and it will be visible as a healthy node in the Brev web UI and CLI, ready to share access and accept workloads whenever needed.
|
||||
|
||||
## What to know before starting
|
||||
|
||||
While Brev automates the complex configuration, understanding a few key concepts when establishing the initial connection will be useful:
|
||||
|
||||
* **Terminal Basics**:
|
||||
* Familiarity with the command line to run a few simple setup commands
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Your DGX Spark [device is set up](https://docs.nvidia.com/dgx/dgx-spark/first-boot.html). You will also need the following:
|
||||
|
||||
* **Brev Account**:
|
||||
* Have an NVIDIA Brev account. Create one [here](https://login.brev.nvidia.com/signin) if you don’t have one.
|
||||
|
||||
* **Permissions**:
|
||||
* You have administrative (root or sudo) access on the DGX Spark device to run the registration command.
|
||||
|
||||
## Time & risk
|
||||
|
||||
* **Estimated time:** 5-10 minutes
|
||||
* **Risk level:** Low - Registration configures the Spark for secure remote access without altering your existing workloads
|
||||
* **Rollback:** The Brev configuration can be removed through the UI and CLI
|
||||
|
||||
## Instructions
|
||||
|
||||
## Step 1. Log in to Brev
|
||||
|
||||
Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you’re in the correct org (by clicking the org button on the top right hand side of the page). Once logged in, go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.
|
||||
|
||||
Click the “Register Compute” button and follow the instructions in the pop-up window.
|
||||
|
||||
## Step 2. Complete Popup Instructions
|
||||
|
||||
* Install the Brev CLI
|
||||
* Configure your compute
|
||||
* Add a name for compute
|
||||
* To configure ssh, ensure the “Enable SSH access” toggle is on
|
||||
* Run the registration command
|
||||
|
||||
## Step 3. Follow Registration Flow
|
||||
|
||||
In the CLI, you’ll be walked through registration. Go through the flow until registration is complete.
|
||||
|
||||
## Step 4. Confirm Spark in Brev UI
|
||||
|
||||
* Go to the [Brev UI](https://brev.nvidia.com)
|
||||
* Navigate to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute)
|
||||
* Confirm that the DGX Spark appears as a registered node with a **Connected** status
|
||||
|
||||
## Step 5. Next Steps
|
||||
|
||||
Your Spark is now integrated into Brev as a secure, remotely accessible GPU environment.
|
||||
|
||||
Now that your hardware is connected, you can:
|
||||
|
||||
* **Share Access Anywhere:** Access your machine from anywhere and share access with others through the Brev UI by:
|
||||
* Adding the user to your [Team](https://brev.nvidia.com/org/team)
|
||||
* Navigating to your instance in the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section
|
||||
* In **SSH Access** section of the instance, search for the user you wish to add and click **Modify Access** to enable access
|
||||
|
||||
## Step 6. Cleanup
|
||||
|
||||
If you ever decide to unregister your Spark with Brev, you can either do so through the Brev UI or the Brev CLI.
|
||||
|
||||
With the CLI simply run:
|
||||
|
||||
```bash
|
||||
brev deregister
|
||||
```
|
||||
|
||||
In the UI:
|
||||
* Go to the [Brev UI](https://brev.nvidia.com)
|
||||
* Navigate to the section listing “GPU Environments” and look under “Registered Compute”
|
||||
* Click the “Remove” menu item on the Spark you wish to delete from Brev.
|
||||
* Confirm your selection.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Symptom | Cause | Fix |
|
||||
|---------|-------|-----|
|
||||
| Your DGX Spark is showing up in the wrong org | You registered your DGX Spark to the wrong org | Run `brev set <my-org>` and then redo the registration |
|
||||
| Unable to `brev shell <name>` | Need to refresh | `brev refresh` |
|
||||
|
||||
For the latest known issues, please review the [DGX Spark User Guide](https://docs.nvidia.com/dgx/dgx-spark/known-issues.html).
|
||||
@ -52,21 +52,18 @@ You will also need the following:
|
||||
|
||||
## Step 1. Log in to Brev
|
||||
|
||||
Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you’re in the correct org (by clicking the org button on the top right-hand side of the page). Once logged in, go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.
|
||||
Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you’re in the correct org (by clicking the org button on the top right hand side of the page). Once logged in, go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.
|
||||
|
||||
Click the “Register Compute” button and follow the instructions in the pop-up window.
|
||||
|
||||
## Step 2. Complete Pop-up Instructions
|
||||
## Step 2. Complete Popup Instructions
|
||||
|
||||
* Install the Brev CLI
|
||||
* Configure your compute
|
||||
* Add a name for compute
|
||||
* To configure SSH, ensure the “Enable SSH access” toggle is on
|
||||
* To configure ssh, ensure the “Enable SSH access” toggle is on
|
||||
* Run the registration command
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Run the Brev CLI install command **without `sudo`**. Prefixing the installer with `sudo` writes the `brev` binary into root's home directory, which is not on your user shell's `PATH` — the next command will fail with `brev: command not found`. Copy the install command from the pop-up and run it as your normal user.
|
||||
|
||||
## Step 3. Follow Registration Flow
|
||||
|
||||
In the CLI, you’ll be walked through registration. Go through the flow until registration is complete.
|
||||
@ -83,14 +80,10 @@ Your DGX Station is now integrated into Brev as a secure, remotely accessible GP
|
||||
|
||||
Now that your hardware is connected, you can:
|
||||
|
||||
* **Access your machine from anywhere:** Open the [Brev UI](https://brev.nvidia.com) and launch a session from [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute).
|
||||
* **Share access with others:** Invite teammates to your DGX Station from the Brev UI:
|
||||
* Go to the [Brev UI](https://brev.nvidia.com) and open [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute).
|
||||
* Find your DGX Station in the list and open the row's three-dot (⋯) menu.
|
||||
* Select **Share Access**.
|
||||
* Enter the email address of the person you want to share with.
|
||||
* Choose their role / permission level.
|
||||
* Confirm to send the invitation.
|
||||
* **Share Access Anywhere:** Access your machine from anywhere and share access with others through the Brev UI by:
|
||||
* Adding the user to your [Team](https://brev.nvidia.com/org/team)
|
||||
* Navigating to your instance in the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section
|
||||
* In **SSH Access** section of the instance, search for the user you wish to add and click **Modify Access** to enable access
|
||||
|
||||
## Step 6. Cleanup
|
||||
|
||||
@ -105,7 +98,7 @@ brev deregister
|
||||
In the UI:
|
||||
* Go to the [Brev UI](https://brev.nvidia.com)
|
||||
* Navigate to the section listing “GPU Environments” and look under “Registered Compute”
|
||||
* Click the “Remove” menu item on the device you wish to delete from Brev.
|
||||
* Click the “Remove” menu item on the DGX Station you wish to delete from Brev.
|
||||
* Confirm your selection.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
@ -82,21 +82,18 @@ spec:
|
||||
content: |
|
||||
# Step 1. Log in to Brev
|
||||
|
||||
Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you’re in the correct org (by clicking the org button on the top right-hand side of the page). Once logged in, go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.
|
||||
Go to the [Brev UI](https://brev.nvidia.com), log in, and confirm you’re in the correct org (by clicking the org button on the top right hand side of the page). Once logged in, go to the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section under the "GPU" tab in the main navigation.
|
||||
|
||||
Click the “Register Compute” button and follow the instructions in the pop-up window.
|
||||
|
||||
# Step 2. Complete Pop-up Instructions
|
||||
# Step 2. Complete Popup Instructions
|
||||
|
||||
* Install the Brev CLI
|
||||
* Configure your compute
|
||||
* Add a name for compute
|
||||
* To configure SSH, ensure the “Enable SSH access” toggle is on
|
||||
* To configure ssh, ensure the “Enable SSH access” toggle is on
|
||||
* Run the registration command
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Run the Brev CLI install command **without `sudo`**. Prefixing the installer with `sudo` writes the `brev` binary into root's home directory, which is not on your user shell's `PATH` — the next command will fail with `brev: command not found`. Copy the install command from the pop-up and run it as your normal user.
|
||||
|
||||
# Step 3. Follow Registration Flow
|
||||
|
||||
In the CLI, you’ll be walked through registration. Go through the flow until registration is complete.
|
||||
@ -113,14 +110,10 @@ spec:
|
||||
|
||||
Now that your hardware is connected, you can:
|
||||
|
||||
* **Access your machine from anywhere:** Open the [Brev UI](https://brev.nvidia.com) and launch a session from [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute).
|
||||
* **Share access with others:** Invite teammates to your DGX Station from the Brev UI:
|
||||
* Go to the [Brev UI](https://brev.nvidia.com) and open [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute).
|
||||
* Find your DGX Station in the list and open the row's three-dot (⋯) menu.
|
||||
* Select **Share Access**.
|
||||
* Enter the email address of the person you want to share with.
|
||||
* Choose their role / permission level.
|
||||
* Confirm to send the invitation.
|
||||
* **Share Access Anywhere:** Access your machine from anywhere and share access with others through the Brev UI by:
|
||||
* Adding the user to your [Team](https://brev.nvidia.com/org/team)
|
||||
* Navigating to your instance in the [Registered Compute](https://brev.nvidia.com/org/environments?tab=registered-compute) section
|
||||
* In **SSH Access** section of the instance, search for the user you wish to add and click **Modify Access** to enable access
|
||||
|
||||
# Step 6. Cleanup
|
||||
|
||||
@ -135,7 +128,7 @@ spec:
|
||||
In the UI:
|
||||
* Go to the [Brev UI](https://brev.nvidia.com)
|
||||
* Navigate to the section listing “GPU Environments” and look under “Registered Compute”
|
||||
* Click the “Remove” menu item on the device you wish to delete from Brev.
|
||||
* Click the “Remove” menu item on the DGX Station you wish to delete from Brev.
|
||||
* Confirm your selection.
|
||||
|
||||
|
||||
|
||||
@ -107,7 +107,7 @@ spec:
|
||||
|
||||
# Time & risk
|
||||
|
||||
- **Estimated time:** ~30 minutes for setup. Full d24 training takes on the order of 16+ hours on a single GB300 Ultra.
|
||||
- **Estimated time:** ~30 minutes for setup. Full d24 training takes on the order of 12+ hours on a single GB300 Ultra.
|
||||
- **Risk level:** Medium
|
||||
- Large downloads (FineWeb) can be slow; ensure stable network and disk space.
|
||||
- API keys (W&B, HF) must be set or `launch.sh` will exit immediately.
|
||||
@ -184,7 +184,7 @@ spec:
|
||||
3. **SFT** — downloads synthetic identity conversations, fine-tunes for chat
|
||||
4. **Report generation** — produces `report.md` with metrics and samples
|
||||
|
||||
Training on a single GB300 Ultra takes on the order of 16+ hours for the full d24 run.
|
||||
Training on a single GB300 Ultra takes on the order of 12+ hours for the full d24 run.
|
||||
|
||||
# Step 4. Monitor training
|
||||
|
||||
|
||||
@ -18,8 +18,10 @@
|
||||
- [Step 9. Configure SSH authentication](#step-9-configure-ssh-authentication)
|
||||
- [Step 10. Test SSH connection](#step-10-test-ssh-connection)
|
||||
- [Step 11. Validate installation](#step-11-validate-installation)
|
||||
- [Step 13. Cleanup and rollback](#step-13-cleanup-and-rollback)
|
||||
- [Step 14. Next steps](#step-14-next-steps)
|
||||
- [Step 12. Access DGX Dashboard over Tailnet](#step-12-access-dgx-dashboard-over-tailnet)
|
||||
- [Step 13. Next steps](#step-13-next-steps)
|
||||
- [Step 14. Cleanup and rollback](#step-14-cleanup-and-rollback)
|
||||
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
@ -316,14 +318,89 @@ Expected output:
|
||||
- Successful file transfers
|
||||
- Remote command execution working
|
||||
|
||||
### Step 13. Cleanup and rollback
|
||||
### Step 12. Access DGX Dashboard over Tailnet
|
||||
|
||||
The DGX Dashboard is locked to localhost:11000 for security. This means you can only access it over localhost thorugh the ssh tunnel. Instead of manually creating an SSH tunnel every time, use Tailscale Serve to proxy the traffic so you can access it via your Tailscale IP/URL from any device.
|
||||
|
||||
## On your DGX Spark machine, run:
|
||||
```bash
|
||||
## Proxy incoming Tailnet traffic to the local dashboard
|
||||
## The --bg flag ensures this keeps running after you close your terminal
|
||||
sudo tailscale serve --bg --http=11000 localhost:11000
|
||||
```
|
||||
|
||||
## Verify proxy is active:
|
||||
```bash
|
||||
tailscale serve status
|
||||
```
|
||||
|
||||
You can access the dashboard using the Tailscale IP address:
|
||||
|
||||
`http://<TAILSCALE_IP>:11000`
|
||||
|
||||
You can find your Tailscale IP by running `tailscale ip -4` on the DGX Spark device.
|
||||
|
||||
Alternatively, if you set up tailsale with Magic DNS, you can use your tailscale URL with:
|
||||
|
||||
`http://SPARK_HOST_NAME.XXXXX-YYYYYY.ts.net:11000`
|
||||
|
||||
Where XXXXX an YYYYYY are part of the custom domain name to your tailnet.
|
||||
|
||||
You can now bookmark this URL and access it anywhere on your tailnet.
|
||||
|
||||
**Option: Enable HTTPS (recommended for security)**
|
||||
|
||||
For secure HTTPS access with SSL certificates, enable MagicDNS and HTTPS Certificates in your Tailscale Admin Console:
|
||||
|
||||
1. Go to your Tailscale Admin Console
|
||||
2. Under DNS, ensure MagicDNS is enabled
|
||||
3. Scroll down to HTTPS Certificates and click Enable
|
||||
|
||||
Then, on your DGX Spark machine, reset the HTTP proxy and start the HTTPS proxy:
|
||||
|
||||
```bash
|
||||
# First, reset the old HTTP proxy
|
||||
sudo tailscale serve --http=11000 off
|
||||
|
||||
# Now, start the HTTPS proxy
|
||||
sudo tailscale serve --bg --https=11000 localhost:11000
|
||||
```
|
||||
|
||||
Access the dashboard securely via: `https://SPARK_HOST_NAME.XXXXX-YYYYYY.ts.net:11000`
|
||||
> **Note:** It may take a little longer on first load to set the SSL certificate. This is normal.
|
||||
|
||||
### Step 13. Next steps
|
||||
|
||||
Your Tailscale setup is complete. You can now:
|
||||
|
||||
- Access your DGX Spark device from any network with: `ssh <USERNAME>@<SPARK_HOSTNAME>`
|
||||
- Transfer files securely: `scp file.txt <USERNAME>@<SPARK_HOSTNAME>:~/`
|
||||
- Open the DGX Dashboard and start JupyterLab, then connect with:
|
||||
`ssh -L 8888:localhost:1102 <USERNAME>@<SPARK_HOSTNAME>`
|
||||
|
||||
> **Note:** Alternatively, see Step 12 for accessing the DGX Dashboard over Tailnet without manual SSH tunneling.
|
||||
|
||||
|
||||
### Step 14. Cleanup and rollback
|
||||
|
||||
Remove Tailscale completely if needed. This will disconnect devices from the
|
||||
tailnet and remove all network configurations.
|
||||
|
||||
**Option A: Remove only DGX Dashboard access**
|
||||
|
||||
If you want to keep Tailscale installed but stop serving the DGX Dashboard:
|
||||
|
||||
```bash
|
||||
## Remove DGX Dashboard access from tailnet (from Step 12)
|
||||
sudo tailscale serve --http=11000 off
|
||||
sudo tailscale serve --https=11000 off
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> This will permanently remove the device from your Tailscale network and require re-authentication to rejoin.
|
||||
|
||||
**Option B: Full Tailscale removal**
|
||||
|
||||
```bash
|
||||
## Stop Tailscale service
|
||||
sudo tailscale down
|
||||
@ -337,19 +414,12 @@ sudo rm /usr/share/keyrings/tailscale-archive-keyring.gpg
|
||||
|
||||
## Update package list
|
||||
sudo apt update
|
||||
|
||||
```
|
||||
|
||||
|
||||
To restore: Re-run installation steps 3-5.
|
||||
|
||||
### Step 14. Next steps
|
||||
|
||||
Your Tailscale setup is complete. You can now:
|
||||
|
||||
- Access your DGX Spark device from any network with: `ssh <USERNAME>@<SPARK_HOSTNAME>`
|
||||
- Transfer files securely: `scp file.txt <USERNAME>@<SPARK_HOSTNAME>:~/`
|
||||
- Open the DGX Dashboard and start JupyterLab, then connect with:
|
||||
`ssh -L 8888:localhost:1102 <USERNAME>@<SPARK_HOSTNAME>`
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Symptom | Cause | Fix |
|
||||
|
||||
Loading…
Reference in New Issue
Block a user