jax: arrays


If you want to differentiate a function of an array, just be mindful that your function has a single (scalar) output, because `grad` only works on scalar-valued functions. The code below calculates the gradient vector as well as the Hessian.

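import jax.numpy as np
from jax import grad, hessian

def f(x):
    # sum the squares so the function returns a single scalar
    return np.sum(x**2)

grad_f = grad(f)      # maps x to the gradient 2*x
hess_f = hessian(f)   # maps x to the Hessian, 2 times the identity

dx = grad_f(np.array([1., 2.]))  # [2., 4.]
hx = hess_f(np.array([1., 2.]))  # [[2., 0.], [0., 2.]]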

Note that `np` here is an alias for `jax.numpy`, so the array we're passing comes from jax.numpy rather than from regular numpy.
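To see the single-output rule in action, here is a minimal sketch that uses the more common `jnp` alias for `jax.numpy`. The vector-valued function `g` is a hypothetical example added for illustration: `jax.grad` rejects it with a `TypeError`, while `jax.jacobian` returns its matrix of partial derivatives. The sketch also checks that `jax.hessian(f)` agrees with the Jacobian of the gradient of `f`.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar output, so jax.grad and jax.hessian apply directly.
    return jnp.sum(x**2)

def g(x):
    # Hypothetical vector-valued function: one output per entry of x.
    return x**2

x = jnp.array([1., 2.])

# jax.grad only accepts scalar-output functions; a vector output raises TypeError.
try:
    jax.grad(g)(x)
    grad_g_failed = False
except TypeError:
    grad_g_failed = True

# jax.jacobian handles vector outputs: entry (i, j) is d g_i / d x_j.
jac_g = jax.jacobian(g)(x)   # diag(2 * x), i.e. [[2, 0], [0, 4]]

# The Hessian of f is the Jacobian of its gradient.
hess_f = jax.hessian(f)(x)   # 2 times the 2x2 identity
jac_of_grad = jax.jacobian(jax.grad(f))(x)

print(grad_g_failed)
print(jac_g)
print(jnp.allclose(hess_f, jac_of_grad))
```

So if your function returns more than one number, either reduce it to a scalar (as `f` does with `np.sum`) or switch from `grad` to `jacobian`.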
