# jax.

Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is the lack of a differentiation tool. Jax is a tool that bridges this gap. It is numpy compatible and does not force you to write code in a different way. It can even compile your code for use on a CPU/GPU.
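As a minimal sketch of what that looks like (the toy function here is invented for illustration, not from the episode): `jax.grad` takes a numpy-style function and returns a new function that computes its derivative.

```
import jax.numpy as jnp
from jax import grad

# a plain numpy-style function: f(x) = x**3
def f(x):
    return x ** 3

# grad returns a new function that computes df/dx = 3 * x**2
df = grad(f)
print(df(2.0))  # 12.0
```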

**Episode Notes**

This is the code that defines the `predict` and `mse` (loss) functions.

```
import jax.numpy as np
from jax import grad, jit

def predict(params, inputs):
    return inputs @ params

def mse(params, inputs, targets):
    preds = predict(params, inputs)
    return np.mean((preds - targets)**2)

# jit-compiled function that computes the gradient of the loss
grad_fun = jit(grad(mse))
```

This is used in the update loop shown below:

```
import tqdm
from numpy import zeros
from numpy.random import normal

# we generate 10_000 rows with 5 feature columns plus an intercept column
n, k = 10_000, 5
X = np.concatenate([np.ones((n, 1)), normal(0, 1, (n, k))], axis=1)
# these are the true coefficients that we have to learn
true_w = normal(0, 5, (k + 1,))
# this is the generated (noiseless) dataset used in training
y = X @ true_w

# random starting weights
W = normal(0, 1, (k + 1,))
stepsize = 0.02
n_step = 100
hist_gd = zeros((n_step,))
for i in tqdm.tqdm(range(n_step)):
    # we calculate the gradient
    dW = grad_fun(W, inputs=X, targets=y)
    # we apply the gradient
    W -= dW * stepsize
    # we keep track of the loss over time
    hist_gd[i] = mse(W, inputs=X, targets=y)
```
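Since the generated targets contain no noise, ordinary least squares recovers the true coefficients exactly, so gradient descent has a known solution to converge toward. A self-contained check with plain numpy (a smaller `n` and fresh variable names are used here so nothing above is clobbered; `onp` is an alias to avoid clashing with `jax.numpy`):

```
import numpy as onp  # plain numpy, aliased to avoid clashing with jax.numpy

rng = onp.random.default_rng(42)
n, k = 1_000, 5
X_chk = onp.concatenate([onp.ones((n, 1)), rng.normal(0, 1, (n, k))], axis=1)
true_w_chk = rng.normal(0, 5, (k + 1,))
y_chk = X_chk @ true_w_chk  # noiseless targets

# the closed-form least squares solution matches the true coefficients
w_hat, *_ = onp.linalg.lstsq(X_chk, y_chk, rcond=None)
print(onp.allclose(w_hat, true_w_chk))  # True
```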

You can now check the loss over time with the chart below:

```
import matplotlib.pyplot as plt

plt.figure(figsize=(20, 4))
plt.plot(hist_gd)
plt.ylim(0)
plt.xlabel("step")
plt.ylabel("mse")
```
