... jax.

Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is the lack of automatic differentiation. JAX is a tool that bridges this gap. Its jax.numpy module mirrors the numpy API closely, so you rarely have to write your code in a different way. It can even compile your functions (via jax.jit) to run on a CPU or GPU.
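As a minimal sketch of that claim, assuming a recent JAX install: the function below is written with jax.numpy the way you would write it in numpy, differentiated with jax.grad and then compiled with jax.jit. The function name loss and its inputs are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function, written just as you would in numpy.
def loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)

# jax.grad builds the gradient function; jax.jit compiles it with XLA
# for whichever backend is available (CPU, GPU or TPU).
grad_loss = jax.grad(loss)
fast_grad_loss = jax.jit(grad_loss)

w = jnp.array([0.0, 0.5, 1.0])
print(fast_grad_loss(w))

# Hand-derived gradient for comparison: d/dw tanh(w)^2 = 2 * tanh(w) * (1 - tanh(w)^2)
expected = 2 * jnp.tanh(w) * (1 - jnp.tanh(w) ** 2)
print(jnp.allclose(fast_grad_loss(w), expected, atol=1e-6))
```

The first call to fast_grad_loss triggers compilation; later calls with inputs of the same shape and dtype reuse the compiled version.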


If you want to differentiate f with respect to x, you can run:

from jax import grad

def f(x, y):
  return y * x**2

grad_f = grad(f)

dx = grad_f(1., 2.)  # df/dx = 2 * x * y = 4.0

This works because grad differentiates with respect to the first argument of f by default (argnums=0). You can also be more explicit by running:

grad_f = grad(f, argnums=0)

dx = grad_f(1., 2.)
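One detail worth knowing about argnums: an integer makes grad return a single gradient, while a tuple, even a one-element tuple such as (0,), makes it return a tuple of gradients. A small sketch comparing the three forms on the same f:

```python
from jax import grad

def f(x, y):
    return y * x**2

# Default: differentiate with respect to the first argument (argnums=0).
dx_default = grad(f)(1., 2.)
# An integer argnums returns a single gradient...
dx_int = grad(f, argnums=0)(1., 2.)
# ...while a tuple, even with one element, returns a tuple of gradients.
dx_tuple = grad(f, argnums=(0,))(1., 2.)

print(dx_default, dx_int, dx_tuple)
```

With the tuple form you would unpack the result, for example (dx,) = grad(f, argnums=(0,))(1., 2.).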

You can also differentiate with respect to y via:

grad_f = grad(f, argnums=1)

dy = grad_f(1., 2.)
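The trailing dots in 1. and 2. matter: jax.grad only differentiates with respect to floating-point (or complex) inputs and raises a TypeError when the differentiated argument is an integer. A small sketch to check this behaviour:

```python
from jax import grad

def f(x, y):
    return y * x**2

grad_y = grad(f, argnums=1)

# Float inputs work: df/dy = x**2 = 1.0 at x = 1.
dy = grad_y(1., 2.)
print(dy)

# Integer inputs for the differentiated argument are rejected by grad.
try:
    grad_y(1, 2)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)
```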

Or you can differentiate with respect to both:

grad_f = grad(f, argnums=(0, 1))

dx, dy = grad_f(1., 2.)
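To convince yourself the numbers are right, you can compare them with the partial derivatives worked out by hand (df/dx = 2xy, df/dy = x²) and with a central finite-difference estimate. The step size eps below is an arbitrary choice for this sketch.

```python
from jax import grad

def f(x, y):
    return y * x**2

x, y = 1., 2.
dx, dy = grad(f, argnums=(0, 1))(x, y)

# Hand-derived partial derivatives of f(x, y) = y * x**2:
#   df/dx = 2 * x * y,  df/dy = x**2
print(dx, 2 * x * y)  # both equal 4.0
print(dy, x**2)       # both equal 1.0

# Central finite differences as an independent sanity check.
eps = 1e-3
fd_dx = (f(x + eps, y) - f(x - eps, y)) / (2 * eps)
fd_dy = (f(x, y + eps) - f(x, y - eps)) / (2 * eps)
print(fd_dx, fd_dy)
```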
