Problem
A JAX treedef records more information about the flattened structure than ours does, so JAX's tree_unflatten function does not need an is_leaf argument. In the same way, we could make the registry argument unnecessary.
This is not just more convenient. It also reduces the potential for errors, because unflattening automatically uses the same options as flattening.
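
For reference, a minimal sketch of the JAX behaviour described above, assuming jax is installed. The example tree and the is_tuple predicate are made up for illustration; only tree_flatten and tree_unflatten are JAX's own functions.

```python
import jax

# Hypothetical example tree: a dict holding a tuple and a list.
tree = {"a": (1, 2), "b": [3, (4, 5)]}

# Hypothetical predicate: treat every tuple as a leaf instead of descending into it.
def is_tuple(x):
    return isinstance(x, tuple)

# is_leaf is passed only when flattening; its effect is captured in the treedef.
leaves, treedef = jax.tree_util.tree_flatten(tree, is_leaf=is_tuple)

# Unflattening takes just the treedef and the leaves, with no is_leaf argument.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
```

Because the tuples were kept as leaves when flattening, the treedef stores three leaf positions, and unflattening puts the tuples back unchanged without being told about is_tuple again.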
To-Do