logprob for non-overlapping switch accepts equivalent spellings #8058
logprob for non-overlapping switch accepts equivalent spellings #8058eclipse1605 wants to merge 5 commits intopymc-devs:mainfrom
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #8058 +/- ##
==========================================
- Coverage 91.42% 90.74% -0.68%
==========================================
Files 117 121 +4
Lines 19154 19487 +333
==========================================
+ Hits 17512 17684 +172
- Misses 1642 1803 +161
🚀 New features to boost your workflow:
|
ricardoV94
left a comment
There was a problem hiding this comment.
I'm hoping we can simplify the code a bit further. Also let's get rid of all these typing.cast. It's in almost every line. Just mark the file as failing mypy if it requires so much work.
pymc/logprob/switch.py
Outdated
| return (True, False) | ||
| return None | ||
|
|
||
| return _direct_form() or _swapped_form() |
tests/logprob/test_switch.py
Outdated
| x = pm.Normal.dist(mu=0, sigma=1, size=(3,)) | ||
| y = pt.switch(x > 0, x, scale * x) | ||
|
|
||
| if cond_variant == "x_gt_0": |
There was a problem hiding this comment.
Why not parametrize already with the objects you gonna need?
pymc/logprob/switch.py
Outdated
|
|
||
| a = _extract_scale_from_measurable_mul( | ||
| cast(TensorVariable, neg_branch), cast(TensorVariable, x) | ||
| match = _match_scaled_switch_branches( |
There was a problem hiding this comment.
Instead of having to guess, you could normalize the switch so that it's always switch(cond, x, neg_branch). You can write switch(c, t, f) -> switch(~c, f, t).
Maybe this allows you to simplify more logic elsewhere?
There was a problem hiding this comment.
ya makes sense, though we still can’t use cond directly inside the logprob, because cond depends on the latent x
There was a problem hiding this comment.
in the logprob you should be able to assume you already have a canonical form, because that's how you made it in the rewrite
There was a problem hiding this comment.
actually i think we can do this and it should simplify the logprob. we can canonicalize in the rewrite so downstream logprob can assume one shape.
maybe something like this (pseudocode):
# in measurable rewrite:
# detect either ordering + equivalent conditions
# switch(cond(x ? 0), x, a*x)
# switch(cond(x ? 0), a*x, x)
# and also swapped comparisons like 0 < x, 0 >= x, ...
(x, ax, sem) = match_branches_and_condition(node)
# normalize to always: switch(cond_canon, x, ax)
if true_branch_is(ax):
sem = negate(sem) # because switch(c, ax, x) == switch(~c, x, ax)
cond_canon = cond_from_semantics(x, sem)
return measurable_switch_non_overlapping(cond_canon, x, ax)
# then in logprob:
# assume inputs are already canonical: switch(cond_canon, x, a*x)
sem = parse_canonical_cond(cond_canon, x)
gate = value_based_gate(value, sem)
return switch(gate, logp(x, value), logp(a*x, value)) + check(a > 0)
sure, is this a general convention we follow throughout the codebase? |
No, but in general I favor not making our code a mess because of mypy |
|
anyways i misunderstood mypy failing locally in the earlier commit, it wasn't required |
|
@ricardoV94 the failing test is flaky right? |
|
@ricardoV94 anything left in this? |
Closes #8049