
Commit 16436e6

feat: refactoring and new features for NPSE (#1370)
* npse MAP
* set default enable_transform to True
* sampling via diffusion twice
* batched sampling for score-based posteriors
* add test for score batched sampling
* better convergence checks
1 parent aa05585 commit 16436e6

File tree: 9 files changed, +387 −133 lines changed

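For orientation, a hedged usage sketch of the features this commit touches. The training workflow (`NPSE`, `append_simulations`, `train`, `build_posterior`) is the standard sbi interface and is assumed here; only `set_default_x`, `map`, and `sample_batched` appear in the diff below.

# Hedged sketch of the features touched by this commit: MAP and batched
# sampling for score-based (NPSE) posteriors. The training calls are the
# standard sbi workflow (assumed, not part of this diff); set_default_x,
# map, and sample_batched are taken directly from the changes below.
import torch
from sbi.inference import NPSE
from sbi.utils import BoxUniform

prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((2000,))
x = theta + 0.1 * torch.randn_like(theta)  # toy "simulator"

inference = NPSE(prior=prior)
score_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(score_estimator)

x_o = torch.zeros(1, 2)
posterior.set_default_x(x_o)

# New in this commit: MAP via gradient ascent on the differentiable potential.
theta_map = posterior.map()

# New in this commit: batched sampling over several observations at once.
x_batch = torch.zeros(3, 2)
samples = posterior.sample_batched((100,), x=x_batch)  # shape (100, 3, 2)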

sbi/inference/posteriors/direct_posterior.py

Lines changed: 3 additions & 3 deletions
@@ -132,7 +132,7 @@ def sample(
         )
 
         samples = rejection.accept_reject_sample(
-            proposal=self.posterior_estimator,
+            proposal=self.posterior_estimator.sample,
             accept_reject_fn=lambda theta: within_support(self.prior, theta),
             num_samples=num_samples,
             show_progress_bars=show_progress_bars,
@@ -176,7 +176,7 @@ def sample_batched(
         )
 
         samples = rejection.accept_reject_sample(
-            proposal=self.posterior_estimator,
+            proposal=self.posterior_estimator.sample,
             accept_reject_fn=lambda theta: within_support(self.prior, theta),
             num_samples=num_samples,
             show_progress_bars=show_progress_bars,
@@ -373,7 +373,7 @@ def leakage_correction(
         def acceptance_at(x: Tensor) -> Tensor:
             # [1:] to remove batch-dimension for `reshape_to_batch_event`.
             return rejection.accept_reject_sample(
-                proposal=self.posterior_estimator,
+                proposal=self.posterior_estimator.sample,
                 accept_reject_fn=lambda theta: within_support(self.prior, theta),
                 num_samples=num_rejection_samples,
                 show_progress_bars=show_progress_bars,
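The three hunks above swap the `proposal` argument from the estimator object to its bound `.sample` method, so the rejection helper only needs a callable that produces candidate draws. A minimal sketch of that pattern, assuming a `sample_shape`-first callable; this is a simplified stand-in, not sbi's actual `rejection.accept_reject_sample`.

# Hedged illustration of the callable-proposal pattern: a generic accept/reject
# loop that works with any sampling function (an estimator's .sample, a
# diffusion sampler, an ODE sampler), not a specific estimator object.
from typing import Callable

import torch
from torch import Tensor


def accept_reject(
    proposal: Callable[..., Tensor],               # e.g. estimator.sample
    accept_reject_fn: Callable[[Tensor], Tensor],  # boolean mask over draws
    num_samples: int,
    max_sampling_batch_size: int = 10_000,
    **proposal_sampling_kwargs,
) -> Tensor:
    accepted = []
    num_accepted = 0
    while num_accepted < num_samples:
        candidates = proposal(
            torch.Size((max_sampling_batch_size,)), **proposal_sampling_kwargs
        )
        mask = accept_reject_fn(candidates)
        accepted.append(candidates[mask])
        num_accepted += int(mask.sum())
    return torch.cat(accepted, dim=0)[:num_samples]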

sbi/inference/posteriors/score_posterior.py

Lines changed: 133 additions & 37 deletions
@@ -9,18 +9,21 @@
 
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
 from sbi.inference.potentials.score_based_potential import (
+    CallableDifferentiablePotentialFunction,
     PosteriorScoreBasedPotential,
     score_estimator_based_potential,
 )
 from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
 from sbi.neural_nets.estimators.shape_handling import (
     reshape_to_batch_event,
 )
+from sbi.samplers.rejection import rejection
 from sbi.samplers.score.correctors import Corrector
 from sbi.samplers.score.diffuser import Diffuser
 from sbi.samplers.score.predictors import Predictor
 from sbi.sbi_types import Shape
 from sbi.utils import check_prior
+from sbi.utils.sbiutils import gradient_ascent, within_support
 from sbi.utils.torchutils import ensure_theta_batched
 
 
@@ -46,7 +49,7 @@ def __init__(
         prior: Distribution,
         max_sampling_batch_size: int = 10_000,
         device: Optional[str] = None,
-        enable_transform: bool = False,
+        enable_transform: bool = True,
         sample_with: str = "sde",
     ):
         """
@@ -110,7 +113,6 @@ def sample(
 
         Args:
             sample_shape: Shape of the samples to be drawn.
-            x: Deprecated - use `.set_default_x()` prior to `.sample()`.
             predictor: The predictor for the diffusion-based sampler. Can be a string or
                 a custom predictor following the API in `sbi.samplers.score.predictors`.
                 Currently, only `euler_maruyama` is implemented.
@@ -136,23 +138,39 @@ def sample(
 
         x = self._x_else_default_x(x)
         x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
-        self.potential_fn.set_x(x)
+        self.potential_fn.set_x(x, x_is_iid=True)
+
+        num_samples = torch.Size(sample_shape).numel()
 
         if self.sample_with == "ode":
-            samples = self.sample_via_zuko(sample_shape=sample_shape, x=x)
-        elif self.sample_with == "sde":
-            samples = self._sample_via_diffusion(
-                sample_shape=sample_shape,
-                predictor=predictor,
-                corrector=corrector,
-                predictor_params=predictor_params,
-                corrector_params=corrector_params,
-                steps=steps,
-                ts=ts,
+            samples = rejection.accept_reject_sample(
+                proposal=self.sample_via_ode,
+                accept_reject_fn=lambda theta: within_support(self.prior, theta),
+                num_samples=num_samples,
+                show_progress_bars=show_progress_bars,
                 max_sampling_batch_size=max_sampling_batch_size,
+            )[0]
+        elif self.sample_with == "sde":
+            proposal_sampling_kwargs = {
+                "predictor": predictor,
+                "corrector": corrector,
+                "predictor_params": predictor_params,
+                "corrector_params": corrector_params,
+                "steps": steps,
+                "ts": ts,
+                "max_sampling_batch_size": max_sampling_batch_size,
+                "show_progress_bars": show_progress_bars,
+            }
+            samples = rejection.accept_reject_sample(
+                proposal=self._sample_via_diffusion,
+                accept_reject_fn=lambda theta: within_support(self.prior, theta),
+                num_samples=num_samples,
                 show_progress_bars=show_progress_bars,
-            )
+                max_sampling_batch_size=max_sampling_batch_size,
+                proposal_sampling_kwargs=proposal_sampling_kwargs,
+            )[0]
 
+        samples = samples.reshape(sample_shape + self.score_estimator.input_shape)
         return samples
 
     def _sample_via_diffusion(
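With this hunk, both the SDE and the ODE paths of `.sample()` are routed through `rejection.accept_reject_sample`, so draws outside the prior support are redrawn. A hedged usage sketch, assuming a `posterior` built as in the example near the top of the page; only keyword names visible in this diff are used.

# Hedged usage sketch of the reworked sampling path. `posterior` is assumed to
# be a ScorePosterior as constructed in the earlier example.
import torch

x_o = torch.zeros(1, 2)

# "sde" (the default): predictor/corrector diffusion sampling, now wrapped in
# rejection.accept_reject_sample so out-of-support draws are resampled.
samples = posterior.sample((500,), x=x_o, predictor="euler_maruyama", steps=500)

# With sample_with="ode" chosen at construction time, the same call instead
# routes through sample_via_ode (probability flow) plus the same rejection wrapper.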
@@ -171,7 +189,6 @@ def _sample_via_diffusion(
 
         Args:
            sample_shape: Shape of the samples to be drawn.
-            x: Deprecated - use `.set_default_x()` prior to `.sample()`.
             predictor: The predictor for the diffusion-based sampler. Can be a string or
                 a custom predictor following the API in `sbi.samplers.score.predictors`.
                 Currently, only `euler_maruyama` is implemented.
@@ -222,11 +239,10 @@ def _sample_via_diffusion(
             )
         samples = torch.cat(samples, dim=0)[:num_samples]
 
-        return samples.reshape(sample_shape + self.score_estimator.input_shape)
+        return samples
 
-    def sample_via_zuko(
+    def sample_via_ode(
         self,
-        x: Tensor,
         sample_shape: Shape = torch.Size(),
     ) -> Tensor:
         r"""Return samples from posterior distribution with probability flow ODE.
@@ -243,10 +259,12 @@
         """
         num_samples = torch.Size(sample_shape).numel()
 
-        flow = self.potential_fn.get_continuous_normalizing_flow(condition=x)
+        flow = self.potential_fn.get_continuous_normalizing_flow(
+            condition=self.potential_fn.x_o
+        )
         samples = flow.sample(torch.Size((num_samples,)))
 
-        return samples.reshape(sample_shape + self.score_estimator.input_shape)
+        return samples
 
     def log_prob(
         self,
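`sample_via_ode` now reads the conditioning from `self.potential_fn.x_o` and samples the continuous normalizing flow returned by `get_continuous_normalizing_flow`. Conceptually this integrates the probability-flow ODE of the diffusion; the rough Euler sketch below only illustrates that idea (hypothetical `score_fn` and noise schedule) and is not sbi's implementation.

# Conceptual sketch of probability-flow ODE sampling: integrate
# dx/dt = f(x, t) - 0.5 * g(t)^2 * score(x, t) from t=1 (noise) to t~0 (data).
# score_fn, beta, and the step count are illustrative assumptions, not sbi code.
import torch


def sample_probability_flow(score_fn, shape, steps=500, beta=lambda t: 0.1 + 19.9 * t):
    x = torch.randn(shape)                  # start from the Gaussian reference
    ts = torch.linspace(1.0, 1e-3, steps)
    dt = ts[0] - ts[1]                      # uniform step size
    for t in ts:
        f = -0.5 * beta(t) * x              # VP-SDE drift
        g2 = beta(t)                        # squared diffusion coefficient
        dx = f - 0.5 * g2 * score_fn(x, t)  # probability-flow ODE drift
        x = x - dx * dt                     # Euler step backwards in time
    return x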
@@ -291,19 +309,73 @@ def sample_batched(
         self,
         sample_shape: torch.Size,
         x: Tensor,
+        predictor: Union[str, Predictor] = "euler_maruyama",
+        corrector: Optional[Union[str, Corrector]] = None,
+        predictor_params: Optional[Dict] = None,
+        corrector_params: Optional[Dict] = None,
+        steps: int = 500,
+        ts: Optional[Tensor] = None,
         max_sampling_batch_size: int = 10000,
         show_progress_bars: bool = True,
     ) -> Tensor:
-        raise NotImplementedError(
-            "Batched sampling is not implemented for ScorePosterior."
+        num_samples = torch.Size(sample_shape).numel()
+        x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
+        condition_dim = len(self.score_estimator.condition_shape)
+        batch_shape = x.shape[:-condition_dim]
+        batch_size = batch_shape.numel()
+        self.potential_fn.set_x(x)
+
+        max_sampling_batch_size = (
+            self.max_sampling_batch_size
+            if max_sampling_batch_size is None
+            else max_sampling_batch_size
         )
 
+        if self.sample_with == "ode":
+            samples = rejection.accept_reject_sample(
+                proposal=self.sample_via_ode,
+                accept_reject_fn=lambda theta: within_support(self.prior, theta),
+                num_samples=num_samples,
+                num_xos=batch_size,
+                show_progress_bars=show_progress_bars,
+                max_sampling_batch_size=max_sampling_batch_size,
+                proposal_sampling_kwargs={"x": x},
+            )[0]
+            samples = samples.reshape(
+                sample_shape + batch_shape + self.score_estimator.input_shape
+            )
+        elif self.sample_with == "sde":
+            proposal_sampling_kwargs = {
+                "predictor": predictor,
+                "corrector": corrector,
+                "predictor_params": predictor_params,
+                "corrector_params": corrector_params,
+                "steps": steps,
+                "ts": ts,
+                "max_sampling_batch_size": max_sampling_batch_size,
+                "show_progress_bars": show_progress_bars,
+            }
+            samples = rejection.accept_reject_sample(
+                proposal=self._sample_via_diffusion,
+                accept_reject_fn=lambda theta: within_support(self.prior, theta),
+                num_samples=num_samples,
+                num_xos=batch_size,
+                show_progress_bars=show_progress_bars,
+                max_sampling_batch_size=max_sampling_batch_size,
+                proposal_sampling_kwargs=proposal_sampling_kwargs,
+            )[0]
+            samples = samples.reshape(
+                sample_shape + batch_shape + self.score_estimator.input_shape
+            )
+
+        return samples
+
     def map(
         self,
         x: Optional[Tensor] = None,
         num_iter: int = 1000,
         num_to_optimize: int = 1000,
-        learning_rate: float = 1e-5,
+        learning_rate: float = 0.01,
         init_method: Union[str, Tensor] = "posterior",
         num_init_samples: int = 1000,
         save_best_every: int = 1000,
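The batched path reshapes the accepted draws to `sample_shape + batch_shape + input_shape` and passes `num_xos=batch_size` so acceptance is tracked per observation. A hedged shape check, reusing the `posterior` from the earlier example (2-d parameters):

# Hedged shape check for the new batched path: with B observations, samples come
# back as sample_shape + batch_shape + input_shape. Numbers are illustrative.
import torch

x_batch = torch.randn(4, 2)                      # B = 4 observations
samples = posterior.sample_batched((250,), x=x_batch)
assert samples.shape == torch.Size([250, 4, 2])  # sample_shape + batch + theta_dim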
@@ -351,17 +423,41 @@ def map(
         Returns:
             The MAP estimate.
         """
-        raise NotImplementedError(
-            "MAP estimation is currently not working accurately for ScorePosterior."
-        )
-        return super().map(
-            x=x,
-            num_iter=num_iter,
-            num_to_optimize=num_to_optimize,
-            learning_rate=learning_rate,
-            init_method=init_method,
-            num_init_samples=num_init_samples,
-            save_best_every=save_best_every,
-            show_progress_bars=show_progress_bars,
-            force_update=force_update,
-        )
+        if x is not None:
+            raise ValueError(
+                "Passing `x` directly to `.map()` has been deprecated."
+                "Use `.self_default_x()` to set `x`, and then run `.map()` "
+            )
+
+        if self.default_x is None:
+            raise ValueError(
+                "Default `x` has not been set."
+                "To set the default, use the `.set_default_x()` method."
+            )
+
+        if self._map is None or force_update:
+            self.potential_fn.set_x(self.default_x)
+            callable_potential_fn = CallableDifferentiablePotentialFunction(
+                self.potential_fn
+            )
+            if init_method == "posterior":
+                inits = self.sample((num_init_samples,))
+            elif init_method == "proposal":
+                inits = self.proposal.sample((num_init_samples,))  # type: ignore
+            elif isinstance(init_method, Tensor):
+                inits = init_method
+            else:
+                raise ValueError
+
+            self._map = gradient_ascent(
+                potential_fn=callable_potential_fn,
+                inits=inits,
+                theta_transform=self.theta_transform,
+                num_iter=num_iter,
+                num_to_optimize=num_to_optimize,
+                learning_rate=learning_rate,
+                save_best_every=save_best_every,
+                show_progress_bars=show_progress_bars,
+            )[0]
+
+        return self._map
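MAP estimation now wraps the potential in `CallableDifferentiablePotentialFunction` and optimizes it with `gradient_ascent`. The sketch below is a simplified stand-in for that loop (it ignores `theta_transform`, `num_to_optimize`, and `save_best_every`), not sbi's `gradient_ascent`.

# Hedged stand-in for the MAP optimization: ascend the differentiable potential
# (roughly the unnormalized log posterior) from many initial samples and keep
# the best candidate. Simplified relative to sbi's gradient_ascent.
import torch


def map_by_gradient_ascent(potential_fn, inits, num_iter=1000, learning_rate=0.01):
    theta = inits.detach().clone().requires_grad_(True)
    optimizer = torch.optim.Adam([theta], lr=learning_rate)
    for _ in range(num_iter):
        optimizer.zero_grad()
        loss = -potential_fn(theta).sum()   # maximize the potential
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        best = theta[potential_fn(theta).argmax()]
    return best.detach()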
