Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in NumPy, but a common downside is the lack of a differentiation tool. JAX is a tool that bridges this gap. It is NumPy-compatible and does not force you to write code in a different way. It can even compile your code to run on a CPU or GPU.
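As a minimal sketch of the NumPy compatibility claim, the same function body can be written against `numpy` and against `jax.numpy`, and `jax.jit` then compiles the JAX version. The `softplus` function is a hypothetical example chosen for illustration, not something from the original text. JAX defaults to 32-bit floats, so the comparison uses a small tolerance.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit

# Hypothetical example function, written once with plain NumPy...
def softplus_np(x):
    return np.log1p(np.exp(x))

# ...and once with jax.numpy: the only change is np -> jnp.
def softplus_jnp(x):
    return jnp.log1p(jnp.exp(x))

# jit compiles the JAX version with XLA for whatever backend is available.
softplus_fast = jit(softplus_jnp)

xs = np.linspace(-3.0, 3.0, 7)
# JAX computes in float32 by default, hence the tolerance.
print(np.allclose(softplus_np(xs), np.asarray(softplus_fast(xs)), atol=1e-6))
```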
If you want to differentiate `f` with respect to `x`, you can run:
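```python
from jax import grad

def f(x, y):
    return y * x**2

# By default grad differentiates with respect to the first argument, x.
grad_f = grad(f)
dx = grad_f(1., 2.)  # df/dx = 2*x*y = 4.0
```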
This works because `grad` differentiates with respect to the first argument of `f` by default. You can also be more explicit by running:
```python
# argnums=0 explicitly selects the first argument, x.
grad_f = grad(f, argnums=0)
dx = grad_f(1., 2.)  # 4.0
```
You can also differentiate with respect to `y`:
```python
# argnums=1 selects the second argument, y.
grad_f = grad(f, argnums=1)
dy = grad_f(1., 2.)  # df/dy = x**2 = 1.0
```
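The form of `argnums` determines what comes back. As a small sketch, reusing the same `f`: an integer gives a single gradient value, while a tuple, even a one-element tuple such as `(1,)`, gives a tuple of gradients that you need to unpack.

```python
from jax import grad

def f(x, y):
    return y * x**2

# An int argnums returns a single gradient value.
dy_scalar = grad(f, argnums=1)(1., 2.)

# A tuple argnums returns a tuple of gradients, even with one entry.
dy_tuple = grad(f, argnums=(1,))(1., 2.)

print(float(dy_scalar))
print(type(dy_tuple), len(dy_tuple))
```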
Or you can differentiate with respect to both:
```python
# A tuple of argnums returns one gradient per selected argument.
grad_f = grad(f, argnums=(0, 1))
dx, dy = grad_f(1., 2.)  # dx = 4.0, dy = 1.0
```
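Because `grad` returns an ordinary Python function, it can be wrapped in `jit` as well. The sketch below does that and, as a sanity check, compares the result with central finite differences. The `central_diff` helper and its step size `h` are illustrative assumptions, not part of JAX.

```python
from jax import grad, jit

def f(x, y):
    return y * x**2

# grad(...) is itself a function, so jit can compile it.
grad_f = jit(grad(f, argnums=(0, 1)))
dx, dy = grad_f(1., 2.)

# Hypothetical helper: central finite differences as an independent check.
def central_diff(fun, x, y, h=1e-3):
    dfdx = (fun(x + h, y) - fun(x - h, y)) / (2 * h)
    dfdy = (fun(x, y + h) - fun(x, y - h)) / (2 * h)
    return dfdx, dfdy

fd_dx, fd_dy = central_diff(f, 1., 2.)
# Analytically: df/dx = 2*x*y = 4.0 and df/dy = x**2 = 1.0.
print(float(dx), float(dy))
print(fd_dx, fd_dy)
```

For a function that is quadratic in each argument, the central difference is exact up to floating-point rounding, which makes it a convenient cross-check here.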