Optimized Jax
Develop with Optimized Jax
Overview
Basic idea
JAX lets you write NumPy-style Python code and run it fast on GPUs without writing CUDA. It does this by:
- NumPy on accelerators: Use jax.numpy just like NumPy, but arrays live on the GPU.
- Function transformations:
  - jit → Compiles your function into fast GPU code.
  - grad → Gives you automatic differentiation.
  - vmap → Vectorizes your function across batches.
  - pmap → Runs across multiple GPUs in parallel.
- XLA backend: JAX hands your code to XLA (Accelerated Linear Algebra compiler), which fuses operations and generates optimized GPU kernels.
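The transformations listed above can be sketched on a toy problem. This is a minimal illustration, not code from the tutorial notebooks: the `mse_loss` function and the sample data are made up for the example, and pmap is left out because it needs multiple devices.

```python
import jax
import jax.numpy as jnp

def mse_loss(w, x, y):
    """Mean squared error of a linear model y ~ x @ w (illustrative only)."""
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jit: compile the function with XLA the first time it is called
fast_loss = jax.jit(mse_loss)

# grad: derivative of the loss with respect to the first argument (w)
loss_grad = jax.grad(mse_loss)

# vmap: evaluate the loss for a whole batch of weight vectors at once
batched_loss = jax.vmap(mse_loss, in_axes=(0, None, None))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([5.0, 11.0])
w = jnp.array([1.0, 2.0])  # fits the data exactly, so the loss is zero

loss_value = fast_loss(w, x, y)
grad_value = loss_grad(w, x, y)
ws = jnp.stack([w, jnp.zeros(2)])  # two candidate weight vectors
batch_values = batched_loss(ws, x, y)

print(loss_value, grad_value, batch_values)
```

The transformations compose, so `jax.jit(jax.grad(mse_loss))` would give a compiled gradient function.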
What you'll accomplish
You'll set up a JAX development environment on NVIDIA Spark with Blackwell architecture that enables high-performance machine learning prototyping using familiar NumPy-like abstractions, complete with GPU acceleration and performance optimization capabilities.
What to know before starting
- Comfortable with Python and NumPy programming
- General understanding of machine learning workflows and techniques
- Experience working in a terminal
- Experience using and building containers
- Familiarity with different versions of CUDA
- Basic understanding of linear algebra (high-school level math sufficient)
Prerequisites
[ ] NVIDIA Spark device with Blackwell architecture
[ ] ARM64 (AArch64) processor architecture
[ ] Docker or container runtime installed
[ ] NVIDIA Container Toolkit configured
[ ] Verify GPU access: nvidia-smi
[ ] Verify Docker GPU support: docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi
[ ] Port 8080 available for marimo notebook access
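One way to check the port prerequisite from Python is to try binding to it. This is a small sketch using only the standard library; `port_is_free` is a hypothetical helper name, and the check only reflects the state of the port at the moment it runs.

```python
import socket

def port_is_free(port, host="0.0.0.0"):
    """Return True if a TCP socket can bind to (host, port) right now."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            # Typically "address already in use" or a permission problem
            return False
    return True

print("port 8080 free:", port_is_free(8080))
```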
Ancillary files
All required assets can be found here on GitHub
- JAX introduction notebook — covers JAX programming model differences from NumPy and performance evaluation
- NumPy SOM implementation — reference implementation of the self-organizing map (SOM) training algorithm in NumPy
- JAX SOM implementations — multiple iteratively refined implementations of SOM algorithm in JAX
- Environment configuration — package dependencies and container setup specifications
- Course guide notebook — overall material navigation and learning path
Time & risk
Duration: 2-3 hours including setup, tutorial completion, and validation
Risks:
- Package dependency conflicts in Python environment
- Performance validation may require architecture-specific optimizations
Rollback: Container environments provide isolation; remove containers and restart to reset state.
Instructions
Step 1. Verify system prerequisites
Confirm your NVIDIA Spark system meets the requirements and has GPU access configured.
## Verify GPU access
nvidia-smi
## Verify ARM64 architecture
uname -m
## Check Docker GPU support
docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi
If the docker command fails with a permission error, you can either
- run it with sudo, e.g., sudo docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi, or
- add yourself to the docker group so you can use docker without sudo.
To add yourself to the docker group, first run sudo usermod -aG docker $USER. Then, as your user account, either run newgrp docker or log out and log back in.
Step 2. Build a Docker image
Warning: This command downloads a base image and builds the container image locally to support this environment.
cd jax-assets
docker build -t jax-on-spark .
Step 3. Launch Docker container
Run the JAX development environment in a Docker container with GPU support and port forwarding for marimo access.
docker run --gpus all --rm -it \
--shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \
-p 8080:8080 \
jax-on-spark
Step 4. Access marimo interface
Connect to the marimo notebook server to begin the JAX tutorial.
## Access via web browser
## Navigate to: http://localhost:8080
The interface will load a table-of-contents display and brief introduction to marimo.
Step 5. Complete JAX introduction tutorial
Work through the introductory material to understand JAX programming model differences from NumPy.
Navigate to and complete the JAX introduction notebook, which covers:
- JAX programming model fundamentals
- Key differences from NumPy
- Performance evaluation techniques
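Two of the differences from NumPy that newcomers tend to hit first are immutable arrays and explicit random keys. The snippet below is a brief illustration of both and is not taken from the notebook itself.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy arrays are mutable: in-place assignment changes the array
a = np.zeros(3)
a[0] = 1.0

# JAX arrays are immutable: .at[...].set(...) returns an updated copy
b = jnp.zeros(3)
b2 = b.at[0].set(1.0)
print(b, b2)  # b is unchanged, b2 holds the update

# JAX random functions take an explicit key instead of using global state
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
r1 = jax.random.normal(subkey, (2,))
r2 = jax.random.normal(subkey, (2,))  # same key, so the same numbers
print(r1, r2)
```

Reusing a key reproduces the same values, which is why code usually splits the key before each random call.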
Step 6. Implement NumPy baseline
Complete the NumPy-based self-organizing map (SOM) implementation to establish a performance baseline.
Work through the NumPy SOM notebook to:
- Understand the SOM training algorithm
- Implement the algorithm using familiar NumPy operations
- Record performance metrics for comparison
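For orientation, a textbook online SOM update can be sketched in a few lines of NumPy. This is an assumption about the general shape of the algorithm, not the notebook's reference implementation; the grid size, decay schedule, and the `som_step` name are illustrative.

```python
import numpy as np

def som_step(weights, sample, learning_rate, sigma):
    """One online SOM update on a (rows, cols, dim) weight grid."""
    rows, cols, _ = weights.shape
    # Best-matching unit (BMU): the node whose weights are closest to the sample
    dists = np.sum((weights - sample) ** 2, axis=-1)
    bmu_r, bmu_c = np.unravel_index(np.argmin(dists), (rows, cols))
    # Gaussian neighborhood on the grid, centered at the BMU
    rr, cc = np.meshgrid(np.arange(rows), np.arange(cols), indexing="ij")
    grid_d2 = (rr - bmu_r) ** 2 + (cc - bmu_c) ** 2
    h = np.exp(-grid_d2 / (2.0 * sigma ** 2))
    # Pull every node toward the sample, weighted by the neighborhood
    return weights + learning_rate * h[..., None] * (sample - weights)

rng = np.random.default_rng(0)
weights = rng.random((8, 8, 3))   # 8x8 map of RGB colors
colors = rng.random((200, 3))     # random color training data
for t, sample in enumerate(colors):
    frac = t / len(colors)
    # Linearly decay the learning rate and neighborhood width (assumed schedule)
    lr = 0.5 * (1.0 - frac)
    sigma = max(3.0 * (1.0 - frac), 0.5)
    weights = som_step(weights, sample, lr, sigma)

print(weights.shape)
```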
Step 7. Optimize with JAX implementations
Progress through the iteratively refined JAX implementations to see performance improvements.
Complete the JAX SOM notebook sections:
- Basic JAX port of NumPy implementation
- Performance-optimized JAX version
- GPU-accelerated parallel JAX implementation
- Compare performance across all versions
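As a rough idea of what a basic JAX port of a SOM update might look like, the sketch below jit-compiles a single update and runs the training loop with `jax.lax.scan`. It is a hypothetical illustration under assumed hyperparameters, not one of the notebook implementations.

```python
import jax
import jax.numpy as jnp

@jax.jit
def som_step(weights, sample, learning_rate, sigma):
    """One online SOM update on a (rows, cols, dim) weight grid."""
    rows, cols, _ = weights.shape  # shapes are static under jit
    dists = jnp.sum((weights - sample) ** 2, axis=-1)
    idx = jnp.argmin(dists)                 # flat index of the best-matching unit
    bmu_r, bmu_c = idx // cols, idx % cols  # convert to grid coordinates
    rr, cc = jnp.meshgrid(jnp.arange(rows), jnp.arange(cols), indexing="ij")
    grid_d2 = (rr - bmu_r) ** 2 + (cc - bmu_c) ** 2
    h = jnp.exp(-grid_d2 / (2.0 * sigma ** 2))
    return weights + learning_rate * h[..., None] * (sample - weights)

def train(weights, samples, learning_rate=0.5, sigma=2.0):
    """Apply som_step to each sample in order, keeping the loop on device."""
    def body(w, s):
        return som_step(w, s, learning_rate, sigma), None
    final, _ = jax.lax.scan(body, weights, samples)
    return final

key_w, key_s = jax.random.split(jax.random.PRNGKey(0))
weights = jax.random.uniform(key_w, (8, 8, 3))
samples = jax.random.uniform(key_s, (200, 3))
trained = jax.jit(train)(weights, samples)
print(trained.shape)
```

Using `lax.scan` instead of a Python loop lets XLA compile the whole training run as one program, which avoids per-step dispatch overhead.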
Step 8. Validate performance gains
The notebooks show you how to measure the performance of each SOM training implementation. You should see the JAX implementations outperform the NumPy baseline, some by a wide margin.
Visually inspect the SOM training output on random color data to confirm algorithm correctness.
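When timing JAX code yourself, keep in mind that JAX dispatches work asynchronously and that the first call to a jitted function includes compilation. The helper below is a minimal sketch that accounts for both; `time_fn` and `gram` are hypothetical names, and the notebooks may measure performance differently.

```python
import time
import jax
import jax.numpy as jnp

def time_fn(fn, *args, repeats=10):
    """Return the best wall-clock time in seconds over `repeats` calls."""
    # Warm-up call: triggers compilation so it is not counted below
    jax.block_until_ready(fn(*args))
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # Wait for the result so we time the computation, not just the dispatch
        jax.block_until_ready(fn(*args))
        best = min(best, time.perf_counter() - start)
    return best

@jax.jit
def gram(x):
    return x @ x.T

x = jnp.ones((256, 256))
elapsed = time_fn(gram, x)
print(f"best of 10 runs: {elapsed * 1e3:.3f} ms")
```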
Step 10. Validate installation
Confirm all components are working correctly and notebooks execute successfully.
## Test GPU JAX functionality
python -c "import jax; print(jax.devices()); print(jax.device_count())"
## Verify JAX can access GPU
python -c "import jax.numpy as jnp; x = jnp.array([1, 2, 3]); print(x.devices())"
Expected output should show GPU devices detected and JAX arrays placed on GPU.
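The same checks can be collected into a short script that prints a warning when JAX falls back to the CPU. This is a sketch; `describe_backend` is a hypothetical helper, and the exact device strings depend on your JAX version.

```python
import jax
import jax.numpy as jnp

def describe_backend():
    """Collect the default backend, visible devices, and array placement."""
    x = jnp.arange(3)
    return {
        "backend": jax.default_backend(),  # e.g. "gpu" or "cpu"
        "device_count": jax.device_count(),
        "devices": [str(d) for d in jax.devices()],
        "array_devices": [str(d) for d in x.devices()],
    }

info = describe_backend()
for name, value in info.items():
    print(f"{name}: {value}")

if info["backend"] != "gpu":
    print("Warning: JAX is not using a GPU backend; check the CUDA-enabled JAX install.")
```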
Step 11. Troubleshooting
Common issues and their solutions:
| Symptom | Cause | Fix |
|---|---|---|
| nvidia-smi not found | Missing NVIDIA drivers | Install NVIDIA drivers for ARM64 |
| Container fails to access GPU | Missing NVIDIA Container Toolkit | Install nvidia-container-toolkit |
| JAX only uses CPU | CUDA/JAX version mismatch | Reinstall JAX with CUDA support |
| Port 8080 unavailable | Port already in use | Use -p 8081:8080 or kill process on 8080 |
| Package conflicts in Docker build | Outdated environment file | Update environment file for Blackwell |
Step 12. Cleanup and rollback
Remove containers and reset environment if needed.
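Warning: This stops every running container on the system, not only the JAX container, and removes all stopped containers, unused networks, dangling images, and build cache.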
## Stop and remove containers
docker stop $(docker ps -q)
docker system prune -f
## Reset pipenv environment
pipenv --rm
To roll back: Re-run the installation steps starting from Step 2.
Step 13. Next steps
Apply JAX optimization techniques to your own NumPy-based machine learning code.
## Example: Profile your existing NumPy code
python -m cProfile your_numpy_script.py
## Then adapt to JAX and compare performance
Try adapting your favorite NumPy algorithms to JAX and measure performance improvements on Blackwell GPU architecture.