
Commit dd4aef7

fix: rebuild flow logic in score methods (#1404)
* change rebuild_flow default to False
* rebuild flow for each new x, but not in __call__; fix map test
* remove old kwargs
* add tolerance options in log_prob; fix map test
* fix tests
1 parent e063c58 · commit dd4aef7

7 files changed: +69 / -57 lines


sbi/diagnostics/sbc.py
Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
 from torch.distributions import Uniform
 from tqdm.auto import tqdm
 
-from sbi.inference import DirectPosterior
+from sbi.inference import DirectPosterior, ScorePosterior
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
 from sbi.inference.posteriors.vi_posterior import VIPosterior
 from sbi.utils.diagnostics_utils import (

@@ -186,7 +186,7 @@ def get_nltp(thetas: Tensor, xs: Tensor, posterior: NeuralPosterior) -> Tensor:
         nltp: negative log probs of true parameters under approximate posteriors.
     """
     nltp = torch.zeros(thetas.shape[0])
-    unnormalized_log_prob = not isinstance(posterior, DirectPosterior)
+    unnormalized_log_prob = not isinstance(posterior, (DirectPosterior, ScorePosterior))
 
     for idx, (tho, xo) in enumerate(zip(thetas, xs)):
         # Log prob of true params under posterior.
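
For context, a minimal sketch of how `get_nltp` would be used after this change; the variable names (`posterior`, `thetas`, `xs`) and the import path are assumptions, not part of the diff:

    # Assumes a trained score-based posterior and matching (thetas, xs) pairs.
    from sbi.diagnostics.sbc import get_nltp

    # With this commit, ScorePosterior is grouped with DirectPosterior, so the
    # unnormalized_log_prob flag is False and the returned negative log-probs
    # of the true parameters are treated as normalized.
    nltp = get_nltp(thetas, xs, posterior)  # shape: (thetas.shape[0],)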

sbi/inference/posteriors/score_posterior.py
Lines changed: 32 additions & 8 deletions

@@ -272,8 +272,8 @@ def log_prob(
         x: Optional[Tensor] = None,
         track_gradients: bool = False,
         atol: float = 1e-5,
-        rtol: float = 1e-6,
-        exact: bool = True,
+        rtol: float = 1e-5,
+        exact: bool = False,
     ) -> Tensor:
         r"""Returns the log-probability of the posterior $p(\theta|x)$.

@@ -294,15 +294,14 @@ def log_prob(
             `(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ in the
             support of the prior, -∞ (corresponding to 0 probability) outside.
         """
-        self.potential_fn.set_x(self._x_else_default_x(x))
+        self.potential_fn.set_x(
+            self._x_else_default_x(x), atol=atol, rtol=rtol, exact=exact
+        )
 
         theta = ensure_theta_batched(torch.as_tensor(theta))
         return self.potential_fn(
             theta.to(self._device),
             track_gradients=track_gradients,
-            atol=atol,
-            rtol=rtol,
-            exact=exact,
         )
 
     def sample_batched(
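
A short usage sketch of the updated `log_prob` API, assuming a trained `ScorePosterior` named `posterior` and tensors `theta_o`, `x_o` (names hypothetical):

    # New defaults after this commit: atol=1e-5, rtol=1e-5, exact=False.
    log_prob = posterior.log_prob(theta_o, x=x_o)

    # Tolerances are now forwarded to set_x, which rebuilds the CNF for this x;
    # tighter tolerances give a more accurate (but slower) evaluation.
    log_prob_tight = posterior.log_prob(theta_o, x=x_o, atol=1e-6, rtol=1e-6, exact=True)
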
@@ -318,6 +317,31 @@ def sample_batched(
         max_sampling_batch_size: int = 10000,
         show_progress_bars: bool = True,
     ) -> Tensor:
+        r"""Given a batch of observations [x_1, ..., x_B] this function samples from
+        posteriors $p(\theta|x_1)$, ..., $p(\theta|x_B)$, in a batched (i.e. vectorized)
+        manner.
+
+        Args:
+            sample_shape: Desired shape of samples that are drawn from the posterior
+                given every observation.
+            x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
+                `batch_dim` corresponds to the number of observations to be
+                drawn.
+            predictor: The predictor for the diffusion-based sampler. Can be a string or
+                a custom predictor following the API in `sbi.samplers.score.predictors`.
+                Currently, only `euler_maruyama` is implemented.
+            corrector: The corrector for the diffusion-based sampler.
+            predictor_params: Additional parameters passed to predictor.
+            corrector_params: Additional parameters passed to corrector.
+            steps: Number of steps to take for the Euler-Maruyama method.
+            ts: Time points at which to evaluate the diffusion process. If None, a
+                linear grid between t_max and t_min is used.
+            max_sampling_batch_size: Maximum batch size for sampling.
+            show_progress_bars: Whether to show sampling progress monitor.
+
+        Returns:
+            Samples from the posteriors of shape (*sample_shape, B, *input_shape)
+        """
         num_samples = torch.Size(sample_shape).numel()
         x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
         condition_dim = len(self.score_estimator.condition_shape)

@@ -339,7 +363,6 @@ def sample_batched(
             num_xos=batch_size,
             show_progress_bars=show_progress_bars,
             max_sampling_batch_size=max_sampling_batch_size,
-            proposal_sampling_kwargs={"x": x},
         )[0]
         samples = samples.reshape(
             sample_shape + batch_shape + self.score_estimator.input_shape
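
An illustrative call matching the new docstring; `posterior` and `x_batch` are hypothetical, and the predictor/steps arguments are taken from the documented parameters:

    # x_batch: B observations of shape (B, *event_shape_x).
    samples = posterior.sample_batched(
        sample_shape=(1000,),
        x=x_batch,
        predictor="euler_maruyama",  # currently the only implemented predictor
        steps=500,
    )
    # Returned shape: (*sample_shape, B, *input_shape), e.g. (1000, B, theta_dim).
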
@@ -436,7 +459,8 @@ def map(
         )
 
         if self._map is None or force_update:
-            self.potential_fn.set_x(self.default_x)
+            # rebuild coarse flow fast for MAP optimization.
+            self.potential_fn.set_x(self.default_x, atol=1e-2, rtol=1e-3, exact=True)
             callable_potential_fn = CallableDifferentiablePotentialFunction(
                 self.potential_fn
             )
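
A sketch of the resulting MAP workflow; with this change the coarse, fast flow (atol=1e-2, rtol=1e-3, exact=True) is rebuilt internally, so the caller passes no tolerances (names hypothetical):

    # The potential's flow is rebuilt with coarse tolerances before the
    # gradient ascent starts; no extra arguments are needed.
    map_estimate = posterior.map(show_progress_bars=True)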

sbi/inference/potentials/score_based_potential.py
Lines changed: 18 additions & 31 deletions

@@ -1,7 +1,6 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
-from functools import partial
 from typing import Optional, Tuple
 
 import torch
@@ -79,41 +78,36 @@ def set_x(
         self,
         x_o: Optional[Tensor],
         x_is_iid: Optional[bool] = False,
-        rebuild_flow: Optional[bool] = True,
+        atol: float = 1e-5,
+        rtol: float = 1e-6,
+        exact: bool = True,
     ):
         """
         Set the observed data and whether it is IID.
+
+        Rebuids the continuous normalizing flow if the observed data is set.
+
         Args:
-            x_o: The observed data.
-            x_is_iid: Whether the observed data is IID (if batch_dim>1).
-            rebuild_flow: Whether to save (overwrrite) a low-tolerance flow model, useful if
-                the flow needs to be evaluated many times (e.g. for MAP calculation).
+            x_o: The observed data.
+            x_is_iid: Whether the observed data is IID (if batch_dim>1).
+            atol: Absolute tolerance for the ODE solver.
+            rtol: Relative tolerance for the ODE solver.
+            exact: Whether to use the exact ODE solver.
         """
         super().set_x(x_o, x_is_iid)
-        if rebuild_flow and self._x_o is not None:
-            # By default, we want a high-tolerance flow.
-            # This flow will be used mainly for MAP calculations, hence we want to save
-            # it instead of rebuilding it every time.
-            self.flow = self.rebuild_flow(atol=1e-2, rtol=1e-3, exact=True)
+        if self._x_o is not None:
+            self.flow = self.rebuild_flow(atol=atol, rtol=rtol, exact=exact)
 
     def __call__(
         self,
         theta: Tensor,
         track_gradients: bool = True,
-        rebuild_flow: bool = True,
-        atol: float = 1e-5,
-        rtol: float = 1e-6,
-        exact: bool = True,
     ) -> Tensor:
         """Return the potential (posterior log prob) via probability flow ODE.
 
         Args:
             theta: The parameters at which to evaluate the potential.
             track_gradients: Whether to track gradients.
-            rebuild_flow: Whether to rebuild the CNF for accurate log_prob evaluation.
-            atol: Absolute tolerance for the ODE solver.
-            rtol: Relative tolerance for the ODE solver.
-            exact: Whether to use the exact ODE solver.
 
         Returns:
             The potential function, i.e., the log probability of the posterior.
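
A sketch of the updated potential-level API; `potential_fn` is assumed to be a `PosteriorScoreBasedPotential` instance with `x_o` and `theta` defined:

    # set_x now takes the ODE-solver tolerances and rebuilds the continuous
    # normalizing flow whenever an observation is set (defaults shown).
    potential_fn.set_x(x_o, atol=1e-5, rtol=1e-6, exact=True)

    # __call__ no longer rebuilds the flow; it evaluates the cached one.
    log_prob = potential_fn(theta, track_gradients=False)
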
@@ -123,15 +117,9 @@ def __call__(
             theta, theta.shape[1:], leading_is_sample=True
         )
         self.score_estimator.eval()
-        # use rebuild_flow to evaluate log_prob with better precision, without
-        # overwriting self.flow
-        if rebuild_flow or self.flow is None:
-            flow = self.rebuild_flow(atol=atol, rtol=rtol, exact=exact)
-        else:
-            flow = self.flow
 
         with torch.set_grad_enabled(track_gradients):
-            log_probs = flow.log_prob(theta_density_estimator).squeeze(-1)
+            log_probs = self.flow.log_prob(theta_density_estimator).squeeze(-1)
             # Force probability to be zero outside prior support.
             in_prior_support = within_support(self.prior, theta)
 
@@ -217,7 +205,7 @@ def rebuild_flow(
         x_density_estimator = reshape_to_batch_event(
             self.x_o, event_shape=self.score_estimator.condition_shape
         )
-        assert x_density_estimator.shape[0] == 1, (
+        assert x_density_estimator.shape[0] == 1 or not self.x_is_iid, (
             "PosteriorScoreBasedPotential supports only x batchsize of 1`."
         )
 
@@ -312,9 +300,8 @@ def __init__(self, posterior_score_based_potential):
         self.posterior_score_based_potential = posterior_score_based_potential
 
     def __call__(self, input):
-        prepared_potential = partial(
-            self.posterior_score_based_potential.__call__, rebuild_flow=False
-        )
         return DifferentiablePotentialFunction.apply(
-            input, prepared_potential, self.posterior_score_based_potential.gradient
+            input,
+            self.posterior_score_based_potential.__call__,
+            self.posterior_score_based_potential.gradient,
         )
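
Because `__call__` no longer needs a `rebuild_flow` override, the differentiable wrapper can hand the bound method straight to autograd. A rough sketch of how the wrapper is used during gradient-based MAP optimization; the usage pattern and `theta_init` are assumptions:

    # Wrap the potential so torch autograd can differentiate through it via
    # the potential's own gradient (score) function.
    callable_potential = CallableDifferentiablePotentialFunction(potential_fn)

    theta = theta_init.clone().requires_grad_(True)
    value = callable_potential(theta)
    value.sum().backward()  # gradients are supplied by potential_fn.gradient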

sbi/utils/sbiutils.py
Lines changed: 2 additions & 1 deletion

@@ -961,7 +961,8 @@ def gradient_ascent(
                 f"Optimizing MAP estimate. Iterations: {iter_ + 1} / "
                 f"{num_iter}. Performance in iteration "
                 f"{divmod(iter_ + 1, save_best_every)[0] * save_best_every}: "
-                f"{best_log_prob_iter.item():.2f} (= unnormalized log-prob)",
+                f"{best_log_prob_iter.item():.2f} (= unnormalized log-prob). "
+                "Press Ctrl-C to interrupt.",
                 end="",
             )
     argmax_ = theta_transform.inv(best_theta_overall)

tests/linearGaussian_npse_test.py
Lines changed: 3 additions & 6 deletions

@@ -1,7 +1,6 @@
 from typing import List
 
 import pytest
-import torch
 from torch import eye, ones, zeros
 from torch.distributions import MultivariateNormal
 

@@ -223,11 +222,9 @@ def test_npse_map():
     theta = prior.sample((num_simulations,))
     x = linear_gaussian(theta, likelihood_shift, likelihood_cov)
 
-    inference.append_simulations(theta, x).train(
-        training_batch_size=100, max_num_epochs=10
-    )
+    inference.append_simulations(theta, x).train()
     posterior = inference.build_posterior().set_default_x(x_o)
 
-    map_ = posterior.map(show_progress_bars=True)
+    map_ = posterior.map(show_progress_bars=True, num_iter=5)
 
-    assert torch.allclose(map_, gt_posterior.mean, atol=0.4), "MAP is not close to GT."
+    assert ((map_ - gt_posterior.mean) ** 2).sum() < 0.5, "MAP is not close to GT."

tests/save_and_load_test.py
Lines changed: 7 additions & 5 deletions

@@ -21,7 +21,9 @@
         (NRE, "rejection"),
     ),
 )
-def test_picklability(inference_method, sampling_method: str, tmp_path):
+def test_picklability(
+    inference_method, sampling_method: str, tmp_path, mcmc_params_fast
+):
     num_dim = 2
     prior = utils.BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
     x_o = torch.zeros(1, num_dim)

@@ -31,15 +33,15 @@ def test_picklability(inference_method, sampling_method: str, tmp_path):
 
     inference = inference_method(prior=prior)
     _ = inference.append_simulations(theta, x).train(max_num_epochs=1)
-    posterior = inference.build_posterior(sample_with=sampling_method).set_default_x(
-        x_o
-    )
+    posterior = inference.build_posterior(
+        sample_with=sampling_method, mcmc_parameters=mcmc_params_fast
+    ).set_default_x(x_o)
 
     # After sample and log_prob, the posterior should still be picklable
     if isinstance(posterior, VIPosterior):
         posterior.train(max_num_iters=10)
     _ = posterior.sample((1,))
-    _ = posterior.log_prob(torch.zeros(1, num_dim))
+    _ = posterior.potential(torch.zeros(1, num_dim))
 
     with open(f"{tmp_path}/saved_posterior.pickle", "wb") as handle:
         pickle.dump(posterior, handle)

tests/sbc_test.py
Lines changed: 5 additions & 4 deletions

@@ -29,7 +29,7 @@
         (NPSE, None),
     ),
 )
-def test_running_sbc(method, prior, reduce_fn_str, sampler, mcmc_params_accurate: dict):
+def test_running_sbc(method, prior, reduce_fn_str, sampler, mcmc_params_fast: dict):
     """Tests running inference and then SBC and obtaining nltp."""
 
     num_dim = 2

@@ -59,7 +59,7 @@ def test_running_sbc(method, prior, reduce_fn_str, sampler, mcmc_params_accurate
         posterior_kwargs = {
             "sample_with": "mcmc" if sampler == "mcmc" else "vi",
             "mcmc_method": "slice_np_vectorized",
-            "mcmc_parameters": mcmc_params_accurate,
+            "mcmc_parameters": mcmc_params_fast,
         }
     else:
         posterior_kwargs = {}

@@ -69,7 +69,7 @@ def test_running_sbc(method, prior, reduce_fn_str, sampler, mcmc_params_accurate
     thetas = prior.sample((num_sbc_runs,))
     xs = linear_gaussian(thetas, likelihood_shift, likelihood_cov)
 
-    reduce_fn = "marginals" if reduce_fn_str == "marginals" else posterior.log_prob
+    reduce_fn = "marginals" if reduce_fn_str == "marginals" else posterior.potential
     run_sbc(
         thetas,
         xs,

@@ -79,7 +79,8 @@ def test_running_sbc(method, prior, reduce_fn_str, sampler, mcmc_params_accurate
     )
 
     # Check nltp
-    get_nltp(thetas, xs, posterior)
+    if method in [NPE, NPSE]:
+        get_nltp(thetas, xs, posterior)
 
 
 @pytest.mark.slow
