
Comments

1446 add api for guidance #1482

Open
manuelgloeckler wants to merge 45 commits into main from 1446-add-api-for-guidance

Conversation

@manuelgloeckler
Contributor

@manuelgloeckler manuelgloeckler commented Mar 18, 2025

This adds an API for post-hoc modification of the trained score estimator, allowing changes to the likelihood or prior, the support, or other additional constraints. This addresses issue #1446.

To do:

  • Add a general API in line with the current iid_method
  • Refactor the parameterizations from stringly typed to strongly typed (also for iid!)
  • Add some useful example guidance approaches:
    • Classifier-free guidance
    • Universal guidance, e.g. interval truncation
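For intuition, affine classifier-free guidance can be sketched as scaling and shifting the "likelihood" part of the posterior score, using the standard decomposition score_posterior = score_prior + score_likelihood. This is a minimal illustration; the function name and signature below are made up for this sketch and are not the PR's actual API.

```python
import torch
from torch import Tensor


def affine_guidance_sketch(
    posterior_score: Tensor,
    prior_score: Tensor,
    scale: float = 2.0,
    shift: float = 0.0,
) -> Tensor:
    """Scale/shift the likelihood contribution of a posterior score.

    Uses the decomposition
        score_posterior(x) = score_prior(x) + score_likelihood(x),
    so scale > 1 sharpens ("super-conditions") the likelihood term,
    e.g. shrinking the variance of the resulting posterior, while
    scale < 1 tempers it.
    """
    likelihood_score = posterior_score - prior_score
    return prior_score + scale * likelihood_score + shift
```

With `scale=1.0` and `shift=0.0` this recovers the unmodified posterior score.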

@manuelgloeckler manuelgloeckler linked an issue Mar 18, 2025 that may be closed by this pull request
5 tasks
@manuelgloeckler manuelgloeckler marked this pull request as draft March 18, 2025 18:48
@manuelgloeckler manuelgloeckler self-assigned this Mar 18, 2025
@manuelgloeckler manuelgloeckler marked this pull request as ready for review March 20, 2025 07:05
@codecov

codecov bot commented Mar 20, 2025

Codecov Report

❌ Patch coverage is 67.21311% with 160 lines in your changes missing coverage. Please review.
✅ Project coverage is 83.83%. Comparing base (937efc2) to head (672514b).
✅ All tests successful. No failed tests found.

Files with missing lines Patch % Lines
sbi/utils/vector_field_utils.py 65.78% 103 Missing ⚠️
sbi/inference/potentials/vector_field_adaptor.py 68.18% 56 Missing ⚠️
sbi/inference/potentials/vector_field_potential.py 90.90% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1482      +/-   ##
==========================================
- Coverage   88.54%   83.83%   -4.72%     
==========================================
  Files         137      136       -1     
  Lines       11515    11797     +282     
==========================================
- Hits        10196     9890     -306     
- Misses       1319     1907     +588     
Flag Coverage Δ
fast 83.83% <67.21%> (?)
full ?

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
sbi/inference/posteriors/vector_field_posterior.py 69.12% <ø> (-8.06%) ⬇️
...i/neural_nets/estimators/flowmatching_estimator.py 81.66% <ø> (-15.00%) ⬇️
sbi/inference/potentials/vector_field_potential.py 75.20% <90.90%> (-15.87%) ⬇️
sbi/inference/potentials/vector_field_adaptor.py 78.19% <68.18%> (ø)
sbi/utils/vector_field_utils.py 66.12% <65.78%> (-17.21%) ⬇️

... and 25 files with indirect coverage changes

@manuelgloeckler manuelgloeckler added the blocked Something is in the way of fixing this. Refer to it in the issue label Mar 24, 2025
@manuelgloeckler manuelgloeckler removed the blocked Something is in the way of fixing this. Refer to it in the issue label Sep 5, 2025
@manuelgloeckler manuelgloeckler requested a review from janfb February 2, 2026 16:49
Contributor

@janfb janfb left a comment


Great work @manuelgloeckler 👏 Great to have PriorGuide in here as well 🚀

Added mostly minor formatting and edge-case handling comments.

At a higher level, I am concerned about how long the signature of sample(...) has become. We should think about introducing config classes for the different methods (SDE config, iid, guidance), or method chaining (see comment below). But this would be a larger refactoring; let's discuss.

On the test side, maybe add another test for the combination of guidance and iid settings, given that this is implemented as an option?

Documentation: great to already have a tutorial as part of advanced tutorial 20. As a follow-up, I suggest refactoring tutorials 19 and 20 and adding 1-2 how-to guides for the VF and guidance methods.

Contributor

Great to have this tutorial. Overall, the VF and NPSE variant tutorials 19 and 20 are very long now and overlap. I am already working on a refactoring and will take the guidance part into account as well.

guidance_method: Method to guide the diffusion process. If None, no guidance
is used. currently we support `affine_classifier_free`, which allows to
scale and shift the "likelihood" or "prior" score contribution. This can
be used to perform "super" conditioning i.e. shring the variance of the
Contributor

typo: shrink

scale and shift the "likelihood" or "prior" score contribution. This can
be used to perform "super" conditioning i.e. shring the variance of the
likelihood. `Universal` can be used to guide the diffusion process with
a general guidance function. `Interval` is an isntance of that where
Contributor

typo: instance

Comment on lines +162 to +163
guidance_method: Optional[str] = None,
guidance_params: Optional[Dict] = None,
Contributor

The sample method now has >20 kwargs. We could think about introducing config classes instead, e.g., something like

from sbi.inference.posteriors import GuidanceConfig, PriorGuideCfg

guidance = GuidanceConfig(
    method="prior_guide",
    params=PriorGuideCfg(train_prior=..., test_prior=..., K=5)
)
samples = posterior.sample((1000,), x=x_o, guidance=guidance)

But this would be a larger, possibly follow-up, refactoring. What do you think?

Contributor

Or even something like

samples = posterior.with_guidance("prior_guide", ...).sample((1000,), x=x_o)
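For illustration, this kind of chaining could follow a copy-on-configure pattern: an immutable object whose with_guidance returns a configured copy. ChainablePosterior and its fields below are hypothetical stand-ins, not the actual sbi classes.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class ChainablePosterior:
    """Toy posterior showing how method chaining could carry guidance config."""

    guidance_method: Optional[str] = None
    guidance_params: Dict[str, Any] = field(default_factory=dict)

    def with_guidance(self, method: str, **params: Any) -> "ChainablePosterior":
        # Return a configured copy; the original posterior is left untouched.
        return replace(self, guidance_method=method, guidance_params=params)

    def sample(self, shape: tuple) -> str:
        # Real sampling would run the (guided) diffusion process here.
        return f"sampled {shape} with guidance={self.guidance_method}"


base = ChainablePosterior()
guided = base.with_guidance("prior_guide", K=5)
```

A side benefit of the frozen-dataclass approach is that configuring guidance cannot mutate a posterior that other code still holds a reference to.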

Contributor Author

Mhh, I think config classes would be fine (I use config classes internally anyway). But at the time of implementation, most of the user API was still configured via dicts, so I made this the external API.

I do actually like your suggestion of with_guidance and with_iid to make it more explicit. But yes, I would make this a follow-up refactoring.

iid_params: Additional parameters passed to the iid method. See the specific
`IIDScoreFunction` child class for details.
guidance_method: Method to guide the diffusion process. If None, no guidance
is used. currently we support `affine_classifier_free`, which allows to
Contributor

Currently

return denoising_posterior_precision


def compute_score(p: Distribution, inputs: Tensor):
Contributor

Add a docstring and return type, please.
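For reference, the requested docstring and return type could look like the following. The autograd-based body is a guess at the intent (the score grad_x log p(x) evaluated via differentiation of log_prob), not necessarily the PR's actual implementation.

```python
import torch
from torch import Tensor
from torch.distributions import Distribution


def compute_score(p: Distribution, inputs: Tensor) -> Tensor:
    """Compute the score, i.e. the gradient of log p w.r.t. the inputs.

    Args:
        p: Distribution whose log-density is differentiated.
        inputs: Points at which to evaluate the score.

    Returns:
        Tensor of the same shape as `inputs` containing grad_x log p(x).
    """
    inputs = inputs.detach().requires_grad_(True)
    log_prob = p.log_prob(inputs).sum()
    (score,) = torch.autograd.grad(log_prob, inputs)
    return score
```

For a standard normal, this evaluates to -x at input x, which makes a convenient sanity check.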

max_log_ratio: float = 50.0,
device: Union[str, torch.device] = "cpu",
) -> Tuple[Tensor, Tensor, Tensor]:
"""Implementation for fitting a generalized GMM to the prior ratio
Contributor

Drop "Implementation for fitting" and replace it with "Fit" to keep the summary on one line.

Comment on lines +179 to +191
class _HashableById:
__slots__ = ("obj", "_id")

def __init__(self, obj: Distribution):
"""Wraps a non-hashable Distribution to make it cache-key compatible."""
self.obj = obj
self._id = id(obj)

def __hash__(self) -> int:
return self._id

def __eq__(self, other: object) -> bool:
return isinstance(other, _HashableById) and self._id == other._id
Contributor

Overall, great to have this here! The main purpose is to avoid refitting the GMM on every call from the potential, right?

Just wondering: this caches by object ID, not by distribution parameters. So if someone passes the same distribution with the same params but as a new instance, the cache won't help. Conversely, if someone changes the distribution params in place, the cache will not notice.
But I think both are edge cases, so it's fine.

Contributor Author

@manuelgloeckler manuelgloeckler Feb 25, 2026

Yeah, there are definitely some edge cases where this will fail. This is why it is only used when required.

I am generally a bit unhappy that the adaptor/score wrapper objects are reinitialized on each call of the potential (which requires hashing as much as possible). I think it would be good to restructure the VF potential a bit so that such quantities can be computed once at the beginning, e.g. via an initialize_aux that needs to be called once before sample.

from sbi.utils import BoxUniform, MultipleIndependent


def build_some_priors(num_dim: int):
Contributor

Rename to build_test_priors and add a short docstring?

Comment on lines +39 to +40
# Bug in Independent introduced????
priors = [prior1, prior2, prior2_2, prior3, prior4, prior5] # Something broke
Contributor

What is this about? Is there a bug?

Contributor Author

Good catch, will check. But it includes all the priors, so these tests did at least pass (I might just have to remove the comments).



Development

Successfully merging this pull request may close these issues.

Add API for guidance

4 participants