jax: arguments


If you want to differentiate f with respect to x, you can run:

from jax import grad

def f(x, y):
  return y * x**2

# argnums defaults to 0, so this differentiates with respect to x
grad_f = grad(f)

dx = grad_f(1., 2.)  # df/dx = 2 * y * x = 4.0

This works because grad differentiates with respect to the first positional argument of f by default (argnums=0). You can also be more explicit by running:

grad_f = grad(f, argnums=0)

dx = grad_f(1., 2.)
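A subtlety worth knowing: argnums accepts either a single int or a tuple of ints, and the form of the result follows that choice. A one-element tuple such as (0,) gives back a one-element tuple rather than a bare gradient, so assigning it straight to dx would hand you a tuple. A minimal sketch, reusing the same f:

```python
from jax import grad

def f(x, y):
    return y * x**2

# An int argnums returns the gradient itself (a scalar array here)
dx = grad(f, argnums=0)(1., 2.)

# A tuple argnums returns a tuple, even when it holds a single index
dx_tuple = grad(f, argnums=(0,))(1., 2.)

print(float(dx))              # 4.0
print(type(dx_tuple))         # <class 'tuple'>
print(float(dx_tuple[0]))     # 4.0
```

If you do want the tuple form, unpacking with `(dx,) = grad(f, argnums=(0,))(1., 2.)` recovers the single value.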

You can also differentiate with respect to y via:

grad_f = grad(f, argnums=1)

dy = grad_f(1., 2.)
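Since f is y * x**2, the partial derivative with respect to y is simply x**2. A quick sanity check, sketched here by comparing grad against that hand-derived formula at a few arbitrarily chosen points:

```python
from jax import grad

def f(x, y):
    return y * x**2

df_dy = grad(f, argnums=1)

# Analytic partial derivative: d/dy (y * x**2) = x**2
for x, y in [(1., 2.), (3., -1.), (0.5, 4.)]:
    assert abs(float(df_dy(x, y)) - x**2) < 1e-6
```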

Or you can differentiate with respect to both, in which case grad returns a tuple with one gradient per argument:

grad_f = grad(f, argnums=(0, 1))

dx, dy = grad_f(1., 2.)
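The tuple form should agree with asking for each argument separately; the sketch below checks that for the same inputs. It also shows why the examples pass 1. and 2. rather than 1 and 2: by default grad rejects integer inputs with a TypeError.

```python
from jax import grad

def f(x, y):
    return y * x**2

# One call with a tuple argnums...
dx, dy = grad(f, argnums=(0, 1))(1., 2.)

# ...matches two separate single-argument calls
assert float(dx) == float(grad(f, argnums=0)(1., 2.))  # 2 * y * x = 4.0
assert float(dy) == float(grad(f, argnums=1)(1., 2.))  # x**2 = 1.0

# Integer inputs are rejected by default, hence the float literals
try:
    grad(f)(1, 2)
    ints_rejected = False
except TypeError:
    ints_rejected = True

print(float(dx), float(dy), ints_rejected)  # 4.0 1.0 True
```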
