The slow part from the previous video is still the list comprehension, which calls the gradient once per input value. Let's start with our function and its gradient.
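import jax.numpy as np  # np refers to jax.numpy here, not NumPy
from jax import grad, jit, vmap

def g(x):
    return np.sum(x**2)

# Build the gradient of g and compile it with jit.
grad_g = jit(grad(g))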
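As a quick sanity check on what grad_g computes: since g(x) is the sum of x², its gradient is 2x elementwise. A minimal sketch, assuming the same definitions as above, that evaluates grad_g on a scalar and on a small vector:

```python
import jax.numpy as np  # np is jax.numpy, as in the code above
from jax import grad, jit

def g(x):
    return np.sum(x**2)

grad_g = jit(grad(g))

# Scalar input: g(x) = x**2, so the gradient is 2 * x.
print(grad_g(3.0))  # 6.0

# Vector input: g(x) = sum_i x_i**2, so each gradient entry is 2 * x_i.
print(grad_g(np.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```

This is the per-element result that both the list comprehension and vmap produce for every point in np.linspace(0, 10, 1000).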
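Let's now compare the two approaches again. JAX dispatches work asynchronously, so .block_until_ready() makes %time wait for the result instead of timing only the dispatch.
Before (a list comprehension that calls grad_g once per element):
%time _ = np.stack([grad_g(i) for i in np.linspace(0, 10, 1000)]).block_until_ready()
After (vmap maps grad_g over the whole array in a single call):
%time _ = vmap(grad_g)(np.linspace(0, 10, 1000)).block_until_ready()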
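%time is an IPython magic, so it only works inside a notebook or IPython shell. Outside one, a minimal sketch of the same comparison can use time.perf_counter. It assumes two things beyond the original: a warm-up call so that compilation time is excluded from both measurements, and a final check that both approaches return the same gradients. The helper names loop_version and timed are hypothetical, introduced only for this illustration.

```python
import time

import jax.numpy as np  # np is jax.numpy, as in the code above
from jax import grad, jit, vmap

def g(x):
    return np.sum(x**2)

grad_g = jit(grad(g))
xs = np.linspace(0, 10, 1000)

def loop_version(xs):
    # Hypothetical helper: one Python-level call to grad_g per element.
    return np.stack([grad_g(x) for x in xs])

# vmap turns the per-element gradient into one batched function.
batched_grad_g = vmap(grad_g)

def timed(fn, *args):
    # Hypothetical helper. Warm-up call first, so compilation is not timed.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    # block_until_ready waits for JAX's asynchronous dispatch to finish.
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start

loop_out, loop_s = timed(loop_version, xs)
vmap_out, vmap_s = timed(batched_grad_g, xs)

print(f"list comprehension: {loop_s * 1e3:.2f} ms")
print(f"vmap:               {vmap_s * 1e3:.2f} ms")

# Both approaches should compute the same gradients, 2 * x per element.
print(np.allclose(loop_out, vmap_out), np.allclose(vmap_out, 2 * xs))
```

The exact timings depend on the machine and backend, but the list comprehension pays Python and dispatch overhead for each of the 1000 calls, while vmap issues a single batched call.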