Skip to content

Commit 6ea16be

Browse files
authored
Batch sampling slow without warning (#1490)
* direct posterior: warning for slow batched sampling with large batch size * add option 'importance' to sample_with argument * bugfix: error message for batched sampling not implemented * adapt threshold for warning * change name to for consistency with rejection.py * change condition for acceptance rate warning in rejection sampling to total number of samples tried to draw instead of total number drawn * update threshold for high batch size sampling
1 parent b7c4acc commit 6ea16be

File tree

5 files changed

+20
-8
lines changed

5 files changed

+20
-8
lines changed

sbi/inference/posteriors/direct_posterior.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4+
import warnings
45
from typing import Optional, Union
56

67
import torch
@@ -168,6 +169,17 @@ def sample_batched(
168169
num_samples = torch.Size(sample_shape).numel()
169170
condition_shape = self.posterior_estimator.condition_shape
170171
x = reshape_to_batch_event(x, event_shape=condition_shape)
172+
num_xos = x.shape[0]
173+
174+
# throw warning if num_x * num_samples is too large
175+
if num_xos * num_samples > 2**21: # 2 million-ish
176+
warnings.warn(
177+
"Note that for batched sampling, the direct posterior sampling "
178+
"generates {num_xos} * {num_samples} = {num_xos * num_samples} "
179+
"samples. This can be slow and memory-intensive. Consider "
180+
"reducing the number of samples or batch size.",
181+
stacklevel=2,
182+
)
171183

172184
max_sampling_batch_size = (
173185
self.max_sampling_batch_size

sbi/inference/posteriors/importance_posterior.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,9 @@ def sample_batched(
210210
show_progress_bars: bool = True,
211211
) -> Tensor:
212212
raise NotImplementedError(
213-
"Batched sampling is not implemented for ImportanceSamplingPosterior. \
214-
Alternatively you can use `sample` in a loop \
215-
[posterior.sample(theta, x_o) for x_o in x]."
213+
"Batched sampling is not implemented for ImportanceSamplingPosterior. "
214+
"Alternatively you can use `sample` in a loop "
215+
"[posterior.sample(theta, x_o) for x_o in x]."
216216
)
217217

218218
def _importance_sample(

sbi/inference/posteriors/vi_posterior.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,9 @@ def sample_batched(
304304
show_progress_bars: bool = True,
305305
) -> Tensor:
306306
raise NotImplementedError(
307-
"Batched sampling is not implemented for VIPosterior. \
308-
Alternatively you can use `sample` in a loop \
309-
[posterior.sample(theta, x_o) for x_o in x]."
307+
"Batched sampling is not implemented for VIPosterior. "
308+
"Alternatively you can use `sample` in a loop "
309+
"[posterior.sample(theta, x_o) for x_o in x]."
310310
)
311311

312312
def log_prob(

sbi/inference/trainers/npe/npe_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def build_posterior(
470470
If `None`, use the latest neural density estimator that was trained.
471471
prior: Prior distribution.
472472
sample_with: Method to use for sampling from the posterior. Must be one of
473-
[`direct` | `mcmc` | `rejection` | `vi`].
473+
[`direct` | `mcmc` | `rejection` | `vi` | `importance`].
474474
mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
475475
`hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy
476476
implementation of slice sampling; select `hmc`, `nuts` or `slice` for

sbi/samplers/rejection/rejection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def accept_reject_sample(
321321
max(int(1.5 * num_remaining / max(min_acceptance_rate, 1e-12)), 100),
322322
)
323323
if (
324-
num_sampled_total.min().item() > 1000
324+
num_samples_possible > 1000
325325
and min_acceptance_rate < warn_acceptance
326326
and not leakage_warning_raised
327327
):

0 commit comments

Comments
 (0)