Skip to content

Commit 6dd4a22

Browse files
committed
further improvements
1 parent 9db9e9e commit 6dd4a22

File tree

3 files changed

+45
-20
lines changed

3 files changed

+45
-20
lines changed

sbi/diagnostics/sbc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

44
import warnings
5-
from typing import Callable, Dict, List, Union
5+
from typing import Callable, Dict, List, Optional, Union
66

77
import torch
88
from scipy.stats import kstest, uniform
@@ -26,6 +26,7 @@ def run_sbc(
2626
num_workers: int = 1,
2727
show_progress_bar: bool = True,
2828
use_batched_sampling: bool = True,
29+
batch_size: Optional[int] = None,
2930
**kwargs,
3031
):
3132
"""Run simulation-based calibration (SBC) (parallelized across sbc runs).
@@ -49,6 +50,8 @@ def run_sbc(
4950
`num_sbc_samples` inferences.
5051
show_progress_bar: whether to display a progress bar over sbc runs.
5152
use_batched_sampling: whether to use batched sampling for posterior samples.
53+
batch_size: batch size for batched sampling. Useful for batched sampling with
54+
large batches of xs for avoiding memory overflow.
5255
5356
Returns:
5457
ranks: ranks of the ground truth parameters under the inferred
@@ -89,6 +92,7 @@ def run_sbc(
8992
num_workers,
9093
show_progress_bar,
9194
use_batched_sampling=use_batched_sampling,
95+
batch_size=batch_size,
9296
)
9397

9498
# take a random draw from each posterior to get data averaged posterior samples.

sbi/diagnostics/tarp.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def run_tarp(
2929
num_bins: Optional[int] = 30,
3030
z_score_theta: bool = True,
3131
use_batched_sampling: bool = True,
32+
batch_size: Optional[int] = None,
3233
) -> Tuple[Tensor, Tensor]:
3334
"""
3435
Estimates coverage of samples given true values thetas with the TARP method.
@@ -56,6 +57,8 @@ def run_tarp(
5657
If ``None``, then ``num_sims // 10`` bins are used.
5758
z_score_theta : whether to normalize parameters before coverage test.
5859
use_batched_sampling: whether to use batched sampling for posterior samples.
60+
batch_size: batch size for batched sampling. Useful for batched sampling with
61+
large batches of xs for avoiding memory overflow.
5962
6063
Returns:
6164
ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
@@ -70,6 +73,7 @@ def run_tarp(
7073
num_workers,
7174
show_progress_bar=show_progress_bar,
7275
use_batched_sampling=use_batched_sampling,
76+
batch_size=batch_size,
7377
)
7478
assert posterior_samples.shape == (
7579
num_posterior_samples,

sbi/utils/diagnostics_utils.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from typing import Optional
23

34
import torch
45
from joblib import Parallel, delayed
@@ -18,6 +19,7 @@ def get_posterior_samples_on_batch(
1819
num_workers: int = 1,
1920
show_progress_bar: bool = False,
2021
use_batched_sampling: bool = True,
22+
batch_size: Optional[int] = None,
2123
) -> Tensor:
2224
"""Get posterior samples for a batch of xs.
2325
@@ -28,22 +30,37 @@ def get_posterior_samples_on_batch(
2830
num_workers: number of workers to use for parallelization.
2931
show_progress_bar: whether to show progress bars.
3032
use_batched_sampling: whether to use batched sampling if possible.
31-
33+
batch_size: batch size for batched sampling. Useful for batched sampling with
34+
large batches of xs for avoiding memory overflow.
3235
Returns:
3336
posterior_samples: of shape (num_samples, num_xs, dim_parameters), where num_xs = len(xs).
3437
"""
35-
batch_size = len(xs)
38+
num_xs = len(xs)
39+
if batch_size is None:
40+
batch_size = num_xs
3641

37-
# Try using batched sampling when implemented.
38-
try:
39-
# has shape (num_samples, batch_size, dim_parameters)
40-
if use_batched_sampling:
41-
posterior_samples = posterior.sample_batched(
42-
sample_shape, x=xs, show_progress_bars=show_progress_bar
42+
if use_batched_sampling:
43+
try:
44+
# distribute the batch of xs into smaller batches
45+
batched_xs = xs.split(batch_size)
46+
posterior_samples = torch.cat(
47+
[ # has shape (num_samples, num_xs, dim_parameters)
48+
posterior.sample_batched(
49+
sample_shape, x=xs_batch, show_progress_bars=show_progress_bar
50+
)
51+
for xs_batch in batched_xs
52+
],
53+
dim=1,
4354
)
44-
else:
45-
raise NotImplementedError
46-
except (NotImplementedError, AssertionError):
55+
except (NotImplementedError, AssertionError):
56+
warnings.warn(
57+
"Batched sampling not implemented for this posterior. "
58+
"Falling back to non-batched sampling.",
59+
stacklevel=2,
60+
)
61+
use_batched_sampling = False
62+
63+
if not use_batched_sampling:
4764
# We need a function with extra training step for new x for VIPosterior.
4865
def sample_fun(
4966
posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
@@ -57,13 +74,13 @@ def sample_fun(
5774
if isinstance(posterior, (VIPosterior, MCMCPosterior)):
5875
warnings.warn(
5976
"Using non-batched sampling. Depending on the number of different xs "
60-
f"( {batch_size}) and the number of parallel workers {num_workers}, "
61-
"this might be slow.",
77+
f"( {num_xs}) and the number of parallel workers {num_workers}, "
78+
"this might take a lot of time.",
6279
stacklevel=2,
6380
)
6481

6582
# Run in parallel with progress bar.
66-
seeds = torch.randint(0, 2**32, (batch_size,))
83+
seeds = torch.randint(0, 2**32, (num_xs,))
6784
outputs = list(
6885
tqdm(
6986
Parallel(return_as="generator", n_jobs=num_workers)(
@@ -72,7 +89,7 @@ def sample_fun(
7289
),
7390
disable=not show_progress_bar,
7491
total=len(xs),
75-
desc=f"Sampling {batch_size} times {sample_shape} posterior samples.",
92+
desc=f"Sampling {num_xs} times {sample_shape} posterior samples.",
7693
)
7794
) # (num_xs, num_samples, dim_parameters)
7895
# Transpose to shape convention: (sample_shape, num_xs, dim_parameters)
@@ -81,8 +98,8 @@ def sample_fun(
8198
).permute(1, 0, 2)
8299

83100
assert posterior_samples.shape[:2] == sample_shape + (
84-
batch_size,
85-
), f"""Expected batched posterior samples of shape {
86-
sample_shape + (batch_size,)
87-
} got {posterior_samples.shape[:2]}."""
101+
num_xs,
102+
), f"""Expected batched posterior samples of shape {sample_shape + (num_xs,)} got {
103+
posterior_samples.shape[:2]
104+
}."""
88105
return posterior_samples

0 commit comments

Comments
 (0)