If you want to differentiate f with respect to x, you can run:

    from jax import grad

    def f(x, y):
        return y * x**2

    grad_f = grad(f)      # argnums defaults to 0, i.e. x
    dx = grad_f(1., 2.)   # df/dx = 2*x*y = 4.0 at (x, y) = (1, 2)

This works because grad differentiates with respect to the first positional argument of f by default (argnums=0). You can also be more explicit by running:

    grad_f = grad(f, argnums=0)
    dx = grad_f(1., 2.)   # 4.0

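One subtlety: passing argnums as a plain integer and passing it as a one-element tuple are not quite the same. The sketch below (assuming a reasonably recent JAX release) compares the default, argnums=0, and argnums=(0,). The integer forms return a single gradient array, while the tuple form returns a tuple holding one array, which then needs unpacking.

```python
from jax import grad

def f(x, y):
    return y * x**2

# Default: differentiate with respect to the first positional argument.
dx_default = grad(f)(1., 2.)

# Integer argnums: same result, returned as a single array.
dx_int = grad(f, argnums=0)(1., 2.)

# Tuple argnums: the result is a tuple with one entry per listed argument.
dx_tuple = grad(f, argnums=(0,))(1., 2.)
(dx_unpacked,) = dx_tuple  # unpack the single gradient

print(float(dx_default), float(dx_int), float(dx_unpacked))  # 4.0 4.0 4.0
print(type(dx_tuple).__name__)  # tuple
```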
To differentiate with respect to y instead, use:

    grad_f = grad(f, argnums=1)
    dy = grad_f(1., 2.)   # df/dy = x**2 = 1.0 at (x, y) = (1, 2)

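To sanity-check these results, the sketch below compares grad against the partial derivatives worked out by hand, df/dx = 2xy and df/dy = x², at a few sample points. The sample points and the tolerance are arbitrary choices made for illustration.

```python
import math
from jax import grad

def f(x, y):
    return y * x**2

df_dx = grad(f, argnums=0)
df_dy = grad(f, argnums=1)

# Arbitrary sample points (x, y) for the comparison.
points = [(1., 2.), (-3., 0.5), (0.25, -4.)]

for x, y in points:
    # Analytic partial derivatives of f(x, y) = y * x**2.
    expected_dx = 2 * x * y
    expected_dy = x**2
    assert math.isclose(float(df_dx(x, y)), expected_dx, rel_tol=1e-6)
    assert math.isclose(float(df_dy(x, y)), expected_dy, rel_tol=1e-6)

print("grad matches the analytic partial derivatives")
```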
Or you can differentiate with respect to both arguments at once; grad then returns a tuple with one gradient per entry in argnums:

    grad_f = grad(f, argnums=(0, 1))
    dx, dy = grad_f(1., 2.)   # (4.0, 1.0) at (x, y) = (1, 2)
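When argnums is a tuple, the gradients come back in the order the indices are listed, not in the order of f's parameters. The sketch below reverses the tuple to illustrate this, reusing the same f; the name grad_yx is just an illustrative choice.

```python
from jax import grad

def f(x, y):
    return y * x**2

# Listing y first means its gradient comes first in the returned tuple.
grad_yx = grad(f, argnums=(1, 0))
dy, dx = grad_yx(1., 2.)

print(float(dy), float(dx))  # 1.0 4.0
```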