
jax: vmap



Notes

The slow part of the code from the previous video is still the list comprehension. Let's start again with our function and its gradient.

import jax.numpy as np  # `np` refers to jax.numpy throughout these notes
from jax import grad, jit, vmap

def g(x):
    return np.sum(x**2)

grad_g = jit(grad(g))
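As a quick sanity check (a sketch, not part of the original notes): for a scalar input `g(x)` reduces to `x**2`, so `grad_g` should return `2 * x`. The same `g` also accepts a vector, in which case the gradient of the sum of squares is `2 * x` element-wise. The input values below are arbitrary examples.

```python
import jax.numpy as np
from jax import grad, jit

def g(x):
    # Sum of squares; for a scalar input this is just x**2.
    return np.sum(x**2)

# Compile the gradient function once with jit.
grad_g = jit(grad(g))

# Scalar input: d/dx x**2 = 2x.
print(grad_g(3.0))  # 6.0

# Vector input: the gradient of sum(x_i**2) is 2 * x element-wise.
print(grad_g(np.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```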

Let's now compare calling the gradient in a Python loop with letting vmap handle the whole batch.

Before (list comprehension)

%time _ = np.stack([grad_g(i) for i in np.linspace(0, 10, 1000)])

After (vmap)

%time _ = vmap(grad_g)(np.linspace(0, 10, 1000))
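`vmap(grad_g)` turns a function written for a single input into one that maps over the leading axis of its argument, so all 1000 values are handled in one call instead of 1000 separate Python-level calls. The sketch below checks that both versions return the same values (`2 * xs`) and times them a bit more carefully than a single `%time` run (an IPython/Jupyter magic): it does one warm-up call so JIT compilation is not counted, and calls `block_until_ready()` because JAX dispatches work asynchronously. The helper names `before`, `after`, `bench` and `batched_grad_g` are hypothetical, chosen for illustration only.

```python
import time

import jax.numpy as np
from jax import grad, jit, vmap

def g(x):
    return np.sum(x**2)

grad_g = jit(grad(g))
xs = np.linspace(0, 10, 1000)

def before():
    # One Python-level call (and one dispatch) per element.
    return np.stack([grad_g(i) for i in xs])

batched_grad_g = vmap(grad_g)

def after():
    # A single call that handles the whole batch.
    return batched_grad_g(xs)

# Both approaches should give the same result: 2 * xs.
assert np.allclose(before(), after())
assert np.allclose(after(), 2 * xs)

def bench(fn, repeats=5):
    """Average wall-clock seconds per call, excluding a first warm-up call."""
    fn().block_until_ready()  # warm-up so compilation is not timed
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the asynchronously dispatched result before stopping the clock.
        fn().block_until_ready()
    return (time.perf_counter() - start) / repeats

print(f"list comprehension: {bench(before):.4f} s")
print(f"vmap:               {bench(after):.4f} s")
```

The exact numbers depend on your hardware and JAX version, so none are quoted here.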
