|
1 | 1 | # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
|
2 | 2 | # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
|
3 | 3 |
|
| 4 | +import inspect |
4 | 5 | import logging
|
5 | 6 | import random
|
6 | 7 | import warnings
|
7 | 8 | from math import pi
|
8 |
| -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union |
| 9 | +from typing import ( |
| 10 | + Any, |
| 11 | + Callable, |
| 12 | + Dict, |
| 13 | + List, |
| 14 | + Optional, |
| 15 | + Sequence, |
| 16 | + Set, |
| 17 | + Tuple, |
| 18 | + Type, |
| 19 | + Union, |
| 20 | +) |
9 | 21 |
|
10 | 22 | import numpy as np
|
11 | 23 | import pyknos.nflows.transforms as nflows_tf
|
@@ -1006,3 +1018,43 @@ def seed_all_backends(seed: Optional[Union[int, Tensor]] = None) -> None:
|
1006 | 1018 | torch.cuda.manual_seed(seed)
|
1007 | 1019 | torch.backends.cudnn.deterministic = True # type: ignore
|
1008 | 1020 | torch.backends.cudnn.benchmark = False # type: ignore
|
| 1021 | + |
| 1022 | + |
| 1023 | +def warn_if_deprecated( |
| 1024 | + method: Callable, locals_dict: Dict[str, Any], deprecated_keys: Set |
| 1025 | +) -> None: |
| 1026 | + """ |
| 1027 | + Issues a warning if any deprecated parameters are used with non-default values. |
| 1028 | +
|
| 1029 | + This function compares the values of deprecated parameters (from `locals_dict`) |
| 1030 | + against their default values in the given `method` signature. If a deprecated |
| 1031 | + parameter is explicitly set to a non-default value, a `DeprecationWarning` is |
| 1032 | + raised. |
| 1033 | +
|
| 1034 | + Args: |
| 1035 | + method: The function whose parameters are checked. |
| 1036 | + locals_dict: The arguments of the function. |
| 1037 | + deprecated_keys: The names of the parameters that are deprecated. |
| 1038 | +
|
| 1039 | + """ |
| 1040 | + |
| 1041 | + # Get the signature of the function |
| 1042 | + method_signature = inspect.signature(method) |
| 1043 | + |
| 1044 | + used = [] |
| 1045 | + for key in deprecated_keys: |
| 1046 | + if key in locals_dict and key in method_signature.parameters: |
| 1047 | + default_value = method_signature.parameters[key].default |
| 1048 | + |
| 1049 | + # Compare value to default |
| 1050 | + if locals_dict[key] != default_value: |
| 1051 | + used.append(key) |
| 1052 | + |
| 1053 | + if used: |
| 1054 | + warnings.warn( |
| 1055 | + f"The following arguments are deprecated and" |
| 1056 | + " will be removed in a future version: " |
| 1057 | + f"{', '.join(used)}. Please use `posterior_parameters` instead.", |
| 1058 | + DeprecationWarning, |
| 1059 | + stacklevel=2, |
| 1060 | + ) |
0 commit comments