Fix NaN padding test: remove xfail decorator#1735
Fix NaN padding test: remove xfail decorator#1735DhyeyJoshi39 wants to merge 1 commit intosbi-dev:mainfrom
Conversation
PermutationInvariantEmbedding intentionally uses NaN padding for varying trial counts. PR sbi-dev#1701 NaN check conflicts with this legitimate use case. Removed @pytest.mark.xfail decorator. Refs: sbi-dev#1701 sbi-dev#1717
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #1735 +/- ##
=======================================
Coverage 88.51% 88.51%
=======================================
Files 137 137
Lines 11527 11527
=======================================
Hits 10203 10203
Misses 1324 1324
Flags with carried forward coverage won't be shown. Click here to find out more. |
|
Thank you for making a PR @DhyeyJoshi39. You removed the |
|
Hi @janfb , Thank you for the quick review! I'd like to propose implementing the check_finite_x parameter upstream to resolve this properly. Problem Statement Proposed Solution NPE.build_posterior(check_finite_x=True) Posterior.sample(check_finite_x=True) Implementation Sketch: In NPE.build_posterior()def build_posterior(self, check_finite_x=True, ...): In Posterior.sample()def sample(self, sample_shape, x, check_finite_x=True, ...): Questions for Guidance Should we add validation docs explaining when to set False? Any other files affected by PR #1701 NaN check? I'm happy to implement this if the approach looks good to you! Looking forward to your thoughts. Thanks! |
|
Hi @DhyeyJoshi39, thanks for following up on this! Your proposed approach with check_finite_x is the right direction. A few housekeeping items first:
Now for the implementation. The test change alone won't work because BackgroundThere's an asymmetry right now:
There's even a TODO in # TODO: add option to allow for NaNs in certain dimensions, e.g., to encode varying
# numbers of trials.ImplementationThe NaN padding use case is specific to NPE with 1. def process_x(
x: Array,
x_event_shape: Optional[torch.Size] = None,
check_finite: bool = True,
) -> Tensor:
...
x = atleast_2d(torch.as_tensor(x, dtype=float32))
if check_finite:
assert_all_finite(x, "Observed data x_o contains Nans or Infs.")
...2.
3.
4.
5.
That's 5 functions to modify. The other posterior types inherit Let me know if you have questions about any of this! |
Removed pytest xfail decorator from test_npe_with_iid_embedding_varying_num_trials.
PermutationInvariantEmbedding intentionally uses NaN padding for varying trial counts.
PR #1701's NaN check conflicts with this legitimate use case.
Test passes once check_finite_x=False kwarg is available.
Fixes #1717
Refs: #1701