# Optimized JAX

> Develop with Optimized JAX

## Table of Contents

- [Overview](#overview)
- [Instructions](#instructions)

---

## Overview

## Basic idea

JAX lets you write **NumPy-style Python code** and run it fast on GPUs without writing CUDA. It does this by:

- **NumPy on accelerators**: Use `jax.numpy` just like NumPy, but arrays live on the GPU.
- **Function transformations**:
  - `jit` → Compiles your function into fast GPU code.
  - `grad` → Gives you automatic differentiation.
  - `vmap` → Vectorizes your function across batches.
  - `pmap` → Runs across multiple GPUs in parallel.
- **XLA backend**: JAX hands your code to XLA (Accelerated Linear Algebra compiler), which fuses operations and generates optimized GPU kernels.

## What you'll accomplish

You'll set up a JAX development environment on NVIDIA Spark with Blackwell architecture that enables high-performance machine learning prototyping using familiar NumPy-like abstractions, complete with GPU acceleration and performance optimization capabilities.
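The transformations listed under "Basic idea" can be sketched on a toy problem. This is a minimal illustration, not code from the playbook notebooks: the `loss` function and the arrays `w`, `x`, and `xs` are hypothetical names chosen for the example. `pmap` is left out because it needs more than one device.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Squared error of a toy linear model against a target of 1.0 (illustrative only).
    return jnp.sum((x @ w - 1.0) ** 2)


# grad: derivative of loss with respect to its first argument (w).
grad_loss = jax.grad(loss)

# jit: compile the gradient function with XLA.
fast_grad_loss = jax.jit(grad_loss)

# vmap: evaluate loss over a leading batch axis of x while sharing w.
batched_loss = jax.vmap(loss, in_axes=(None, 0))

w = jnp.ones(3)
x = jnp.arange(6.0).reshape(2, 3)
xs = jnp.stack([x, 2.0 * x])  # a batch of two input matrices

g = fast_grad_loss(w, x)
losses = batched_loss(w, xs)
print(g.shape, losses.shape)
```

The transformations compose: here `jit` wraps the function that `grad` produced, and `vmap` could be wrapped in `jit` the same way.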
## What to know before starting

- Comfortable with Python and NumPy programming
- General understanding of machine learning workflows and techniques
- Experience working in a terminal
- Experience using and building containers
- Familiarity with different versions of CUDA
- Basic understanding of linear algebra (high-school level math sufficient)

## Prerequisites

- [ ] NVIDIA Spark device with Blackwell architecture
- [ ] ARM64 (AArch64) processor architecture
- [ ] Docker or container runtime installed
- [ ] NVIDIA Container Toolkit configured
- [ ] Verify GPU access: `nvidia-smi`
- [ ] Verify Docker GPU support: `docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi`
- [ ] Port 8080 available for marimo notebook access

## Ancillary files

All required assets can be found [here on GitLab](https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets/-/blob/main)

- [**JAX introduction notebook**](https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets/-/blob/main/${MODEL}/assets/jax-intro.py): covers JAX programming model differences from NumPy and performance evaluation
- [**NumPy SOM implementation**](https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets/-/blob/main/${MODEL}/assets/numpy-som.py): reference implementation of the self-organizing map training algorithm in NumPy
- [**JAX SOM implementations**](https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets/-/blob/main/${MODEL}/assets/som-jax.py): multiple iteratively refined implementations of the SOM algorithm in JAX
- [**Environment configuration**](https://gitlab.com/nvidia/dgx-spark/temp-external-playbook-assets/dgx-spark-playbook-assets/-/blob/main/${MODEL}/assets/Dockerfile): package dependencies and container setup specifications
- **Course guide notebook**: overall material navigation and learning path

## Time & risk

**Duration:** 2-3 hours including setup, tutorial completion, and validation

**Risks:**

- Package dependency conflicts in Python environment
- Performance validation may require architecture-specific optimizations

**Rollback:** Container environments provide isolation; remove containers and restart to reset state.

## Instructions

## Step 1. Verify system prerequisites

Confirm your NVIDIA Spark system meets the requirements and has GPU access configured.

```bash
# Verify GPU access
nvidia-smi

# Verify ARM64 architecture (should print aarch64)
uname -m

# Check Docker GPU support
docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi
```

If the `docker` command fails with a permission error, you can either:

1. Run it with `sudo`, e.g., `sudo docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi`, or
2. Add yourself to the `docker` group so you can use `docker` without `sudo`.

To add yourself to the `docker` group, first run `sudo usermod -aG docker $USER`. Then, as your user account, either run `newgrp docker` or log out and log back in.

## Step 2. Build a Docker image

> **Warning:** This command downloads a base image and builds a container image locally to support this environment.

```bash
cd jax-assets
docker build -t jax-on-spark .
```

## Step 3. Launch Docker container

Run the JAX development environment in a Docker container with GPU support and port forwarding for marimo access.

```bash
docker run --gpus all --rm -it \
    --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \
    -p 8080:8080 \
    jax-on-spark
```

## Step 4. Access marimo interface

Connect to the marimo notebook server to begin the JAX tutorial.

```bash
# Access via web browser
# Navigate to: http://localhost:8080
```

The interface will load a table-of-contents display and brief introduction to marimo.

## Step 5. Complete JAX introduction tutorial

Work through the introductory material to understand JAX programming model differences from NumPy.
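Two differences from NumPy that often come up when porting code are sketched below: JAX arrays are immutable, and JAX random number generation takes an explicit key. Whether the introduction notebook presents them in this form is an assumption, and the variable names (`a`, `b`, `b2`, `samples`) are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy arrays are mutable: in-place assignment works.
a = np.zeros(3)
a[0] = 1.0

# JAX arrays are immutable: item assignment raises TypeError.
b = jnp.zeros(3)
try:
    b[0] = 1.0
except TypeError as err:
    print("in-place update rejected:", type(err).__name__)

# The functional alternative returns an updated copy and leaves b unchanged.
b2 = b.at[0].set(1.0)

# JAX random numbers take an explicit key instead of hidden global state.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
samples = jax.random.uniform(subkey, (4, 3))

print(b, b2, samples.shape)
```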
Navigate to and complete the JAX introduction notebook, which covers:

- JAX programming model fundamentals
- Key differences from NumPy
- Performance evaluation techniques

## Step 6. Implement NumPy baseline

Complete the NumPy-based self-organizing map (SOM) implementation to establish a performance baseline.

Work through the NumPy SOM notebook to:

- Understand the SOM training algorithm
- Implement the algorithm using familiar NumPy operations
- Record performance metrics for comparison

## Step 7. Optimize with JAX implementations

Progress through the iteratively refined JAX implementations to see performance improvements.

Complete the JAX SOM notebook sections:

- Basic JAX port of NumPy implementation
- Performance-optimized JAX version
- GPU-accelerated parallel JAX implementation
- Compare performance across all versions

## Step 8. Validate performance gains

The notebooks show how to measure the performance of each SOM training implementation. You should see the JAX implementations outperform the NumPy baseline, some of them by a wide margin.

Visually inspect the SOM training output on random color data to confirm algorithm correctness.

## Step 9. Validate installation

Confirm all components are working correctly and notebooks execute successfully.

```bash
# Test GPU JAX functionality
python -c "import jax; print(jax.devices()); print(jax.device_count())"

# Verify JAX places arrays on the GPU
# (x.devices() replaces the deprecated x.device() method call)
python -c "import jax.numpy as jnp; x = jnp.array([1, 2, 3]); print(x.devices())"
```

The output should list the detected GPU devices and show the array placed on a GPU.

## Step 10. Troubleshooting

Common issues and their solutions:

| Symptom | Cause | Fix |
|---------|-------|-----|
| `nvidia-smi` not found | Missing NVIDIA drivers | Install NVIDIA drivers for ARM64 |
| Container fails to access GPU | Missing NVIDIA Container Toolkit | Install `nvidia-container-toolkit` |
| JAX only uses CPU | CUDA/JAX version mismatch | Reinstall JAX with CUDA support |
| Port 8080 unavailable | Port already in use | Use `-p 8081:8080` or kill the process on 8080 |
| Package conflicts in Docker build | Outdated environment file | Update environment file for Blackwell |

## Step 11. Cleanup and rollback

Remove containers and reset environment if needed.

> **Warning:** These commands stop every running container on the host, not only the JAX one, and remove stopped containers, unused networks, dangling images, and build cache.

```bash
# Stop all running containers, then prune unused Docker data
docker stop $(docker ps -q)
docker system prune -f

# Reset pipenv environment (only applies if you created one)
pipenv --rm
```

To roll back: re-run the installation steps starting from Step 2.

## Step 12. Next steps

Apply JAX optimization techniques to your own NumPy-based machine learning code.

```bash
# Example: Profile your existing NumPy code
python -m cProfile your_numpy_script.py

# Then adapt to JAX and compare performance
```

Try adapting your favorite NumPy algorithms to JAX and measure performance improvements on Blackwell GPU architecture.
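One way to compare a NumPy function with its JAX port is sketched below. It is a minimal example under stated assumptions: the pairwise-distance workload and the helper names (`pairwise_sq_dists_np`, `pairwise_sq_dists_jax`, `best_time`) are hypothetical and not taken from the playbook notebooks. JAX dispatches work asynchronously, so the timer waits on the result with `jax.block_until_ready`, and one warm-up call keeps compilation time out of the measurement.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def pairwise_sq_dists_np(x):
    # Squared Euclidean distance between every pair of rows (NumPy).
    diff = x[:, None, :] - x[None, :, :]
    return np.sum(diff * diff, axis=-1)


@jax.jit
def pairwise_sq_dists_jax(x):
    # Same computation written with jax.numpy and compiled by XLA.
    diff = x[:, None, :] - x[None, :, :]
    return jnp.sum(diff * diff, axis=-1)


def best_time(fn, *args, repeats=5):
    # Smallest wall-clock time over several runs; block_until_ready waits
    # for JAX results and passes NumPy arrays through unchanged.
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        best = min(best, time.perf_counter() - start)
    return best


x_np = np.random.default_rng(0).random((256, 3), dtype=np.float32)
x_jax = jnp.asarray(x_np)

# Warm-up call: triggers compilation so it is not included in the timing.
pairwise_sq_dists_jax(x_jax).block_until_ready()

t_np = best_time(pairwise_sq_dists_np, x_np)
t_jax = best_time(pairwise_sq_dists_jax, x_jax)
print(f"NumPy: {t_np * 1e3:.3f} ms, JAX: {t_jax * 1e3:.3f} ms")
```

For small inputs like this one, dispatch overhead can dominate and NumPy may come out ahead, so measure at the problem sizes you actually care about.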