Re-enable time-dependent z-scoring for Flow Matching#1752

Open
satwiksps wants to merge 16 commits into sbi-dev:main from satwiksps:fm-z-scoring

Conversation

@satwiksps
Contributor

@satwiksps satwiksps commented Feb 3, 2026

Description

This PR re-introduces z-scoring for Flow Matching estimators using a time-dependent normalization approach and adds a Gaussian Baseline for improved training stability.

As discussed in #1623, standard z-scoring is problematic because the network input evolves from data to noise. This implementation provides two distinct normalization modes to handle this evolution while maintaining training stability.

Corrected Normalization Statistics:
Since we define $t=0$ as Data and $t=1$ as Noise, the statistics are handled based on the chosen mode:

  1. Gaussian Baseline (gaussian_baseline=True): Normalizes inputs to $N(0, 1)$ across the entire path. The drift signal is handled by the hard-coded affine baseline.

    $$\mu_t = (1 - t) \cdot \mu_{data}$$
    $$\sigma_t = \sqrt{(1 - t)^2 \sigma_{data}^2 + t^2}$$

  2. Variance Only (gaussian_baseline=False): Normalizes variance while preserving the raw data location at $t=0$. This ensures the network can still learn the drift signal when no baseline is used.

    $$\mu_t = t \cdot \mu_{data}$$
    $$\sigma_t = \sqrt{t^2 \sigma_{data}^2 + (1 - t)^2}$$

Gaussian Baseline:
We implemented an affine vector field baseline (enabled by default). The network now learns the residual vector field with respect to the optimal Gaussian probability path, significantly improving convergence on shifted datasets.
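For concreteness, the two normalization schedules above can be sketched as a small helper (hypothetical code, not the PR's actual estimator; `eps` mirrors the 1e-5 stabilizer mentioned in the Changes below):

```python
import torch

def time_dependent_stats(t, mean_data, std_data, gaussian_baseline=True, eps=1e-5):
    """Normalization statistics along the path, with t=0 = data and t=1 = noise.

    Hypothetical helper mirroring the two modes described above.
    """
    if gaussian_baseline:
        # Gaussian baseline: normalize to N(0, 1) across the entire path.
        mu_t = (1 - t) * mean_data
        var_t = ((1 - t) * std_data) ** 2 + t**2
    else:
        # Variance only: preserve the raw data location at t=0.
        mu_t = t * mean_data
        var_t = (t * std_data) ** 2 + (1 - t) ** 2
    return mu_t, torch.sqrt(var_t + eps)

# Boundary check for the Gaussian-baseline mode: data stats at t=0, N(0, 1) at t=1.
mean_d, std_d = torch.tensor(100.0), torch.tensor(2.0)
print(time_dependent_stats(torch.tensor(0.0), mean_d, std_d))
print(time_dependent_stats(torch.tensor(1.0), mean_d, std_d))
```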

Related Issues/PRs

Changes

  1. sbi/neural_nets/net_builders/vector_field_nets.py: Updated build_vector_field_estimator to calculate training data statistics, accept the gaussian_baseline flag, and pass them to the estimator.
  2. sbi/neural_nets/estimators/flowmatching_estimator.py:
    • Buffer Management: Registered mean_1 and std_1 as buffers and expanded them to match input_shape to ensure compatibility with multi-dimensional data in CI.
    • Split Logic: Implemented the split logic in forward() to support both Gaussian Baseline (residual learning) and Variance-only (signal preserving) modes.
    • Numerical Stability: Added a small epsilon (1e-5) to variance calculations to prevent division-by-zero errors.
  3. tests/linearGaussian_vector_field_test.py:
    • Added test_fmpe_time_dependent_z_scoring_integration: Verifies statistics population, buffer registration, and forward pass shapes.
    • Added test_fmpe_shifted_data_gaussian_baseline: Verifies that the Gaussian Baseline outperforms variance-only scaling on shifted data ($U[95, 105]$) with robust simulation counts ($N=2000$).
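The buffer registration described in item 2 can be illustrated with a minimal sketch (class and argument names are hypothetical, not the PR's code); buffers follow the module across `.to(device)` and are saved in `state_dict()` without being trained:

```python
import torch
from torch import nn

class TimeDependentScaler(nn.Module):
    """Minimal sketch of the buffer pattern; names are illustrative only."""

    def __init__(self, input_shape, mean_1, std_1):
        super().__init__()
        # Expand (possibly scalar) statistics to the event shape, then register
        # as buffers so they move with the module and survive (de)serialization.
        self.register_buffer(
            "mean_1", torch.as_tensor(mean_1).expand(input_shape).clone()
        )
        self.register_buffer(
            "std_1", torch.as_tensor(std_1).expand(input_shape).clone()
        )

scaler = TimeDependentScaler((3,), 100.0, 2.0)
print(scaler.mean_1.shape, sorted(scaler.state_dict()))
```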

Verification

  • Shifted Data Benchmark: Confirmed that gaussian_baseline=True achieves lower validation loss and faster convergence than variance-only scaling on a shifted 1D prior ($U[95, 105]$).
  • Integration Tests: All new tests pass, confirming correct buffer registration and stability with z_score_x='independent'.
  • Benchmarks: I ran the sbi benchmarks locally (pytest --bm --bm-mode fmpe) to check for stability and performance. All 12 tests passed successfully.

mini SBIBM results

@satwiksps satwiksps marked this pull request as ready for review February 3, 2026 19:51
@codecov

codecov bot commented Feb 3, 2026

Codecov Report

❌ Patch coverage is 76.92308% with 6 lines in your changes missing coverage. Please review.
✅ Project coverage is 88.20%. Comparing base (937efc2) to head (3998d75).
⚠️ Report is 11 commits behind head on main.

Files with missing lines Patch % Lines
...i/neural_nets/estimators/flowmatching_estimator.py 76.00% 6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1752      +/-   ##
==========================================
- Coverage   88.54%   88.20%   -0.35%     
==========================================
  Files         137      138       +1     
  Lines       11515    13467    +1952     
==========================================
+ Hits        10196    11878    +1682     
- Misses       1319     1589     +270     
Flag Coverage Δ
fast 84.39% <76.92%> (?)

Flags with carried forward coverage won't be shown.

Files with missing lines Coverage Δ
sbi/neural_nets/net_builders/vector_field_nets.py 93.16% <100.00%> (ø)
...i/neural_nets/estimators/flowmatching_estimator.py 90.36% <76.00%> (-6.31%) ⬇️

... and 35 files with indirect coverage changes

@satwiksps
Contributor Author

It seems tests/torchutils_test.py::TorchUtilsTest::test_searchsorted is consistently failing in the CI with an execnet.gateway_base.DumpError.

Since this failure is in torchutils_test.py (which I haven't touched) and appears to be a serialization issue with pytest-xdist masking a local assertion error, I believe it is unrelated to my changes in flowmatching_estimator.py.

The actual Flow Matching benchmarks and integration tests for this PR passed successfully, though.

This was an old bug that surfaced now, likely because codecov was trying to serialize things.
@janfb
Contributor

janfb commented Feb 5, 2026

It seems tests/torchutils_test.py::TorchUtilsTest::test_searchsorted is consistently failing in the CI with an execnet.gateway_base.DumpError.

Since this failure is in torchutils_test.py (which I haven't touched) and appears to be a serialization issue with pytest-xdist masking a local assertion error, I believe it is unrelated to my changes in flowmatching_estimator.py ?

The actual Flow Matching benchmarks and integration tests for this PR passed successfully though


Yes, this is unrelated and popped up here by chance or because of an unrelated change in a downstream package. I pushed a fix to this branch ✅

@janfb
Contributor

janfb commented Feb 5, 2026

Thanks for working on this @satwiksps !

Overall, this looks exactly right. However, after reviewing the code and tracing through the flow matching implementation, I believe the z-scoring formula is inverted relative to the interpolation convention (quite confusing!).

The interpolation in the loss function is:

theta_t = (1 - t) * theta_data + t * theta_noise

So the expected input mean at each time is:

E[θ_t] = (1-t) * mean_data + t * 0 = (1-t) * mean_data

Current PR formula:

  • mu_t = t * mean_1
  • var_t = (t * std_1)² + (1 - t)²

This gives mu_t = 0 at t=0 and mu_t = mean_data at t=1 — exactly backwards.

Correct formula should be:

  • mu_t = (1 - t) * mean_1
  • var_t = ((1 - t) * std_1)² + t²

The current formula only matches the correct one at t = 0.5 and is maximally wrong at the boundaries.
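The expected-mean claim itself is easy to verify with a quick Monte Carlo check of the interpolant (a hypothetical sketch, assuming theta_data ~ N(100, 1) and theta_noise ~ N(0, 1)):

```python
import torch

torch.manual_seed(0)
mean_data, n = 100.0, 200_000
theta_data = mean_data + torch.randn(n)   # data samples
theta_noise = torch.randn(n)              # standard Gaussian noise

# Empirical mean of theta_t = (1 - t) * theta_data + t * theta_noise.
means = {
    t: ((1 - t) * theta_data + t * theta_noise).mean().item()
    for t in (0.0, 0.5, 1.0)
}
print(means)  # tracks (1 - t) * mean_data: ~100, ~50, ~0
```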

Note on zuko's sampling: I had to dig a bit but in zuko, NormalizingFlow.sample() uses transform.inv() which integrates backward (t1→t0), so training and sampling conventions do align — the issue is purely the z-scoring formula.

To verify this, I suggest the following test: the standard linear Gaussian test, but with a uniform prior between 95 and 105, and with data x_o centered at 100 (far from N(0,1)). With the inverted formula, C2ST should degrade significantly compared to no z-scoring, and it should be fixed (C2ST close to 0.5) with the correct formula.

Can you confirm this (maybe I got confused with the integration directions after all)?

Contributor

@manuelgloeckler manuelgloeckler left a comment


Hey @satwiksps!

Thanks for the contribution! I checked against main and, as of now, it performs on average very similarly to (if not a bit worse than) before (although I think that's mostly fine for these tasks).

[benchmark comparison image]

I wonder if it would make sense to improve the "preconditioning" a bit more (see comments).

@janfb
Contributor

janfb commented Feb 6, 2026

Hey @satwiksps!

Thanks for the contribution! I checked against main and, as of now, it performs on average very similarly to (if not a bit worse than) before (although I think that's mostly fine for these tasks).

I wonder if it would make sense to improve the "preconditioning" a bit more (see comments).

Thanks for adding the comparison to main. What could happen here is that the benchmarking tasks are not discriminative w.r.t. z-scoring, no? I.e., we need a task that benefits from z-scoring.

@janfb
Contributor

janfb commented Feb 6, 2026

Alright, I looked at it again and I realized that my proposal was actually incorrect. The formulas I proposed would result in total normalization, i.e., "independent" z-scoring, where all time steps have zero mean after z-scoring and we lose valuable time-dependent information. Sorry @satwiksps, your formulas were actually correct!

What Manuel proposed is great: we z-score with respect to the Gaussian baseline, i.e., what one would expect when the posterior is actually Gaussian. Then the flow matching network only has to learn the residual from this ideal baseline (please correct me @manuelgloeckler if this intuition is inaccurate).

I tested this locally with the following setup:

  • Prior: BoxUniform([95, 105]), x_o=100
  • Simulator: x = theta + 0.5 * noise
  • Reference posterior: N(x_o, 0.5²I)
  • 3000 simulations.

Results:

Formula                 C2ST   Description
Gaussian                0.631  Gaussian baseline + residual learning
var_only                0.772  Variance scaling only
pr                      0.774  PR's time-dependent z-scoring
static                  0.796  Static mean subtraction
none                    0.865  No z-scoring
independent z-scoring   0.922  "Correct" mean formula

Thus, @satwiksps, I suggest you implement both options (your proposal and Manuel's proposal), add the test as a new z-scoring test, and confirm the results.
@manuelgloeckler I think it would be good to have both options, as the Gaussian baseline assumption can be suboptimal when the posterior is multi-modal or skewed?

@manuelgloeckler
Contributor

@janfb The preconditioning is with respect to the "prior", not the posterior (as the latter would require regression from x). I don't think it will "hurt" in almost all cases, i.e., FM nets are initialized to output zero, hence the initialized network will effectively sample from a mass-covering Gaussian approximation of the prior (and everything else needs to be learned).

Nonetheless having an option to disable it is always good.

Agree that the benchmark tests are not really sensitive to the z-scoring, but as we usually enable z-scoring by default, it shouldn't hurt performance even if it's not necessary. But as said, the deviation is small enough to be fine (and might improve with the additional baseline).

Contributor

@janfb janfb left a comment


Thanks for the update @satwiksps! Looks good. I just have one crucial question on the standard z-scoring formulas again, please check 🙏


inference_gauss = FMPE(
prior,
z_score_x="independent",
Contributor


Just to be sure: the z-scoring affects the data only, so it's fine to have it as "independent" (because the data is not structured here).

But internally, the flow-time z-scoring will happen according to the gaussian_baseline option, yes?

Contributor Author


Yes, z_score_x="independent" only handles the initial, static normalization of the training data. Internally, the FlowMatchingEstimator then performs the dynamic, time-dependent z-scoring based on the gaussian_baseline flag.

Comment on lines 151 to 152
mu_t = (1 - t_view) * mean_1_view
var_t = ((1 - t_view) * std_1_view) ** 2 + t_view**2 + 1e-5
Contributor


Wait, this is the formula I suggested, no? And we discussed that this is actually not what we would want here, because it normalizes to zero and performs worse (see table).

Please double check this change. We should use your initially suggested formula as one option, and the gaussian_baseline formulas as the second (default) option.

Contributor Author


Oh, sorry @janfb, I should have read the conversation properly. Using the (1-t) formula without the baseline does indeed force inputs to zero.

Comment on lines +147 to +154
if self.gaussian_baseline:
mu_t = (1 - t_view) * mean_1_view
var_t = ((1 - t_view) * std_1_view) ** 2 + t_view**2 + 1e-5

std_t = torch.sqrt(var_t)
input_norm = (input - mu_t) / std_t

num = t_view - (1 - t_view) * std_1_view**2
Contributor


I think this is still wrong (but my previous comment also was a bit incomplete, sorry).

Your previous approach was already right; the gaussian_baseline should just extend it additively, i.e., I would suggest the following (but I might be missing something).

    mu_t = t_view * mean_1_view
    var_t = (t_view * std_1_view) ** 2 + (1 - t_view) ** 2 + 1e-5

    std_t = torch.sqrt(var_t)
    input_centered = input - mu_t
    input_norm = input_centered / std_t

    v_out = self.net(input_norm, condition_emb, time)
    v = v_out * std_t

    if self.gaussian_baseline:
        k_t = (t_view * std_1_view**2) / var_t
        x1_hat = mean_1_view + k_t * input_centered
        v_affine = (t_view * input) + (1 - t_view) * x1_hat
        v += v_affine
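For reference, the `num = t_view - (1 - t_view) * std_1_view**2` term in the earlier excerpt is the slope of the analytic marginal drift for Gaussian data: since theta_data, theta_noise, and the interpolant are jointly Gaussian, E[v | theta_t] is linear in theta_t with slope cov(v, theta_t)/var(theta_t) = (t - (1 - t) σ²) / var_t. A hypothetical Monte Carlo check under the t=0-data convention (my own sketch, not code from the PR):

```python
import torch

torch.manual_seed(0)
m, s, t, n = 3.0, 2.0, 0.3, 500_000
theta_data = m + s * torch.randn(n)      # theta at t=0, ~ N(m, s^2)
theta_noise = torch.randn(n)             # theta at t=1, ~ N(0, 1)
x_t = (1 - t) * theta_data + t * theta_noise
v = theta_noise - theta_data             # target vector field of the linear path

# Regression slope of v on x_t vs. the closed-form num / var_t.
var_t = (1 - t) ** 2 * s**2 + t**2
analytic = (t - (1 - t) * s**2) / var_t
cov = ((v - v.mean()) * (x_t - x_t.mean())).mean()
empirical = (cov / ((x_t - x_t.mean()) ** 2).mean()).item()
print(empirical, analytic)
```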

self.noise_scale = noise_scale
self.gaussian_baseline = gaussian_baseline

mean_1_tensor = torch.as_tensor(mean_1).expand(input_shape)
Contributor


I think we do not have to always expand mean_1 and std_1 by default, i.e., if it's a scalar we should keep it as a scalar; if it's a vector we check that it is of size input_shape (or can be reshaped into input_shape), otherwise we raise an error. This would save a bit of memory.

However, note that this would require adjusting the forward implementation.



@pytest.mark.slow
def test_fmpe_shifted_data_gaussian_baseline():
Contributor


I would adapt this test: at initialization, FM nets will produce zero for any input, so this basically tests whether predicting 0 is better than predicting "something".

Instead, I think a better tests would be to:

  • Switch to a shifted Gaussian prior (with similarly extreme values)
  • Init FMPE with the Gaussian baseline, append simulations, and build a posterior + sample (untrained!)
  • Compare samples to the prior (i.e., sample mean/var vs. prior mean/var, or C2ST to prior samples).

The preconditioning with Gaussian baseline should transport the samples to the prior (which should be tested here).
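This transport claim can also be sketched without the sbi API at all (a hypothetical, self-contained check of my own, using the closed-form Gaussian-path drift implied by the `num / var_t` term discussed above): with an untrained net outputting zero, only the analytic drift acts, and integrating the flow ODE from noise (t=1) back to data (t=0) should land on N(m, s²):

```python
import torch

torch.manual_seed(0)
m, s = 100.0, 2.0                  # stats of a shifted Gaussian "prior"
n_samples, n_steps = 20_000, 1000

x = torch.randn(n_samples)         # start from N(0, 1) noise at t = 1
for i in range(n_steps, 0, -1):
    t = i / n_steps
    mu_t = (1 - t) * m
    var_t = (1 - t) ** 2 * s**2 + t**2
    # Closed-form marginal drift of the Gaussian path (zero network output).
    v = (t - (1 - t) * s**2) / var_t * (x - mu_t) - m
    x = x - v / n_steps            # Euler step backward in time

print(x.mean().item(), x.std().item())  # should approach m = 100 and s = 2
```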

Contributor

@manuelgloeckler manuelgloeckler left a comment


Thanks for your contribution.

I think the formula is still a bit off (but it also was never very clearly defined by us anyway (: ).

I do have a minor suggestion on the mean/var buffers as well as the Gaussian baseline test, which should be addressed (see comments). Once this is done, we can merge it :)

Kind regards,
Manuel


Development

Successfully merging this pull request may close these issues.

Add back z-scoreing for flow matching

3 participants