Skip to content

Commit 3183f98

Browse files
committed
refactor: improve warning clarity and formatting
1 parent f4d7068 commit 3183f98

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

sbi/inference/posteriors/mcmc_posterior.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -426,9 +426,9 @@ def sample_batched(
426426
# warn if num_chains is larger than num requested samples
427427
if num_chains > torch.Size(sample_shape).numel():
428428
warnings.warn(
429-
f"""Passed num_chains {num_chains} is larger than the number of
430-
requested samples {torch.Size(sample_shape).numel()}, resetting
431-
it to {torch.Size(sample_shape).numel()}.""",
429+
"The passed number of MCMC chains is larger than the number of "
430+
f"requested samples: {num_chains} > {torch.Size(sample_shape).numel()},"
431+
f" resetting it to {torch.Size(sample_shape).numel()}.",
432432
stacklevel=2,
433433
)
434434
num_chains = torch.Size(sample_shape).numel()
@@ -453,12 +453,11 @@ def sample_batched(
453453
num_chains_extended = batch_size * num_chains
454454
if num_chains_extended > 100:
455455
warnings.warn(
456-
f"""Note that for batched sampling, we use {num_chains} for each
457-
x in the batch. With the given settings, this results in a
458-
large number of chains ({num_chains_extended}), This can be
459-
large number of chains ({num_chains_extended}), which can be
460-
slow and memory-intensive. Consider reducing the number of
461-
chains.""",
456+
"Note that for batched sampling, we use num_chains many chains for each"
457+
" x in the batch. With the given settings, this results in a large "
458+
f"number large number of chains ({num_chains_extended}), which can be "
459+
"slow and memory-intensive for vectorized MCMC. Consider reducing the "
460+
"number of chains.",
462461
stacklevel=2,
463462
)
464463
init_strategy_parameters["num_return_samples"] = num_chains_extended

sbi/utils/sbiutils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -406,11 +406,11 @@ def warn_on_batched_x(batch_size):
406406
if batch_size > 1:
407407
warnings.warn(
408408
f"An x with a batch size of {batch_size} was passed. "
409-
+ """Unless you are using `sample_batched` or `log_prob_batched`, this will
410-
be interpreted as a batch of independent and identically distributed data
411-
X={x_1, ..., x_n}, i.e., data generated based on the same underlying
412-
(unknown) parameter. The resulting posterior will be with respect to entire
413-
batch, i.e,. p(theta | X).""",
409+
"Unless you are using `sample_batched` or `log_prob_batched`, this will"
410+
"be interpreted as a batch of independent and identically distributed data"
411+
"X={x_1, ..., x_n}, i.e., data generated based on the same underlying"
412+
"(unknown) parameter. The resulting posterior will be with respect to"
413+
" the entire batch, i.e,. p(theta | X).",
414414
stacklevel=2,
415415
)
416416

0 commit comments

Comments
 (0)