There are currently no unit tests for gradient_creator(), and I don't trust its JAX implementation: it calls jax.grad() on a ParametrizedFunction rather than on the internal JAX function.
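A unit test could pin down the expected behaviour. The sketch below is only illustrative. It assumes a hypothetical ParametrizedFunction that stores an internal JAX function fn(x, params) plus its parameters. It also assumes a hypothetical gradient_creator() that differentiates that internal function with respect to x. The project's real classes and signatures may differ. The test checks the result against both an analytic gradient and a central finite difference, so it would catch a gradient taken with respect to the wrong argument.

```python
import jax
import jax.numpy as jnp

# float64 so the finite-difference comparison can use a tight tolerance
jax.config.update("jax_enable_x64", True)


class ParametrizedFunction:
    """Hypothetical stand-in: wraps an internal JAX function fn(x, params)."""

    def __init__(self, fn, params):
        self.fn = fn
        self.params = params

    def __call__(self, x):
        return self.fn(x, self.params)


def gradient_creator(pf):
    """Hypothetical sketch: differentiate the *internal* JAX function w.r.t. x,
    closing over the fixed parameters, instead of the wrapper object."""
    return jax.grad(lambda x: pf.fn(x, pf.params))


def central_difference(f, x, eps=1e-6):
    """Independent reference gradient via central finite differences."""
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    return jnp.array([(f(x + eps * e) - f(x - eps * e)) / (2 * eps) for e in basis])


def test_gradient_creator():
    # f(x; a) = a * sum(x**2) + sin(x0)  =>  grad = 2*a*x + [cos(x0), 0, 0]
    fn = lambda x, a: a * jnp.sum(x ** 2) + jnp.sin(x[0])
    pf = ParametrizedFunction(fn, 3.0)
    x = jnp.array([0.5, -1.0, 2.0])

    grad = gradient_creator(pf)(x)
    analytic = 2 * 3.0 * x + jnp.array([jnp.cos(0.5), 0.0, 0.0])

    assert jnp.allclose(grad, analytic, atol=1e-8)
    assert jnp.allclose(grad, central_difference(pf, x), atol=1e-6)


test_gradient_creator()
```

Whether jax.grad(pf) on the wrapper itself works depends on whether its __call__ stays traceable by JAX. Differentiating the internal function makes it explicit which argument is being differentiated.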
It would also be nice to illustrate the creation of a gradient for a function in a notebook. (E.g., a simple demo of error propagation with autodiff.)
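A notebook demo could start from first-order (linear) error propagation. For inputs with mean μ and covariance Σ, Var[f] ≈ gᵀ Σ g with g = ∇f(μ), and jax.grad supplies g directly. This is a minimal sketch. The helper propagate_uncertainty and the product example are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def propagate_uncertainty(f, mean, cov):
    """First-order error propagation: sigma_f = sqrt(g^T cov g), g = grad f(mean)."""
    g = jax.grad(f)(mean)
    return f(mean), jnp.sqrt(g @ cov @ g)


def product(v):
    # f(x, y) = x * y, so grad f = (y, x)
    return v[0] * v[1]


mean = jnp.array([2.0, 3.0])
# Independent errors: sigma_x = 0.1, sigma_y = 0.2
cov = jnp.diag(jnp.array([0.1, 0.2]) ** 2)

value, sigma = propagate_uncertainty(product, mean, cov)
print(f"f = {float(value):.3f} +/- {float(sigma):.3f}")  # f = 6.000 +/- 0.500
```

This reproduces the textbook result for a product, sqrt((y·σx)² + (x·σy)²) = sqrt(0.09 + 0.16) = 0.5. The linear approximation is only accurate when f is close to linear over the spread of the inputs.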