jax: introduction


Note that you have to install jax before you can run it. From within Jupyter you can do that via:

%pip install jax
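If the install worked, a quick sanity check is to import jax and ask which devices it can see. This is a minimal sketch that was not part of the video. On a machine without a GPU it will typically list only a CPU device.

```python
import jax

# Print the installed version to confirm that the import works
print(jax.__version__)

# List the devices jax can run computations on (e.g. CPU or GPU)
print(jax.devices())
```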

This is the code that we ran in this video:

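from jax import grad

def f(x):
    return x**2

# grad(f) returns a new function that computes the derivative df/dx = 2*x
grad_f = grad(f)

# evaluates to 9.0 and 6.0: f(3) = 3**2 and f'(3) = 2*3
f(3.), grad_f(3.)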
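Because grad returns an ordinary Python function, you can evaluate it at several points or apply grad to it again. The sketch below uses the same f as above and checks the results against the analytic derivatives f'(x) = 2x and f''(x) = 2. The name grad_grad_f is only an illustrative choice.

```python
from jax import grad

def f(x):
    # f(x) = x**2, so analytically f'(x) = 2x and f''(x) = 2
    return x**2

grad_f = grad(f)              # first derivative of f
grad_grad_f = grad(grad_f)    # grad of a grad gives the second derivative

for x in [-2.0, 0.0, 3.0]:
    # float() turns the returned jax scalar array into a plain Python float
    print(x, float(grad_f(x)), float(grad_grad_f(x)))
```

JAX's grad expects floating-point inputs, which is why the inputs are written as 3. and -2.0 rather than 3 and -2. Passing an integer raises an error instead of returning a derivative.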
