# jax: arguments

Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in NumPy, but a common downside is that it has no automatic differentiation. JAX is a tool that bridges this gap. Its API is NumPy-compatible, so it does not force you to write code in a different way. It can even compile your code to run on a CPU or GPU.
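To make the "NumPy-compatible" claim concrete, here is a minimal sketch. It assumes `jax` is installed, and the function `loss` is purely illustrative. The function is written with `jax.numpy` the way you would write it with NumPy, `jax.grad` turns it into its derivative, and `jax.jit` compiles that derivative for whichever backend is available.

```python
import jax
import jax.numpy as jnp

# An illustrative function, written like ordinary NumPy code
def loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)

w = jnp.array([0.5, -1.0, 2.0])

value = loss(w)              # evaluate it like any NumPy function
dloss = jax.grad(loss)       # a new function that returns d(loss)/dw
fast_dloss = jax.jit(dloss)  # the same derivative, compiled with XLA

print(value)
print(dloss(w))
print(fast_dloss(w))
```

The compiled and uncompiled gradients should agree up to floating-point rounding; the first call to `fast_dloss` is slower because that is when compilation happens.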

**Notes**

If you want to differentiate `f` with respect to `x`, you can run:

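```python
from jax import grad

def f(x, y):
    return y * x**2

# Without argnums, grad differentiates with respect to the first argument, x
grad_f = grad(f)
dx = grad_f(1., 2.)  # 2 * x * y = 4.0
```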
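A quick way to convince yourself that `grad(f)` really differentiates with respect to `x` is to compare it with the derivative worked out by hand, $\partial f / \partial x = 2xy$. A small sketch; the list `points` is just an arbitrary choice of inputs:

```python
from jax import grad

def f(x, y):
    return y * x**2

# No argnums: differentiate with respect to the first argument, x
grad_f = grad(f)

# By hand: d/dx (y * x**2) = 2 * x * y
points = [(1., 2.), (3., -0.5), (-2., 4.)]
for x, y in points:
    print(f"x={x}, y={y}: jax={float(grad_f(x, y))}, by hand={2 * x * y}")
```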

This works because `grad` differentiates with respect to the first argument of `f` by default. You can also be more explicit by running:

```python
# argnums=0 is the default: differentiate with respect to x
grad_f = grad(f, argnums=0)
dx = grad_f(1., 2.)  # 4.0
```
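One detail worth knowing about `argnums`: the type you pass decides the shape of what comes back. An integer gives a single gradient, while a tuple, even a one-element tuple, gives a tuple of gradients. A small sketch to check this:

```python
from jax import grad

def f(x, y):
    return y * x**2

as_int = grad(f, argnums=0)(1., 2.)       # a single array
as_tuple = grad(f, argnums=(0,))(1., 2.)  # a tuple holding one array

print(type(as_int), type(as_tuple), len(as_tuple))
```

So the result of `grad(f, argnums=(0,))(1., 2.)` has to be unpacked, for example with `dx, = ...`, before you can use it as a number.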

You can also differentiate with respect to `y` via:

```python
# Differentiate with respect to the second argument, y
grad_f = grad(f, argnums=1)
dy = grad_f(1., 2.)  # x**2 = 1.0
```
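Because `f` is linear in `y`, the derivative with respect to `y` is simply $x^2$. As a sanity check that does not rely on `grad` at all, you can compare it against a central finite difference. This is a minimal sketch; `finite_diff_y` is a hypothetical helper written only for this comparison:

```python
from jax import grad

def f(x, y):
    return y * x**2

grad_f_y = grad(f, argnums=1)

def finite_diff_y(x, y, eps=1e-3):
    # Hypothetical helper: central difference approximation of df/dy
    return (f(x, y + eps) - f(x, y - eps)) / (2 * eps)

x, y = 3., 2.
print(float(grad_f_y(x, y)), finite_diff_y(x, y), x**2)
```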

Or you can differentiate with respect to both:

```python
# A tuple of argnums returns a tuple of gradients, one per argument
grad_f = grad(f, argnums=(0, 1))
dx, dy = grad_f(1., 2.)  # 4.0, 1.0
```
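The same `argnums` pattern carries over to arrays. In the minimal sketch below, `g` is a hypothetical variant of `f` that takes a vector `x`. `grad` still needs the function to return a scalar, so the sum is taken inside `g`, and each returned gradient has the same shape as the argument it belongs to.

```python
import jax.numpy as jnp
from jax import grad

def g(x, y):
    # grad only works on functions with a scalar output, hence the sum
    return y * jnp.sum(x**2)

grad_g = grad(g, argnums=(0, 1))

x = jnp.array([1., 2., 3.])
dx, dy = grad_g(x, 2.)

print(dx)  # same shape as x, equal to 2 * y * x
print(dy)  # a scalar, equal to sum(x**2)
```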
