# jax.

Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is the lack of a tool for automatic differentiation. Jax is a tool that fills this gap. Its `jax.numpy` module largely mirrors the numpy API, so it does not force you to write your code in a different way. It can even compile your code to run efficiently on a CPU or GPU.
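As a small illustration of what "numpy plus differentiation" means in practice, here is a minimal sketch that writes an ordinary numpy-style function with `jax.numpy` (imported under the common alias `jnp`; the notes below use `np` for the same module) and lets `jax.grad` compute its derivative. The function `f` and the hand-written `df_analytic` are hypothetical examples used only to check the result.

```python
import jax.numpy as jnp
from jax import grad

# An ordinary numpy-style function, written with jax.numpy instead of numpy.
def f(x):
    return jnp.sin(x) * x**2

# grad(f) returns a new function that computes df/dx automatically.
df = grad(f)

# Hand-derived derivative for comparison: d/dx [x^2 sin x] = 2x sin x + x^2 cos x
def df_analytic(x):
    return 2 * x * jnp.sin(x) + x**2 * jnp.cos(x)

# The two values should agree up to float32 rounding.
print(df(1.5), df_analytic(1.5))
```

Note that `grad` expects floating-point inputs, which is why the example passes `1.5` rather than an integer.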

**Episode Notes**

In the first cell we define a function `g` together with a gradient function for it.

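```
import jax.numpy as np
from jax import grad

def g(x):
    return np.sum(x**2)

# grad(g) builds a new function that returns the gradient of g
grad_g = grad(g)
```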
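To see what `grad_g` actually returns: for $g(x) = \sum_i x_i^2$ every partial derivative is $2 x_i$, so the gradient is simply `2 * x`. A minimal sketch checking this on a vector and on a scalar (the input values are arbitrary choices for illustration):

```python
import jax.numpy as jnp
from jax import grad

def g(x):
    return jnp.sum(x**2)

grad_g = grad(g)

# For g(x) = sum(x_i^2), the gradient is 2 * x, element by element.
x = jnp.array([1.0, 2.0, 3.0])
print(grad_g(x))    # gradient at [1, 2, 3], i.e. [2, 4, 6]

# A scalar input also works, since sum() of a scalar is the scalar itself.
print(grad_g(3.0))  # gradient at 3.0, i.e. 6.0
```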

We then time how long it takes to evaluate this gradient at 1000 points in a list comprehension.

```
%time _ = np.stack([grad_g(i) for i in np.linspace(0, 10, 1000)])
```
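The `%time` magic only works in IPython or Jupyter. If you want to run the same measurement as a plain Python script, a rough sketch using `time.perf_counter` might look like this. One assumption worth stating: JAX dispatches work asynchronously, so the sketch calls `block_until_ready()` on the result before stopping the clock, otherwise the timing could stop before the computation has finished.

```python
import time

import jax.numpy as jnp
from jax import grad

def g(x):
    return jnp.sum(x**2)

grad_g = grad(g)

# A rough stand-in for %time when running outside IPython/Jupyter.
start = time.perf_counter()
result = jnp.stack([grad_g(i) for i in jnp.linspace(0, 10, 1000)])
# Wait until JAX has actually finished computing before reading the clock.
result.block_until_ready()
elapsed = time.perf_counter() - start

print(f"{elapsed:.3f} s for {result.shape[0]} gradients")
```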

We can speed this up by wrapping the gradient function in `jit`, JAX's just-in-time compiler.

```
from jax import jit

def g(x):
    return np.sum(x**2)

# compile the gradient function just in time
grad_g = jit(grad(g))
```

If you rerun the same timing you should now see that it runs much faster, because the compiled function is reused for every call after the first.

```
%time _ = np.stack([grad_g(i) for i in np.linspace(0, 10, 1000)])
```
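One detail the timing above hides is that the very first call to a jitted function includes the cost of tracing and compiling it; later calls with the same input shape and dtype reuse the cached compiled version. A minimal sketch that separates the two, assuming a plain Python script with `time.perf_counter` (the names `grad_g_plain` and `grad_g_jit` are hypothetical, chosen to compare both versions side by side):

```python
import time

import jax.numpy as jnp
from jax import grad, jit

def g(x):
    return jnp.sum(x**2)

grad_g_plain = grad(g)        # evaluated op by op on every call
grad_g_jit = jit(grad(g))     # compiled on first use, then cached

x = jnp.array(3.0)

# First call: JAX traces and compiles the function for this shape/dtype.
t0 = time.perf_counter()
grad_g_jit(x).block_until_ready()
first = time.perf_counter() - t0

# Second call with the same shape/dtype reuses the compiled version.
t0 = time.perf_counter()
grad_g_jit(x).block_until_ready()
second = time.perf_counter() - t0

print(f"first call: {first:.4f} s, second call: {second:.6f} s")

# Compilation changes the speed, not the answer.
print(grad_g_plain(x), grad_g_jit(x))
```

The exact numbers depend on your machine and backend, but the second call is typically much cheaper than the first.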
