Sometimes you'd like to write your own algorithm. You can get far with rapid prototyping in NumPy, but a common downside is the lack of a differentiation tool. JAX is a tool that bridges this gap. It is NumPy-compatible and does not force you to write code in a different way. It can even compile your code to run on a CPU or GPU.
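To make that concrete, here is a minimal sketch of what this can look like. The `loss` function, the toy data, and the variable names are made up for illustration; only `jax.numpy`, `jax.grad`, and `jax.jit` come from JAX itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: mean squared error of a linear model,
# written with the same calls you would use in numpy.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad returns a new function that computes the gradient
# with respect to the first argument (here: w).
grad_loss = jax.grad(loss)

# jax.jit compiles that function for whichever backend is available (CPU or GPU).
fast_grad_loss = jax.jit(grad_loss)

# Toy data, assumed purely for illustration.
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

g = fast_grad_loss(w, x, y)
print(g)
```

Note that `loss` reads like ordinary numpy code; the differentiation and compilation are added afterwards by wrapping the function, not by rewriting it.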
If you want to run this yourself on Colab, you can use the shared notebook by clicking here.