If you want to differentiate with respect to an array, just make sure your function returns a single scalar output: `grad` only works on scalar-valued functions. The code below computes both the gradient vector and the Hessian.
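```python
import jax.numpy as jnp
from jax import grad, hessian

# f maps a vector to a single scalar, which grad requires
def f(x):
    return jnp.sum(x**2)

grad_f = grad(f)     # function returning the gradient vector df/dx
hess_f = hessian(f)  # function returning the matrix of second derivatives

dx = grad_f(jnp.array([1., 2.]))  # gradient of sum(x**2) is 2*x: [2., 4.]
hx = hess_f(jnp.array([1., 2.]))  # Hessian is 2*I: [[2., 0.], [0., 2.]]
```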
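To make the single-output requirement concrete, here is a small sketch, assuming a recent JAX release. The vector-valued function `g` is a hypothetical example: `grad` rejects it with a `TypeError`, and `jax.jacobian` is the function to use for non-scalar outputs. The last lines check that `hessian` agrees with forward-over-reverse differentiation (`jacfwd(jacrev(f))`), which is how JAX documents it.

```python
import jax
import jax.numpy as jnp

def f(x):
    # scalar output: suitable for grad and hessian
    return jnp.sum(x**2)

def g(x):
    # hypothetical vector-valued function: output has the same shape as x
    return x**2

x = jnp.array([1., 2.])

grad_ok = jax.grad(f)(x)  # gradient of sum(x**2) is 2*x

# grad refuses functions whose output is not a scalar
try:
    jax.grad(g)(x)
    grad_vector_failed = False
except TypeError:
    grad_vector_failed = True

# For vector outputs, compute the full Jacobian instead (here diag(2*x))
jac_g = jax.jacobian(g)(x)

# hessian(f) matches differentiating the reverse-mode Jacobian in forward mode
hess_a = jax.hessian(f)(x)
hess_b = jax.jacfwd(jax.jacrev(f))(x)
print(grad_ok, grad_vector_failed, jnp.allclose(hess_a, hess_b))
```

If you actually need derivatives of every output component, `jax.jacobian` (or `jax.jacrev`/`jax.jacfwd` directly) avoids having to reduce the output to a scalar first.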
Note that the array we're passing is not from `numpy` but from