1
1
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
2
2
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
3
3
4
- from typing import Any , Callable , Dict , Optional , Union
4
+ from typing import Any , Callable , Dict , Literal , Optional , Union
5
5
6
6
from torch .distributions import Distribution
7
7
@@ -94,9 +94,19 @@ def build_posterior(
94
94
self ,
95
95
density_estimator : Optional [MixedDensityEstimator ] = None ,
96
96
prior : Optional [Distribution ] = None ,
97
- sample_with : str = "direct" ,
98
- mcmc_method : str = "slice_np_vectorized" ,
99
- vi_method : str = "rKL" ,
97
+ sample_with : Literal [
98
+ "mcmc" , "rejection" , "vi" , "importance" , "direct"
99
+ ] = "direct" ,
100
+ mcmc_method : Literal [
101
+ "slice_np" ,
102
+ "slice_np_vectorized" ,
103
+ "hmc_pyro" ,
104
+ "nuts_pyro" ,
105
+ "slice_pymc" ,
106
+ "hmc_pymc" ,
107
+ "nuts_pymc" ,
108
+ ] = "slice_np_vectorized" ,
109
+ vi_method : Literal ["rKL" , "fKL" , "IW" , "alpha" ] = "rKL" ,
100
110
direct_sampling_parameters : Optional [Dict [str , Any ]] = None ,
101
111
mcmc_parameters : Optional [Dict [str , Any ]] = None ,
102
112
vi_parameters : Optional [Dict [str , Any ]] = None ,
@@ -117,10 +127,14 @@ def build_posterior(
117
127
prior: Prior distribution.
118
128
sample_with: Method to use for sampling from the posterior. Must be one of
119
129
[`direct` | `mcmc` | `rejection` | `vi` | `importance`].
120
- mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
121
- `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy
122
- implementation of slice sampling; select `hmc`, `nuts` or `slice` for
123
- Pyro-based sampling.
130
+ mcmc_method: Method used for MCMC sampling, one of `slice_np`,
131
+ `slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`,
132
+ `hmc_pymc`, `nuts_pymc`. `slice_np` is a custom
133
+ numpy implementation of slice sampling. `slice_np_vectorized` is
134
+ identical to `slice_np`, but if `num_chains>1`, the chains are
135
+ vectorized for `slice_np_vectorized` whereas they are run sequentially
136
+ for `slice_np`. The samplers ending on `_pyro` are using Pyro, and
137
+ likewise the samplers ending on `_pymc` are using PyMC.
124
138
vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`].
125
139
direct_sampling_parameters: Additional kwargs passed to `DirectPosterior`.
126
140
mcmc_parameters: Additional kwargs passed to `MCMCPosterior`.
0 commit comments