mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-04-23 02:23:53 +00:00
# Optimized Jax

> Develop with Optimized Jax

## Table of Contents

- [Overview](#overview)
- [Instructions](#instructions)

---

## Overview

## Basic idea

JAX lets you write **NumPy-style Python code** and run it fast on GPUs without writing CUDA. It does this by:

- **NumPy on accelerators**: Use `jax.numpy` just like NumPy, but arrays live on the GPU.
- **Function transformations**:
  - `jit` → compiles your function into fast GPU code.
  - `grad` → gives you automatic differentiation.
  - `vmap` → vectorizes your function across batches.
  - `pmap` → runs across multiple GPUs in parallel.
- **XLA backend**: JAX hands your code to XLA (the Accelerated Linear Algebra compiler), which fuses operations and generates optimized GPU kernels.
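
These transformations compose with ordinary NumPy-style code; a minimal sketch (the toy function and shapes are illustrative, not taken from the playbook assets):

```python
import jax
import jax.numpy as jnp

# A plain NumPy-style function: no CUDA, no explicit device management.
def predict(w, x):
    return jnp.tanh(x @ w)

def loss(w, x, y):
    return jnp.mean((predict(w, x) - y) ** 2)

# jit: compile the whole function with XLA so its ops can be fused.
fast_loss = jax.jit(loss)

# grad: derivative of the scalar loss with respect to the first argument (w).
loss_grad = jax.grad(loss)

# vmap: turn a single-example function into a batched one.
single = lambda w, x: jnp.tanh(x @ w)
batched = jax.vmap(single, in_axes=(None, 0))

w = jnp.ones((3, 2))
x = jnp.ones((5, 3))
y = jnp.zeros((5, 2))

print(fast_loss(w, x, y))        # scalar loss
print(loss_grad(w, x, y).shape)  # (3, 2) -- same shape as w
print(batched(w, x).shape)       # (5, 2) -- one output row per batch element
```

The transformations also compose with each other: `jax.jit(jax.grad(loss))` compiles the gradient computation too.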

## What you'll accomplish

You'll set up a JAX development environment on NVIDIA Spark with Blackwell architecture that enables high-performance machine learning prototyping using familiar NumPy-like abstractions, complete with GPU acceleration and performance optimization capabilities.

## What to know before starting

- Comfortable with Python and NumPy programming
- General understanding of machine learning workflows and techniques
- Experience working in a terminal
- Experience using and building containers
- Familiarity with different versions of CUDA
- Basic understanding of linear algebra (high-school level math is sufficient)

## Prerequisites

- [ ] NVIDIA Spark device with Blackwell architecture
- [ ] ARM64 (AArch64) processor architecture
- [ ] Docker or another container runtime installed
- [ ] NVIDIA Container Toolkit configured
- [ ] Verify GPU access: `nvidia-smi`
- [ ] Verify Docker GPU support: `docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi`
- [ ] Port 8080 available for marimo notebook access

## Ancillary files

All required assets can be found [here on GitLab](https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets/-/blob/main).

- [**JAX introduction notebook**](https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets/-/blob/main/${MODEL}/assets/jax-intro.py) — covers the JAX programming model, its differences from NumPy, and performance evaluation
- [**NumPy SOM implementation**](https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets/-/blob/main/${MODEL}/assets/numpy-som.py) — reference implementation of the self-organizing map training algorithm in NumPy
- [**JAX SOM implementations**](https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets/-/blob/main/${MODEL}/assets/som-jax.py) — multiple iteratively refined implementations of the SOM algorithm in JAX
- [**Environment configuration**](https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets/-/blob/main/${MODEL}/assets/Dockerfile) — package dependencies and container setup specifications
- [**Course guide notebook**]() — overall material navigation and learning path

## Time & risk

**Duration:** 2–3 hours, including setup, tutorial completion, and validation

**Risks:**

- Package dependency conflicts in the Python environment
- Performance validation may require architecture-specific optimizations

**Rollback:** Container environments provide isolation; remove the containers and restart to reset state.

## Instructions

## Step 1. Verify system prerequisites

Confirm your NVIDIA Spark system meets the requirements and has GPU access configured.

```bash
# Verify GPU access
nvidia-smi

# Verify ARM64 architecture
uname -m

# Check Docker GPU support
docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi
```

If the `docker` command fails with a permission error, you can either

1. run it with `sudo`, e.g., `sudo docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi`, or
2. add yourself to the `docker` group so you can use `docker` without `sudo`.

To add yourself to the `docker` group, first run `sudo usermod -aG docker $USER`. Then, as your user account, either run `newgrp docker` or log out and log back in.

## Step 2. Build a Docker image

> **Warning:** This command downloads a base image and builds a container image locally to support this environment.

```bash
cd jax-assets
docker build -t jax-on-spark .
```

## Step 3. Launch Docker container

Run the JAX development environment in a Docker container with GPU support and port forwarding for marimo access.

```bash
docker run --gpus all --rm -it \
  --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \
  -p 8080:8080 \
  jax-on-spark
```

## Step 4. Access marimo interface

Connect to the marimo notebook server to begin the JAX tutorial: open a web browser and navigate to `http://localhost:8080`.

The interface loads a table-of-contents display and a brief introduction to marimo.

## Step 5. Complete JAX introduction tutorial

Work through the introductory material to understand how the JAX programming model differs from NumPy.

Navigate to and complete the JAX introduction notebook, which covers:

- JAX programming model fundamentals
- Key differences from NumPy
- Performance evaluation techniques

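One difference worth previewing here: JAX arrays are immutable, so NumPy's in-place assignments become functional `.at[...]` updates. A minimal sketch (not taken from the notebook):

```python
import numpy as np
import jax.numpy as jnp

# NumPy: mutate the array in place.
a = np.zeros(3)
a[0] = 1.0

# JAX: arrays are immutable; .at[...] returns an updated copy.
b = jnp.zeros(3)
b = b.at[0].set(1.0)

print(a)  # [1. 0. 0.]
print(b)  # [1. 0. 0.]
```

Under `jit`, XLA can usually perform `.at[...]` updates in place, so the functional style does not necessarily imply a copy at runtime.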
## Step 6. Implement NumPy baseline

Complete the NumPy-based self-organizing map (SOM) implementation to establish a performance baseline.

Work through the NumPy SOM notebook to:

- Understand the SOM training algorithm
- Implement the algorithm using familiar NumPy operations
- Record performance metrics for comparison

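The notebook is the reference implementation; for orientation, one SOM update step in NumPy might look roughly like this (the function name, grid layout, and hyperparameters are illustrative, not the notebook's):

```python
import numpy as np

def som_step(weights, x, lr, sigma, grid):
    """One SOM update: find the best-matching unit, pull its neighbors toward x.

    weights: (n_units, dim) codebook vectors
    x:       (dim,) one input sample
    grid:    (n_units, 2) coordinates of each unit on the 2-D map
    """
    # Best-matching unit (BMU): the codebook vector closest to the sample.
    bmu = np.argmin(np.sum((weights - x) ** 2, axis=1))
    # Gaussian neighborhood around the BMU, measured on the map grid.
    d2 = np.sum((grid - grid[bmu]) ** 2, axis=1)
    h = np.exp(-d2 / (2.0 * sigma ** 2))
    # Move every unit toward the sample, weighted by its neighborhood value.
    return weights + lr * h[:, None] * (x - weights)

rng = np.random.default_rng(0)
grid = np.stack(np.meshgrid(np.arange(4), np.arange(4)), axis=-1).reshape(-1, 2)
weights = rng.random((16, 3))
weights = som_step(weights, rng.random(3), lr=0.5, sigma=1.0, grid=grid)
print(weights.shape)  # (16, 3)
```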
## Step 7. Optimize with JAX implementations

Progress through the iteratively refined JAX implementations to see performance improvements.

Complete the JAX SOM notebook sections:

- Basic JAX port of the NumPy implementation
- Performance-optimized JAX version
- GPU-accelerated parallel JAX implementation
- Compare performance across all versions

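As a rough sketch of the kind of refinement the notebook walks through, the update step can be jitted and the training loop fused into one compiled program with `jax.lax.fori_loop` (names, shapes, and hyperparameters are illustrative, not the notebook's):

```python
import jax
import jax.numpy as jnp

@jax.jit
def som_step(weights, x, lr, sigma, grid):
    # Same update as the NumPy baseline, but traced and compiled by XLA.
    bmu = jnp.argmin(jnp.sum((weights - x) ** 2, axis=1))
    d2 = jnp.sum((grid - grid[bmu]) ** 2, axis=1)
    h = jnp.exp(-d2 / (2.0 * sigma ** 2))
    return weights + lr * h[:, None] * (x - weights)

@jax.jit
def train(weights, data, lr, sigma, grid):
    # Fuse the whole training loop into one compiled program instead of
    # dispatching one kernel launch per Python-level iteration.
    def body(i, w):
        return som_step(w, data[i], lr, sigma, grid)
    return jax.lax.fori_loop(0, data.shape[0], body, weights)

key_w, key_d = jax.random.split(jax.random.PRNGKey(0))
grid = jnp.stack(jnp.meshgrid(jnp.arange(4), jnp.arange(4)), axis=-1).reshape(-1, 2)
weights = jax.random.uniform(key_w, (16, 3))
data = jax.random.uniform(key_d, (100, 3))
weights = train(weights, data, 0.5, 1.0, grid)
print(weights.shape)  # (16, 3)
```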
## Step 8. Validate performance gains

The notebooks show you how to check the performance of each SOM training implementation; you'll see that the JAX implementations outperform the NumPy baseline (some by a wide margin).

Visually inspect the SOM training output on random color data to confirm the algorithm is correct.

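If you time JAX code yourself outside the notebooks, remember that dispatch is asynchronous: call `block_until_ready()` before stopping the clock, and exclude the first call, which includes compilation. A minimal pattern (the workload is illustrative):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((512, 512))

# First call compiles; exclude it from the measurement.
step(x).block_until_ready()

t0 = time.perf_counter()
for _ in range(10):
    # block_until_ready() waits for the asynchronous dispatch to finish;
    # otherwise you would only time how fast Python enqueues the work.
    step(x).block_until_ready()
t1 = time.perf_counter()
print(f"{(t1 - t0) / 10 * 1e3:.3f} ms per call")
```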
## Step 9. Validate installation

Confirm all components are working correctly and notebooks execute successfully.

```bash
# Test GPU JAX functionality
python -c "import jax; print(jax.devices()); print(jax.device_count())"

# Verify JAX can place arrays on the GPU
python -c "import jax.numpy as jnp; x = jnp.array([1, 2, 3]); print(x.devices())"
```

The expected output shows GPU devices detected and JAX arrays placed on the GPU.

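As one more quick check, `jax.default_backend()` reports which backend JAX selected; on a working CUDA install it should not report `cpu` (a minimal sketch):

```python
import jax
import jax.numpy as jnp

# "gpu" means the CUDA-enabled JAX is active;
# "cpu" means the GPU wheel is not being used.
print(jax.default_backend())

# Tiny end-to-end computation through the selected backend.
x = jnp.linspace(0.0, 1.0, 4)
print(jax.jit(jnp.cumsum)(x))
```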
## Step 10. Troubleshooting

Common issues and their solutions:

| Symptom | Cause | Fix |
|---------|-------|-----|
| `nvidia-smi` not found | Missing NVIDIA drivers | Install the NVIDIA drivers for ARM64 |
| Container fails to access the GPU | Missing NVIDIA Container Toolkit | Install `nvidia-container-toolkit` |
| JAX only uses the CPU | CUDA/JAX version mismatch | Reinstall JAX with CUDA support |
| Port 8080 unavailable | Port already in use | Use `-p 8081:8080` or kill the process on 8080 |
| Package conflicts in the Docker build | Outdated environment file | Update the environment file for Blackwell |

## Step 11. Cleanup and rollback

Remove containers and reset the environment if needed.

> **Warning:** This removes all container data and downloaded images.

```bash
# Stop and remove containers (note: this stops every running container)
docker stop $(docker ps -q)
docker system prune -f

# Reset the pipenv environment
pipenv --rm
```

To roll back, re-run the installation steps from Step 2.

## Step 12. Next steps

Apply JAX optimization techniques to your own NumPy-based machine learning code.

```bash
# Example: profile your existing NumPy code
python -m cProfile your_numpy_script.py

# Then adapt it to JAX and compare performance
```

Try adapting your favorite NumPy algorithms to JAX and measure the performance improvements on the Blackwell GPU architecture.