Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sbi/inference/posteriors/vector_field_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def _sample_via_diffusion(
ts: Optional[Tensor] = None,
max_sampling_batch_size: int = 10_000,
show_progress_bars: bool = True,
save_intermediate: bool = False,
) -> Tensor:
r"""Return samples from posterior distribution $p(\theta|x)$.

Expand All @@ -281,6 +282,9 @@ def _sample_via_diffusion(
sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior to
`.sample()`.
show_progress_bars: Whether to show a progress bar during sampling.
save_intermediate: Whether to save intermediate results of the diffusion
process. If True, the returned tensor has shape
`(*sample_shape, steps, *input_shape)`.
"""

if not self.vector_field_estimator.SCORE_DEFINED:
Expand Down Expand Up @@ -332,6 +336,7 @@ def _sample_via_diffusion(
num_samples=current_batch_size,
ts=ts,
show_progress_bars=show_progress_bars,
save_intermediate=save_intermediate,
)

all_samples.append(batch_samples)
Expand Down
8 changes: 6 additions & 2 deletions sbi/neural_nets/estimators/flowmatching_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,9 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
Drift function at a given time.
"""
# analytical f(t) does not depend on noise_scale and is undefined at t = 1.
return -input / torch.maximum(1 - times, torch.tensor(1e-6).to(input))
# NOTE: We bound the singularity to avoid numerical issues i.e. 1 - t > 0.01
# this effectively prevents and explosion of the SDE in the beginning.
return -input / torch.maximum(1 - times, torch.tensor(1e-2).to(input))

def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
r"""Diffusion function for the flow matching estimator.
Expand All @@ -293,10 +295,12 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
Diffusion function at a given time.
"""
# analytical g(t) is undefined at t = 1.
# NOTE: We bound the singularity to avoid numerical issues i.e. 1 - t > 0.01
# this effectively prevents and explosion of the SDE in the beginning.
return torch.sqrt(
2
* (times + self.noise_scale)
/ torch.maximum(1 - times, torch.tensor(1e-6).to(times))
/ torch.maximum(1 - times, torch.tensor(1e-2).to(times))
)

def mean_t_fn(self, times: Tensor) -> Tensor:
Expand Down
2 changes: 1 addition & 1 deletion sbi/samplers/score/diffuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,6 @@ def run(
intermediate_samples.append(samples)

if save_intermediate:
return torch.cat(intermediate_samples, dim=0)
return torch.cat(intermediate_samples, dim=1)
else:
return samples