9
9
10
10
from sbi .inference .posteriors .base_posterior import NeuralPosterior
11
11
from sbi .inference .potentials .score_based_potential import (
12
+ CallableDifferentiablePotentialFunction ,
12
13
PosteriorScoreBasedPotential ,
13
14
score_estimator_based_potential ,
14
15
)
15
16
from sbi .neural_nets .estimators .score_estimator import ConditionalScoreEstimator
16
17
from sbi .neural_nets .estimators .shape_handling import (
17
18
reshape_to_batch_event ,
18
19
)
20
+ from sbi .samplers .rejection import rejection
19
21
from sbi .samplers .score .correctors import Corrector
20
22
from sbi .samplers .score .diffuser import Diffuser
21
23
from sbi .samplers .score .predictors import Predictor
22
24
from sbi .sbi_types import Shape
23
25
from sbi .utils import check_prior
26
+ from sbi .utils .sbiutils import gradient_ascent , within_support
24
27
from sbi .utils .torchutils import ensure_theta_batched
25
28
26
29
@@ -46,7 +49,7 @@ def __init__(
46
49
prior : Distribution ,
47
50
max_sampling_batch_size : int = 10_000 ,
48
51
device : Optional [str ] = None ,
49
- enable_transform : bool = False ,
52
+ enable_transform : bool = True ,
50
53
sample_with : str = "sde" ,
51
54
):
52
55
"""
@@ -110,7 +113,6 @@ def sample(
110
113
111
114
Args:
112
115
sample_shape: Shape of the samples to be drawn.
113
- x: Deprecated - use `.set_default_x()` prior to `.sample()`.
114
116
predictor: The predictor for the diffusion-based sampler. Can be a string or
115
117
a custom predictor following the API in `sbi.samplers.score.predictors`.
116
118
Currently, only `euler_maruyama` is implemented.
@@ -136,23 +138,39 @@ def sample(
136
138
137
139
x = self ._x_else_default_x (x )
138
140
x = reshape_to_batch_event (x , self .score_estimator .condition_shape )
139
- self .potential_fn .set_x (x )
141
+ self .potential_fn .set_x (x , x_is_iid = True )
142
+
143
+ num_samples = torch .Size (sample_shape ).numel ()
140
144
141
145
if self .sample_with == "ode" :
142
- samples = self .sample_via_zuko (sample_shape = sample_shape , x = x )
143
- elif self .sample_with == "sde" :
144
- samples = self ._sample_via_diffusion (
145
- sample_shape = sample_shape ,
146
- predictor = predictor ,
147
- corrector = corrector ,
148
- predictor_params = predictor_params ,
149
- corrector_params = corrector_params ,
150
- steps = steps ,
151
- ts = ts ,
146
+ samples = rejection .accept_reject_sample (
147
+ proposal = self .sample_via_ode ,
148
+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
149
+ num_samples = num_samples ,
150
+ show_progress_bars = show_progress_bars ,
152
151
max_sampling_batch_size = max_sampling_batch_size ,
152
+ )[0 ]
153
+ elif self .sample_with == "sde" :
154
+ proposal_sampling_kwargs = {
155
+ "predictor" : predictor ,
156
+ "corrector" : corrector ,
157
+ "predictor_params" : predictor_params ,
158
+ "corrector_params" : corrector_params ,
159
+ "steps" : steps ,
160
+ "ts" : ts ,
161
+ "max_sampling_batch_size" : max_sampling_batch_size ,
162
+ "show_progress_bars" : show_progress_bars ,
163
+ }
164
+ samples = rejection .accept_reject_sample (
165
+ proposal = self ._sample_via_diffusion ,
166
+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
167
+ num_samples = num_samples ,
153
168
show_progress_bars = show_progress_bars ,
154
- )
169
+ max_sampling_batch_size = max_sampling_batch_size ,
170
+ proposal_sampling_kwargs = proposal_sampling_kwargs ,
171
+ )[0 ]
155
172
173
+ samples = samples .reshape (sample_shape + self .score_estimator .input_shape )
156
174
return samples
157
175
158
176
def _sample_via_diffusion (
@@ -171,7 +189,6 @@ def _sample_via_diffusion(
171
189
172
190
Args:
173
191
sample_shape: Shape of the samples to be drawn.
174
- x: Deprecated - use `.set_default_x()` prior to `.sample()`.
175
192
predictor: The predictor for the diffusion-based sampler. Can be a string or
176
193
a custom predictor following the API in `sbi.samplers.score.predictors`.
177
194
Currently, only `euler_maruyama` is implemented.
@@ -222,11 +239,10 @@ def _sample_via_diffusion(
222
239
)
223
240
samples = torch .cat (samples , dim = 0 )[:num_samples ]
224
241
225
- return samples . reshape ( sample_shape + self . score_estimator . input_shape )
242
+ return samples
226
243
227
- def sample_via_zuko (
244
+ def sample_via_ode (
228
245
self ,
229
- x : Tensor ,
230
246
sample_shape : Shape = torch .Size (),
231
247
) -> Tensor :
232
248
r"""Return samples from posterior distribution with probability flow ODE.
@@ -243,10 +259,12 @@ def sample_via_zuko(
243
259
"""
244
260
num_samples = torch .Size (sample_shape ).numel ()
245
261
246
- flow = self .potential_fn .get_continuous_normalizing_flow (condition = x )
262
+ flow = self .potential_fn .get_continuous_normalizing_flow (
263
+ condition = self .potential_fn .x_o
264
+ )
247
265
samples = flow .sample (torch .Size ((num_samples ,)))
248
266
249
- return samples . reshape ( sample_shape + self . score_estimator . input_shape )
267
+ return samples
250
268
251
269
def log_prob (
252
270
self ,
@@ -291,19 +309,73 @@ def sample_batched(
291
309
self ,
292
310
sample_shape : torch .Size ,
293
311
x : Tensor ,
312
+ predictor : Union [str , Predictor ] = "euler_maruyama" ,
313
+ corrector : Optional [Union [str , Corrector ]] = None ,
314
+ predictor_params : Optional [Dict ] = None ,
315
+ corrector_params : Optional [Dict ] = None ,
316
+ steps : int = 500 ,
317
+ ts : Optional [Tensor ] = None ,
294
318
max_sampling_batch_size : int = 10000 ,
295
319
show_progress_bars : bool = True ,
296
320
) -> Tensor :
297
- raise NotImplementedError (
298
- "Batched sampling is not implemented for ScorePosterior."
321
+ num_samples = torch .Size (sample_shape ).numel ()
322
+ x = reshape_to_batch_event (x , self .score_estimator .condition_shape )
323
+ condition_dim = len (self .score_estimator .condition_shape )
324
+ batch_shape = x .shape [:- condition_dim ]
325
+ batch_size = batch_shape .numel ()
326
+ self .potential_fn .set_x (x )
327
+
328
+ max_sampling_batch_size = (
329
+ self .max_sampling_batch_size
330
+ if max_sampling_batch_size is None
331
+ else max_sampling_batch_size
299
332
)
300
333
334
+ if self .sample_with == "ode" :
335
+ samples = rejection .accept_reject_sample (
336
+ proposal = self .sample_via_ode ,
337
+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
338
+ num_samples = num_samples ,
339
+ num_xos = batch_size ,
340
+ show_progress_bars = show_progress_bars ,
341
+ max_sampling_batch_size = max_sampling_batch_size ,
342
+ proposal_sampling_kwargs = {"x" : x },
343
+ )[0 ]
344
+ samples = samples .reshape (
345
+ sample_shape + batch_shape + self .score_estimator .input_shape
346
+ )
347
+ elif self .sample_with == "sde" :
348
+ proposal_sampling_kwargs = {
349
+ "predictor" : predictor ,
350
+ "corrector" : corrector ,
351
+ "predictor_params" : predictor_params ,
352
+ "corrector_params" : corrector_params ,
353
+ "steps" : steps ,
354
+ "ts" : ts ,
355
+ "max_sampling_batch_size" : max_sampling_batch_size ,
356
+ "show_progress_bars" : show_progress_bars ,
357
+ }
358
+ samples = rejection .accept_reject_sample (
359
+ proposal = self ._sample_via_diffusion ,
360
+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
361
+ num_samples = num_samples ,
362
+ num_xos = batch_size ,
363
+ show_progress_bars = show_progress_bars ,
364
+ max_sampling_batch_size = max_sampling_batch_size ,
365
+ proposal_sampling_kwargs = proposal_sampling_kwargs ,
366
+ )[0 ]
367
+ samples = samples .reshape (
368
+ sample_shape + batch_shape + self .score_estimator .input_shape
369
+ )
370
+
371
+ return samples
372
+
301
373
def map (
302
374
self ,
303
375
x : Optional [Tensor ] = None ,
304
376
num_iter : int = 1000 ,
305
377
num_to_optimize : int = 1000 ,
306
- learning_rate : float = 1e-5 ,
378
+ learning_rate : float = 0.01 ,
307
379
init_method : Union [str , Tensor ] = "posterior" ,
308
380
num_init_samples : int = 1000 ,
309
381
save_best_every : int = 1000 ,
@@ -351,17 +423,41 @@ def map(
351
423
Returns:
352
424
The MAP estimate.
353
425
"""
354
- raise NotImplementedError (
355
- "MAP estimation is currently not working accurately for ScorePosterior."
356
- )
357
- return super ().map (
358
- x = x ,
359
- num_iter = num_iter ,
360
- num_to_optimize = num_to_optimize ,
361
- learning_rate = learning_rate ,
362
- init_method = init_method ,
363
- num_init_samples = num_init_samples ,
364
- save_best_every = save_best_every ,
365
- show_progress_bars = show_progress_bars ,
366
- force_update = force_update ,
367
- )
426
+ if x is not None :
427
+ raise ValueError (
428
+ "Passing `x` directly to `.map()` has been deprecated."
429
+ "Use `.self_default_x()` to set `x`, and then run `.map()` "
430
+ )
431
+
432
+ if self .default_x is None :
433
+ raise ValueError (
434
+ "Default `x` has not been set."
435
+ "To set the default, use the `.set_default_x()` method."
436
+ )
437
+
438
+ if self ._map is None or force_update :
439
+ self .potential_fn .set_x (self .default_x )
440
+ callable_potential_fn = CallableDifferentiablePotentialFunction (
441
+ self .potential_fn
442
+ )
443
+ if init_method == "posterior" :
444
+ inits = self .sample ((num_init_samples ,))
445
+ elif init_method == "proposal" :
446
+ inits = self .proposal .sample ((num_init_samples ,)) # type: ignore
447
+ elif isinstance (init_method , Tensor ):
448
+ inits = init_method
449
+ else :
450
+ raise ValueError
451
+
452
+ self ._map = gradient_ascent (
453
+ potential_fn = callable_potential_fn ,
454
+ inits = inits ,
455
+ theta_transform = self .theta_transform ,
456
+ num_iter = num_iter ,
457
+ num_to_optimize = num_to_optimize ,
458
+ learning_rate = learning_rate ,
459
+ save_best_every = save_best_every ,
460
+ show_progress_bars = show_progress_bars ,
461
+ )[0 ]
462
+
463
+ return self ._map
0 commit comments