jax: arguments


If you want to differentiate f with respect to x, you can run:

from jax import grad

def f(x, y):
  return y * x**2

grad_f = grad(f)

dx = grad_f(1., 2.)  # df/dx = 2*x*y = 4.0 at x=1, y=2

This works because grad differentiates with respect to the first positional argument of f by default (argnums=0). You can also be more explicit by passing argnums:

grad_f = grad(f, argnums=0)  # an int returns the gradient itself; argnums=(0,) would return a 1-tuple

dx = grad_f(1., 2.)
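A detail that is easy to miss is that the type of argnums changes the shape of the result. The sketch below compares the default, an integer, and a one-element tuple; the variable names are illustrative only. As far as the JAX documentation describes it, an int gives back the gradient directly, while a tuple gives back a tuple of gradients, even when it holds a single index.

```python
from jax import grad

def f(x, y):
    return y * x**2

# Three ways to ask for the derivative with respect to x (hypothetical names).
grad_default = grad(f)              # argnums defaults to 0
grad_int = grad(f, argnums=0)       # int: returns the gradient itself
grad_tuple = grad(f, argnums=(0,))  # tuple: returns a tuple of gradients

dx_default = grad_default(1., 2.)
dx_int = grad_int(1., 2.)
dx_tuple = grad_tuple(1., 2.)

# dx_default and dx_int both equal 2*x*y = 4.0 at x=1, y=2;
# dx_tuple wraps that same value in a 1-tuple.
print(dx_default, dx_int, dx_tuple)
```

If you do use a one-element tuple, unpack it with `dx, = grad_tuple(1., 2.)`.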

You can also differentiate with respect to y:

grad_f = grad(f, argnums=1)

dy = grad_f(1., 2.)
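As a quick sanity check, the gradient with respect to y should match the analytic derivative df/dy = x**2, which does not depend on y at all. The sketch below evaluates it at a few arbitrarily chosen points (the `points` list is just for illustration) and compares against x**2.

```python
from jax import grad

def f(x, y):
    return y * x**2

# Differentiate with respect to the second argument (y).
grad_y = grad(f, argnums=1)

# Arbitrary sample points; the analytic answer is x**2 regardless of y.
points = [(1., 2.), (3., -1.), (0.5, 10.)]
results = [(float(grad_y(x, y)), x**2) for x, y in points]

for (x, y), (autodiff, analytic) in zip(points, results):
    print(f"x={x}, y={y}: grad={autodiff}, x**2={analytic}")
```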

Or you can differentiate with respect to both:

grad_f = grad(f, argnums=(0, 1))

dx, dy = grad_f(1., 2.)  # a tuple of gradients: dx = 4.0, dy = 1.0
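The same argnums argument is also accepted by `jax.value_and_grad`, which is not covered above but is part of the public JAX API. It returns the function value together with the gradients, so you do not need a separate call to f. A minimal sketch, assuming you want both partial derivatives at once:

```python
from jax import value_and_grad

def f(x, y):
    return y * x**2

# Returns (f(x, y), gradients); with a tuple argnums the gradients form a tuple.
f_and_grads = value_and_grad(f, argnums=(0, 1))

value, (dx, dy) = f_and_grads(1., 2.)

# At x=1, y=2: f = 2.0, df/dx = 2*x*y = 4.0, df/dy = x**2 = 1.0
print(value, dx, dy)
```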