# jax

Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is that numpy has no tool for automatic differentiation. JAX is a library that bridges this gap. It is numpy-compatible, so it does not force you to write your code in a different way, and it can even compile your code to run on a CPU or GPU.
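As a minimal sketch of the "numpy-compatible, but compiled" idea, the function below is written with ordinary numpy-style calls through `jax.numpy`, then wrapped in `jax.jit`. The function name `sum_of_squares` is just an illustrative choice. Which device the compiled code actually runs on depends on how JAX was installed; a CPU-only install will simply use the CPU.

```python
import jax
import jax.numpy as np


def sum_of_squares(x):
    # Plain numpy-style code; nothing JAX-specific in the body.
    return np.sum(x ** 2)


# jax.jit traces the function and compiles it with XLA for the
# default backend (CPU, or GPU/TPU if a matching JAX build is installed).
fast_sum_of_squares = jax.jit(sum_of_squares)

x = np.arange(5.0)                 # [0., 1., 2., 3., 4.]
print(fast_sum_of_squares(x))      # 0 + 1 + 4 + 9 + 16 = 30.0
print(jax.devices())               # the devices JAX found on this machine
```

The first call with a given input shape and dtype triggers compilation; later calls with the same shape reuse the compiled version.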

**Episode Notes**

If you want to differentiate a function of an array, just be mindful that the function must have a single (scalar) output. The code below calculates the gradient vector as well as the Hessian matrix.

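```python
import jax.numpy as np
from jax import grad, hessian

def f(x):
    # grad and hessian need a function with a single scalar output
    return np.sum(x**2)

grad_f = grad(f)      # the gradient of sum(x**2) is 2*x
hess_f = hessian(f)   # the Hessian is 2 times the identity matrix
dx = grad_f(np.array([1., 2.]))  # [2., 4.]
hx = hess_f(np.array([1., 2.]))  # [[2., 0.], [0., 2.]]
```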
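To see why the single-output rule matters, here is a small sketch with a hypothetical function `f_vector` that returns one value per element instead of a sum. `jax.grad` rejects it with a `TypeError`. If you really want derivatives of every output, `jax.jacobian` gives you the full matrix; if you only need a gradient, reduce the output to a scalar first, which is exactly what the `np.sum` in `f` does.

```python
import jax
import jax.numpy as np


def f_vector(x):
    # Element-wise square: the output is a vector, not a scalar.
    return x ** 2


x = np.array([1., 2.])

grad_failed = False
try:
    jax.grad(f_vector)(x)
except TypeError as err:
    # grad only works for scalar-output functions
    grad_failed = True
    print("grad failed:", err)

# The jacobian has one row per output and one column per input: diag(2 * x).
jac = jax.jacobian(f_vector)(x)
print(jac)

# Summing the outputs turns it back into a scalar function, so grad works.
summed_grad = jax.grad(lambda x: np.sum(f_vector(x)))(x)
print(summed_grad)  # 2 * x
```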

Note that the array we're passing is not from `numpy` but from `jax.numpy`.
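If your data starts out as a regular numpy array, a reasonable habit is to convert it explicitly on the way in and out. The sketch below imports regular numpy as `onp` (an arbitrary alias, chosen so it doesn't clash with `jax.numpy` imported as `np`). JAX will generally also accept a plain numpy array and convert it for you; either way, what comes back is a JAX array, which `onp.asarray` turns back into an ordinary numpy array.

```python
import numpy as onp          # regular numpy, renamed to avoid a clash
import jax.numpy as np       # same alias as in the snippet above
from jax import grad


def f(x):
    return np.sum(x**2)


grad_f = grad(f)

plain = onp.array([1., 2.])   # a regular numpy array
as_jax = np.asarray(plain)    # explicit conversion to a JAX array

dx = grad_f(as_jax)           # JAX array holding the gradient 2*x
dx_numpy = onp.asarray(dx)    # back to a regular numpy array

print(type(dx), type(dx_numpy))
print(dx_numpy)
```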
