jax: arrays
Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is the lack of a differentiation tool. Jax is a tool that bridges this gap. It is numpy compatible and does not force you to write code in a different way. It can even compile your code to run on a CPU or GPU.
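As a minimal sketch of the points above, numpy-style code can be differentiated with jax.grad and compiled with jax.jit. The function name mse_loss and the data are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# A loss written exactly as you would write it in numpy (hypothetical example)
def mse_loss(w, X, y):
    preds = X @ w
    return jnp.mean((preds - y) ** 2)

# jax.grad differentiates with respect to the first argument (w) by default
grad_loss = jax.grad(mse_loss)

# jax.jit compiles the gradient function for the available backend (CPU/GPU)
fast_grad_loss = jax.jit(grad_loss)

X = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
y = jnp.array([1., 2., 3.])
w = jnp.zeros(2)

g = fast_grad_loss(w, X, y)
print(g)
```

The result should agree with the analytic gradient of the mean squared error, 2/n * X^T (Xw - y); the first call to fast_grad_loss triggers compilation, later calls with the same shapes reuse it.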
Notes
If you want to differentiate a function of an array, just be mindful that the function has a single (scalar) output. The code below will calculate the gradient vector as well as the hessian.
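import jax.numpy as np  # jax's numpy-compatible API, aliased as np
from jax import grad, hessian

def f(x):
    # a single scalar output, which grad requires
    return np.sum(x**2)

grad_f = grad(f)
hess_f = hessian(f)

dx = grad_f(np.array([1., 2.]))  # gradient 2*x: [2., 4.]
hx = hess_f(np.array([1., 2.]))  # hessian 2*I: [[2., 0.], [0., 2.]]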
Note that the array we're passing is not from numpy but from jax.numpy (imported here under the alias np).
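To see why the single-output rule matters, here is a small sketch with a hypothetical vector-valued function f_vec: jax.grad rejects it, while jax.jacobian returns the full matrix of partial derivatives.

```python
import jax
import jax.numpy as jnp

# Vector output: one value per input element (hypothetical example)
def f_vec(x):
    return x ** 2

x = jnp.array([1., 2.])

# jax.grad only accepts scalar-output functions, so this raises a TypeError
try:
    jax.grad(f_vec)(x)
    raised = False
except TypeError:
    raised = True

# For vector outputs, jax.jacobian returns d f_vec[i] / d x[j]
J = jax.jacobian(f_vec)(x)  # diagonal matrix with 2 * x on the diagonal
print(raised, J)
```

An alternative is to reduce the output to a scalar yourself, for example with a sum, which is exactly what f does in the code above.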