Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in numpy, but a common downside is the lack of a differentiation tool. JAX is a tool that bridges this gap. It is numpy compatible and does not force you to write code in a different way. It can even compile your code for use on a CPU/GPU.
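To make that concrete, here is a minimal sketch of what differentiating and compiling numpy-style code might look like. The `loss` function and the toy data are made up for illustration and are not taken from the videos; only `jax.grad`, `jax.jit` and `jax.numpy` are actual JAX APIs.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Mean squared error of a linear model (hypothetical example function)."""
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


# jax.grad returns a new function that computes the gradient of `loss`
# with respect to its first argument (w).
grad_loss = jax.grad(loss)

# jax.jit compiles the gradient function with XLA so it can run on CPU/GPU.
fast_grad_loss = jax.jit(grad_loss)

# Toy data, chosen only for illustration.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)            # plain gradient
g_fast = fast_grad_loss(w, x, y)  # same gradient, compiled
print(g, g_fast)
```

Note that `loss` is written the same way you would write it in numpy, with `jnp` in place of `np`. The gradient and the compiled version are both obtained by wrapping the function rather than rewriting it.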
If you want to play with the notebook that contains all the final code for all of the videos here, check out the notebook repo.