1
1
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
2
2
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
3
3
4
+ import math
4
5
import warnings
5
6
from typing import Dict , Literal , Optional , Union
6
7
@@ -150,7 +151,9 @@ def sample(
150
151
corrector_params : Optional [Dict ] = None ,
151
152
steps : int = 500 ,
152
153
ts : Optional [Tensor ] = None ,
153
- iid_method : Literal ["fnpe" , "gauss" , "auto_gauss" , "jac_gauss" ] = "auto_gauss" ,
154
+ iid_method : Optional [
155
+ Literal ["fnpe" , "gauss" , "auto_gauss" , "jac_gauss" ]
156
+ ] = None ,
154
157
iid_params : Optional [Dict ] = None ,
155
158
max_sampling_batch_size : int = 10_000 ,
156
159
sample_with : Optional [str ] = None ,
@@ -201,19 +204,22 @@ def sample(
201
204
x = reshape_to_batch_event (x , self .vector_field_estimator .condition_shape )
202
205
is_iid = x .shape [0 ] > 1
203
206
self .potential_fn .set_x (
204
- x , x_is_iid = is_iid , iid_method = iid_method , iid_params = iid_params
207
+ x ,
208
+ x_is_iid = is_iid ,
209
+ iid_method = iid_method or self .potential_fn .iid_method ,
210
+ iid_params = iid_params ,
205
211
)
206
212
207
213
num_samples = torch .Size (sample_shape ).numel ()
208
214
209
215
if sample_with == "ode" :
210
- samples = rejection .accept_reject_sample (
216
+ samples , _ = rejection .accept_reject_sample (
211
217
proposal = self .sample_via_ode ,
212
218
accept_reject_fn = lambda theta : within_support (self .prior , theta ),
213
219
num_samples = num_samples ,
214
220
show_progress_bars = show_progress_bars ,
215
221
max_sampling_batch_size = max_sampling_batch_size ,
216
- )[ 0 ]
222
+ )
217
223
elif sample_with == "sde" :
218
224
proposal_sampling_kwargs = {
219
225
"predictor" : predictor ,
@@ -225,14 +231,14 @@ def sample(
225
231
"max_sampling_batch_size" : max_sampling_batch_size ,
226
232
"show_progress_bars" : show_progress_bars ,
227
233
}
228
- samples = rejection .accept_reject_sample (
234
+ samples , _ = rejection .accept_reject_sample (
229
235
proposal = self ._sample_via_diffusion ,
230
236
accept_reject_fn = lambda theta : within_support (self .prior , theta ),
231
237
num_samples = num_samples ,
232
238
show_progress_bars = show_progress_bars ,
233
239
max_sampling_batch_size = max_sampling_batch_size ,
234
240
proposal_sampling_kwargs = proposal_sampling_kwargs ,
235
- )[ 0 ]
241
+ )
236
242
else :
237
243
raise ValueError (
238
244
f"Expected sample_with to be 'ode' or 'sde', but got { sample_with } ."
@@ -282,13 +288,16 @@ def _sample_via_diffusion(
282
288
"The vector field estimator does not support the 'sde' sampling method."
283
289
)
284
290
285
- num_samples = torch .Size (sample_shape ).numel ()
291
+ total_samples_needed = torch .Size (sample_shape ).numel ()
286
292
287
- max_sampling_batch_size = (
293
+ # Determine effective batch size for sampling
294
+ effective_batch_size = (
288
295
self .max_sampling_batch_size
289
296
if max_sampling_batch_size is None
290
297
else max_sampling_batch_size
291
298
)
299
+ # Ensure we don't use larger batches than total samples needed
300
+ effective_batch_size = min (effective_batch_size , total_samples_needed )
292
301
293
302
# TODO: the time schedule should be provided by the estimator, see issue #1437
294
303
if ts is None :
@@ -297,28 +306,45 @@ def _sample_via_diffusion(
297
306
ts = torch .linspace (t_max , t_min , steps )
298
307
ts = ts .to (self .device )
299
308
309
+ # Initialize the diffusion sampler
300
310
diffuser = Diffuser (
301
311
self .potential_fn ,
302
312
predictor = predictor ,
303
313
corrector = corrector ,
304
314
predictor_params = predictor_params ,
305
315
corrector_params = corrector_params ,
306
316
)
307
- max_sampling_batch_size = min (max_sampling_batch_size , num_samples )
308
- samples = []
309
- num_iter = num_samples // max_sampling_batch_size
310
- num_iter = (
311
- num_iter + 1 if (num_samples % max_sampling_batch_size ) != 0 else num_iter
312
- )
313
- for _ in range (num_iter ):
314
- samples .append (
315
- diffuser .run (
316
- num_samples = max_sampling_batch_size ,
317
- ts = ts ,
318
- show_progress_bars = show_progress_bars ,
319
- )
317
+
318
+ # Calculate how many batches we need
319
+ num_batches = math .ceil (total_samples_needed / effective_batch_size )
320
+
321
+ # Generate samples in batches
322
+ all_samples = []
323
+ samples_generated = 0
324
+
325
+ for _ in range (num_batches ):
326
+ # Calculate how many samples to generate in this batch
327
+ remaining_samples = total_samples_needed - samples_generated
328
+ current_batch_size = min (effective_batch_size , remaining_samples )
329
+
330
+ # Generate samples for this batch
331
+ batch_samples = diffuser .run (
332
+ num_samples = current_batch_size ,
333
+ ts = ts ,
334
+ show_progress_bars = show_progress_bars ,
335
+ )
336
+
337
+ all_samples .append (batch_samples )
338
+ samples_generated += current_batch_size
339
+
340
+ # Concatenate all batches and ensure we return exactly the requested number
341
+ samples = torch .cat (all_samples , dim = 0 )[:total_samples_needed ]
342
+
343
+ if torch .isnan (samples ).all ():
344
+ raise RuntimeError (
345
+ "All samples NaN after diffusion sampling. "
346
+ "This may indicate numerical instability in the vector field."
320
347
)
321
- samples = torch .cat (samples , dim = 0 )[:num_samples ]
322
348
323
349
return samples
324
350
@@ -443,14 +469,14 @@ def sample_batched(
443
469
max_sampling_batch_size = capped
444
470
445
471
if self .sample_with == "ode" :
446
- samples = rejection .accept_reject_sample (
472
+ samples , _ = rejection .accept_reject_sample (
447
473
proposal = self .sample_via_ode ,
448
474
accept_reject_fn = lambda theta : within_support (self .prior , theta ),
449
475
num_samples = num_samples ,
450
476
num_xos = batch_size ,
451
477
show_progress_bars = show_progress_bars ,
452
478
max_sampling_batch_size = max_sampling_batch_size ,
453
- )[ 0 ]
479
+ )
454
480
samples = samples .reshape (
455
481
sample_shape + batch_shape + self .vector_field_estimator .input_shape
456
482
)
@@ -465,15 +491,15 @@ def sample_batched(
465
491
"max_sampling_batch_size" : max_sampling_batch_size ,
466
492
"show_progress_bars" : show_progress_bars ,
467
493
}
468
- samples = rejection .accept_reject_sample (
494
+ samples , _ = rejection .accept_reject_sample (
469
495
proposal = self ._sample_via_diffusion ,
470
496
accept_reject_fn = lambda theta : within_support (self .prior , theta ),
471
497
num_samples = num_samples ,
472
498
num_xos = batch_size ,
473
499
show_progress_bars = show_progress_bars ,
474
500
max_sampling_batch_size = max_sampling_batch_size ,
475
501
proposal_sampling_kwargs = proposal_sampling_kwargs ,
476
- )[ 0 ]
502
+ )
477
503
samples = samples .reshape (
478
504
sample_shape + batch_shape + self .vector_field_estimator .input_shape
479
505
)
0 commit comments