jax: training


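This code defines the predict function and the mse (mean squared error) loss function. It then builds a jit-compiled function that returns the gradient of mse with respect to params.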

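import jax.numpy as np  # JAX's NumPy API, aliased as np
from jax import grad, jit

def predict(params, inputs):
    # linear model: one prediction per row of inputs
    return inputs @ params

def mse(params, inputs, targets):
    preds = predict(params, inputs)
    return np.mean((preds - targets)**2)

# gradient of mse with respect to params (its first argument), compiled with jit
grad_fun = jit(grad(mse))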
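Because mse averages squared residuals, its gradient also has a closed form: 2/n · X^T (X W - y). Comparing that formula with what grad returns on a small random problem is a quick sanity check. This is a sketch assuming a recent JAX install; the problem sizes and PRNG seed are arbitrary, and it imports jax.numpy as jnp so it stays independent of the NumPy imports used later on this page.

```python
import jax.numpy as jnp
from jax import grad, jit, random

def predict(params, inputs):
    return inputs @ params

def mse(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.mean((preds - targets) ** 2)

grad_fun = jit(grad(mse))

# small random problem; sizes and seed are arbitrary choices for illustration
key_x, key_y, key_w = random.split(random.PRNGKey(0), 3)
X = random.normal(key_x, (100, 3))
y = random.normal(key_y, (100,))
W = random.normal(key_w, (3,))

# hand-derived gradient of mean((X @ W - y) ** 2) with respect to W
analytic = (2.0 / X.shape[0]) * (X.T @ (X @ W - y))
# gradient computed by JAX's automatic differentiation
autodiff = grad_fun(W, inputs=X, targets=y)

print(jnp.allclose(analytic, autodiff, atol=1e-4))
```

The tolerance is loose on purpose: JAX computes in float32 by default, so the two results agree only up to rounding.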

The gradient function grad_fun is used in the training loop below:

import tqdm
from numpy import zeros
from numpy.random import normal

# we generate 10_000 rows with 5 random features, plus a column of ones for the intercept
n, k = 10_000, 5
X = np.concatenate([np.ones((n, 1)), normal(0, 1, (n, k))], axis=1)

# these are the true coefficients that we have to learn
true_w = normal(0, 5, (k + 1,))
# these are the targets used in training (no noise is added)
y = X @ true_w
# random starting point for the weights
W = normal(0, 1, (k + 1,))

stepsize = 0.02
n_step = 100
# buffer for the loss after every step
hist_gd = zeros((n_step,))
for i in tqdm.tqdm(range(n_step)):
    # we calculate the gradient 
    dW = grad_fun(W, inputs=X, targets=y)
    # we apply the gradient 
    W -= dW*stepsize
    # we keep track of the loss over time
    hist_gd[i] = mse(W, inputs=X, targets=y)

You can now plot the loss over time:

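import matplotlib.pyplot as plt

plt.figure(figsize=(20, 4))
plt.plot(hist_gd)  # loss after every gradient-descent step
plt.xlabel("step")
plt.ylabel("mean squared error")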
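The training loop on this page draws its data with NumPy and calls mse a second time to record the loss. As a sketch of an alternative (not the approach used above), the whole loop can stay in JAX: jax.random generates the data, and jax.value_and_grad returns the loss and the gradient from a single call. The update helper and the losses list are hypothetical names for this example. Note that the recorded loss is measured at the weights before each step, whereas hist_gd stores the loss after each step.

```python
import jax
import jax.numpy as jnp

def mse(params, inputs, targets):
    # mean squared error of a linear model, same as the mse defined earlier
    return jnp.mean((inputs @ params - targets) ** 2)

# value_and_grad returns (loss, gradient with respect to params) from one call
loss_and_grad = jax.value_and_grad(mse)

@jax.jit
def update(params, inputs, targets, stepsize):
    # hypothetical helper: one plain gradient-descent step
    loss, g = loss_and_grad(params, inputs, targets)
    return params - stepsize * g, loss

# same setup as above: intercept column plus 5 standard-normal features
n, k = 10_000, 5
key_x, key_w, key_init = jax.random.split(jax.random.PRNGKey(42), 3)
X = jnp.concatenate([jnp.ones((n, 1)), jax.random.normal(key_x, (n, k))], axis=1)
true_w = 5.0 * jax.random.normal(key_w, (k + 1,))  # std 5, like normal(0, 5, ...)
y = X @ true_w
W = jax.random.normal(key_init, (k + 1,))

losses = []
for _ in range(100):
    # the returned loss is measured at W before this step is applied
    W, loss = update(W, X, y, 0.02)
    losses.append(float(loss))

print(losses[0], losses[-1])
```

Returning new weights instead of updating W in place fits JAX's immutable arrays, and keeping everything in JAX avoids converting between NumPy and JAX arrays on every step.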
