In the first cell we define a function g, which returns the sum of the squares of its input, and use grad to build its gradient function.
import jax.numpy as np
from jax import grad
def g(x):
    return np.sum(x**2)
grad_g = grad(g)
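As a quick sanity check, the derivative of the sum of squares is 2x, so grad_g should return twice its input. The sketch below assumes, as the code above appears to, that np refers to jax.numpy; the test points 3.0 and [1.0, 2.0, 3.0] are arbitrary illustrative values.

```python
import jax.numpy as np
from jax import grad

def g(x):
    # Sum of squares; for a scalar input this is simply x**2
    return np.sum(x**2)

grad_g = grad(g)

# d/dx sum(x_i^2) = 2x. grad requires floating-point inputs, so use 3.0, not 3.
print(grad_g(3.0))                        # expected: 2 * 3.0 = 6.0
print(grad_g(np.array([1.0, 2.0, 3.0])))  # expected: elementwise 2x = [2, 4, 6]
```

The same gradient function works for scalars and arrays because g reduces its input to a single scalar, which is what grad requires of the function it differentiates.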
Next, we time how long it takes to evaluate this gradient at 1000 evenly spaced points between 0 and 10 using a list comprehension.
%time _ = np.stack([grad_g(i) for i in np.linspace(0, 10, 1000)])
We can speed this up by wrapping the gradient function in JAX's just-in-time (JIT) compiler, jit.
from jax import jit
def g(x):
    return np.sum(x**2)
grad_g = jit(grad(g))  # compile the gradient function with XLA on first call
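A jitted function is traced and compiled the first time it is called, and later calls with inputs of the same shape and dtype reuse the cached compiled version. The minimal sketch below times a first and a second call separately to make that one-off cost visible; the input value 2.0 is an arbitrary choice, and block_until_ready is used because JAX dispatches work asynchronously.

```python
import time
import jax.numpy as np
from jax import grad, jit

def g(x):
    return np.sum(x**2)

grad_g = jit(grad(g))

x = np.array(2.0)  # arbitrary float32 scalar input

# First call: JAX traces g, compiles it with XLA, then runs it
t0 = time.perf_counter()
grad_g(x).block_until_ready()
first = time.perf_counter() - t0

# Second call with the same shape/dtype reuses the cached compiled function
t0 = time.perf_counter()
grad_g(x).block_until_ready()
second = time.perf_counter() - t0

print(f"first call:  {first * 1e3:.3f} ms (includes tracing and compilation)")
print(f"second call: {second * 1e3:.3f} ms")
```

Exact timings depend on the machine and backend, but the first call is typically the slower of the two.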
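Rerunning the same timing should now report a much shorter wall time. Because every input has the same shape and dtype, grad_g is compiled only once, on its first call, and the remaining calls reuse the compiled version.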
%time _ = np.stack([grad_g(i) for i in np.linspace(0, 10, 1000)])
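The %time magic only works inside IPython or Jupyter. Outside a notebook, the same comparison can be sketched with time.perf_counter, as below. Two choices here differ from the cells above and are assumptions of this sketch: a warm-up call excludes one-off compilation from the measurement, and block_until_ready waits for JAX's asynchronously dispatched work to finish before the clock stops. The helper name bench is hypothetical.

```python
import time
import jax.numpy as np
from jax import grad, jit

def g(x):
    return np.sum(x**2)

grad_g = grad(g)            # un-jitted gradient
grad_g_jit = jit(grad(g))   # jitted gradient

xs = np.linspace(0, 10, 1000)

def bench(f):
    """Hypothetical helper: seconds to evaluate f at every point of xs."""
    f(xs[0]).block_until_ready()   # warm-up so compilation is not timed
    t0 = time.perf_counter()
    out = np.stack([f(i) for i in xs])
    out.block_until_ready()        # wait for asynchronous dispatch to finish
    return time.perf_counter() - t0, out

t_plain, out_plain = bench(grad_g)
t_jit, out_jit = bench(grad_g_jit)
print(f"grad:      {t_plain:.3f} s")
print(f"jit(grad): {t_jit:.3f} s")
```

Both versions compute the same values (2x at each point); only the evaluation cost differs.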