diff --git a/sbi/inference/posteriors/vector_field_posterior.py b/sbi/inference/posteriors/vector_field_posterior.py
index 37f7dc96c..97675a03b 100644
--- a/sbi/inference/posteriors/vector_field_posterior.py
+++ b/sbi/inference/posteriors/vector_field_posterior.py
@@ -260,6 +260,7 @@ def _sample_via_diffusion(
         ts: Optional[Tensor] = None,
         max_sampling_batch_size: int = 10_000,
         show_progress_bars: bool = True,
+        save_intermediate: bool = False,
     ) -> Tensor:
         r"""Return samples from posterior distribution $p(\theta|x)$.

@@ -281,6 +282,9 @@ def _sample_via_diffusion(
             sample_with: Deprecated - use `.build_posterior(sample_with=...)`
                 prior to `.sample()`.
             show_progress_bars: Whether to show a progress bar during sampling.
+            save_intermediate: Whether to save intermediate results of the diffusion
+                process. If True, the returned tensor has shape
+                `(*sample_shape, steps, *input_shape)`.
         """

         if not self.vector_field_estimator.SCORE_DEFINED:
@@ -332,6 +336,7 @@ def _sample_via_diffusion(
                 num_samples=current_batch_size,
                 ts=ts,
                 show_progress_bars=show_progress_bars,
+                save_intermediate=save_intermediate,
             )
             all_samples.append(batch_samples)

diff --git a/sbi/neural_nets/estimators/flowmatching_estimator.py b/sbi/neural_nets/estimators/flowmatching_estimator.py
index 3eb37f52e..dfbb59da9 100644
--- a/sbi/neural_nets/estimators/flowmatching_estimator.py
+++ b/sbi/neural_nets/estimators/flowmatching_estimator.py
@@ -243,7 +243,9 @@ def score(self, input: Tensor, condition: Tensor, t: Tensor) -> Tensor:
         score = (-(1 - t) * v - input) / (t + self.noise_scale)
         return score

-    def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
+    def drift_fn(
+        self, input: Tensor, times: Tensor, effective_t_max: float = 0.99
+    ) -> Tensor:
         r"""Drift function for the flow matching estimator.

         The drift function is calculated based on [3]_ (see Equation 7):
@@ -263,14 +265,22 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
         Args:
             input: Parameters :math:`\theta_t`.
             times: SDE time variable in [0,1].
+            effective_t_max: Upper bound on time to avoid numerical issues at t=1.
+                This prevents the SDE from exploding near t=1 (the start of sampling).
+                Note that this does not affect ODE sampling, which always uses
+                times in [0,1].

         Returns:
             Drift function at a given time.
         """
         # analytical f(t) does not depend on noise_scale and is undefined at t = 1.
-        return -input / torch.maximum(1 - times, torch.tensor(1e-6).to(input))
+        return -input / torch.maximum(
+            1 - times, torch.tensor(1 - effective_t_max).to(input)
+        )

-    def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
+    def diffusion_fn(
+        self, input: Tensor, times: Tensor, effective_t_max: float = 0.99
+    ) -> Tensor:
         r"""Diffusion function for the flow matching estimator.

         The diffusion function is calculated based on [3]_ (see Equation 7):
@@ -288,6 +298,10 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         Args:
             input: Parameters :math:`\theta_t`.
             times: SDE time variable in [0,1].
+            effective_t_max: Upper bound on time to avoid numerical issues at t=1.
+                This prevents the SDE from exploding near t=1 (the start of sampling).
+                Note that this does not affect ODE sampling, which always uses
+                times in [0,1].

         Returns:
             Diffusion function at a given time.
@@ -296,10 +310,10 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         return torch.sqrt(
             2
             * (times + self.noise_scale)
-            / torch.maximum(1 - times, torch.tensor(1e-6).to(times))
+            / torch.maximum(1 - times, torch.tensor(1 - effective_t_max).to(times))
         )

-    def mean_t_fn(self, times: Tensor) -> Tensor:
+    def mean_t_fn(self, times: Tensor, effective_t_max: float = 0.99) -> Tensor:
         r"""Linear coefficient of the perturbation kernel expectation
         :math:`\mu_t(t) = E[\theta_t | \theta_0]` for the flow matching estimator.

@@ -316,10 +330,18 @@ def mean_t_fn(self, times: Tensor) -> Tensor:

         Args:
             times: SDE time variable in [0,1].
+            effective_t_max: Upper bound on time to avoid numerical issues at t=1,
+                where the mean function has a singularity (mean_t=0). NOTE: Without
+                this bound, IID sampling was affected: the analytical denoising
+                moments break down when mean_t=0, since the sample is effectively
+                pure noise and the equations are no longer well defined.
+                Alternatively, we could adapt the analytical denoising equations in
+                `utils/score_utils.py` to account for this case.

         Returns:
             Mean function at a given time.
         """
+        times = torch.clamp(times, max=effective_t_max)
         mean_t = 1 - times
         for _ in range(len(self.input_shape)):
             mean_t = mean_t.unsqueeze(-1)
diff --git a/sbi/samplers/score/diffuser.py b/sbi/samplers/score/diffuser.py
index e4f32be80..90349a9e2 100644
--- a/sbi/samplers/score/diffuser.py
+++ b/sbi/samplers/score/diffuser.py
@@ -175,6 +175,6 @@ def run(
             intermediate_samples.append(samples)

         if save_intermediate:
-            return torch.cat(intermediate_samples, dim=0)
+            return torch.cat(intermediate_samples, dim=1)
         else:
             return samples
diff --git a/tests/linearGaussian_vector_field_test.py b/tests/linearGaussian_vector_field_test.py
index 32e82e5a8..6854637bb 100644
--- a/tests/linearGaussian_vector_field_test.py
+++ b/tests/linearGaussian_vector_field_test.py
@@ -373,14 +373,6 @@ def test_vector_field_iid_inference(
         num_trials: The number of trials to run.
     """

-    if (
-        vector_field_type == "fmpe"
-        and prior_type == "uniform"
-        and iid_method in ["gauss", "auto_gauss", "jac_gauss"]
-    ):
-        # TODO: Predictor produces NaNs for these cases, see #1656
-        pytest.skip("Known issue of IID methods with uniform priors, see #1656.")
-
     vector_field_trained_model = train_vector_field_model(vector_field_type, prior_type)

     # Extract data from the trained model
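As a sanity check on the new floor, here is a minimal standalone sketch (plain `torch`, not the sbi API; the `drift` helper is hypothetical and only mirrors the form of the patched `FlowMatchingEstimator.drift_fn`) comparing the old `1e-6` floor with the new `1 - effective_t_max = 0.01` floor near `t = 1`:

```python
import torch

def drift(theta: torch.Tensor, t: torch.Tensor, t_floor: float) -> torch.Tensor:
    # Same form as the patched drift_fn: f(t) = -theta / max(1 - t, t_floor).
    # The floor bounds the magnitude of the drift as t approaches 1.
    return -theta / torch.maximum(1 - t, torch.tensor(t_floor))

theta = torch.ones(3)
t = torch.tensor(0.9999)  # late SDE time, i.e. the start of reverse sampling

print(drift(theta, t, t_floor=1e-6))  # old floor: 1 - t = 1e-4 wins, drift ~ -1e4
print(drift(theta, t, t_floor=0.01))  # new floor (effective_t_max=0.99): drift = -100
```

With the old `1e-6` floor the drift keeps growing as `t` approaches 1, which is the SDE explosion this patch addresses; the `0.01` floor caps its magnitude at 100.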