
Commit e989d76

refactor: improve warning clarity and formatting
1 parent f4d7068

8 files changed: +68 -72 lines changed

sbi/diagnostics/sbc.py

Lines changed: 13 additions & 14 deletions
@@ -56,14 +56,14 @@ def run_sbc(
 
     if num_sbc_samples < 100:
         warnings.warn(
-            """Number of SBC samples should be on the order of 100s to give realiable
-            results.""",
+            "Number of SBC samples should be on the order of 100s to give reliable "
+            "results.",
             stacklevel=2,
         )
     if num_posterior_samples < 100:
         warnings.warn(
-            """Number of posterior samples for ranking should be on the order
-            of 100s to give reliable SBC results.""",
+            "Number of posterior samples for ranking should be on the order "
+            "of 100s to give reliable SBC results.",
             stacklevel=2,
         )

@@ -73,8 +73,8 @@ def run_sbc(
 
     if "sbc_batch_size" in kwargs:
         warnings.warn(
-            """`sbc_batch_size` is deprecated and will be removed in future versions.
-            Use `num_workers` instead.""",
+            "`sbc_batch_size` is deprecated and will be removed in future versions."
+            " Use `num_workers` instead.",
             DeprecationWarning,
             stacklevel=2,
         )
@@ -182,8 +182,8 @@ def get_nltp(thetas: Tensor, xs: Tensor, posterior: NeuralPosterior) -> Tensor:
 
     if unnormalized_log_prob:
         warnings.warn(
-            """Note that log probs of the true parameters under the posteriors
-            are not normalized because the posterior used is likelihood-based.""",
+            "Note that log probs of the true parameters under the posteriors are not "
+            "normalized because the posterior used is likelihood-based.",
             stacklevel=2,
         )

@@ -216,9 +216,9 @@ def check_sbc(
     """
     if ranks.shape[0] < 100:
         warnings.warn(
-            """You are computing SBC checks with less than 100 samples. These checks
-            should be based on a large number of test samples theta_o, x_o. We
-            recommend using at least 100.""",
+            "You are computing SBC checks with fewer than 100 samples. These checks"
+            " should be based on a large number of test samples theta_o, x_o. We"
+            " recommend using at least 100.",
             stacklevel=2,
         )

@@ -315,9 +315,8 @@ def check_uniformity_c2st(
     c2st_std = c2st_scores.std(0, correction=0 if num_repetitions == 1 else 1)
     if (c2st_std > 0.05).any():
         warnings.warn(
-            f"""C2ST score variability is larger than {0.05}: std={c2st_scores.std(0)},
-            result may be unreliable. Consider increasing the number of samples.
-            """,
+            f"C2ST score variability is larger than 0.05: std={c2st_scores.std(0)}, "
+            "result may be unreliable. Consider increasing the number of samples.",
             stacklevel=2,
         )

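The refactor above trades triple-quoted strings for adjacent string literals, which Python concatenates at compile time. A triple-quoted message carries its source newlines and indentation into the warning text; adjacent literals yield one clean line, at the cost of having to remember the joining spaces. A minimal sketch of both behaviors, independent of sbi:

import warnings

# Triple-quoted: the emitted message keeps the newline and leading indentation.
warnings.warn(
    """Number of SBC samples should be on the order of 100s to give reliable
    results.""",
    stacklevel=2,
)

# Adjacent literals are concatenated at compile time into a single clean line.
# The trailing space before each closing quote matters: without it, the pieces
# glue together as "reliableresults".
warnings.warn(
    "Number of SBC samples should be on the order of 100s to give reliable "
    "results.",
    stacklevel=2,
)
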
sbi/inference/posteriors/mcmc_posterior.py

Lines changed: 19 additions & 21 deletions
@@ -136,9 +136,9 @@ def __init__(
 
         if init_strategy_num_candidates is not None:
             warn(
-                """Passing `init_strategy_num_candidates` is deprecated as of sbi
-                v0.19.0. Instead, use e.g.,
-                `init_strategy_parameters={"num_candidate_samples": 1000}`""",
+                "Passing `init_strategy_num_candidates` is deprecated as of sbi "
+                "v0.19.0. Instead, use e.g., `init_strategy_parameters"
+                f"={'num_candidate_samples': 1000}`",
                 stacklevel=2,
             )
             self.init_strategy_parameters["num_candidate_samples"] = (
@@ -194,9 +194,8 @@ def log_prob(
             `len($\theta$)`-shaped log-probability.
         """
         warn(
-            """`.log_prob()` is deprecated for methods that can only evaluate the
-            log-probability up to a normalizing constant. Use `.potential()`
-            instead.""",
+            "`.log_prob()` is deprecated for methods that can only evaluate the "
+            "log-probability up to a normalizing constant. Use `.potential()` instead.",
             stacklevel=2,
         )
         warn("The log-probability is unnormalized!", stacklevel=2)
@@ -264,9 +263,9 @@ def sample(
         )
         if init_strategy_num_candidates is not None:
             warn(
-                """Passing `init_strategy_num_candidates` is deprecated as of sbi
-                v0.19.0. Instead, use e.g.,
-                `init_strategy_parameters={"num_candidate_samples": 1000}`""",
+                "Passing `init_strategy_num_candidates` is deprecated as of sbi "
+                "v0.19.0. Instead, use e.g., `init_strategy_parameters"
+                f"={'num_candidate_samples': 1000}`",
                 stacklevel=2,
             )
             self.init_strategy_parameters["num_candidate_samples"] = (
@@ -275,7 +274,7 @@ def sample(
         if sample_with is not None:
             raise ValueError(
                 f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting "
-                f"`sample_with` is no longer supported. You have to rerun "
+                "`sample_with` is no longer supported. You have to rerun "
                 f"`.build_posterior(sample_with={sample_with}).`"
             )
         if mcmc_method is not None:
@@ -426,9 +425,9 @@ def sample_batched(
         # warn if num_chains is larger than num requested samples
         if num_chains > torch.Size(sample_shape).numel():
             warnings.warn(
-                f"""Passed num_chains {num_chains} is larger than the number of
-                requested samples {torch.Size(sample_shape).numel()}, resetting
-                it to {torch.Size(sample_shape).numel()}.""",
+                "The passed number of MCMC chains is larger than the number of "
+                f"requested samples: {num_chains} > {torch.Size(sample_shape).numel()},"
+                f" resetting it to {torch.Size(sample_shape).numel()}.",
                 stacklevel=2,
             )
             num_chains = torch.Size(sample_shape).numel()
@@ -453,12 +452,11 @@ def sample_batched(
         num_chains_extended = batch_size * num_chains
         if num_chains_extended > 100:
             warnings.warn(
-                f"""Note that for batched sampling, we use {num_chains} for each
-                x in the batch. With the given settings, this results in a
-                large number of chains ({num_chains_extended}), This can be
-                large number of chains ({num_chains_extended}), which can be
-                slow and memory-intensive. Consider reducing the number of
-                chains.""",
+                f"Note that for batched sampling, we use {num_chains} chains for each"
+                " x in the batch. With the given settings, this results in a large "
+                f"number of chains ({num_chains_extended}), which can be "
+                "slow and memory-intensive for vectorized MCMC. Consider reducing the "
+                "number of chains.",
                 stacklevel=2,
             )
         init_strategy_parameters["num_return_samples"] = num_chains_extended
@@ -905,8 +903,8 @@ def _prepare_potential(self, method: str) -> Callable:
         else:
             if "hmc" in method or "nuts" in method:
                 warn(
-                    """The kwargs "hmc" and "nuts" are deprecated. Use "hmc_pyro",
-                    "nuts_pyro", "hmc_pymc", or "nuts_pymc" instead.""",
+                    "The kwargs 'hmc' and 'nuts' are deprecated. Use 'hmc_pyro', "
+                    "'nuts_pyro', 'hmc_pymc', or 'nuts_pymc' instead.",
                     DeprecationWarning,
                     stacklevel=2,
                 )

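Two details in the `init_strategy_num_candidates` hunks are easy to get wrong when splitting messages this way: with implicit concatenation, only the pieces carrying an `f` prefix are interpolated, and before Python 3.12 a replacement field may not reuse the quote character that delimits the f-string. A short standalone illustration:

# The dict is bound to a name so the f-string stays readable and avoids
# nesting double quotes inside a double-quoted f-string (a SyntaxError
# before Python 3.12).
example = {"num_candidate_samples": 1000}

msg = (
    "Passing `init_strategy_num_candidates` is deprecated as of sbi "  # plain piece
    f"v0.19.0. Instead, use e.g., `init_strategy_parameters={example}`"  # f-piece
)
print(msg)
# ... use e.g., `init_strategy_parameters={'num_candidate_samples': 1000}`
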
sbi/inference/posteriors/rejection_posterior.py

Lines changed: 2 additions & 3 deletions
@@ -84,9 +84,8 @@ def log_prob(
             `len($\theta$)`-shaped log-probability.
         """
         warn(
-            """`.log_prob()` is deprecated for methods that can only evaluate the
-            log-probability up to a normalizing constant. Use `.potential()`
-            instead.""",
+            "`.log_prob()` is deprecated for methods that can only evaluate the "
+            "log-probability up to a normalizing constant. Use `.potential()` instead.",
             stacklevel=2,
         )
         warn("The log-probability is unnormalized!", stacklevel=2)

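A caveat that applies to the deprecation messages in this commit (general Python behavior, not specific to this diff): warnings raised with the `DeprecationWarning` category, such as the `sbc_batch_size` and 'hmc'/'nuts' ones above, are ignored by default when triggered from imported code, so downstream users must opt in to see them. A self-contained sketch:

import warnings

def old_api():
    warnings.warn(
        "`old_api` is deprecated. Use `new_api` instead.",
        DeprecationWarning,
        stacklevel=2,
    )

# Surface deprecations explicitly, e.g., during development or in tests:
warnings.simplefilter("always", DeprecationWarning)
old_api()  # warning is now shown

# Or escalate them to errors to catch stale call sites early:
with warnings.catch_warnings():
    warnings.simplefilter("error", DeprecationWarning)
    try:
        old_api()
    except DeprecationWarning as exc:
        print(f"caught: {exc}")
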
sbi/neural_nets/mnle.py

Lines changed: 4 additions & 4 deletions
@@ -124,10 +124,10 @@ def build_mnle(
     check_data_device(batch_x, batch_y)
 
     warnings.warn(
-        """The mixed neural likelihood estimator assumes that x contains
-        continuous data in the first n-1 columns (e.g., reaction times) and
-        categorical data in the last column (e.g., corresponding choices). If
-        this is not the case for the passed `x` do not use this function.""",
+        "The mixed neural likelihood estimator assumes that x contains "
+        "continuous data in the first n-1 columns (e.g., reaction times) and "
+        "categorical data in the last column (e.g., corresponding choices). If "
+        "this is not the case for the passed `x`, do not use this function.",
         stacklevel=2,
     )
     # Separate continuous and discrete data.

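The MNLE warning documents a layout contract: continuous columns first, a single categorical column last. A toy sketch of the split that the `# Separate continuous and discrete data.` step then performs (illustrative only, not the sbi implementation):

import torch

# Toy mixed data in the layout build_mnle assumes: two continuous columns
# (e.g., reaction times) followed by one categorical column (e.g., choices).
x = torch.tensor([
    [0.42, 0.13, 1.0],
    [0.57, 0.20, 0.0],
    [0.31, 0.08, 1.0],
])

cont_x = x[:, :-1]        # first n-1 columns: continuous data
disc_x = x[:, -1].long()  # last column: categorical data as integer labels

assert cont_x.shape == (3, 2)
assert disc_x.shape == (3,)
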
sbi/samplers/mcmc/slice_numpy.py

Lines changed: 2 additions & 2 deletions
@@ -394,8 +394,8 @@ def __init__(
         # TODO: implement parallelization across batches of chains.
         if num_workers > 1:
             warn(
-                """Parallelization of vectorized slice sampling not implement, running
-                serially.""",
+                "Parallelization of vectorized slice sampling not implemented, "
+                "running serially.",
                 stacklevel=2,
             )
         self._reset()

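Every call in this commit passes `stacklevel=2`, which makes the warning point at the user's call site rather than at the `warn()` line inside the library. A minimal example of the effect:

import warnings

def sampler(num_workers: int = 1):
    if num_workers > 1:
        # stacklevel=2 attributes the warning to the caller of sampler(),
        # not to this line, so users see where *they* triggered it.
        warnings.warn(
            "Parallelization of vectorized slice sampling not implemented, "
            "running serially.",
            stacklevel=2,
        )

sampler(num_workers=4)  # the warning is reported for this line
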
sbi/utils/sbiutils.py

Lines changed: 21 additions & 21 deletions
@@ -43,9 +43,9 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -
     # Check we do have different data in the batch
     if num_unique == 1:
         warnings.warn(
-            """Beware that there is only a single unique element in the simulated data.
-            If this is intended, make sure to set `z_score_x='none'` as z-scoring would
-            result in NaNs""",
+            "Beware that there is only a single unique element in the simulated data. "
+            "If this is intended, make sure to set `z_score_x='none'` as z-scoring "
+            "would result in NaNs.",
             UserWarning,
             stacklevel=2,
         )
@@ -61,13 +61,14 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -
 
     if num_unique_z < num_unique * (1 - duplicate_tolerance):
         warnings.warn(
-            """Z-scoring these simulation outputs resulted in {num_unique_z} unique
-            datapoints. Before z-scoring, it had been {num_unique}. This can occur
-            due to numerical inaccuracies when the data covers a large range of
-            values. Consider either setting `z_score_x=False` (but beware that this
-            can be problematic for training the NN) or exclude outliers from your
-            dataset. Note: if you have already set `z_score_x=False`, this warning
-            will still be displayed, but you can ignore it.""",
+            f"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
+            f"datapoints. Before z-scoring, it had been {num_unique}. This can "
+            "occur due to numerical inaccuracies when the data covers a large "
+            "range of values. Consider either setting `z_score_x=False` (but "
+            "beware that this can be problematic for training the NN) or "
+            "excluding outliers from your dataset. Note: if you have already set "
+            "`z_score_x=False`, this warning will still be displayed, but you can"
+            " ignore it.",
             UserWarning,
             stacklevel=2,
         )
@@ -406,11 +407,11 @@ def warn_on_batched_x(batch_size):
     if batch_size > 1:
         warnings.warn(
             f"An x with a batch size of {batch_size} was passed. "
-            + """Unless you are using `sample_batched` or `log_prob_batched`, this will
-            be interpreted as a batch of independent and identically distributed data
-            X={x_1, ..., x_n}, i.e., data generated based on the same underlying
-            (unknown) parameter. The resulting posterior will be with respect to entire
-            batch, i.e,. p(theta | X).""",
+            "Unless you are using `sample_batched` or `log_prob_batched`, this will "
+            "be interpreted as a batch of independent and identically distributed data "
+            "X={x_1, ..., x_n}, i.e., data generated based on the same underlying "
+            "(unknown) parameter. The resulting posterior will be with respect to"
+            " the entire batch, i.e., p(theta | X).",
             stacklevel=2,
         )

@@ -714,9 +715,9 @@ def mcmc_transform(
         # does not implement support.
         # AttributeError -> Custom distribution that has no support attribute.
         warnings.warn(
-            """The passed prior has no support property, transform will be
-            constructed from mean and std. If the passed prior is supposed to be
-            bounded consider implementing the prior.support property.""",
+            "The passed prior has no support property, transform will be "
+            "constructed from mean and std. If the passed prior is supposed to be "
+            "bounded, consider implementing the prior.support property.",
             stacklevel=2,
         )
         has_support = False
@@ -749,9 +750,8 @@ def mcmc_transform(
         # does not implement mean, e.g., TransformedDistribution.
         # AttributeError -> Custom distribution that has no mean/std attribute.
         warnings.warn(
-            """The passed prior has no mean or stddev attribute, estimating
-            them from samples to build affimed standardizing
-            transform.""",
+            "The passed prior has no mean or stddev attribute, estimating "
+            "them from samples to build an affine standardizing transform.",
             stacklevel=2,
         )
         theta = prior.sample(torch.Size((num_prior_samples_for_zscoring,)))

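One pitfall visible in `warn_if_zscoring_changes_data`: placeholders such as `{num_unique_z}` interpolate only in literals that carry the `f` prefix, and with implicit concatenation every interpolating piece needs its own prefix. A quick demonstration:

num_unique_z, num_unique = 812, 1000

# Without the f prefix, the braces are emitted literally:
print("Z-scoring resulted in {num_unique_z} unique datapoints.")
# -> Z-scoring resulted in {num_unique_z} unique datapoints.

# Each concatenated piece that interpolates needs its own f prefix:
print(
    f"Z-scoring resulted in {num_unique_z} unique datapoints. "
    f"Before z-scoring, it had been {num_unique}."
)
# -> Z-scoring resulted in 812 unique datapoints. Before z-scoring, it had been 1000.
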
sbi/utils/user_input_checks.py

Lines changed: 3 additions & 3 deletions
@@ -71,9 +71,9 @@ def process_prior(
     # If prior is a sequence, assume independent components and check as PyTorch prior.
     if isinstance(prior, Sequence):
         warnings.warn(
-            f"""Prior was provided as a sequence of {len(prior)} priors. They will be
-            interpreted as independent of each other and matched in order to the
-            components of the parameter.""",
+            f"Prior was provided as a sequence of {len(prior)} priors. They will be "
+            "interpreted as independent of each other and matched in order to the "
+            "components of the parameter.",
             stacklevel=2,
         )
         # process individual priors

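For context on the warning in `process_prior`: a sequence of priors is treated as independent one-dimensional components, matched in order to the parameter vector. A usage sketch (it assumes the public `sbi.utils.process_prior` helper; return values beyond the processed prior are elided):

import torch
from torch.distributions import Normal, Uniform

from sbi.utils import process_prior

# Three independent 1D priors, matched in order to theta[0], theta[1], theta[2].
prior_list = [
    Uniform(low=torch.zeros(1), high=torch.ones(1)),
    Normal(loc=torch.zeros(1), scale=torch.ones(1)),
    Uniform(low=-torch.ones(1), high=torch.ones(1)),
]

prior, *_ = process_prior(prior_list)  # emits the warning shown above
theta = prior.sample((10,))  # joint samples, one column per component
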
sbi/utils/user_input_checks_utils.py

Lines changed: 4 additions & 4 deletions
@@ -77,8 +77,8 @@ def _set_mean_and_variance(self):
                 ** 2
             )
             warnings.warn(
-                """Prior is lacking variance attribute, estimating prior variance from
-                samples...""",
+                "Prior is lacking variance attribute, estimating prior variance from "
+                "samples.",
                 UserWarning,
                 stacklevel=2,
             )
@@ -333,8 +333,8 @@ def build_support(
     if lower_bound is None and upper_bound is None:
         support = constraints.real
         warnings.warn(
-            """No prior bounds were passed, consider passing lower_bound
-            and / or upper_bound if your prior has bounded support.""",
+            "No prior bounds were passed, consider passing lower_bound "
+            "and/or upper_bound if your prior has bounded support.",
             stacklevel=2,
         )
     # Only lower bound is specified.

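The `build_support` warning fires on the no-bounds fallback to the whole real line. A hedged sketch of the branching it guards, using `torch.distributions.constraints` (illustrative of the structure, not the exact sbi code):

import warnings
from typing import Optional

import torch
from torch.distributions import constraints

def build_support_sketch(
    lower_bound: Optional[torch.Tensor] = None,
    upper_bound: Optional[torch.Tensor] = None,
):
    if lower_bound is None and upper_bound is None:
        warnings.warn(
            "No prior bounds were passed, consider passing lower_bound "
            "and/or upper_bound if your prior has bounded support.",
            stacklevel=2,
        )
        return constraints.real
    if upper_bound is None:
        return constraints.greater_than(lower_bound)  # only lower bound given
    if lower_bound is None:
        return constraints.less_than(upper_bound)  # only upper bound given
    return constraints.interval(lower_bound, upper_bound)  # both bounds given

support = build_support_sketch(torch.tensor(0.0), torch.tensor(1.0))
print(support.check(torch.tensor(0.5)))  # tensor(True)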