
jax: jit



Notes

In the first cell we define a function g, which sums the squares of its input, and use grad to build a function that computes its gradient.

import jax.numpy as np  # JAX's NumPy-compatible API
from jax import grad

def g(x):
    return np.sum(x**2)

grad_g = grad(g)
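As a quick sanity check, the gradient of g(x) = sum of x_i squared is 2 * x_i in each component, so grad_g should return twice its input. This is a minimal sketch that assumes the np in the notes is jax.numpy (imported here as jnp); the names analytic and numeric are illustrative only.

```python
import jax.numpy as jnp
from jax import grad

def g(x):
    # Sum of squares: g(x) = sum_i x_i^2
    return jnp.sum(x**2)

grad_g = grad(g)

x = jnp.array([1.0, 2.0, 3.0])
analytic = 2 * x      # d/dx_i of sum_j x_j^2 is 2 * x_i
numeric = grad_g(x)   # gradient computed by JAX's automatic differentiation
print(numeric)
```

Note that grad needs floating-point inputs, which is why the example uses 1.0 rather than 1.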

Next we time how long it takes to evaluate this gradient for 1000 inputs in a list comprehension.

%time _ = np.stack([grad_g(i) for i in np.linspace(0, 10, 1000)])

We can speed this up by wrapping the gradient function in jit, JAX's just-in-time compiler.

from jax import jit

def g(x):
    return np.sum(x**2)

grad_g = jit(grad(g))
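One detail worth knowing is that jit compiles lazily: the first call traces and compiles the function, and later calls with inputs of the same shape and dtype reuse the compiled code. The sketch below times both calls; it uses time.perf_counter instead of the %time notebook magic, and calls block_until_ready because JAX dispatches work asynchronously, so without it the timer could stop before the result is computed. The variable names first and second are illustrative.

```python
import time
import jax.numpy as jnp
from jax import grad, jit

def g(x):
    return jnp.sum(x**2)

grad_g = jit(grad(g))
x = jnp.float32(3.0)

t0 = time.perf_counter()
grad_g(x).block_until_ready()   # first call: trace, compile, then run
first = time.perf_counter() - t0

t0 = time.perf_counter()
grad_g(x).block_until_ready()   # second call: reuses the compiled function
second = time.perf_counter() - t0

print(f"first call: {first:.4f}s, second call: {second:.4f}s")
```

The exact numbers depend on your machine, but the first call is typically the slow one.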

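If we time the same comprehension again, it should run much faster. The first call to grad_g is still slower than the rest, because that is when JAX traces and compiles the function; every later call reuses the compiled version.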

%time _ = np.stack([grad_g(i) for i in np.linspace(0, 10, 1000)])
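To compare the two versions outside a notebook, a minimal sketch could time both comprehensions in one script and confirm they return the same values. The helper timed is hypothetical; the jitted function is called once before timing (a warm-up) so that compilation is not counted, and block_until_ready makes sure the timer waits for JAX's asynchronous computation to finish.

```python
import time
import jax.numpy as jnp
from jax import grad, jit

def g(x):
    return jnp.sum(x**2)

plain = grad(g)
jitted = jit(grad(g))

xs = jnp.linspace(0, 10, 1000)
jitted(xs[0]).block_until_ready()  # warm-up: pay the compilation cost up front

def timed(f):
    # Evaluate f on every element of xs and measure the wall-clock time.
    start = time.perf_counter()
    out = jnp.stack([f(x) for x in xs])
    out.block_until_ready()
    return out, time.perf_counter() - start

out_plain, t_plain = timed(plain)
out_jit, t_jit = timed(jitted)
print(f"grad: {t_plain:.3f}s, jit(grad): {t_jit:.3f}s")
```

Both versions compute the same gradient, 2 * x, so only the timings should differ.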
