Skip to content

Commit 2e1509e

Browse files
authored
refactor: Literals for build posterior method arguments (#1606)
* Fix spelling
* refactor: Replace string annotations with Literals for posterior classes init methods
* refactor(trainer): replace build_posterior string annotations with Literals
* Add missing Protocol import
1 parent 046cdb0 commit 2e1509e

File tree

13 files changed

+135
-58
lines changed

13 files changed

+135
-58
lines changed

sbi/diagnostics/lc2st.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ def __init__(
536536
References:
537537
[1] : https://arxiv.org/abs/2306.03580, https://github.com/JuliaLinhart/lc2st
538538
"""
539-
# Aplly the inverse transform to the thetas and the posterior samples
539+
# Apply the inverse transform to the thetas and the posterior samples
540540
self.flow_inverse_transform = flow_inverse_transform
541541
inverse_thetas = flow_inverse_transform(thetas, xs).detach()
542542
inverse_posterior_samples = flow_inverse_transform(

sbi/inference/posteriors/importance_posterior.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Any, Callable, Optional, Tuple, Union
4+
from typing import Any, Callable, Literal, Optional, Tuple, Union
55

66
import torch
77
from torch import Tensor
@@ -30,7 +30,7 @@ def __init__(
3030
potential_fn: Union[Callable, BasePotential],
3131
proposal: Any,
3232
theta_transform: Optional[TorchTransform] = None,
33-
method: str = "sir",
33+
method: Literal["sir", "importance"] = "sir",
3434
oversampling_factor: int = 32,
3535
max_sampling_batch_size: int = 10_000,
3636
device: Optional[Union[str, torch.device]] = None,

sbi/inference/posteriors/mcmc_posterior.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from copy import deepcopy
66
from functools import partial
77
from math import ceil
8-
from typing import Any, Callable, Dict, Optional, Union
8+
from typing import Any, Callable, Dict, Literal, Optional, Union
99
from warnings import warn
1010

1111
import arviz as az
@@ -50,15 +50,23 @@ def __init__(
5050
potential_fn: Union[Callable, BasePotential],
5151
proposal: Any,
5252
theta_transform: Optional[TorchTransform] = None,
53-
method: str = "slice_np_vectorized",
53+
method: Literal[
54+
"slice_np",
55+
"slice_np_vectorized",
56+
"hmc_pyro",
57+
"nuts_pyro",
58+
"slice_pymc",
59+
"hmc_pymc",
60+
"nuts_pymc",
61+
] = "slice_np_vectorized",
5462
thin: int = -1,
5563
warmup_steps: int = 200,
5664
num_chains: int = 20,
57-
init_strategy: str = "resample",
65+
init_strategy: Literal["proposal", "sir", "resample"] = "resample",
5866
init_strategy_parameters: Optional[Dict[str, Any]] = None,
5967
init_strategy_num_candidates: Optional[int] = None,
6068
num_workers: int = 1,
61-
mp_context: str = "spawn",
69+
mp_context: Literal["fork", "spawn"] = "spawn",
6270
device: Optional[Union[str, torch.device]] = None,
6371
x_shape: Optional[torch.Size] = None,
6472
):

sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(
5252
max_sampling_batch_size: int = 10_000,
5353
device: Optional[Union[str, torch.device]] = None,
5454
enable_transform: bool = True,
55-
sample_with: str = "sde",
55+
sample_with: Literal["ode", "sde"] = "sde",
5656
**kwargs,
5757
):
5858
"""

sbi/inference/posteriors/vi_posterior.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import copy
55
from copy import deepcopy
6-
from typing import Callable, Dict, Iterable, Optional, Union
6+
from typing import Callable, Dict, Iterable, Literal, Optional, Union
77

88
import numpy as np
99
import torch
@@ -60,9 +60,14 @@ def __init__(
6060
self,
6161
potential_fn: Union[BasePotential, CustomPotential],
6262
prior: Optional[TorchDistribution] = None, # type: ignore
63-
q: Union[str, PyroTransformedDistribution, "VIPosterior", Callable] = "maf",
63+
q: Union[
64+
Literal["nsf", "scf", "maf", "mcf", "gaussian", "gaussian_diag"],
65+
PyroTransformedDistribution,
66+
"VIPosterior",
67+
Callable,
68+
] = "maf",
6469
theta_transform: Optional[TorchTransform] = None,
65-
vi_method: str = "rKL",
70+
vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL",
6671
device: Union[str, torch.device] = "cpu",
6772
x_shape: Optional[torch.Size] = None,
6873
parameters: Optional[Iterable] = None,

sbi/inference/trainers/fmpe/fmpe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# under the Apache License v2.0, see <https://www.apache.org/licenses/LICENSE-2.0>.
33

44

5-
from typing import Optional, Union
5+
from typing import Literal, Optional, Union
66

77
from torch.distributions import Distribution
88
from torch.utils.tensorboard.writer import SummaryWriter
@@ -67,7 +67,7 @@ def build_posterior(
6767
self,
6868
vector_field_estimator: Optional[ConditionalVectorFieldEstimator] = None,
6969
prior: Optional[Distribution] = None,
70-
sample_with: str = "ode",
70+
sample_with: Literal["ode", "sde"] = "ode",
7171
**kwargs,
7272
) -> VectorFieldPosterior:
7373
r"""Build posterior from the flow matching estimator.

sbi/inference/trainers/nle/mnle.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

44
from copy import deepcopy
5-
from typing import Any, Callable, Dict, Optional, Union
5+
from typing import Any, Callable, Dict, Literal, Optional, Union
66

77
from torch.distributions import Distribution
88

@@ -94,9 +94,17 @@ def build_posterior(
9494
self,
9595
density_estimator: Optional[TorchModule] = None,
9696
prior: Optional[Distribution] = None,
97-
sample_with: str = "mcmc",
98-
mcmc_method: str = "slice_np_vectorized",
99-
vi_method: str = "rKL",
97+
sample_with: Literal["mcmc", "rejection", "vi"] = "mcmc",
98+
mcmc_method: Literal[
99+
"slice_np",
100+
"slice_np_vectorized",
101+
"hmc_pyro",
102+
"nuts_pyro",
103+
"slice_pymc",
104+
"hmc_pymc",
105+
"nuts_pymc",
106+
] = "slice_np_vectorized",
107+
vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL",
100108
mcmc_parameters: Optional[Dict[str, Any]] = None,
101109
vi_parameters: Optional[Dict[str, Any]] = None,
102110
rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
@@ -114,10 +122,14 @@ def build_posterior(
114122
prior: Prior distribution.
115123
sample_with: Method to use for sampling from the posterior. Must be one of
116124
[`mcmc` | `rejection` | `vi`].
117-
mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
118-
`hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy
119-
implementation of slice sampling; select `hmc`, `nuts` or `slice` for
120-
Pyro-based sampling.
125+
mcmc_method: Method used for MCMC sampling, one of `slice_np`,
126+
`slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`,
127+
`hmc_pymc`, `nuts_pymc`. `slice_np` is a custom
128+
numpy implementation of slice sampling. `slice_np_vectorized` is
129+
identical to `slice_np`, but if `num_chains>1`, the chains are
130+
vectorized for `slice_np_vectorized` whereas they are run sequentially
131+
for `slice_np`. The samplers ending in `_pyro` use Pyro, and
132+
likewise the samplers ending in `_pymc` use PyMC.
121133
vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`]. Note
122134
some of the methods admit a `mode seeking` property (e.g. rKL) whereas
123135
some admit a `mass covering` one (e.g. fKL).

sbi/inference/trainers/nle/nle_base.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from abc import ABC
66
from copy import deepcopy
7-
from typing import Any, Callable, Dict, Optional, Union
7+
from typing import Any, Callable, Dict, Literal, Optional, Union
88

99
import torch
1010
from torch import Tensor
@@ -286,9 +286,17 @@ def build_posterior(
286286
self,
287287
density_estimator: Optional[ConditionalDensityEstimator] = None,
288288
prior: Optional[Distribution] = None,
289-
sample_with: str = "mcmc",
290-
mcmc_method: str = "slice_np_vectorized",
291-
vi_method: str = "rKL",
289+
sample_with: Literal["mcmc", "rejection", "vi", "importance"] = "mcmc",
290+
mcmc_method: Literal[
291+
"slice_np",
292+
"slice_np_vectorized",
293+
"hmc_pyro",
294+
"nuts_pyro",
295+
"slice_pymc",
296+
"hmc_pymc",
297+
"nuts_pymc",
298+
] = "slice_np_vectorized",
299+
vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL",
292300
mcmc_parameters: Optional[Dict[str, Any]] = None,
293301
vi_parameters: Optional[Dict[str, Any]] = None,
294302
rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
@@ -308,11 +316,15 @@ def build_posterior(
308316
If `None`, use the latest neural density estimator that was trained.
309317
prior: Prior distribution.
310318
sample_with: Method to use for sampling from the posterior. Must be one of
311-
[`mcmc` | `rejection` | `vi`].
312-
mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
313-
`hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy
314-
implementation of slice sampling; select `hmc`, `nuts` or `slice` for
315-
Pyro-based sampling.
319+
[`mcmc` | `rejection` | `vi` | `importance`].
320+
mcmc_method: Method used for MCMC sampling, one of `slice_np`,
321+
`slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`,
322+
`hmc_pymc`, `nuts_pymc`. `slice_np` is a custom
323+
numpy implementation of slice sampling. `slice_np_vectorized` is
324+
identical to `slice_np`, but if `num_chains>1`, the chains are
325+
vectorized for `slice_np_vectorized` whereas they are run sequentially
326+
for `slice_np`. The samplers ending in `_pyro` use Pyro, and
327+
likewise the samplers ending in `_pymc` use PyMC.
316328
vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`]. Note
317329
some of the methods admit a `mode seeking` property (e.g. rKL) whereas
318330
some admit a `mass covering` one (e.g. fKL).

sbi/inference/trainers/npe/mnpe.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Any, Callable, Dict, Optional, Union
4+
from typing import Any, Callable, Dict, Literal, Optional, Union
55

66
from torch.distributions import Distribution
77

@@ -94,9 +94,19 @@ def build_posterior(
9494
self,
9595
density_estimator: Optional[MixedDensityEstimator] = None,
9696
prior: Optional[Distribution] = None,
97-
sample_with: str = "direct",
98-
mcmc_method: str = "slice_np_vectorized",
99-
vi_method: str = "rKL",
97+
sample_with: Literal[
98+
"mcmc", "rejection", "vi", "importance", "direct"
99+
] = "direct",
100+
mcmc_method: Literal[
101+
"slice_np",
102+
"slice_np_vectorized",
103+
"hmc_pyro",
104+
"nuts_pyro",
105+
"slice_pymc",
106+
"hmc_pymc",
107+
"nuts_pymc",
108+
] = "slice_np_vectorized",
109+
vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL",
100110
direct_sampling_parameters: Optional[Dict[str, Any]] = None,
101111
mcmc_parameters: Optional[Dict[str, Any]] = None,
102112
vi_parameters: Optional[Dict[str, Any]] = None,
@@ -117,10 +127,14 @@ def build_posterior(
117127
prior: Prior distribution.
118128
sample_with: Method to use for sampling from the posterior. Must be one of
119129
[`direct` | `mcmc` | `rejection` | `vi` | `importance`].
120-
mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
121-
`hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy
122-
implementation of slice sampling; select `hmc`, `nuts` or `slice` for
123-
Pyro-based sampling.
130+
mcmc_method: Method used for MCMC sampling, one of `slice_np`,
131+
`slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`,
132+
`hmc_pymc`, `nuts_pymc`. `slice_np` is a custom
133+
numpy implementation of slice sampling. `slice_np_vectorized` is
134+
identical to `slice_np`, but if `num_chains>1`, the chains are
135+
vectorized for `slice_np_vectorized` whereas they are run sequentially
136+
for `slice_np`. The samplers ending in `_pyro` use Pyro, and
137+
likewise the samplers ending in `_pymc` use PyMC.
124138
vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`].
125139
direct_sampling_parameters: Additional kwargs passed to `DirectPosterior`.
126140
mcmc_parameters: Additional kwargs passed to `MCMCPosterior`.

sbi/inference/trainers/npe/npe_base.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
from abc import ABC, abstractmethod
66
from copy import deepcopy
7-
from typing import Any, Callable, Dict, Optional, Union
7+
from typing import Any, Callable, Dict, Literal, Optional, Union
88
from warnings import warn
99

1010
import torch
@@ -440,9 +440,19 @@ def build_posterior(
440440
self,
441441
density_estimator: Optional[ConditionalDensityEstimator] = None,
442442
prior: Optional[Distribution] = None,
443-
sample_with: str = "direct",
444-
mcmc_method: str = "slice_np_vectorized",
445-
vi_method: str = "rKL",
443+
sample_with: Literal[
444+
"mcmc", "rejection", "vi", "importance", "direct"
445+
] = "direct",
446+
mcmc_method: Literal[
447+
"slice_np",
448+
"slice_np_vectorized",
449+
"hmc_pyro",
450+
"nuts_pyro",
451+
"slice_pymc",
452+
"hmc_pymc",
453+
"nuts_pymc",
454+
] = "slice_np_vectorized",
455+
vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL",
446456
direct_sampling_parameters: Optional[Dict[str, Any]] = None,
447457
mcmc_parameters: Optional[Dict[str, Any]] = None,
448458
vi_parameters: Optional[Dict[str, Any]] = None,
@@ -471,10 +481,14 @@ def build_posterior(
471481
prior: Prior distribution.
472482
sample_with: Method to use for sampling from the posterior. Must be one of
473483
[`direct` | `mcmc` | `rejection` | `vi` | `importance`].
474-
mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
475-
`hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy
476-
implementation of slice sampling; select `hmc`, `nuts` or `slice` for
477-
Pyro-based sampling.
484+
mcmc_method: Method used for MCMC sampling, one of `slice_np`,
485+
`slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`,
486+
`hmc_pymc`, `nuts_pymc`. `slice_np` is a custom
487+
numpy implementation of slice sampling. `slice_np_vectorized` is
488+
identical to `slice_np`, but if `num_chains>1`, the chains are
489+
vectorized for `slice_np_vectorized` whereas they are run sequentially
490+
for `slice_np`. The samplers ending in `_pyro` use Pyro, and
491+
likewise the samplers ending in `_pymc` use PyMC.
478492
vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`]. Note
479493
some of the methods admit a `mode seeking` property (e.g. rKL) whereas
480494
some admit a `mass covering` one (e.g. fKL).

0 commit comments

Comments
 (0)