        [JAX](https://jax.readthedocs.io) is an implementation of a significant subset of the NumPy API with some features that make it especially suitable for machine learning research and high-performance computing. As we'll see, these same features also make JAX extremely useful for accelerating functions and prototypes that we've developed in NumPy. This notebook will provide a quick introduction to just some of the features of JAX that we'll be using in the rest of this workshop -- as well as pointers to some of the potential pitfalls you might run into with it. There's a lot more to JAX than we'll be able to cover in this notebook (and the rest of the workshop), so you'll want to read the (great) documentation as you dive in more.

        We'll start by importing JAX and its implementation of the NumPy API. By convention, we'll import `jax.numpy` as `jnp` and regular `numpy` as `np` — this is because we can use both in our programs and we may want to use both for different things.
        """
    )
    return
@app.cell
def _():
    import jax
    import numpy as np
    import jax.numpy as jnp
    from timeit import timeit
    return jax, jnp, np, timeit
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""We can use `jax.numpy` as a drop-in replacement for `numpy` in many cases (we'll see some caveats later in this notebook), and JAX arrays generally interoperate transparently with NumPy arrays.""")
    return
@app.cell
def _(jnp):
    za = jnp.zeros(7)
    za
    return (za,)
@app.cell
def _(np, za):
    # note that we're doing elementwise addition
    # between a NumPy array and a JAX array
    zna = np.ones(7)
    za + zna
    return (zna,)
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""There are some differences, though, and an important one is that JAX arrays can be stored in GPU memory. If you're running this notebook with a GPU-enabled version of JAX, you can see where our array is stored:""")
    return
@app.cell
def _(za):
    za.device
    return
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## NumPy and JAX

        By themselves, the JAX implementations of NumPy operations are unlikely to be much faster than those from NumPy or (especially) CuPy, but — as we'll see — JAX offers some special functionality that can make JAX code much faster, especially on a GPU. Let's start by doing some simple timings of operations, though.

        We'll use the `jax.device_put` function to convert a NumPy array to a JAX array.
        """
    )
    return
@app.cell
def _(jax, np):
    random_shape = (8192, 2048)
    random_values = np.random.random(size=random_shape)
    jrandom_values = jax.device_put(random_values)
    return jrandom_values, random_values
@app.cell
def _(np, random_values):
    np.matmul(random_values, random_values.T)
    return
@app.cell
def _(jnp, jrandom_values, np, random_values, timeit):
    _np_result = timeit(lambda:
        np.matmul(random_values, random_values.T),
        number=10)
    # block_until_ready() ensures the asynchronously-dispatched JAX
    # computation has finished before timeit stops the clock
    _jax_result = timeit(lambda:
        jnp.matmul(jrandom_values, jrandom_values.T).block_until_ready(),
        number=10)
    print(f"NumPy matrix multiplication took {_np_result/10:.4f} seconds per iteration")
    print(f"JAX matrix multiplication took {_jax_result/10:.4f} seconds per iteration")
    return
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        The `block_until_ready` call is a special detail that is important when we're getting timings of single lines of JAX code — basically, JAX dispatches our code to the GPU asynchronously and we need to make sure that the operation has completed before we consider it done for the purposes of timing it.

        ✅ You probably saw that JAX was faster than NumPy with the matrix shape we provided (in `random_shape`). Make sure that you've added `.block_until_ready()` back to the JAX code and try some other matrix shapes (both larger and smaller). Does JAX exhibit more of a speed advantage on some matrix sizes than others? Is JAX slower than NumPy for some of these? Why (or why not), do you suppose?
        """
    )
    return
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Functional array updates

        One major difference between JAX and NumPy is that JAX arrays are _immutable_. This means that once you create an array, you can't update it in place. In NumPy, you'd do this:
        """
    )
    return
@app.cell
def _(zna):
    zna[3] = 5.0
    zna
    return
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""...whereas in JAX, you'd use the `.at[...]` indexed-update syntax, which returns a copy of the array with only that value changed:""")
    return
@app.cell
def _(za):
    za_1 = za.at[3].set(5.0)
    za_1
    return (za_1,)
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""Those of you who have used functional languages will likely be comfortable with this style, but it may be an adjustment. (It's not necessarily as inefficient as it sounds! See [here](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation) for more details on how to avoid copies — and what JAX does under the hood.)""")
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""When we split `key`, we generated two new keys. The advantage of using `nextkey` in the call to `jax.random.poisson` is that we don't have to explicitly assign to `key` later on (e.g., if we were in a loop).""")
    return
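The paragraph above refers to random-number cells that fall outside this excerpt. As a hedged sketch of what they likely looked like (the `key`/`nextkey` names come from the text; the Poisson rate and output shape are arbitrary choices here):

```python
import jax

# JAX random functions are pure: they take an explicit PRNG key
# instead of mutating hidden global state.
key = jax.random.PRNGKey(42)

# Splitting consumes `key` and yields two fresh, independent keys;
# passing `nextkey` to the sampler leaves `key` ready to be split
# again on the next loop iteration.
key, nextkey = jax.random.split(key)
samples = jax.random.poisson(nextkey, lam=3.0, shape=(5,))
print(samples)
```

Because `key` is reassigned by the split itself, the same two lines can sit unchanged inside a loop body.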
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Just-in-time compilation

        If you're running with accelerated hardware, JAX's implementations of NumPy functions require sending code (and sometimes data) to the GPU. This may not be noticeable if you're doing a lot of work, but it can impact performance if you're invoking many small functions. JAX provides a mechanism for _just-in-time compilation_ so that the first time you execute a function it produces a specialized version that can execute more efficiently.

        We'll see the `jax.jit` function in action later in this workshop. There are some things we'll need to keep in mind to use it effectively, and we'll cover those when we get to them!
        """
    )
    return
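We won't rely on `jax.jit` until later, but a minimal sketch (the function and values below are invented for illustration) shows the shape of the API: wrap a function, and the first call traces and compiles it.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x):
    # Traced and compiled on the first call; later calls with the same
    # input shape and dtype reuse the compiled version.
    return jnp.sum(2.0 * x)

x = jnp.arange(4.0)   # [0., 1., 2., 3.]
print(scaled_sum(x))  # 2 * (0 + 1 + 2 + 3) = 12.0
```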
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""## Parallelizing along axes""")
    return
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""A powerful feature of JAX is the capability to parallelize functions along axes. So if, for example, you want to calculate the norm of each row in a matrix, you can do each of these in parallel. We'll start by generating a random matrix and moving it to the GPU again:""")
    return
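The cells that generated the matrix and timed the per-row norms aren't included in this excerpt; a sketch along those lines (reusing the earlier matrix shape; the variable names are assumptions) might be:

```python
import jax
import jax.numpy as jnp
import numpy as np
from timeit import timeit

rows = np.random.random(size=(8192, 2048))
jrows = jax.device_put(rows)

# jax.vmap maps jnp.linalg.norm over the leading axis,
# computing one norm per row of the matrix.
row_norms = jax.vmap(jnp.linalg.norm)(jrows)

_result = timeit(
    lambda: jax.vmap(jnp.linalg.norm)(jrows).block_until_ready(),
    number=100)
print(f"JAX vmapped norm took {_result/100:.7f} seconds per iteration")
```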
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        Calculating the norm of every row is already pretty efficient in JAX, so we don't see much (if any) performance improvement from mapping over axes. (This will likely be the case for most NumPy functions in JAX that have an `axis` argument.) However, it's an easy example to understand and we'll see a higher-impact application of `vmap` in the next notebook.

        ## Timings

        In the above cell, we've used the `timeit` module in the standard library to repeatedly execute a small code snippet and get the average execution time. For longer-executing cells, we can simply use marimo's direct support for recording cell timings to see how long they executed.

        ✅ Mouse over a cell that has executed and look for timing information. In the version of marimo I'm using now, it will show up in the right margin, but only when you mouse over the cell. For the cells above, the timing should be roughly the value printed out times the number of iterations in the `timeit` call, so if the cell printed something like:

            JAX vmapped norm took 0.0360000 seconds per iteration

        then you'd expect the cell timing to show something like 3.6 seconds, given that we ran 100 iterations of the code.

        ## Automatic differentiation

        A particularly interesting feature of JAX is its support for _automatic differentiation_. This means that, given a function $f(x)$, JAX can automatically calculate $f'(x)$, or the _derivative_ of $f$, which is a function describing the rate of change between $f(x)$ and $f(x + \epsilon)$, where $\epsilon$ is a very small number. (JAX can also calculate the derivative for functions of multiple arguments, but our running example in this notebook will be a single-argument function.)

        If your daily work regularly involves implementing machine learning and optimization algorithms, you probably already have some ideas why this functionality could be useful. (If it doesn't and you're curious, [here's an explanation](https://en.wikipedia.org/wiki/Gradient_descent) to read on your own time.)

        In the rest of this notebook, we'll show an example of a problem we can solve with the help of JAX's support for automatic differentiation. Since not everyone spends their days thinking about optimizing functions, we've chosen a problem that doesn't require any specialized mathematical or machine learning background to understand, but we'll throw in a wrinkle at the end to show everyone why JAX's automatic differentiation is especially cool.

        Let's start with a very simple Python function:
        """
    )
    return
@app.function
def square(x):
    return x * x
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""We can calculate the derivative of `square` numerically, by calculating the slope of `x * x` while making a very small change to `x`.""")
    return
@app.function
def square_num_prime(x, h=1e-5):
    above = x + h
    below = x - h
    rise = (above ** 2) - (below ** 2)
    run = h * 2
    return rise / run
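For instance, here is a quick check of the numerical estimate against the power-rule answer of $2x$ (the sample points below are arbitrary, and `square_num_prime` is repeated so the snippet stands alone):

```python
def square_num_prime(x, h=1e-5):
    # Central-difference estimate of the derivative of x**2
    # (same as the cell above).
    above = x + h
    below = x - h
    rise = (above ** 2) - (below ** 2)
    run = h * 2
    return rise / run

# The power rule says the exact derivative of x**2 is 2x.
for x in [1.0, 3.0, 10.0]:
    print(x, square_num_prime(x), 2 * x)
```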
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        You may remember the [power rule](https://en.wikipedia.org/wiki/Power_rule), which states that the derivative of $x^a$ is $ax^{a-1}$. Given this rule, the derivative of $x^2$ is $2x^{2-1} = 2x$. We can use this to check our answer for several values of $x$.

        You may have noticed that not every result is what we'd expect! There are several ways in which numerical differentiation may not produce a precise result, but for this example we may be able to improve the results by changing the range around $x$ for which we're measuring the change in $x^2$.

        We can use `square_prime` to implement [Newton's method](https://en.wikipedia.org/wiki/Newton%27s_method#Square_root) for the specific problem of approximating square roots. Basically the idea is that we'll start with an initial guess as our current guess and then repeatedly:

        1. calculate the difference between the square of our current guess and our goal number,
        2. calculate the value of the derivative of the square function at our current guess, and
        3. update our current guess by subtracting the quotient of the first value and the second.

        After the third step, we'll compare the square of our guess to our goal number and stop if it's close enough or if we've gone a certain number of iterations. (You may have used a similar but less-efficient method of iteratively refining guesses for square roots on paper in a primary school arithmetic class!)

        We'll start with a function that updates a guess value given a guess and a goal:
        """
    )
    return
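The cell defining `square_prime` isn't shown in this excerpt, but given the workshop's focus it was presumably built with JAX's automatic differentiation, along these lines (`square` is repeated here so the snippet stands alone):

```python
import jax

def square(x):
    return x * x

# jax.grad returns a new function that computes the
# derivative of `square` at a given point.
square_prime = jax.grad(square)

print(square_prime(3.0))  # the exact derivative 2x, so 6.0
```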
@app.cell
def _(square_prime):
    def guess_sqrt(guess, goal):
        n = (guess * guess) - goal
        d = square_prime(guess)
        return guess - (n / d)
    return (guess_sqrt,)
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""We'll then build out the whole method, including a tolerance value (i.e., how close does `guess ** 2` need to be to `target` for us to accept it?) and a maximum number of iterations so we don't accidentally get into an infinite loop.""")
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        Automatic differentiation is a cool technology, but — as you may object — differentiating $x^2$ isn't a particularly cool application. If this were the extent of our requirements, we could simply implement a few rules that inspected Python functions and replaced functions by their derivatives. If we didn't want to get our hands dirty, we could probably also use a library like `sympy` or hire a first-year undergraduate to perform our calculations for us.

        JAX isn't limited to trivial functions, though. Let's take a look at a more syntactically (and semantically) complex implementation of the square function. We'll call it `bogus_square` to emphasize that it is a contrived example that is meant to be difficult for JAX to deal with.
        """
    )
    return
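The completed loop doesn't appear in this excerpt; a minimal sketch of the method described above (the `newton_sqrt` name, tolerance, and iteration cap are illustrative assumptions, and the helpers are repeated so the snippet stands alone) might look like:

```python
import jax

def square(x):
    return x * x

square_prime = jax.grad(square)

def guess_sqrt(guess, goal):
    # One Newton step for f(x) = x**2 - goal (same as the cell above).
    n = (guess * guess) - goal
    d = square_prime(guess)
    return guess - (n / d)

def newton_sqrt(guess, goal, tolerance=1e-5, max_iterations=100):
    # Refine the guess until guess**2 is within `tolerance` of `goal`,
    # or give up after `max_iterations` refinements.
    for _ in range(max_iterations):
        if abs(guess * guess - goal) < tolerance:
            break
        guess = guess_sqrt(guess, goal)
    return guess

print(newton_sqrt(3.0, 2.0))
```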
@app.cell
def _(jnp):
    def bogus_square(x):
        result = 0
        for _ in range(int(jnp.floor(x))):
            result = result + x
        return result + x * (x - jnp.floor(x))
    return (bogus_square,)
@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""We can check our results on some examples:""")
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""✅ Test your function out with a few examples. You may want to avoid finding square roots of larger numbers (we'll see why in a second).""")
    return
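The `newton_bogus_sqrt` called in the next cell was assembled in cells not shown in this excerpt; presumably it is the same Newton loop with the autodiff derivative of `bogus_square` swapped in, roughly like this sketch (names and defaults are assumptions):

```python
import jax
import jax.numpy as jnp

def bogus_square(x):
    # Same contrived implementation as above: add x to the result
    # floor(x) times, then add the fractional remainder.
    result = 0
    for _ in range(int(jnp.floor(x))):
        result = result + x
    return result + x * (x - jnp.floor(x))

# This works eagerly (outside jit) because the Python loop sees
# concrete values while JAX traces the arithmetic for the gradient.
bogus_square_prime = jax.grad(bogus_square)

def newton_bogus_sqrt(guess, goal, tolerance=1e-5, max_iterations=100):
    for _ in range(max_iterations):
        if abs(guess * guess - goal) < tolerance:
            break
        # One Newton step using the autodiff derivative of bogus_square.
        n = bogus_square(guess) - goal
        d = bogus_square_prime(guess)
        guess = float(guess - (n / d))
    return guess

print(newton_bogus_sqrt(8.0, 72.0))
```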
@app.cell
def _():
    newton_bogus_sqrt(8.0, 72.0)
    return
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        It's impressive that JAX can differentiate `bogus_square`, but this doesn't mean that we're free to use pathological implementations in our code. The derivative of `bogus_square` is much slower to compute than that of `square`, which means that the performance of an iterative process that depends on computing this many times (like machine learning model training or even like approximating square roots) will suffer.