From 2c87551d7f06d2286d36b541eb96a26048b6c99d Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Thu, 4 Sep 2025 11:42:31 +0200 Subject: [PATCH 1/7] allow sampling at every timepoint for diagnostic purposes --- sbi/inference/posteriors/vector_field_posterior.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sbi/inference/posteriors/vector_field_posterior.py b/sbi/inference/posteriors/vector_field_posterior.py index a5b5bfcfd..35c151a22 100644 --- a/sbi/inference/posteriors/vector_field_posterior.py +++ b/sbi/inference/posteriors/vector_field_posterior.py @@ -254,6 +254,7 @@ def _sample_via_diffusion( ts: Optional[Tensor] = None, max_sampling_batch_size: int = 10_000, show_progress_bars: bool = True, + save_intermediate: bool = False, ) -> Tensor: r"""Return samples from posterior distribution $p(\theta|x)$. @@ -275,6 +276,9 @@ def _sample_via_diffusion( sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior to `.sample()`. show_progress_bars: Whether to show a progress bar during sampling. + save_intermediate: Whether to save intermediate results of the diffusion + process. If True, the returned tensor has shape + `(*sample_shape, steps, *input_shape)`. 
""" if not self.vector_field_estimator.SCORE_DEFINED: @@ -316,9 +320,10 @@ def _sample_via_diffusion( num_samples=max_sampling_batch_size, ts=ts, show_progress_bars=show_progress_bars, + save_intermediate=save_intermediate, ) ) - samples = torch.cat(samples, dim=0)[:num_samples] + samples = torch.cat(samples, dim=0)#[:num_samples] return samples From d4d4c1d60c58282218022b1b5c25b0432b416956 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Thu, 4 Sep 2025 11:42:45 +0200 Subject: [PATCH 2/7] Fix FMPE SDE sampling --- sbi/neural_nets/estimators/flowmatching_estimator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sbi/neural_nets/estimators/flowmatching_estimator.py b/sbi/neural_nets/estimators/flowmatching_estimator.py index 3eb37f52e..4ca540e87 100644 --- a/sbi/neural_nets/estimators/flowmatching_estimator.py +++ b/sbi/neural_nets/estimators/flowmatching_estimator.py @@ -268,7 +268,9 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor: Drift function at a given time. """ # analytical f(t) does not depend on noise_scale and is undefined at t = 1. - return -input / torch.maximum(1 - times, torch.tensor(1e-6).to(input)) + # NOTE: We bound the singularity to avoid numerical issues i.e. 1 - t > 0.01 + # this effectively prevents and explosion of the SDE in the beginning. + return -input / torch.maximum(1 - times, torch.tensor(1e-2).to(input)) def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: r"""Diffusion function for the flow matching estimator. @@ -293,10 +295,12 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: Diffusion function at a given time. """ # analytical g(t) is undefined at t = 1. + # NOTE: We bound the singularity to avoid numerical issues i.e. 1 - t > 0.01 + # this effectively prevents and explosion of the SDE in the beginning. 
return torch.sqrt( 2 * (times + self.noise_scale) - / torch.maximum(1 - times, torch.tensor(1e-6).to(times)) + / torch.maximum(1 - times, torch.tensor(1e-2).to(times)) ) def mean_t_fn(self, times: Tensor) -> Tensor: From b88080add14b285a244abc06785da8c0e8faca6b Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Thu, 4 Sep 2025 11:43:15 +0200 Subject: [PATCH 3/7] Fix bug --- sbi/samplers/score/diffuser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/samplers/score/diffuser.py b/sbi/samplers/score/diffuser.py index f10c0e092..778e9cce8 100644 --- a/sbi/samplers/score/diffuser.py +++ b/sbi/samplers/score/diffuser.py @@ -163,6 +163,6 @@ def run( intermediate_samples.append(samples) if save_intermediate: - return torch.cat(intermediate_samples, dim=0) + return torch.cat(intermediate_samples, dim=1) else: return samples From 72ec46d15d9fee353e518eece07387487ced4f2a Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Thu, 4 Sep 2025 12:01:27 +0200 Subject: [PATCH 4/7] Add as argument with docstring --- .../estimators/flowmatching_estimator.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/sbi/neural_nets/estimators/flowmatching_estimator.py b/sbi/neural_nets/estimators/flowmatching_estimator.py index 4ca540e87..a3d552417 100644 --- a/sbi/neural_nets/estimators/flowmatching_estimator.py +++ b/sbi/neural_nets/estimators/flowmatching_estimator.py @@ -243,7 +243,9 @@ def score(self, input: Tensor, condition: Tensor, t: Tensor) -> Tensor: score = (-(1 - t) * v - input) / (t + self.noise_scale) return score - def drift_fn(self, input: Tensor, times: Tensor) -> Tensor: + def drift_fn( + self, input: Tensor, times: Tensor, effective_t_max: float = 0.99 + ) -> Tensor: r"""Drift function for the flow matching estimator. The drift function is calculated based on [3]_ (see Equation 7): @@ -263,16 +265,22 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor: Args: input: Parameters :math:`\theta_t`. 
times: SDE time variable in [0,1]. + effective_t_max: Upper bound on time to avoid numerical issues at t=1. + This effectively prevents and explosion of the SDE in the beginning. + Note that this does not affect the ODE sampling, which always uses + times in [0,1]. Returns: Drift function at a given time. """ # analytical f(t) does not depend on noise_scale and is undefined at t = 1. - # NOTE: We bound the singularity to avoid numerical issues i.e. 1 - t > 0.01 - # this effectively prevents and explosion of the SDE in the beginning. - return -input / torch.maximum(1 - times, torch.tensor(1e-2).to(input)) + return -input / torch.maximum( + 1 - times, torch.tensor(1 - effective_t_max).to(input) + ) - def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: + def diffusion_fn( + self, input: Tensor, times: Tensor, effective_t_max: float = 0.99 + ) -> Tensor: r"""Diffusion function for the flow matching estimator. The diffusion function is calculated based on [3]_ (see Equation 7): @@ -290,17 +298,19 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: Args: input: Parameters :math:`\theta_t`. times: SDE time variable in [0,1]. + effective_t_max: Upper bound on time to avoid numerical issues at t=1. + This effectively prevents and explosion of the SDE in the beginning. + Note that this does not affect the ODE sampling, which always uses + times in [0,1]. Returns: Diffusion function at a given time. """ # analytical g(t) is undefined at t = 1. - # NOTE: We bound the singularity to avoid numerical issues i.e. 1 - t > 0.01 - # this effectively prevents and explosion of the SDE in the beginning. 
return torch.sqrt( 2 * (times + self.noise_scale) - / torch.maximum(1 - times, torch.tensor(1e-2).to(times)) + / torch.maximum(1 - times, torch.tensor(1 - effective_t_max).to(times)) ) def mean_t_fn(self, times: Tensor) -> Tensor: From 916d226c39eadc1dab01c6428728b2c474f0db8e Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Thu, 4 Sep 2025 12:03:35 +0200 Subject: [PATCH 5/7] Fix typo --- sbi/neural_nets/estimators/flowmatching_estimator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sbi/neural_nets/estimators/flowmatching_estimator.py b/sbi/neural_nets/estimators/flowmatching_estimator.py index a3d552417..94e3b5857 100644 --- a/sbi/neural_nets/estimators/flowmatching_estimator.py +++ b/sbi/neural_nets/estimators/flowmatching_estimator.py @@ -266,7 +266,7 @@ def drift_fn( input: Parameters :math:`\theta_t`. times: SDE time variable in [0,1]. effective_t_max: Upper bound on time to avoid numerical issues at t=1. - This effectively prevents and explosion of the SDE in the beginning. + This effectively prevents an explosion of the SDE in the beginning. Note that this does not affect the ODE sampling, which always uses times in [0,1]. @@ -299,7 +299,7 @@ def diffusion_fn( input: Parameters :math:`\theta_t`. times: SDE time variable in [0,1]. effective_t_max: Upper bound on time to avoid numerical issues at t=1. - This effectively prevents and explosion of the SDE in the beginning. + This effectively prevents an explosion of the SDE in the beginning. Note that this does not affect the ODE sampling, which always uses times in [0,1]. 
From b56dd857b7fea3eda91315ca450fb77853e60c8a Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Thu, 4 Sep 2025 12:30:40 +0200 Subject: [PATCH 6/7] Mean_t_fn also needs effective_t_max --- sbi/neural_nets/estimators/flowmatching_estimator.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sbi/neural_nets/estimators/flowmatching_estimator.py b/sbi/neural_nets/estimators/flowmatching_estimator.py index 94e3b5857..dfbb59da9 100644 --- a/sbi/neural_nets/estimators/flowmatching_estimator.py +++ b/sbi/neural_nets/estimators/flowmatching_estimator.py @@ -313,7 +313,7 @@ def diffusion_fn( / torch.maximum(1 - times, torch.tensor(1 - effective_t_max).to(times)) ) - def mean_t_fn(self, times: Tensor) -> Tensor: + def mean_t_fn(self, times: Tensor, effective_t_max: float = 0.99) -> Tensor: r"""Linear coefficient of the perturbation kernel expectation :math:`\mu_t(t) = E[\theta_t | \theta_0]` for the flow matching estimator. @@ -330,10 +330,18 @@ Args: times: SDE time variable in [0,1]. + effective_t_max: Upper bound on time to avoid numerical issues at t=1. + This prevents singularity at t=1 in the mean function (mean_t=0.). + NOTE: This did affect the IID sampling, as the analytical denoising + moments run into issues (as mean_t=0, which effectively makes it pure + noise and the equations are not well defined anymore). Alternatively + we could also adapt the analytical denoising equations in + `utils/score_utils.py` to account for this case. Returns: Mean function at a given time. 
""" + times = torch.clamp(times, max=effective_t_max) mean_t = 1 - times for _ in range(len(self.input_shape)): mean_t = mean_t.unsqueeze(-1) From 214c182a9168017d955f044bccb504035b9234b0 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Thu, 4 Sep 2025 13:44:25 +0200 Subject: [PATCH 7/7] All tests pass now --- tests/linearGaussian_vector_field_test.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/linearGaussian_vector_field_test.py b/tests/linearGaussian_vector_field_test.py index 32e82e5a8..6854637bb 100644 --- a/tests/linearGaussian_vector_field_test.py +++ b/tests/linearGaussian_vector_field_test.py @@ -373,14 +373,6 @@ def test_vector_field_iid_inference( num_trials: The number of trials to run. """ - if ( - vector_field_type == "fmpe" - and prior_type == "uniform" - and iid_method in ["gauss", "auto_gauss", "jac_gauss"] - ): - # TODO: Predictor produces NaNs for these cases, see #1656 - pytest.skip("Known issue of IID methods with uniform priors, see #1656.") - vector_field_trained_model = train_vector_field_model(vector_field_type, prior_type) # Extract data from the trained model