In this notebook, we'll develop implementations of online and batch self-organizing map training in JAX, refining each as we go to get better performance. We'll start with the easiest option: simply using JAX as a drop-in replacement for numpy.
## Accelerating NumPy functions with JAX
"""
)
return
@app.cell
def _():
    import jax
    import jax.numpy as jnp
    import numpy as np
    return jax, jnp, np
@app.cell(hide_code=True)
def _(mo):
mo.md(
r"""
Recall that we're initializing our map with random vectors. The result of this function is a matrix with a row for every element in a self-organizing map; each row contains uniformly-sampled random numbers between 0 and 1.
Because JAX uses a purely functional approach to random number generation, we'll need to rewrite this code from the numpy implementation -- instead of using a stateful generator like numpy's `Generator` or `RandomState`, we'll create a `PRNGKey` object and pass that to `jax.random.uniform`. (For this example, we're not doing anything with the key -- for a real application, we'd want to _split_ it so we could get the next number in the seeded sequence.)
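As a sketch of this pattern (the function name `random_som` and the map dimensions here are illustrative, not the notebook's actual values):

```python
import jax

# Seed a key once; every JAX call that needs randomness takes a key explicitly.
key = jax.random.PRNGKey(0)

def random_som(key, width, height, dim):
    # One row per map cell; each entry is uniform in [0, 1).
    return jax.random.uniform(key, shape=(width * height, dim))

random_map = random_som(key, 16, 16, 3)
```
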
mo.md(r"""...and we should be able to see that this array is stored in GPU memory (if we're actually running on a GPU).""")
return
@app.cell
def _(random_map):
    random_map.device
    return
@app.cell(hide_code=True)
def_(mo):
mo.md(r"""As before, you can visualize the result if you want — JAX will transfer arrays directly to device memory when needed by plotting libraries.""")
We're now ready to see the basic online (i.e., one sample at a time) training algorithm. Most of it is unchanged from the numpy implementation, with a few key differences:
1. The first differences are related to how we shuffle the example array. Because we aren't using a stateful random number generator, we'll need to split the random state key into two parts (one representing the key for the very next generation and one representing the key for the rest of the stream). We'll declare a little helper function that splits the key, shuffles the array, and returns both the key and the shuffled array.
2. The second difference relates to how JAX handles arrays. In JAX, arrays offer an _immutable_ interface: instead of changing an array directly, JAX's API lets you make a copy of the array with a change. (In practice, this does not always mean the array is actually copied!) This impacts our code because the numpy version used some functions with output parameters, which indicate where to write the return value (rather than merely returning a new array). So, instead of `np.add(a, b, a)`, we'd do `a = jnp.add(a, b)`.
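Putting the first point into code, a minimal sketch of such a helper (the name `shuffle` and its exact signature are assumptions):

```python
import jax
import jax.numpy as jnp

def shuffle(key, examples):
    # Split the key: use one half for this permutation, hand the other
    # half back to the caller for the rest of the random stream.
    key, subkey = jax.random.split(key)
    return key, jax.random.permutation(subkey, examples)

key = jax.random.PRNGKey(42)
examples = jnp.arange(10.0)
key, shuffled = shuffle(key, examples)
```
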
Depending on your computer, this may have actually been slower than the numpy version! Let's try using JAX's _just-in-time_ compilation to improve our performance. We'll make just-in-time compiled versions of our `neighborhood` and `shuffle` functions (as well as of the inner part of the training loop). We'll also add a progress bar.
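As an illustration of the `jax.jit` decorator (the Gaussian form of `neighborhood` shown here is an assumption -- the notebook's actual function may differ):

```python
import jax
import jax.numpy as jnp

@jax.jit
def neighborhood(distances, sigma):
    # Hypothetical Gaussian falloff; compiled with XLA on first call,
    # then reused for subsequent calls with the same shapes.
    return jnp.exp(-(distances ** 2) / (2 * sigma ** 2))

d = jnp.arange(5.0)
weights = neighborhood(d, 1.0)
```
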
*The `alphas` variable (which we didn't expose as a parameter) indicates how much of an effect each example has on the map. We've set it to `jnp.geomspace(0.35, 0.01, max_iter)`; try some different values and see if you get better or worse results!
Let's now consider the batch variant of the algorithm. It can be much faster, can be implemented in parallel (or even on a cluster), and is less sensitive to hyperparameter settings. In order to exploit additional parallelism, we're going to use `jax.vmap` to calculate weight updates for each training example in parallel. This should result in a dramatic performance improvement.
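A small sketch of the `jax.vmap` pattern: write a function for one example, then map it over a batch (the helper name `bmu` and the toy data are illustrative):

```python
import jax
import jax.numpy as jnp

def bmu(som, example):
    # Index of the map cell closest to one example (squared Euclidean distance).
    return jnp.argmin(jnp.sum((som - example) ** 2, axis=1))

# Vectorize over axis 0 of the batch; the map itself is not mapped over.
batch_bmu = jax.vmap(bmu, in_axes=(None, 0))

som = jnp.eye(4)          # toy 4-cell map with 4-dimensional weights
batch = jnp.eye(4)[::-1]  # each example exactly matches one cell
units = batch_bmu(som, batch)
```
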
We have only optimized the batch step here (i.e., we're calculating the best matching unit and map updates for many examples in parallel and then summing these all at once). There are more opportunities to optimize this code with JAX, but we haven't exploited them in order to make it possible to use code that has side effects within `train_som_batch` -- in particular, we're
✅ Using JAX looping constructs instead of Python looping (e.g., `for t in tqdm.trange(epochs):`) may enable further optimizations and performance improvements. Try rewriting `train_som_batch` to use JAX's `lax.fori_loop` (use `help` or see the JAX documentation for details). How does the performance change?
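As a starting point, a minimal `lax.fori_loop` example (the body here is a toy update, not the real batch step): the body function takes the loop index and the carried state, and returns the updated state.

```python
import jax.numpy as jnp
from jax import lax

def body(i, som):
    # Hypothetical per-epoch update: shrink the map toward zero.
    return som * 0.9

som = jnp.ones((4, 3))
# Equivalent to: for i in range(0, 10): som = body(i, som)
som = lax.fori_loop(0, 10, body, som)
```
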
so that you're taking the `argmin` (or `argmax`, if you're looking for similarity!) of a different function over each entry in the map and the current example. If you don't have a favorite distance or similarity measure, a common example is cosine similarity, which you can calculate for two vectors by dividing their dot product by the product of their magnitudes, like this:
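For example (the helper name is illustrative):

```python
import jax.numpy as jnp

def cosine_similarity(u, v):
    # Dot product divided by the product of the vectors' magnitudes.
    return jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))

s = cosine_similarity(jnp.array([1.0, 0.0]), jnp.array([1.0, 1.0]))
```
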
mo.md(r"""✅ If you implemented cosine similarity, what change did you notice to the performance of batch training? What changes could you make to `train_som_batch` to improve performance?""")