Skip to content

Commit be39800

Browse files
committed
refactor: add deprecation warnings for deprecated build_posterior arguments
1 parent a406410 commit be39800

File tree

8 files changed

+151
-4
lines changed

8 files changed

+151
-4
lines changed

sbi/inference/trainers/fmpe/fmpe.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from sbi.neural_nets import flowmatching_nn
1818
from sbi.neural_nets.estimators import ConditionalVectorFieldEstimator
19+
from sbi.utils.sbiutils import warn_if_deprecated
1920

2021

2122
class FMPE(VectorFieldTrainer):
@@ -98,6 +99,14 @@ def build_posterior(
9899
Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods.
99100
"""
100101

102+
warn_if_deprecated(
103+
self.build_posterior,
104+
locals(),
105+
{
106+
"vectorfield_sampling_parameters",
107+
},
108+
)
109+
101110
return super().build_posterior(
102111
estimator=vector_field_estimator,
103112
prior=prior,

sbi/inference/trainers/nle/mnle.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
1616
from sbi.neural_nets.estimators import MixedDensityEstimator
1717
from sbi.sbi_types import TensorboardSummaryWriter
18-
from sbi.utils.sbiutils import del_entries
18+
from sbi.utils.sbiutils import del_entries, warn_if_deprecated
1919

2020

2121
class MNLE(LikelihoodEstimatorTrainer):
@@ -162,6 +162,20 @@ def build_posterior(
162162
Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods
163163
(the returned log-probability is unnormalized).
164164
"""
165+
166+
warn_if_deprecated(
167+
self.build_posterior,
168+
locals(),
169+
{
170+
"mcmc_parameters",
171+
"vi_parameters",
172+
"rejection_sampling_parameters",
173+
"importance_sampling_parameters",
174+
"mcmc_method",
175+
"vi_method",
176+
},
177+
)
178+
165179
if density_estimator is not None:
166180
assert isinstance(
167181
density_estimator, MixedDensityEstimator

sbi/inference/trainers/nle/nle_base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from sbi.sbi_types import TorchTransform
3333
from sbi.utils import check_estimator_arg, x_shape_from_simulation
34+
from sbi.utils.sbiutils import warn_if_deprecated
3435
from sbi.utils.torchutils import assert_all_finite
3536

3637

@@ -359,6 +360,19 @@ def build_posterior(
359360
(the returned log-probability is unnormalized).
360361
"""
361362

363+
warn_if_deprecated(
364+
self.build_posterior,
365+
locals(),
366+
{
367+
"mcmc_parameters",
368+
"vi_parameters",
369+
"rejection_sampling_parameters",
370+
"importance_sampling_parameters",
371+
"mcmc_method",
372+
"vi_method",
373+
},
374+
)
375+
362376
return super().build_posterior(
363377
density_estimator,
364378
prior,

sbi/inference/trainers/npe/mnpe.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sbi.inference.trainers.npe.npe_c import NPE_C
1717
from sbi.neural_nets.estimators import MixedDensityEstimator
1818
from sbi.sbi_types import TensorboardSummaryWriter
19-
from sbi.utils.sbiutils import del_entries
19+
from sbi.utils.sbiutils import del_entries, warn_if_deprecated
2020

2121

2222
class MNPE(NPE_C):
@@ -158,6 +158,21 @@ def build_posterior(
158158
Returns:
159159
Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods.
160160
"""
161+
162+
warn_if_deprecated(
163+
self.build_posterior,
164+
locals(),
165+
{
166+
"direct_sampling_parameters",
167+
"mcmc_parameters",
168+
"vi_parameters",
169+
"rejection_sampling_parameters",
170+
"importance_sampling_parameters",
171+
"mcmc_method",
172+
"vi_method",
173+
},
174+
)
175+
161176
if density_estimator is not None:
162177
assert isinstance(
163178
density_estimator, MixedDensityEstimator

sbi/inference/trainers/npe/npe_base.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@
4646
validate_theta_and_x,
4747
warn_if_zscoring_changes_data,
4848
)
49-
from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior
49+
from sbi.utils.sbiutils import (
50+
ImproperEmpirical,
51+
mask_sims_from_prior,
52+
warn_if_deprecated,
53+
)
5054
from sbi.utils.torchutils import assert_all_finite
5155

5256

@@ -529,6 +533,20 @@ def build_posterior(
529533
(the returned log-probability is unnormalized).
530534
"""
531535

536+
warn_if_deprecated(
537+
self.build_posterior,
538+
locals(),
539+
{
540+
"direct_sampling_parameters",
541+
"mcmc_parameters",
542+
"vi_parameters",
543+
"rejection_sampling_parameters",
544+
"importance_sampling_parameters",
545+
"mcmc_method",
546+
"vi_method",
547+
},
548+
)
549+
532550
self._check_prior_for_rejection_sampling(
533551
prior, sample_with, posterior_parameters
534552
)

sbi/inference/trainers/npse/npse.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from sbi.neural_nets.estimators import ConditionalVectorFieldEstimator
1616
from sbi.neural_nets.factory import posterior_score_nn
17+
from sbi.utils.sbiutils import warn_if_deprecated
1718

1819

1920
class NPSE(VectorFieldTrainer):
@@ -111,6 +112,15 @@ def build_posterior(
111112
Returns:
112113
Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods.
113114
"""
115+
116+
warn_if_deprecated(
117+
self.build_posterior,
118+
locals(),
119+
{
120+
"vectorfield_sampling_parameters",
121+
},
122+
)
123+
114124
return super().build_posterior(
115125
estimator=vector_field_estimator,
116126
prior=prior,

sbi/inference/trainers/nre/nre_base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
check_estimator_arg,
3232
clamp_and_warn,
3333
)
34+
from sbi.utils.sbiutils import warn_if_deprecated
3435
from sbi.utils.torchutils import repeat_rows
3536

3637

@@ -405,6 +406,20 @@ def build_posterior(
405406
Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods
406407
(the returned log-probability is unnormalized).
407408
"""
409+
410+
warn_if_deprecated(
411+
self.build_posterior,
412+
locals(),
413+
{
414+
"mcmc_parameters",
415+
"vi_parameters",
416+
"rejection_sampling_parameters",
417+
"importance_sampling_parameters",
418+
"mcmc_method",
419+
"vi_method",
420+
},
421+
)
422+
408423
return super().build_posterior(
409424
density_estimator,
410425
prior,

sbi/utils/sbiutils.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4+
import inspect
45
import logging
56
import random
67
import warnings
78
from math import pi
8-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
9+
from typing import (
10+
Any,
11+
Callable,
12+
Dict,
13+
List,
14+
Optional,
15+
Sequence,
16+
Set,
17+
Tuple,
18+
Type,
19+
Union,
20+
)
921

1022
import numpy as np
1123
import pyknos.nflows.transforms as nflows_tf
@@ -1006,3 +1018,43 @@ def seed_all_backends(seed: Optional[Union[int, Tensor]] = None) -> None:
10061018
torch.cuda.manual_seed(seed)
10071019
torch.backends.cudnn.deterministic = True # type: ignore
10081020
torch.backends.cudnn.benchmark = False # type: ignore
1021+
1022+
1023+
def warn_if_deprecated(
1024+
method: Callable, locals_dict: Dict[str, Any], deprecated_keys: Set
1025+
) -> None:
1026+
"""
1027+
Issues a warning if any deprecated parameters are used with non-default values.
1028+
1029+
This function compares the values of deprecated parameters (from `locals_dict`)
1030+
against their default values in the given `method` signature. If a deprecated
1031+
parameter is explicitly set to a non-default value, a `DeprecationWarning` is
1032+
raised.
1033+
1034+
Args:
1035+
method: The function whose parameters are checked.
1036+
locals_dict: The arguments of the function.
1037+
deprecated_keys: The names of the parameters that are deprecated.
1038+
1039+
"""
1040+
1041+
# Get the signature of the function
1042+
method_signature = inspect.signature(method)
1043+
1044+
used = []
1045+
for key in deprecated_keys:
1046+
if key in locals_dict and key in method_signature.parameters:
1047+
default_value = method_signature.parameters[key].default
1048+
1049+
# Compare value to default
1050+
if locals_dict[key] != default_value:
1051+
used.append(key)
1052+
1053+
if used:
1054+
warnings.warn(
1055+
f"The following arguments are deprecated and"
1056+
" will be removed in a future version: "
1057+
f"{', '.join(used)}. Please use `posterior_parameters` instead.",
1058+
DeprecationWarning,
1059+
stacklevel=2,
1060+
)

0 commit comments

Comments
 (0)