Add rewrites with RV reparametrization tricks #8056
lucianopaz wants to merge 5 commits into pymc-devs:main from
Conversation
Codecov Report ✅ All modified and coverable lines are covered by tests.

@@            Coverage Diff             @@
##              main    #8056      +/-   ##
==========================================
- Coverage   91.42%   90.80%   -0.62%
==========================================
  Files         117      121       +4
  Lines       19154    19443     +289
==========================================
+ Hits        17512    17656     +144
- Misses       1642     1787     +145
@node_rewriter([GammaRV, InvGammaRV])
def gamma_reparametrization(fgraph, node):
    rng, size, shape, scale = node.inputs
    return scale * node.op(shape, 1.0, rng=rng, size=size)
We can't pull the shape parameter out of the RV?
The actual reparametrization trick in this case requires some effort (the Marsaglia-Tsang method); the JAX implementation is here:
https://github.com/jax-ml/jax/blob/c6568036b83b39d556c68d647a99196e63062612/jax/_src/random.py#L1298
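For reference, a minimal NumPy sketch of the Marsaglia-Tsang rejection sampler for Gamma(shape, 1) with shape >= 1 (the function name and loop structure are illustrative only; the PR ultimately needs a symbolic, scan-based version of the same loop):

```python
# Minimal sketch of the Marsaglia-Tsang sampler for Gamma(shape, 1), shape >= 1.
# Illustrative only: the PR needs a symbolic (scan-based) version of this loop.
import numpy as np


def marsaglia_tsang_gamma(shape, rng, size):
    d = shape - 1.0 / 3.0
    c = 1.0 / np.sqrt(9.0 * d)
    out = np.empty(size)
    flat = out.reshape(-1)
    for i in range(flat.size):
        while True:
            x = rng.standard_normal()
            v = (1.0 + c * x) ** 3
            if v <= 0.0:
                continue  # proposal outside the support, reject immediately
            u = rng.uniform()
            # Full log-acceptance test (the usual squeeze shortcut is omitted).
            if np.log(u) < 0.5 * x**2 + d - d * v + d * np.log(v):
                flat[i] = d * v
                break
    return out


rng = np.random.default_rng(0)
draws = marsaglia_tsang_gamma(4.5, rng, size=(10_000,))
print(draws.mean())  # should be close to the shape parameter, 4.5
```

Shapes below 1 are typically handled by sampling with shape + 1 and multiplying the result by u**(1/shape) for an extra uniform u.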
I'll write the implementation for that then. Thanks @jessegrabowski
@jessegrabowski, I wrote the generator code. I still need to test that the samples it returns are distributed like the reference Gamma, but maybe you could have a look and comment on the scans
I've fixed the gamma reparametrization and it now matches the generation we get from Gamma.dist. I can't take gradients through the reparametrized version because of pymc-devs/pytensor#555
@node_rewriter([BernoulliRV])
def bernoulli_reparametrization(fgraph, node):
    rng, size, p = node.inputs
    return switch(
Commenting here, but in general you may be making non-equivalent graphs by using the `size` argument. Remember it may be None, meaning the shape will be implied by the parameters. In that case you may be rewriting something like `normal([0, 1, 2], size=None)` as `[0, 1, 2] + normal(size=None)`.
There's a non-default PyTensor rewrite that makes `size` explicit that you may want to run before these rewrites. Then make sure these rewrites fail if `size` is None (`isinstance(size.type, NoneTypeT)`). If you do have a `size`, it will be all you need.
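As a concrete illustration of that guard, here is a hedged sketch of a location-scale rewrite that refuses to fire when `size` is implicit; the import paths and the rewrite body are assumptions, not code from this PR:

```python
# Hedged sketch: skip the rewrite when size is implicit (NoneTypeT), so the
# broadcasting implied by the parameter shapes is not lost. Import paths are
# assumptions based on the current PyTensor layout, not code from this PR.
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.random.basic import NormalRV
from pytensor.tensor.type_other import NoneTypeT


@node_rewriter([NormalRV])
def normal_reparametrization(fgraph, node):
    rng, size, loc, scale = node.inputs
    if isinstance(size.type, NoneTypeT):
        # size is implied by the parameter shapes; rewriting here could
        # silently change the draw's shape, so refuse to apply.
        return None
    # Standard location-scale trick: x = loc + scale * z, z ~ N(0, 1).
    next_rng, z = node.op(0.0, 1.0, rng=rng, size=size).owner.outputs
    return [next_rng, loc + scale * z]
```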
Hmm. I'll have to ask you how to add that rewrite into the database and make it go before any reparametrization rewrites
)
from pytensor.tensor.slinalg import cholesky

reparametrization_trick_db = RewriteDatabaseQuery(include=["random_reparametrization_trick"])
This works, but it's a bit expensive every time you extend it. You should start with an actual database like `SequenceDB`, in which you can then register rewrite phases. pymc/logprob/rewriting.py may be a good example.
I'm even surprised that `RewriteDatabaseQuery` has a register argument. Feels like bad design.
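A rough sketch of that layout, loosely modeled on pymc/logprob/rewriting.py (all database and phase names below are placeholders, not what this PR defines):

```python
# Illustrative layout only: names and phases are placeholders, not this PR's.
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery, SequenceDB

random_reparametrization_db = SequenceDB()

# Phase 1: canonicalization, e.g. the rewrite that makes implicit sizes explicit.
canonicalize_db = EquilibriumDB()
random_reparametrization_db.register("canonicalize", canonicalize_db, "basic")

# Phase 2: the actual reparametrization-trick rewrites.
reparametrization_trick_db = EquilibriumDB()
random_reparametrization_db.register("reparametrize", reparametrization_trick_db, "basic")

# Individual node rewrites get registered into reparametrization_trick_db, and
# the whole pipeline is selected with a single query over the sequence.
rewrites = random_reparametrization_db.query(RewriteDatabaseQuery(include=["basic"]))
```

The intent is that the canonicalization phase (e.g. making implicit sizes explicit) runs before any reparametrization rewrite.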
@register_random_reparametrization
@node_rewriter([DirichletRV])
def dirichlet_reparametrization(fgraph, node):
    raise NotImplementedError("DirichletRV is not reparametrizable")
It doesn't need to block this PR, but I know TFP has something for the Dirichlet that is one-time differentiable.
It comes from the Gamma, so we need to sort that out first
They might have changed it, and they now just define the Dirichlet as a normalized batch of gamma variates, which would be differentiable if the gammas also are?
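For what it's worth, a quick NumPy check of that construction, just to illustrate the relationship (not PR code):

```python
# Quick NumPy check (not PR code): if g_i ~ Gamma(alpha_i, 1), then
# g / g.sum() ~ Dirichlet(alpha), so a reparametrized Gamma would give a
# differentiable Dirichlet for free.
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([2.0, 3.0, 5.0])

g = rng.gamma(alpha, size=(100_000, 3))
dirichlet_draws = g / g.sum(axis=1, keepdims=True)

print(dirichlet_draws.mean(axis=0))  # close to the analytic mean below
print(alpha / alpha.sum())           # [0.2, 0.3, 0.5]
```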
Thanks, I'll have a look at that
next_rng, U = UniformRV()(
    zeros_like(alpha),
    ones_like(alpha),
    rng=rng,
).owner.outputs
next_rng, x = NormalRV()(
    zeros_like(c),
    ones_like(c),
    rng=next_rng,
).owner.outputs
It's wasteful, but to avoid the current grad limitations you can sample all of these once outside of the scan and then pass them in as sequences.
Because nothing about U or x actually depends on the state of the scan
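Something along these lines, as a hedged sketch rather than PR code; the fixed rejection budget `n_iters` and the simplified accept rule are illustrative assumptions:

```python
# Hedged sketch (not PR code): pre-sample all the noise outside the scan and
# pass it in as sequences, so the scan body is purely deterministic.
# n_iters is an illustrative fixed rejection budget; the accept rule below is
# simplified and only meant to show the structure of the loop.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream

srng = RandomStream(seed=123)
n_iters = 32

c = pt.vector("c")
U = srng.uniform(0.0, 1.0, size=(n_iters, c.shape[0]))
x = srng.normal(0.0, 1.0, size=(n_iters, c.shape[0]))


def step(u_t, x_t, accepted, c):
    # Deterministic body: all randomness enters through the sequences u_t, x_t.
    v = (1.0 + c * x_t) ** 3
    accept = pt.and_(pt.gt(v, 0.0), pt.lt(pt.log(u_t), 0.5 * x_t**2))
    # Keep the first accepted proposal, leave already-filled entries untouched.
    return pt.switch(pt.and_(accept, pt.isnan(accepted)), v, accepted)


proposals, _ = pytensor.scan(
    step,
    sequences=[U, x],
    outputs_info=[pt.full_like(c, np.nan)],
    non_sequences=[c],
)
final = proposals[-1]
```

Since the body is purely deterministic, the gradient with respect to the distribution parameters only has to flow through the scan body, not through any RNG updates inside it.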
This was born from pymc-devs/pytensor#1424. I started to write down the reparametrization tricks I could find or figure out for most of PyTensor's basic `RandomVariable` Ops.