Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is the lack of a differentiation tool. JAX is a tool that bridges this gap. It is numpy compatible and does not force you to write your code in a different way. It can even compile your code to run on a CPU or GPU.
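As a small illustration of those claims, here is a minimal sketch. The function loss and the input points are hypothetical examples, not from the original; the only change from plain numpy code is calling jax.numpy (imported as jnp) instead of numpy (imported as onp). grad then produces the derivative and jit compiles it.

```python
import numpy as onp        # plain NumPy, used here only for inputs and comparison
import jax.numpy as jnp    # JAX's numpy-compatible API
from jax import grad, jit

def loss(x):
    # Written just like NumPy code, using jnp in place of onp
    return jnp.sum(jnp.sin(x) ** 2)

points = onp.array([0.0, 0.5, 1.0])  # a plain NumPy array works as input

print(loss(points))

# grad builds a new function computing d(loss)/dx; jit compiles it with XLA
dloss = jit(grad(loss))
print(dloss(points))

# Analytic derivative for comparison: d/dx sin(x)^2 = 2 sin(x) cos(x) = sin(2x)
expected = onp.sin(2 * points)
print(expected)
```

JAX computes in float32 by default, so expect agreement with the NumPy reference up to float32 precision rather than bit-for-bit.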
The slow part of the code from the previous video is still the list comprehension. Let's start with our function and its gradient.
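import jax.numpy as np  # JAX's numpy-compatible API, so grad can trace g
from jax import grad, jit, vmap

def g(x):
    return np.sum(x**2)

grad_g = jit(grad(g))  # compiled gradient of g, which equals 2*x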
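A quick sanity check of grad_g, as a minimal sketch: for g(x) = sum(x**2) the gradient is 2*x, so it can be compared against that by hand. The test inputs below are illustrative choices, not from the original.

```python
import jax.numpy as np
from jax import grad, jit

def g(x):
    return np.sum(x**2)

grad_g = jit(grad(g))

# Scalar input: the gradient should be 2 * 3.0.
# Note that grad only accepts floating-point inputs, so pass 3.0 rather than 3.
print(grad_g(3.0))

# Vector input: the gradient is taken elementwise, giving 2 * x
print(grad_g(np.array([1.0, 2.0, 3.0])))
```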
Let's now compare the list comprehension with vmap, which maps grad_g over the whole array in a single vectorised call.
%time _ = np.stack([grad_g(i) for i in np.linspace(0, 10, 1000)])
%time _ = vmap(grad_g)(np.linspace(0, 10, 1000))
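Before trusting the timings, it helps to confirm that both approaches compute the same thing. The sketch below assumes a plain Python script rather than IPython, so time.perf_counter stands in for %time. It also calls block_until_ready(), because JAX dispatches work asynchronously and a timer can otherwise stop before the computation has actually finished. Wrapping vmap(grad_g) in jit and doing a warm-up call are choices made here so that compilation time is not included in the measurement.

```python
import time
import jax.numpy as np
from jax import grad, jit, vmap

def g(x):
    return np.sum(x**2)

grad_g = jit(grad(g))
xs = np.linspace(0, 10, 1000)

# Loop version: one Python-level call of grad_g per element
looped = np.stack([grad_g(i) for i in xs])

# vmap version: grad_g is vectorised over the leading axis in one call
batched = vmap(grad_g)(xs)

# Both should agree (and equal 2 * xs elementwise)
print(np.max(np.abs(looped - batched)))

# Time the vectorised version, excluding compilation
vmapped = jit(vmap(grad_g))
vmapped(xs).block_until_ready()  # warm-up call triggers compilation
start = time.perf_counter()
vmapped(xs).block_until_ready()  # wait for the result before stopping the clock
print(f"vmap: {time.perf_counter() - start:.6f} s")
```

Absolute timings depend on the machine and backend, so only the relative difference between the two approaches is meaningful.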