If you want to differentiate with respect to an array, just make sure your function returns a single scalar output. The code below computes the gradient vector as well as the Hessian.
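```python
import jax.numpy as jnp
from jax import grad, hessian

def f(x):
    # Scalar output: the sum of squared entries
    return jnp.sum(x**2)

grad_f = grad(f)      # function returning the gradient vector
hess_f = hessian(f)   # function returning the Hessian matrix

dx = grad_f(jnp.array([1., 2.]))
hx = hess_f(jnp.array([1., 2.]))
```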
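For f(x) = sum(x_i^2), the gradient is 2x and the Hessian is 2 times the identity, so the results are easy to sanity-check. A minimal self-contained sketch that runs the same computation and compares against those analytic values:

```python
import jax.numpy as jnp
from jax import grad, hessian

def f(x):
    # Sum of squares: a scalar-valued function of a vector
    return jnp.sum(x**2)

x = jnp.array([1., 2.])
dx = grad(f)(x)      # analytically 2 * x
hx = hessian(f)(x)   # analytically 2 * identity

print(dx)  # [2. 4.]
print(hx)  # 2x2 matrix with 2.0 on the diagonal
```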
Note that the array we're passing is created with jax.numpy, not plain numpy.
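To illustrate why the function needs a single output: grad rejects a function whose output is an array with more than one element. A minimal sketch, where g is a hypothetical vector-valued function used only for illustration; summing its output is one way to get back to a scalar that grad accepts.

```python
import jax.numpy as jnp
from jax import grad

def g(x):
    # Hypothetical vector-valued function: one output per input element
    return x**2

x = jnp.array([1., 2.])

try:
    grad(g)(x)
    raised = False
except TypeError as err:
    # grad only accepts functions with a scalar output
    raised = True
    print(err)

# Reducing the output to a single scalar makes grad applicable again
dg = grad(lambda v: jnp.sum(g(v)))(x)
```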