Mirror of https://github.com/NVIDIA/dgx-spark-playbooks.git, synced 2026-04-25 03:13:53 +00:00.
| name | description |
|---|---|
| dgx-spark-jax | Optimize JAX to run on NVIDIA DGX Spark. Use when setting up JAX on Spark hardware. |
# Optimized JAX

> Optimize JAX to run on Spark
JAX lets you write NumPy-style Python code and run it fast on GPUs without writing CUDA. It does this by:
- **NumPy on accelerators:** Use `jax.numpy` just like NumPy, but arrays live on the GPU.
- **Function transformations:**
  - `jit` → Compiles your function into fast GPU code
  - `grad` → Gives you automatic differentiation
  - `vmap` → Vectorizes your function across batches
  - `pmap` → Runs across multiple GPUs in parallel
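As a rough illustration of the list above, the sketch below applies each transformation to a toy function. The function `loss` and the sample values are made up for this example and are not part of the playbook. It assumes a working JAX install. On a CUDA-enabled install the arrays land on the GPU; otherwise the same code runs on the CPU.

```python
import jax
import jax.numpy as jnp

# jax.numpy mirrors the NumPy API; arrays are placed on the default device
# (the GPU when a CUDA-enabled jaxlib is installed, otherwise the CPU).
x = jnp.linspace(0.0, 1.0, 5)  # [0.0, 0.25, 0.5, 0.75, 1.0]


def loss(w, x):
    # Hypothetical toy scalar function of a weight w and an input vector x.
    return jnp.sum((w * x - 1.0) ** 2)


# jit: trace once, compile with XLA, and reuse the compiled code on later calls.
fast_loss = jax.jit(loss)

# grad: derivative of loss with respect to its first argument (w).
dloss_dw = jax.grad(loss)

# vmap: evaluate loss for a batch of weights without a Python loop.
# in_axes=(0, None) maps over w and shares the same x for every batch element.
batched_loss = jax.vmap(loss, in_axes=(0, None))

# pmap: run the same function on every local device in parallel.
# The leading axis must equal the number of local devices (1 on a single-GPU system).
n_dev = jax.local_device_count()
per_device = jax.pmap(lambda v: v * 2.0)(jnp.ones((n_dev, 3)))

if __name__ == "__main__":
    print("loss:", fast_loss(2.0, x))
    print("d loss / d w:", dloss_dw(2.0, x))
    print("batched loss:", batched_loss(jnp.array([0.0, 1.0, 2.0]), x))
    print("pmap output shape:", per_device.shape)
```

At `w = 2.0` both the loss and its gradient evaluate to 2.5 for this input. The batched call returns one loss per weight (5.0, 1.875, 2.5). Newer JAX releases also offer sharding-based APIs for multi-device work. `pmap` is shown here only because the list names it.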
Outcome: You'll set up a JAX development environment on NVIDIA DGX Spark (Blackwell architecture) for high-performance machine learning prototyping with familiar NumPy-like abstractions, GPU acceleration, and performance optimization.
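Once the environment is set up, a quick sanity check is to confirm which backend JAX picked and to run one small compiled computation. The sketch below assumes JAX was installed with GPU support as described in the full playbook. The helper names `describe_backend` and `smoke_test` are hypothetical and used only for illustration. If the backend reports only CPU devices, the GPU-enabled install is probably not being picked up.

```python
import jax
import jax.numpy as jnp


def describe_backend():
    """Return the default backend name and the list of visible devices."""
    return jax.default_backend(), jax.devices()


def smoke_test(n=1024):
    # Run a small jit-compiled matmul on the default device and wait for it,
    # so any compilation or driver error surfaces here rather than later.
    a = jnp.ones((n, n), dtype=jnp.float32)
    result = jax.jit(lambda m: m @ m)(a)
    result.block_until_ready()
    return result  # every entry equals n for an all-ones input


if __name__ == "__main__":
    backend, devices = describe_backend()
    print("default backend:", backend)
    for d in devices:
        print(" device:", d.platform, d.device_kind)
    out = smoke_test()
    print("matmul check:", float(out[0, 0]))
```

The explicit `block_until_ready()` matters because JAX dispatches work asynchronously. Without it, an error or the true runtime might only show up at a later access.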
Full playbook: `nvidia/jax/README.md` in the dgx-spark-playbooks repository.