Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is the lack of automatic differentiation. JAX is a tool that bridges this gap. It is numpy-compatible and does not force you to write code in a different way. It can even compile your code to run on a CPU or GPU.
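As a minimal illustration of what "differentiation for numpy code" means, `jax.grad` takes an ordinary numpy-style function and returns a new function that computes its gradient. The function `f` below is a made-up example, not part of the original post:

```python
import jax.numpy as jnp
from jax import grad

# A plain numpy-style function: the sum of squares of a vector.
def f(x):
    return jnp.sum(x ** 2)

# grad(f) returns a new function that computes df/dx, which is 2 * x here.
df = grad(f)

x = jnp.array([1.0, 2.0, 3.0])
print(df(x))  # [2. 4. 6.]
```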
This is the code that defines a linear model, the mean squared error (MSE) loss, and a compiled function that computes the gradient of that loss.
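```python
import jax.numpy as np
from jax import grad, jit

def predict(params, inputs):
    return inputs @ params

def mse(params, inputs, targets):
    preds = predict(params, inputs)
    return np.mean((preds - targets)**2)

# gradient of mse w.r.t. params (the first argument), compiled with jit
grad_fun = jit(grad(mse))
```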
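For a linear model the gradient of the MSE also has a closed form, 2/n · Xᵀ(Xw − y), so one way to sanity-check what `grad` produces is to compare the two. The sketch below does this on a small random dataset; the sizes, the seed, and the use of `jax.random` for the data are arbitrary choices for illustration, not from the original.

```python
import jax.numpy as jnp
from jax import grad, jit, random

def predict(params, inputs):
    return inputs @ params

def mse(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.mean((preds - targets) ** 2)

grad_fun = jit(grad(mse))

# Small random problem; shapes and seed are arbitrary.
key_x, key_w, key_y = random.split(random.PRNGKey(0), 3)
n, k = 100, 3
X = random.normal(key_x, (n, k))
w = random.normal(key_w, (k,))
y = random.normal(key_y, (n,))

# Gradient from JAX's automatic differentiation.
auto = grad_fun(w, inputs=X, targets=y)
# Hand-derived gradient of mean((Xw - y)^2) with respect to w.
manual = 2.0 / n * X.T @ (X @ w - y)

# should print True (the two agree up to float32 rounding)
print(jnp.allclose(auto, manual, atol=1e-4))
```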
This gradient function is used in the update loop shown below:
```python
import tqdm
import jax.numpy as np
from numpy import zeros
from numpy.random import normal

# we generate 10_000 rows and 5 random feature columns, plus an intercept column
n, k = 10_000, 5
X = np.concatenate([np.ones((n, 1)), normal(0, 1, (n, k))], axis=1)
# these are the true coefficients that we have to learn
true_w = normal(0, 5, (k + 1,))
# this is the generated dataset used in training
y = X @ true_w

W = normal(0, 1, (k + 1,))
stepsize = 0.02
n_step = 100
hist_gd = zeros((n_step,))

for i in tqdm.tqdm(range(n_step)):
    # we calculate the gradient
    dW = grad_fun(W, inputs=X, targets=y)
    # we take a step against the gradient
    W -= dW*stepsize
    # we keep track of the loss over time
    hist_gd[i] = mse(W, inputs=X, targets=y)
```
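Because `y` is generated without noise, gradient descent should recover `true_w` almost exactly if it runs long enough. A minimal sketch of that check follows. It swaps numpy's random generator for `jax.random`, uses fewer rows, runs more steps than the loop above, and wraps the whole update in a hypothetical `step` helper compiled with `jit`; all of these are illustrative assumptions rather than the post's setup.

```python
import jax.numpy as jnp
from jax import grad, jit, random

def predict(params, inputs):
    return inputs @ params

def mse(params, inputs, targets):
    return jnp.mean((predict(params, inputs) - targets) ** 2)

stepsize = 0.02

@jit
def step(W, X, y):
    # One gradient-descent update: move W against the gradient of the loss.
    return W - stepsize * grad(mse)(W, X, y)

key_x, key_true, key_init = random.split(random.PRNGKey(42), 3)
n, k = 1_000, 5
# Column of ones (intercept) plus k standard-normal feature columns.
X = jnp.concatenate([jnp.ones((n, 1)), random.normal(key_x, (n, k))], axis=1)
true_w = 5.0 * random.normal(key_true, (k + 1,))
y = X @ true_w  # noise-free targets
W = random.normal(key_init, (k + 1,))

for _ in range(1_000):
    W = step(W, X, y)

# Largest absolute difference between learned and true coefficients.
print(jnp.max(jnp.abs(W - true_w)))
```

Compiling the full update (gradient plus step) rather than only the gradient lets XLA fuse the two, though for a model this small the difference is negligible.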
You can now check the loss over time with this chart:
```python
import matplotlib.pyplot as plt

plt.figure(figsize=(20, 4))
plt.plot(hist_gd)
plt.xlabel("step")
plt.ylabel("MSE")
plt.ylim(0)
plt.show()
```