Skip to content

Comments

Add NPE-PFN#1778

Open
jsvetter wants to merge 27 commits intosbi-dev:mainfrom
jsvetter:npe_pfn_dev
Open

Add NPE-PFN#1778
jsvetter wants to merge 27 commits intosbi-dev:mainfrom
jsvetter:npe_pfn_dev

Conversation

@jsvetter
Copy link
Contributor

@jsvetter jsvetter commented Feb 24, 2026

The goal of this PR is to add NPE-PFN to SBI, as discussed in #1682.

The implementation is realized mostly by three new components, which I will briefly describe in the following.
Happy to discuss all of this, as the exisiting assumptions encoded trough base classes like NeuralInference or ConditionalDensityEstimator sometimes make more and sometimes make less sense for NPE-PFN.

There are three key files that implement the method:

1.) tabpfn_flow.py implements the in-context ConditionalDensityEstimator based on the autoregressive use of TabPFN. It behaves exactly like other estimators, and given some context dataset provides sampling and log-prob functionality.

2.) npe_pfn.py implements the NPE_PFN class which, inherits from NeuralInference and implements the basic logic used across the package (append_simulations, train, build_posterior etc.). Since NPE-PFN is training free, the train method is a stub, and most functionality is handled directly by build_posterior. This allows users to call train without breaking any previous workflow, but they can also "forget" about it as would be suggested by a training-free method.

Since the TabPFN-based flow behaves like any other flow, NPE-PFN supports out-of-the-box many different types of posteriors (Direct, Rejection, IS, could add more, but inference is too slow for MCMC). However, a crucial feature of NPE-PFN is filtering, where the context dataset is selected based on a given observation.

To support this functionality, a new posterior class is required.

3.) filtered_direct_posterior.py implements this posterior (inheriting from DirectPosterior), which allows filtering based on different filters (usually KNN, but users can also provide a custom callable).

There are many other smaller changes (builders, dataclasses, etc.) and so far no tests.
Also, this PR contains the core functionality for amortized inference. More advanced stuff like sequential inference, or even support for finetuning etc. (which we didn't even do in the paper) are not added.

It probably makes sense to dicuss this approach first, before I add fine-grained tests or possibly more functionality.

Will also share mini benchmark results here in a minute.

@codecov
Copy link

codecov bot commented Feb 24, 2026

Codecov Report

❌ Patch coverage is 31.13772% with 230 lines in your changes missing coverage. Please review.
✅ Project coverage is 86.16%. Comparing base (d41efa6) to head (690df4a).

Files with missing lines Patch % Lines
sbi/neural_nets/estimators/tabpfn_flow.py 20.00% 104 Missing ⚠️
sbi/inference/trainers/npe/npe_pfn.py 36.66% 57 Missing ⚠️
.../inference/posteriors/filtered_direct_posterior.py 34.32% 44 Missing ⚠️
sbi/neural_nets/net_builders/flow.py 20.00% 8 Missing ⚠️
sbi/inference/posteriors/posterior_parameters.py 53.84% 6 Missing ⚠️
sbi/utils/torchutils.py 40.00% 6 Missing ⚠️
sbi/utils/user_input_checks.py 0.00% 3 Missing ⚠️
sbi/inference/trainers/base.py 60.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1778      +/-   ##
==========================================
- Coverage   87.65%   86.16%   -1.50%     
==========================================
  Files         140      143       +3     
  Lines       12178    12503     +325     
==========================================
+ Hits        10675    10773      +98     
- Misses       1503     1730     +227     
Flag Coverage Δ
fast 81.90% <31.13%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
sbi/inference/__init__.py 100.00% <100.00%> (ø)
sbi/inference/posteriors/__init__.py 100.00% <100.00%> (ø)
.../inference/potentials/posterior_based_potential.py 89.36% <100.00%> (ø)
sbi/inference/trainers/npe/__init__.py 100.00% <100.00%> (ø)
sbi/neural_nets/factory.py 81.92% <ø> (ø)
sbi/neural_nets/net_builders/__init__.py 100.00% <ø> (ø)
sbi/inference/trainers/base.py 93.03% <60.00%> (-0.51%) ⬇️
sbi/utils/user_input_checks.py 76.68% <0.00%> (ø)
sbi/inference/posteriors/posterior_parameters.py 80.57% <53.84%> (-2.76%) ⬇️
sbi/utils/torchutils.py 67.77% <40.00%> (-1.64%) ⬇️
... and 4 more

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant