fix derived scan logprob when observed provides more broadcastable information #8016
eclipse1605 wants to merge 1 commit into pymc-devs:main from […]
Conversation
|
I'm inclined to fix this on the Scan side in PyTensor. Basically if […] Furthermore […] Otherwise it's too tricky to work with scans programmatically, and we end up with failures/awkward workarounds as demonstrated in this PR |
|
sorry for the slow reply, am i getting this right, implementation wise? we’d update PyTensor so that:
|
|
@ricardoV94 does this sound right? |
|
The inner should have as much information as the outer; the outer can have less. But yes. Let's open a PR and see how it goes |
|
so in […] we change […] and in […]

    if (
        type_input.dtype != type_output.dtype
        or type_input.broadcastable != type_output.broadcastable
    ):

we replace the broadcastable eq check with a compatibility check using […]

    if type_input.dtype != type_output.dtype:
        # dtype error
    elif isinstance(type_input, TensorType) and isinstance(type_output, TensorType):
        if not type_input.is_super(type_output):
            # compatibility error using is_super
    else:
        if type_input != type_output:
            # fallback error

|
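The difference between the strict equality check and the proposed compatibility check can be sketched without PyTensor. The snippet below is a simplified model, not PyTensor's implementation: static shapes are plain tuples in which `None` stands for an unknown length, and the helper names (`broadcastable`, `is_super_shape`, `check_inner_types`) are hypothetical. The direction of the check follows the comment above (`type_input.is_super(type_output)`), on the assumption that "is super" means "at least as general as".

```python
from typing import Optional, Tuple

# Simplified stand-in for a tensor type's static shape: None = unknown length.
StaticShape = Tuple[Optional[int], ...]


def broadcastable(shape: StaticShape) -> Tuple[bool, ...]:
    """Broadcast pattern: True where the static length is known to be 1."""
    return tuple(dim == 1 for dim in shape)


def is_super_shape(general: StaticShape, specific: StaticShape) -> bool:
    """True if `general` accepts everything `specific` does.

    Each axis of `general` must either be unknown (None) or match the
    corresponding axis of `specific` exactly.
    """
    if len(general) != len(specific):
        return False
    return all(g is None or g == s for g, s in zip(general, specific))


def check_inner_types(
    dtype_in: str, shape_in: StaticShape, dtype_out: str, shape_out: StaticShape
) -> None:
    """Hypothetical validation step: raise if input and output are incompatible."""
    if dtype_in != dtype_out:
        raise TypeError(f"dtype mismatch: {dtype_in} vs {dtype_out}")
    if not is_super_shape(shape_in, shape_out):
        raise TypeError(f"shape mismatch: {shape_in} cannot accept {shape_out}")


# The recurring state is typed as a generic matrix, while the update
# carries the extra information that its last axis has length 1.
state_shape: StaticShape = (None, None)
update_shape: StaticShape = (None, 1)

# A strict equality check on broadcast patterns rejects this pair ...
strict_ok = broadcastable(state_shape) == broadcastable(update_shape)

# ... while the supertype-style check accepts it (no exception raised).
check_inner_types("float64", state_shape, "float64", update_shape)
```

In this toy model the check is deliberately one-directional: a less specific input accepts a more specific output, but swapping the two shapes raises a `TypeError`.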
|
I think you can use |
|
@ricardoV94 i have opened a PR (pymc-devs/pytensor#1861) |
|
@ricardoV94 should i close this PR? |
|
No, but it should be changed to just have a regression test for the original issue to make sure it's working once we fix it in PyTensor |
|
yeah makes sense |
|
@ricardoV94 once we fix it in pytensor do we just add the snippet used to repro the bug in #7892 as a test? |
Yes, plus we test the result is non-sensical when evaluated |
Description
fixed a failure in derived scan logprob construction when the observed/value tensor provides more static broadcastability information than the generative scan graph (e.g. observed has a size-1 axis like (date, 1) while the scan state was inferred as non-broadcastable on that axis).
in this case, model.logp() could fail during the measurable scan rewrite with a scan outputs_info broadcast pattern mismatch (scan output inferred as matrix-like vs. outputs_info expecting vector-like).
[…] Apply nodes (so scan reconstruction remains valid).
note
I think the same idea can be generalized by treating static broadcastability metadata as part of the measurable scan rewrite contract:
- […] outputs_info proxies for the logprob rewritten scan, ensure their TensorType.shape reflects any size-1/broadcastable axes implied by the outer variables.
- […] pt.join/outputs_info) so that init and scan outputs agree on broadcastability, without inserting broadcast Apply nodes into the inner graph (placeholders must remain nominal vars).
Related Issue
Checklist
Type of change