8
8
9
9
from sbi import utils as utils
10
10
from sbi .inference import NLE , NPE , NRE
11
+ from sbi .inference .posteriors .posterior_parameters import (
12
+ DirectPosteriorParameters ,
13
+ MCMCPosteriorParameters ,
14
+ RejectionPosteriorParameters ,
15
+ VIPosteriorParameters ,
16
+ )
11
17
from sbi .inference .posteriors .vi_posterior import VIPosterior
12
18
13
19
14
20
@pytest .mark .parametrize (
15
- "inference_method, sampling_method " ,
21
+ "inference_method, posterior_parameter " ,
16
22
(
17
- (NPE , "direct" ),
18
- pytest .param (NLE , "mcmc" , marks = pytest .mark .mcmc ),
19
- pytest .param (NRE , "mcmc" , marks = pytest .mark .mcmc ),
20
- pytest .param (NRE , "vi" , marks = pytest .mark .mcmc ),
21
- (NRE , "rejection" ),
23
+ (NPE , DirectPosteriorParameters ),
24
+ pytest .param (NLE , MCMCPosteriorParameters , marks = pytest .mark .mcmc ),
25
+ pytest .param (NRE , MCMCPosteriorParameters , marks = pytest .mark .mcmc ),
26
+ pytest .param (NRE , VIPosteriorParameters , marks = pytest .mark .mcmc ),
27
+ (NRE , RejectionPosteriorParameters ),
22
28
),
23
29
)
24
30
def test_picklability (
25
- inference_method , sampling_method : str , tmp_path , mcmc_params_fast
31
+ inference_method , posterior_parameter , tmp_path , mcmc_params_fast
26
32
):
27
33
num_dim = 2
28
34
prior = utils .BoxUniform (low = - 2 * torch .ones (num_dim ), high = 2 * torch .ones (num_dim ))
@@ -33,10 +39,14 @@ def test_picklability(
33
39
34
40
inference = inference_method (prior = prior )
35
41
_ = inference .append_simulations (theta , x ).train (max_num_epochs = 1 )
36
- posterior = inference .build_posterior (
37
- sample_with = sampling_method , mcmc_parameters = mcmc_params_fast
38
- ).set_default_x (x_o )
39
-
42
+ if posterior_parameter is MCMCPosteriorParameters :
43
+ posterior = inference .build_posterior (
44
+ posterior_parameters = posterior_parameter (** mcmc_params_fast )
45
+ ).set_default_x (x_o )
46
+ else :
47
+ posterior = inference .build_posterior (
48
+ posterior_parameters = posterior_parameter ()
49
+ ).set_default_x (x_o )
40
50
# After sample and log_prob, the posterior should still be picklable
41
51
if isinstance (posterior , VIPosterior ):
42
52
posterior .train (max_num_iters = 10 )
0 commit comments