Refactor transformed_conditional_logp to remove IR dummy Ops (issue #8100)#8101
Refactor transformed_conditional_logp to remove IR dummy Ops (issue #8100)#8101mikjkd wants to merge 2 commits intopymc-devs:mainfrom
Conversation
Replace the TransformValuesRewrite mechanism with a simpler two-stage approach as suggested by Ricardo Vieira in the issue: apply transform.backward() before conditional_logp, then add Jacobian correction afterward. This eliminates TransformedValue and TransformedValueRV dummy Ops from the intermediate representation. Closes pymc-devs#8100
|
|
There was a problem hiding this comment.
Pull request overview
Refactors transformed_conditional_logp to avoid the TransformValuesRewrite/dummy-Op based transform injection by computing logp on back-transformed (constrained) values first and then adding Jacobian corrections afterward, aligning with the two-stage approach proposed in issue #8100.
Changes:
- Replaced
TransformValuesRewriteusage intransformed_conditional_logpwith explicittransform.backward()pre-processing and post-hoc Jacobian correction. - Added explicit RV-detection checks and a replacement pass (
replace_rvs_by_values) for conditional dependencies. - Updated a regression test to no longer expect
TransformedValuewarnings (onlyValuedVarremains).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
pymc/logprob/basic.py |
Implements the two-stage transformed logp derivation (back-transform before conditional_logp, Jacobian after) and removes the rewrite-based approach. |
tests/logprob/test_basic.py |
Updates warning expectations to reflect removal of TransformedValue from the final/evaluated graph. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| rvs_to_values: dict[Variable, Variable], | ||
| rvs_to_transforms: dict[Variable, Transform], | ||
| jacobian: bool = True, |
There was a problem hiding this comment.
Type hint for rvs_to_transforms should allow None (and possibly missing keys). Model.create_value_var stores transform which can be None, and this function checks transform is not None, so the current dict[Variable, Transform] annotation is inconsistent and breaks type checking.
| """Compute conditional log-probabilities for RVs, applying value transforms and Jacobian corrections. | ||
|
|
||
| This helper will only return the subset of logprob terms corresponding to `rvs`. | ||
| All rvs_to_values and rvs_to_transforms mappings are required. |
There was a problem hiding this comment.
Docstring says "All rvs_to_values and rvs_to_transforms mappings are required", but the implementation treats missing entries in rvs_to_transforms as "no transform" (and there are callers that pass {}). Please update the docstring to match the actual contract (or enforce the requirement with an explicit check).
| All rvs_to_values and rvs_to_transforms mappings are required. | |
| All `rvs` must appear in ``rvs_to_values``. Entries in ``rvs_to_transforms`` are | |
| optional; RVs missing from this mapping are treated as having no transform. |
| f"Random variables detected in the logp graph: {measurable_logp_terms}.\n" | ||
| "This can happen when mixing variables from different models, " | ||
| "or when CustomDist logp or Interval transform functions reference nonlocal variables." |
There was a problem hiding this comment.
The early measurable_logp_terms error message describes mixing models / nonlocal references, but this specific check is triggered when a logp term itself is a MeasurableOp/RandomVariable output (e.g., a CustomDist logp function returning an RV). Consider adjusting the message (or splitting messages) so it points to the actual failure mode and suggested fix.
| f"Random variables detected in the logp graph: {measurable_logp_terms}.\n" | |
| "This can happen when mixing variables from different models, " | |
| "or when CustomDist logp or Interval transform functions reference nonlocal variables." | |
| "Invalid logp terms: some logp outputs are themselves RandomVariables/MeasurableOps.\n" | |
| f"Offending logp terms: {measurable_logp_terms}.\n" | |
| "This usually happens when a CustomDist logp or transform function returns a random " | |
| "variable (e.g., by calling a distribution constructor) instead of a numeric log-density " | |
| "tensor. Ensure that your custom logp/transform functions return a deterministic tensor " | |
| "representing the log-probability, and do not create or return RandomVariables." |
Using constrained value expressions (transform.backward(...)) as the keys passed to conditional_logp means conditional_logp can no longer name logp terms via original_value.name (it will often be None for these intermediate expressions). If logp term names are relied on for debugging/traceability, consider propagating a meaningful name from the original unconstrained value var / RV (e.g., set a name on val_constrained or re-name the retrieved logp term).

Description
Replace the TransformValuesRewrite mechanism with a simpler two-stage approach as suggested by Ricardo Vieira in the issue #8100 : apply transform.backward() before conditional_logp, then add Jacobian correction afterward. This eliminates TransformedValue and TransformedValueRV dummy Ops from the intermediate representation.
Related Issue
Checklist
Type of change