## To-Do

- [ ] Write `sample_tree` function that creates a random pytree based on some high-level information
- [ ] Create random pytrees and test against JAX
- [ ] Test current example pytrees against JAX

## Note

Need to find a way of not installing JAX on Windows in `tox`.
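
As a starting point for the `sample_tree` item in the to-do list, here is a minimal sketch of what such a function might look like. The signature, the parameter names (`rng`, `max_depth`, `max_children`), the 30% early-leaf probability, and the `tree_depth` helper are all assumptions made for illustration, not an existing API of this project. It builds pytrees from dicts, lists, tuples, and integer leaves only.

```python
import random


def sample_tree(rng, max_depth=3, max_children=3):
    """Return a random pytree of dicts, lists, tuples and int leaves.

    Hypothetical sketch: `rng` is a `random.Random` instance, so a fixed
    seed reproduces the same tree.
    """
    # Emit a leaf when the depth budget is spent, or early with 30% chance.
    if max_depth == 0 or rng.random() < 0.3:
        return rng.randint(0, 9)

    # Otherwise build a container with 0..max_children random subtrees.
    n_children = rng.randint(0, max_children)
    children = [
        sample_tree(rng, max_depth - 1, max_children) for _ in range(n_children)
    ]

    kind = rng.choice(("list", "tuple", "dict"))
    if kind == "list":
        return children
    if kind == "tuple":
        return tuple(children)
    return {f"k{i}": child for i, child in enumerate(children)}


def tree_depth(tree):
    """Depth of a pytree: 0 for a leaf, 1 + deepest child for a container."""
    if isinstance(tree, dict):
        children = list(tree.values())
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    else:
        return 0
    return 1 + max((tree_depth(c) for c in children), default=0)


if __name__ == "__main__":
    tree = sample_tree(random.Random(0))
    print(tree, tree_depth(tree))
```

Passing a seeded `random.Random` keeps failing cases reproducible. For the comparison against JAX, the same sampled tree could be flattened by both this library and `jax.tree_util.tree_flatten`, and the resulting leaves compared.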