Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is the lack of a differentiation tool. JAX is a tool that fills this gap. It is numpy compatible and does not force you to write code in a different way. It can even compile your code to run on a CPU or GPU.
Note that you have to install JAX before you can run it. You can do that from Jupyter via:
%pip install jax
This is the code that we ran in this video:
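from jax import grad

def f(x):
    return x**2

# grad(f) returns a new function that computes the derivative of f
grad_f = grad(f)

# the value of f at 3 and its derivative 2*x at 3, i.e. 9.0 and 6.0
f(3.), grad_f(3.)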
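
The same idea carries over to numpy-style array code. The snippet below is a minimal sketch rather than code from the video: the function name `loss` and the input values are made up for illustration. It assumes you write the function with `jax.numpy` and shows that `grad` can be combined with `jit` to compile the gradient function.

```python
import jax.numpy as jnp
from jax import grad, jit

# Hypothetical example function: a sum of squares, written the way
# you would write it in numpy, but using jax.numpy instead.
def loss(w):
    return jnp.sum(w ** 2)

# grad builds the gradient function, jit compiles it.
grad_loss = jit(grad(loss))

w = jnp.array([1.0, 2.0, 3.0])
g = grad_loss(w)

# Analytically the gradient of sum(w**2) is 2 * w, so we can compare.
print(g, 2 * w)
```

Note that `grad` expects a function that returns a single number, which is why `loss` sums over the array before returning.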