
Broadcast probability functions with value and size #8105

Open

ricardoV94 wants to merge 5 commits into pymc-devs:main from ricardoV94:xrv_bcast

Conversation

ricardoV94 (Member) commented Feb 11, 2026:

Since our syntax is foo(rv, value), we should respect the shape of the rv. We were only respecting the shape implied by the parameters, not the one implied by size.

Previously, pm.logp(pm.Normal.dist(size=(10, 3)), 0) would return a single-item tensor.

This PR also tests that we are allowed to broadcast further with value:

pm.logp(pm.Normal.dist(size=(10, 3)), zeros((3, 10, 3)))

Implementation-wise, I think it's cleaner not to include size in the definitions themselves, as it is our own construct.

This should also be compatible out of the box when we replace our distributions with those in pytensor-distributions.
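
A minimal sketch of the intended shape behavior (the assertions encode the post-PR semantics described above; np.zeros stands in for the zeros shorthand):

import numpy as np
import pymc as pm

rv = pm.Normal.dist(size=(10, 3))

# The rv's size is now respected: a scalar value broadcasts to (10, 3)
assert pm.logp(rv, 0).eval().shape == (10, 3)

# The value may also broadcast beyond the rv's shape
assert pm.logp(rv, np.zeros((3, 10, 3))).eval().shape == (3, 10, 3)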

codecov bot commented Feb 11, 2026:

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 90.92%. Comparing base (11d0f1b) to head (f8d5790).
⚠️ Report is 5 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #8105      +/-   ##
==========================================
+ Coverage   90.89%   90.92%   +0.02%     
==========================================
  Files         123      123              
  Lines       19501    19558      +57     
==========================================
+ Hits        17726    17783      +57     
  Misses       1775     1775              
Files with missing lines              Coverage Δ
pymc/dims/distributions/core.py       91.86% <100.00%> (+1.86%) ⬆️
pymc/distributions/distribution.py    94.69% <100.00%> (+0.19%) ⬆️
pymc/distributions/moments/means.py   99.55% <100.00%> (-0.01%) ⬇️
pymc/distributions/shape_utils.py     92.15% <100.00%> (+0.27%) ⬆️
pymc/logprob/transforms.py            95.65% <100.00%> (+<0.01%) ⬆️
pymc/math.py                          74.00% <100.00%> (ø)

pytestmark = pytest.mark.filterwarnings(
    "error",
    "ignore::numba.core.errors.NumbaPerformanceWarning",
    "ignore:create_index_for_new_dim:UserWarning",
ricardoV94 (Member Author):

already fixed upstream, but we haven't released yet

ricardoV94 changed the title from "Allow xtensor probability functions to broadcast with value" to "Broadcast probability functions with value and size" on Feb 11, 2026
ricardoV94 force-pushed the xrv_bcast branch 4 times, most recently from 94d7d75 to 0cc66d5, on February 11, 2026 at 21:22
ricardoV94 marked this pull request as ready for review on February 11, 2026 at 21:43
logcdf = pt.switch(
    pt.lt(backward_value, 0),
    pt.log(pt.exp(logcdf_zero) - pt.exp(logcdf)),
    logdiffexp(logcdf_zero, logcdf),
ricardoV94 (Member Author):

With the broadcasted logp, this started triggering pymc-devs/pytensor#1883.

So I fixed it in our math helper to work around the bug and avoid having to wait for a pytensor release.

Member:

Add a TODO that links to the pytensor issue and says to remove this once it's fixed upstream.

ricardoV94 (Member Author):

We have this helper in pymc.math anyway, so there's no need to revert anything? It just starts out already in the optimized form.
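
For reference, a minimal sketch of what such a helper looks like in its already-optimized form (assuming a >= b elementwise, so the difference of exponentials is non-negative):

import pytensor.tensor as pt

def logdiffexp(a, b):
    # log(exp(a) - exp(b)) computed stably as a + log(1 - exp(b - a)),
    # avoiding overflow in exp; assumes a >= b elementwise.
    return a + pt.log1mexp(b - a)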

- arviz>=0.13.0
- blas
- cachetools>=4.2.1
- cachetools>=4.2.1,<7
ricardoV94 (Member Author):

unrelated issue: #8106

Comment on lines +156 to +159
if size_idx is not None:
    size = dist_params[size_idx]
if params_idxs is not None:
    dist_params = [dist_params[i] for i in params_idxs]
Member:

Refactor this out into a helper; it's repeated 4x.
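
A hypothetical sketch of such a helper (the name and signature are illustrative, following the snippet above):

def split_size_and_params(dist_params, size_idx, params_idxs):
    # Consolidate the repeated extraction: pull out `size` when it is
    # tracked among dist_params, then subset the actual parameters.
    size = dist_params[size_idx] if size_idx is not None else None
    if params_idxs is not None:
        dist_params = [dist_params[i] for i in params_idxs]
    return size, dist_params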


def test_censored_categorical(self):
    cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2], shape=(5,))
    cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2], shape=())
Member:
Why change the shapes in these tests?

Member:
Oh I think I get it, because previously you had to say 5 to match the incoming parameter shape?

ricardoV94 (Member Author):
No, the previous shape was just wrong, but because it was ignored nothing happened.
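
For context, the last axis of p indexes categories rather than independent draws, so a 1-D p parameterizes a single scalar draw; a minimal sketch:

import pymc as pm

# A length-5 `p` describes one draw over 5 categories, so the RV is a
# scalar and shape=() is the correct explicit shape.
cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2])
assert cat.eval().shape == ()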


def _to_xtensor(var, op: MeasurableXTensorFromTensor):
def _to_tensor(op: MeasurableXTensorFromTensor, value: XTensorVariable) -> TensorVariable:
    # Align dims thate are shared between value and op to the right
Member:
typo: that (can't suggest, I'm in commit-by-commit mode)

missing_axis = [
    i for i, dim in enumerate(op.dims, start=n_value_unique_dims) if dim not in value_dims_set
]
return pt_expand_dims(value.values, axis=missing_axis)
Member:
It seems like we ought to have a helper that does this on the pytensor side.
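
Such a helper might look like the following hypothetical sketch (align_dims_to is not an existing pytensor function; it assumes every dim of x appears in target_dims, in the same relative order):

import pytensor.tensor as pt

def align_dims_to(x, x_dims, target_dims):
    # Insert length-1 axes into `x` (whose axes are labeled by `x_dims`)
    # so that its axes line up positionally with `target_dims`.
    missing_axis = [i for i, dim in enumerate(target_dims) if dim not in x_dims]
    return pt.expand_dims(x, axis=missing_axis)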

def _to_xtensor(
    op: MeasurableXTensorFromTensor, value: XTensorVariable, var: TensorVariable
) -> XTensorVariable:
    value_unique_dims = [dim for dim in value.dims if dim not in op.dims]
Member:
Name could be better, because dims are always unique by construction. value_only_dims?

ricardoV94 (Member Author) commented Feb 12, 2026:

extra_value_dims?
