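See how tree-math wraps a pytree in a vector-like class (`tree_math.Vector`) so that arithmetic (`+`, `-`, scalar `*`, dot products) works directly on pytrees. This is especially useful for gradients, since `jax.grad` returns a pytree with the same structure as the parameters. https://github.com/google/tree-math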
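To illustrate the idea, here is a minimal hand-rolled sketch of a pytree-as-vector wrapper built on `jax.tree_util`. It is not tree-math's implementation or API. The class name `TreeVector` and the toy `loss` are hypothetical, chosen only to show why vector-style math on gradient pytrees is convenient.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class TreeVector:
    """Hypothetical minimal wrapper (not tree_math.Vector): treats a pytree as one flat vector."""

    def __init__(self, tree):
        self.tree = tree

    def _binary(self, other, op):
        if isinstance(other, TreeVector):
            # Leafwise op between two pytrees with the same structure.
            return TreeVector(tree_util.tree_map(op, self.tree, other.tree))
        # Otherwise treat `other` as a scalar and broadcast it to every leaf.
        return TreeVector(tree_util.tree_map(lambda x: op(x, other), self.tree))

    def __add__(self, other):
        return self._binary(other, lambda a, b: a + b)

    def __sub__(self, other):
        return self._binary(other, lambda a, b: a - b)

    def __mul__(self, other):
        return self._binary(other, lambda a, b: a * b)

    __rmul__ = __mul__  # scalar * vector; multiplication is commutative

    def __neg__(self):
        return TreeVector(tree_util.tree_map(lambda x: -x, self.tree))

    def dot(self, other):
        # Inner product over all leaves, as if the pytree were flattened.
        pairs = zip(tree_util.tree_leaves(self.tree), tree_util.tree_leaves(other.tree))
        return sum(jnp.vdot(a, b) for a, b in pairs)


# Toy loss over a dict-of-arrays pytree (hypothetical example).
def loss(params):
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2


params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(3.0)}
grads = jax.grad(loss)(params)  # same pytree structure as params

# One gradient-descent step written as plain vector math.
new_params = (TreeVector(params) - 0.1 * TreeVector(grads)).tree
grad_norm_sq = TreeVector(grads).dot(TreeVector(grads))
```

Without a wrapper, the same update would need an explicit `tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)`; the wrapper lets optimizer code read like ordinary vector algebra.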