

jax


<p>Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is that numpy has no tool for computing derivatives. <a href="https://jax.readthedocs.io/en/latest/">Jax</a> is a tool that bridges this gap. It is numpy-compatible and does not force you to write your code in a different way. It can even compile your code to run on a CPU or GPU.</p>
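Here is a minimal sketch of what "numpy-compatible" looks like in practice, assuming `jax` is installed. The function is written with `jax.numpy` just as it would be with numpy. It is then differentiated with `jax.grad` and compiled with `jax.jit`. The names `loss`, `w` and `target` are hypothetical and only serve as an illustration.

```python
import jax
import jax.numpy as jnp

# An ordinary numpy-style function: the sum of squared differences.
# `loss`, `w` and `target` are hypothetical names for illustration.
def loss(w, target):
    return jnp.sum((w - target) ** 2)

# grad differentiates with respect to the first argument (w) by default.
grad_loss = jax.grad(loss)

# jit compiles the gradient function for the available backend (CPU, GPU, ...).
fast_grad_loss = jax.jit(grad_loss)

w = jnp.array([1.0, 2.0, 3.0])
target = jnp.array([0.0, 0.0, 0.0])

print(loss(w, target))            # 14.0
print(fast_grad_loss(w, target))  # [2. 4. 6.], i.e. 2 * (w - target)
```

Apart from the import, the body of `loss` is plain numpy code. JAX adds the derivative and the compilation step on top of it.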


1 - Introduction
2 - Arguments
3 - Arrays
4 - Jit
5 - Vmap
6 - Training
7 - GPU
8 - Closing Notes

Note that you need to install jax before you can run it. You can do that from Jupyter via:

%pip install jax

This is the code that we ran in this video:

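from jax import grad

def f(x):
    return x**2

# grad(f) returns a new function that computes the derivative df/dx = 2*x
grad_f = grad(f)

# Use floats (3.) rather than ints: grad only accepts floating-point inputs.
f(3.), grad_f(3.)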
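Here `f(3.)` gives 9.0 and `grad_f(3.)` gives 6.0, returned as a JAX array. The sketch below checks that `grad_f` matches the analytic derivative 2*x at a few points. It also compares against a central finite-difference approximation. The helper `finite_difference` and the step size `eps=1e-3` are illustrative choices, not part of JAX.

```python
from jax import grad

def f(x):
    return x**2

grad_f = grad(f)

# Hypothetical helper: central finite-difference estimate of fn'(x).
def finite_difference(fn, x, eps=1e-3):
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)

for x in [-2.0, 0.0, 3.0]:
    auto = float(grad_f(x))            # derivative computed by JAX
    exact = 2 * x                      # analytic derivative of x**2
    numeric = finite_difference(f, x)  # numerical approximation
    print(x, auto, exact, numeric)
```

The finite-difference estimate is only approximate and depends on `eps`. In contrast, `grad` differentiates the code of `f` itself, so it returns the derivative without a step size to tune.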