2026-04-19 09:25:00 +00:00


name: dgx-spark-jax
description: Optimize JAX to run on NVIDIA DGX Spark. Use when setting up JAX on Spark hardware.

Optimized JAX

Optimize JAX to run on Spark

JAX lets you write NumPy-style Python code and run it fast on GPUs without writing CUDA. It does this by:

  • NumPy on accelerators: Use jax.numpy just like NumPy, but arrays live on the GPU.
  • Function transformations:
    • jit → Compiles your function with XLA into optimized GPU code
    • grad → Gives you automatic differentiation
    • vmap → Vectorizes your function across batches
    • pmap → Runs the same function in parallel across multiple devices (e.g., several GPUs)
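The following is a minimal sketch of the four transformations above, assuming JAX is installed with GPU support (it also runs on CPU). The `loss` and `per_example_sq_err` functions and the toy data are illustrative only, not part of the playbook.

```python
import jax
import jax.numpy as jnp


# A plain NumPy-style function: mean squared error of a linear model.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# jit: trace once, then compile with XLA for the default device (GPU if available).
fast_loss = jax.jit(loss)

# grad: gradient of `loss` with respect to its first argument, `w` (also jit-compiled).
grad_loss = jax.jit(jax.grad(loss))


# vmap: write the function for a single example, then map it over the batch axis.
def per_example_sq_err(w, x_i, y_i):
    return (jnp.dot(x_i, w) - y_i) ** 2


# in_axes=(None, 0, 0): share `w`, batch over axis 0 of `x` and `y`.
batched_sq_err = jax.vmap(per_example_sq_err, in_axes=(None, 0, 0))

# pmap: run one copy per local device; the leading axis must equal the device count.
n_dev = jax.local_device_count()
pmapped_square = jax.pmap(lambda a: a ** 2)

# Toy data: 3 examples with 2 features each.
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

loss_val = fast_loss(w, x, y)  # with w = 0 this is mean(y**2) = 14/3
grad_val = grad_loss(w, x, y)
sq_errs = batched_sq_err(w, x, y)
squares = pmapped_square(jnp.arange(n_dev, dtype=jnp.float32))

print("loss:", loss_val)
print("grad:", grad_val)
print("per-example squared error:", sq_errs)
print("pmap result:", squares)
```

`jax.pmap` maps over the leading axis, so its input needs one entry per local device; on a single-GPU system such as DGX Spark that axis has length 1.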

Outcome: You'll set up a JAX development environment on NVIDIA DGX Spark (Blackwell architecture) for high-performance machine learning prototyping with familiar NumPy-like APIs, including GPU acceleration and tools for performance optimization.
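Once the environment is installed, a quick check like the one below can confirm that JAX sees the GPU and separate compile time from run time. This is a hedged sketch: it assumes a CUDA-enabled JAX build is already installed, and the `step` function and 1024×1024 matrix size are arbitrary, illustrative choices.

```python
import time

import jax
import jax.numpy as jnp

# Expect "gpu" when a CUDA-enabled jaxlib can see the GPU; "cpu" means JAX fell back.
backend = jax.default_backend()
devices = jax.devices()
print("backend:", backend)
print("devices:", devices)


# Illustrative workload: a matmul, an elementwise op, and a reduction.
@jax.jit
def step(a):
    return jnp.tanh(a @ a.T).sum()


a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024), dtype=jnp.float32)

# First call: tracing + XLA compilation + execution.
# block_until_ready() waits for the asynchronously dispatched result.
t0 = time.perf_counter()
step(a).block_until_ready()
compile_and_run = time.perf_counter() - t0

# Second call with the same shape and dtype reuses the compiled executable.
t0 = time.perf_counter()
out = step(a).block_until_ready()
run_only = time.perf_counter() - t0

print(f"first call (compile + run): {compile_and_run:.4f} s")
print(f"cached call (run only):     {run_only:.4f} s")
```

JAX dispatches work asynchronously, so `block_until_ready()` is needed for meaningful timings; the first call includes tracing and XLA compilation, which later calls with the same shapes and dtypes reuse from cache.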

Full playbook: nvidia/jax/README.md in the dgx-spark-playbooks repository