jax: arrays

If you want to differentiate a function of an array, make sure the function returns a single scalar output. The code below computes the gradient vector as well as the Hessian.

import jax.numpy as np
from jax import grad, hessian

def f(x):
    # Scalar-valued function: sum of squares
    return np.sum(x**2)

grad_f = grad(f)
hess_f = hessian(f)

dx = grad_f(np.array([1., 2.]))  # gradient vector
hx = hess_f(np.array([1., 2.]))  # Hessian matrix

Note that the array we're passing is not from numpy but from jax.numpy.
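
As a quick sanity check, the results can be compared against the analytic derivatives: for f(x) = sum(x**2) the gradient is 2x and the Hessian is twice the identity. The sketch below also illustrates the single-output requirement with a vector-valued function g, which is a hypothetical helper added here for illustration and is not part of the original example. For such a function, grad raises a TypeError, and jacobian is one way to get the full matrix of partial derivatives instead.

```python
import jax.numpy as np
from jax import grad, hessian, jacobian

def f(x):
    # Scalar-valued: sum of squares
    return np.sum(x ** 2)

x = np.array([1., 2.])

dx = grad(f)(x)      # analytically 2 * x
hx = hessian(f)(x)   # analytically 2 * identity

assert np.allclose(dx, 2. * x)
assert np.allclose(hx, 2. * np.eye(2))

def g(x):
    # Hypothetical vector-valued function: one output per input element
    return x ** 2

try:
    grad(g)(x)           # grad only accepts scalar-output functions
    grad_failed = False
except TypeError:
    grad_failed = True

jx = jacobian(g)(x)  # analytically diag(2 * x)
assert np.allclose(jx, np.diag(2. * x))
```

If the function naturally has several outputs, either reduce them to a scalar (for example with np.sum) before calling grad, or switch to jacobian as above.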