# jax: training

Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is the lack of a differentiation tool. JAX is a tool that fills this gap. It is largely numpy-compatible, so it does not force you to write code in a different way, and it can even compile your code to run on a CPU or GPU.
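To make the idea concrete, here is a minimal sketch of automatic differentiation on a toy function. The function `f` and the input value are made up for illustration; `grad` and `jit` are the JAX transformations for differentiation and compilation.

```python
from jax import grad, jit

def f(x):
    # toy function f(x) = x^2 + 3x, whose derivative is 2x + 3
    return x ** 2 + 3.0 * x

# grad(f) builds the derivative function, jit compiles it
df = jit(grad(f))

# evaluate the derivative at x = 2.0 and convert to a Python float
value = float(df(2.0))
analytic = 2 * 2.0 + 3.0
print(value, analytic)
```

The function body is ordinary Python arithmetic; the differentiation comes from wrapping it, not from rewriting it.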

**Notes**

This is the code that defines the `predict` and `mse` (loss) functions.

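```
import jax.numpy as np  # jax's numpy-compatible API
from jax import grad, jit

def predict(params, inputs):
    # linear model: one prediction per row of inputs
    return inputs @ params

def mse(params, inputs, targets):
    preds = predict(params, inputs)
    return np.mean((preds - targets)**2)

# gradient of mse with respect to its first argument (params), compiled with jit
grad_fun = jit(grad(mse))
```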
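As a sanity check, the gradient that `grad` produces for a mean squared error can be compared with the closed form `(2/n) * X.T @ (X @ w - y)`. The sketch below repeats the two function definitions so that it runs on its own. The dataset sizes, the seed, and the names `onp` (plain numpy) and `jnp` (`jax.numpy`) are assumptions made for this illustration.

```python
import numpy as onp          # plain numpy, used for data generation
import jax.numpy as jnp      # jax's numpy-compatible API
from jax import grad, jit

def predict(params, inputs):
    return inputs @ params

def mse(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.mean((preds - targets) ** 2)

grad_fun = jit(grad(mse))

# small made-up dataset in float32, the default precision in JAX
rng = onp.random.default_rng(0)
X = rng.normal(size=(50, 3)).astype("float32")
y = rng.normal(size=(50,)).astype("float32")
w = rng.normal(size=(3,)).astype("float32")

# gradient from automatic differentiation
auto = onp.asarray(grad_fun(w, inputs=X, targets=y))

# gradient from the closed-form expression for mean squared error
manual = 2.0 / X.shape[0] * X.T @ (X @ w - y)

max_diff = float(onp.max(onp.abs(auto - manual)))
print(max_diff)
```

The two results should agree up to floating point rounding, which is a cheap way to confirm that the loss function is wired up as intended.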

This is used in the update loop shown below:

```
import tqdm
from numpy import zeros
from numpy.random import normal

# we generate 10_000 rows and 5 columns
n, k = 10_000, 5
X = np.concatenate([np.ones((n, 1)), normal(0, 1, (n, k))], axis=1)

# these are the true coefficients that we have to learn
true_w = normal(0, 5, (k + 1,))

# this is the generated dataset used in training
y = X @ true_w

W = normal(0, 1, (k + 1,))
stepsize = 0.02
n_step = 100
hist_gd = zeros((n_step,))

for i in tqdm.tqdm(range(n_step)):
    # we calculate the gradient
    dW = grad_fun(W, inputs=X, targets=y)
    # we apply the gradient
    W -= dW*stepsize
    # we keep track of the loss over time
    hist_gd[i] = mse(W, inputs=X, targets=y)
```
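One possible variation on this loop, sketched below under a few assumptions, uses `jax.value_and_grad` to get the loss and the gradient from a single call instead of evaluating `mse` separately. Note that this records the loss *before* each update rather than after it. The smaller dataset, the seed, the 500 steps (more than the 100 used above, so that the weights get close to the true ones), and the names `onp` and `jnp` are choices made for this illustration.

```python
import numpy as onp          # plain numpy, used for data generation
import jax.numpy as jnp      # jax's numpy-compatible API
from jax import jit, value_and_grad

def mse(params, inputs, targets):
    return jnp.mean((inputs @ params - targets) ** 2)

# returns (loss, gradient with respect to params) in one compiled call
loss_and_grad = jit(value_and_grad(mse))

# made-up dataset: an intercept column plus k random features
rng = onp.random.default_rng(42)
n, k = 1_000, 5
X = onp.concatenate([onp.ones((n, 1)), rng.normal(0, 1, (n, k))], axis=1)
true_w = rng.normal(0, 5, (k + 1,))
y = X @ true_w

W = jnp.asarray(rng.normal(0, 1, (k + 1,)))
stepsize = 0.02
n_step = 500  # assumption: more steps than in the text above
history = []

for _ in range(n_step):
    loss, dW = loss_and_grad(W, X, y)
    # jax arrays are immutable, so we rebind W instead of updating in place
    W = W - stepsize * dW
    history.append(float(loss))

# largest absolute difference between learned and true coefficients
max_err = float(jnp.max(jnp.abs(W - true_w)))
print(history[0], history[-1], max_err)
```

Because the data is generated without noise, the learned weights can be compared directly against `true_w` as a check that the loop is doing what it should.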

You can now check the loss over time with this chart:

```
import matplotlib.pyplot as plt

plt.figure(figsize=(20, 4))
plt.plot(hist_gd)
# start the y-axis at zero so the loss is shown on an absolute scale
plt.ylim(0);
```
