Codecov Report
@@ Coverage Diff @@
## main #1778 +/- ##
==========================================
- Coverage 87.65% 86.16% -1.50%
==========================================
Files 140 143 +3
Lines 12178 12503 +325
==========================================
+ Hits 10675 10773 +98
- Misses 1503 1730 +227
The goal of this PR is to add NPE-PFN to SBI, as discussed in #1682.
The implementation is realized mostly by three new components, which I will briefly describe in the following.
Happy to discuss all of this, as the existing assumptions encoded through base classes like `NeuralInference` or `ConditionalDensityEstimator` sometimes make more and sometimes make less sense for NPE-PFN.
There are three key files that implement the method:
1.)
`tabpfn_flow.py` implements the in-context `ConditionalDensityEstimator` based on the autoregressive use of TabPFN. It behaves exactly like other estimators: given some context dataset, it provides sampling and log-prob functionality.
2.)
`npe_pfn.py` implements the `NPE_PFN` class, which inherits from `NeuralInference` and implements the basic logic used across the package (`append_simulations`, `train`, `build_posterior`, etc.). Since NPE-PFN is training-free, the `train` method is a stub, and most functionality is handled directly by `build_posterior`. This allows users to call `train` without breaking any previous workflow, but they can also "forget" about it, as a training-free method would suggest.
Since the TabPFN-based flow behaves like any other flow, NPE-PFN supports many different types of posteriors out of the box (Direct, Rejection, IS; more could be added, but inference is too slow for MCMC). However, a crucial feature of NPE-PFN is filtering, where the context dataset is selected based on a given observation.
To support this functionality, a new posterior class is required.
3.)
`filtered_direct_posterior.py` implements this posterior (inheriting from `DirectPosterior`), which allows filtering based on different filters (usually KNN, but users can also provide a custom callable).
There are many other smaller changes (builders, dataclasses, etc.) and so far no tests.
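To make the filtering idea concrete for discussion, here is a minimal standard-library sketch of a KNN filter that selects the context dataset based on the observation. The name `knn_filter`, the `(theta, x, x_o)` call signature, and the Euclidean distance on `x` are assumptions for illustration only; they are not necessarily what `filtered_direct_posterior.py` uses.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = Sequence[float]
# Hypothetical filter signature: (theta, x, x_o) -> (theta_subset, x_subset).
ContextFilter = Callable[
    [List[Vector], List[Vector], Vector], Tuple[List[Vector], List[Vector]]
]


def knn_filter(k: int) -> ContextFilter:
    """Return a filter that keeps the k simulations whose x is closest to x_o."""

    def _filter(theta, x, x_o):
        # Rank simulation indices by Euclidean distance between x[i] and x_o.
        order = sorted(range(len(x)), key=lambda i: math.dist(x[i], x_o))
        keep = order[:k]
        return [theta[i] for i in keep], [x[i] for i in keep]

    return _filter


# Toy simulations: theta is 1-D and x is theta shifted by 0.5.
theta = [[0.0], [1.0], [2.0], [3.0], [4.0]]
x = [[t[0] + 0.5] for t in theta]
x_o = [2.4]

# Keep the two simulations nearest to the observation, closest first.
theta_ctx, x_ctx = knn_filter(k=2)(theta, x, x_o)
```

Any callable with the same signature could be swapped in for `knn_filter(k=2)`, which is one way a user-provided custom filter might plug in.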
Also, this PR contains the core functionality for amortized inference. More advanced stuff like sequential inference, or even support for fine-tuning (which we didn't even do in the paper), is not added.
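As a rough illustration of the training-free workflow described above, the sketch below uses hypothetical stand-in classes (`TrainingFreeInference`, `ContextPosterior`) rather than the actual `NPE_PFN` or `NeuralInference` code. It only shows the control flow: `train` is a no-op, and `build_posterior` works whether or not `train` was called. Return types and argument handling are assumptions.

```python
from typing import List, Sequence

Vector = Sequence[float]


class ContextPosterior:
    """Hypothetical posterior holding the context an in-context estimator conditions on."""

    def __init__(self, theta: List[Vector], x: List[Vector]) -> None:
        self.theta = theta
        self.x = x

    def context_size(self) -> int:
        return len(self.theta)


class TrainingFreeInference:
    """Hypothetical stand-in for a training-free trainer; not the sbi API."""

    def __init__(self) -> None:
        self._theta: List[Vector] = []
        self._x: List[Vector] = []

    def append_simulations(self, theta, x) -> "TrainingFreeInference":
        if len(theta) != len(x):
            raise ValueError("theta and x must have the same number of rows")
        self._theta.extend(theta)
        self._x.extend(x)
        return self  # allow chaining

    def train(self, **kwargs) -> "TrainingFreeInference":
        # Stub: there is nothing to fit. Keyword arguments are accepted and
        # ignored so that scripts written for trained methods keep running.
        return self

    def build_posterior(self) -> ContextPosterior:
        if not self._theta:
            raise RuntimeError("call append_simulations before build_posterior")
        # All real work would happen here: the stored simulations become the context.
        return ContextPosterior(list(self._theta), list(self._x))


inference = TrainingFreeInference().append_simulations(
    [[0.1], [0.2]], [[1.0], [2.0]]
)
with_train = inference.train().build_posterior()  # previous workflow still works
without_train = inference.build_posterior()  # train can be skipped entirely
```

Both call patterns yield a posterior backed by the same two-simulation context, which is the behavior the stub is meant to preserve.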
It probably makes sense to discuss this approach first, before I add fine-grained tests or possibly more functionality.
Will also share mini benchmark results here in a minute.