Skip to content

Commit 196c106

Browse files
authored
fix: cap max_sampling_batch_size to prevent excessive memory (#1624)
* add warnings
* review fixes, add test.
1 parent eb5ca16 commit 196c106

File tree

4 files changed

+35
-9
lines changed

4 files changed

+35
-9
lines changed

sbi/inference/posteriors/direct_posterior.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ def sample_batched(
219219
# throw warning if num_x * num_samples is too large
220220
if num_xos * num_samples > 2**21: # 2 million-ish
221221
warnings.warn(
222-
"Note that for batched sampling, the direct posterior sampling "
223-
"generates {num_xos} * {num_samples} = {num_xos * num_samples} "
222+
f"Note that for batched sampling, the direct posterior sampling "
223+
f"generates {num_xos} * {num_samples} = {num_xos * num_samples} "
224224
"samples. This can be slow and memory-intensive. Consider "
225225
"reducing the number of samples or batch size.",
226226
stacklevel=2,
@@ -232,6 +232,16 @@ def sample_batched(
232232
else max_sampling_batch_size
233233
)
234234

235+
# Adjust max_sampling_batch_size to avoid excessive memory usage
236+
if max_sampling_batch_size * num_xos > 100_000:
237+
capped = max(1, 100_000 // num_xos)
238+
warnings.warn(
239+
f"Capping max_sampling_batch_size from {max_sampling_batch_size} "
240+
f"to {capped} to avoid excessive memory usage.",
241+
stacklevel=2,
242+
)
243+
max_sampling_batch_size = capped
244+
235245
samples = rejection.accept_reject_sample(
236246
proposal=self.posterior_estimator.sample,
237247
accept_reject_fn=lambda theta: within_support(self.prior, theta),

sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4+
import warnings
45
from typing import Dict, Literal, Optional, Union
56

67
import torch
@@ -431,6 +432,16 @@ def sample_batched(
431432
else max_sampling_batch_size
432433
)
433434

435+
# Adjust max_sampling_batch_size to avoid excessive memory usage
436+
if max_sampling_batch_size * batch_size > 100_000:
437+
capped = max(1, 100_000 // batch_size)
438+
warnings.warn(
439+
f"Capping max_sampling_batch_size from {max_sampling_batch_size} "
440+
f"to {capped} to avoid excessive memory usage.",
441+
stacklevel=2,
442+
)
443+
max_sampling_batch_size = capped
444+
434445
if self.sample_with == "ode":
435446
samples = rejection.accept_reject_sample(
436447
proposal=self.sample_via_ode,

sbi/samplers/rejection/rejection.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ def accept_reject_sample(
215215
216216
Args:
217217
proposal: A callable that takes `sample_shape` as arguments (and kwargs as
218-
needed). Returns samples from the proposal distribution with shape
219-
(*sample_shape, event_dim).
218+
needed). Returns samples from the proposal distribution with shape
219+
(*sample_shape, event_dim).
220220
accept_reject_fn: Function that evaluates which samples are accepted or
221221
rejected. Must take a batch of parameters and return a boolean tensor which
222222
indicates which parameters get accepted.
@@ -272,7 +272,6 @@ def accept_reject_sample(
272272
accepted = [[] for _ in range(num_xos)]
273273
acceptance_rate = torch.full((num_xos,), float("Nan"))
274274
leakage_warning_raised = False
275-
# Ruff suggestion
276275

277276
# To cover cases with few samples without leakage:
278277
sampling_batch_size = min(num_samples, max_sampling_batch_size)

tests/posterior_nn_test.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,19 @@ def test_importance_posterior_sample_log_prob(snplre_method: type):
102102

103103
@pytest.mark.parametrize("snpe_method", [NPE_A, NPE_C])
104104
@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
105-
@pytest.mark.parametrize("prior", ("mvn", "uniform"))
105+
@pytest.mark.parametrize("prior_type", ("mvn", "uniform"))
106106
def test_batched_sample_log_prob_with_different_x(
107107
snpe_method: type,
108108
x_o_batch_dim: bool,
109-
prior: str,
109+
prior_type: str,
110110
):
111111
num_dim = 2
112112
num_simulations = 1000
113113

114114
# We also want to test on bounded support! Which will invoke leakage correction.
115-
if prior == "mvn":
115+
if prior_type == "mvn":
116116
prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
117-
elif prior == "uniform":
117+
elif prior_type == "uniform":
118118
prior = Independent(Uniform(-1.0 * ones(num_dim), 1.0 * ones(num_dim)), 1)
119119
simulator = diagonal_linear_gaussian
120120

@@ -131,6 +131,12 @@ def test_batched_sample_log_prob_with_different_x(
131131
samples = posterior.sample_batched((10,), x_o)
132132
batched_log_probs = posterior.log_prob_batched(samples, x_o)
133133

134+
# Test large max_sampling_batch_size to test capping warning.
135+
with pytest.warns(UserWarning, match="Capping max_sampling_batch_size"):
136+
posterior.sample_batched(
137+
(10,), ones(3, num_dim), max_sampling_batch_size=40_000
138+
)
139+
134140
assert (
135141
samples.shape == (10, x_o_batch_dim, num_dim)
136142
if x_o_batch_dim > 0

0 commit comments

Comments (0)