... jax.

Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is the lack of a differentiation tool. JAX bridges this gap. It is numpy compatible, so for the most part you don't have to write your code in a different way. It can even compile your code to run on a CPU or GPU.
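To make the compilation claim a bit more concrete, here is a minimal sketch using `jax.jit`. It assumes a standard `jax` install. The function name `sum_of_squares` is just an illustrative choice, and it uses the common `jnp` alias for `jax.numpy`. Which backend the compiled code runs on depends on what your installation supports: CPU by default, GPU if a GPU-enabled build is installed.

```python
import jax
import jax.numpy as jnp

# Ordinary numpy-style code: nothing JAX-specific in the function body.
def sum_of_squares(x):
    return jnp.sum(x ** 2)

# jax.jit traces the function on its first call and compiles it with XLA
# for the default device that JAX detected.
fast_sum_of_squares = jax.jit(sum_of_squares)

x = jnp.arange(4.0)               # [0., 1., 2., 3.]
result = fast_sum_of_squares(x)   # 0 + 1 + 4 + 9
print(result)
print(jax.devices())              # the devices JAX found on this machine
```

The first call includes compilation time. Later calls with the same input shape and dtype reuse the compiled version.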


If you want to differentiate a function of an array, make sure the function has a single (scalar) output. The code below calculates the gradient vector as well as the Hessian.

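import jax.numpy as np
from jax import grad, hessian

def f(x):
    # A scalar-valued function: the sum of the squared elements.
    return np.sum(x**2)

grad_f = grad(f)      # x -> vector of first derivatives, here 2 * x
hess_f = hessian(f)   # x -> matrix of second derivatives, here 2 * I

x = np.array([1., 2.])
dx = grad_f(x)  # [2., 4.]
hx = hess_f(x)  # [[2., 0.], [0., 2.]]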

Note that the array we're passing is created with jax.numpy, not with regular numpy.
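To see why the single-output rule matters, the sketch below calls `jax.grad` on a vector-valued function. As far as I know, JAX rejects this with a `TypeError`. For such functions, `jax.jacobian` gives the full matrix of partial derivatives instead. The sketch also checks that the Hessian equals the Jacobian of the gradient. The helper `g` is a hypothetical example, and `jnp` is the usual alias for `jax.numpy`.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar output: fine for jax.grad and jax.hessian.
    return jnp.sum(x ** 2)

def g(x):
    # Hypothetical vector-valued function: one output per element.
    return x ** 2

x = jnp.array([1.0, 2.0])

# jax.grad only accepts scalar-output functions; a vector output raises.
try:
    jax.grad(g)(x)
    raised = False
except TypeError:
    raised = True

# For vector outputs, the Jacobian holds d g_i / d x_j; here diag(2 * x).
jac_g = jax.jacobian(g)(x)

# The Hessian of f is the Jacobian of its gradient.
hess_via_jac = jax.jacfwd(jax.grad(f))(x)
print(raised, jac_g, hess_via_jac)
```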
