Skip to content

Commit d82c103

Browse files
committed
test: update tests to work with updated build_posterior logic
1 parent be39800 commit d82c103

10 files changed

+77
-34
lines changed

tests/embedding_net_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from sbi import utils
1515
from sbi.inference import NLE, NPE, NRE, simulate_for_sbi
16+
from sbi.inference.posteriors.posterior_parameters import MCMCPosteriorParameters
1617
from sbi.neural_nets import classifier_nn, likelihood_nn, posterior_nn
1718
from sbi.neural_nets.embedding_nets import (
1819
CNNEmbedding,
@@ -81,8 +82,9 @@ def test_embedding_net_api(
8182

8283
_ = inference.append_simulations(theta, x).train(max_num_epochs=2)
8384
posterior = inference.build_posterior(
84-
mcmc_method="slice_np_vectorized",
85-
mcmc_parameters=mcmc_params_fast,
85+
posterior_parameters=MCMCPosteriorParameters(
86+
method="slice_np_vectorized", **mcmc_params_fast
87+
)
8688
).set_default_x(x_o)
8789

8890
s = posterior.sample((1,))

tests/inference_on_device_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
)
3434
from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
3535
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
36+
from sbi.inference.posteriors.posterior_parameters import (
37+
MCMCPosteriorParameters,
38+
)
3639
from sbi.inference.potentials.base_potential import BasePotential
3740
from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential
3841
from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential
@@ -190,9 +193,9 @@ def simulator(theta):
190193
# mcmc cases
191194
if sampling_method in ["slice_np", "slice_np_vectorized", "nuts_pymc"]:
192195
posterior = inferer.build_posterior(
193-
sample_with="mcmc",
194-
mcmc_method=sampling_method,
195-
mcmc_parameters=mcmc_params_fast,
196+
posterior_parameters=MCMCPosteriorParameters(
197+
method=sampling_method, **mcmc_params_fast
198+
)
196199
)
197200
elif sampling_method in ["rejection", "direct"]:
198201
# all other cases: rejection, direct

tests/linearGaussian_snle_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
VIPosterior,
2020
likelihood_estimator_based_potential,
2121
)
22+
from sbi.inference.posteriors.posterior_parameters import MCMCPosteriorParameters
2223
from sbi.neural_nets import likelihood_nn
2324
from sbi.simulators.linear_gaussian import (
2425
diagonal_linear_gaussian,
@@ -64,8 +65,9 @@ def test_api_nle_multiple_trials_and_rounds_map(
6465
for num_trials in [1, 3]:
6566
x_o = zeros((num_trials, num_dim))
6667
posterior = inference.build_posterior(
67-
mcmc_method="slice_np_vectorized",
68-
mcmc_parameters=mcmc_params_fast,
68+
posterior_parameters=MCMCPosteriorParameters(
69+
method="slice_np_vectorized", **mcmc_params_fast
70+
)
6971
).set_default_x(x_o)
7072
posterior.sample(sample_shape=(num_samples,))
7173
proposals.append(posterior)

tests/linearGaussian_snre_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
VIPosterior,
2424
ratio_estimator_based_potential,
2525
)
26+
from sbi.inference.posteriors.posterior_parameters import MCMCPosteriorParameters
2627
from sbi.neural_nets.ratio_estimators import RatioEstimator
2728
from sbi.simulators.linear_gaussian import (
2829
diagonal_linear_gaussian,
@@ -65,8 +66,9 @@ def test_api_nre_multiple_trials_and_rounds_map(
6566
for num_trials in [1, 3]:
6667
x_o = zeros((num_trials, num_dim))
6768
posterior = inference.build_posterior(
68-
mcmc_method="slice_np_vectorized",
69-
mcmc_parameters=mcmc_params_fast,
69+
posterior_parameters=MCMCPosteriorParameters(
70+
method="slice_np_vectorized", **mcmc_params_fast
71+
)
7072
).set_default_x(x_o)
7173
posterior.sample(sample_shape=(num_samples,))
7274
proposals.append(posterior)

tests/linearGaussian_vector_field_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
simulate_for_sbi,
2222
vector_field_estimator_based_potential,
2323
)
24+
from sbi.inference.posteriors.posterior_parameters import VectorFieldPosteriorParameters
2425
from sbi.neural_nets.factory import flowmatching_nn
2526
from sbi.simulators import linear_gaussian
2627
from sbi.simulators.linear_gaussian import (
@@ -103,9 +104,9 @@ def test_c2st_vector_field_on_linearGaussian(
103104
posterior = inference.build_posterior(
104105
score_estimator,
105106
sample_with=method,
106-
vectorfield_sampling_parameters={
107-
"neural_ode_backend": "zuko",
108-
},
107+
posterior_parameters=VectorFieldPosteriorParameters(
108+
neural_ode_backend="zuko"
109+
),
109110
)
110111
posterior.set_default_x(x_o)
111112
samples = posterior.sample((num_samples,))

tests/mnle_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from torch.distributions import Beta, Binomial, Distribution, Gamma
1111

1212
from sbi.inference import MNLE, MCMCPosterior
13+
from sbi.inference.posteriors.posterior_parameters import MCMCPosteriorParameters
1314
from sbi.inference.posteriors.rejection_posterior import RejectionPosterior
1415
from sbi.inference.posteriors.vi_posterior import VIPosterior
1516
from sbi.inference.potentials.base_potential import BasePotential
@@ -182,7 +183,9 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
182183
).sample((num_samples,), show_progress_bars=False)
183184

184185
posterior = trainer.build_posterior(
185-
prior=prior, sample_with=sampler, mcmc_parameters=mcmc_params_accurate
186+
prior=prior,
187+
sample_with=sampler,
188+
posterior_parameters=MCMCPosteriorParameters(**mcmc_params_accurate),
186189
)
187190
posterior.set_default_x(x_o)
188191
if sampler == "vi":

tests/posterior_nn_test.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
NRE_C,
2222
DirectPosterior,
2323
)
24-
from sbi.inference.posteriors.posterior_parameters import RejectionPosteriorParameters
24+
from sbi.inference.posteriors.posterior_parameters import (
25+
MCMCPosteriorParameters,
26+
RejectionPosteriorParameters,
27+
)
2528
from sbi.simulators.linear_gaussian import (
2629
diagonal_linear_gaussian,
2730
linear_gaussian,
@@ -195,9 +198,10 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
195198
x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim)
196199

197200
posterior = inference.build_posterior(
198-
sample_with="mcmc",
199-
mcmc_method="slice_np_vectorized",
200-
mcmc_parameters=mcmc_params_fast,
201+
posterior_parameters=MCMCPosteriorParameters(
202+
method="slice_np_vectorized",
203+
**mcmc_params_fast,
204+
)
201205
)
202206

203207
samples = posterior.sample_batched(
@@ -219,9 +223,10 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
219223
inference = snlre_method(prior=prior)
220224
_ = inference.append_simulations(theta, x).train()
221225
posterior = inference.build_posterior(
222-
sample_with="mcmc",
223-
mcmc_method="slice_np_vectorized",
224-
mcmc_parameters=mcmc_params_fast,
226+
posterior_parameters=MCMCPosteriorParameters(
227+
method="slice_np_vectorized",
228+
**mcmc_params_fast,
229+
)
225230
)
226231

227232
x_o = torch.stack([0.5 * ones(num_dim), -0.5 * ones(num_dim)], dim=0)

tests/posterior_parameters_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,3 +385,12 @@ def test_invalid_literal_field_values():
385385
"""
386386

387387
MCMCPosteriorParameters(method="invalid")
388+
389+
390+
def test_if_warning_raised_for_deprecated_build_posterior_parameters(get_inference):
391+
"""
392+
Check if the build_posterior method raises a warning for deprecated parameters
393+
"""
394+
395+
with pytest.warns(DeprecationWarning):
396+
get_inference.build_posterior(mcmc_parameters={})

tests/save_and_load_test.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,27 @@
88

99
from sbi import utils as utils
1010
from sbi.inference import NLE, NPE, NRE
11+
from sbi.inference.posteriors.posterior_parameters import (
12+
DirectPosteriorParameters,
13+
MCMCPosteriorParameters,
14+
RejectionPosteriorParameters,
15+
VIPosteriorParameters,
16+
)
1117
from sbi.inference.posteriors.vi_posterior import VIPosterior
1218

1319

1420
@pytest.mark.parametrize(
15-
"inference_method, sampling_method",
21+
"inference_method, posterior_parameter",
1622
(
17-
(NPE, "direct"),
18-
pytest.param(NLE, "mcmc", marks=pytest.mark.mcmc),
19-
pytest.param(NRE, "mcmc", marks=pytest.mark.mcmc),
20-
pytest.param(NRE, "vi", marks=pytest.mark.mcmc),
21-
(NRE, "rejection"),
23+
(NPE, DirectPosteriorParameters),
24+
pytest.param(NLE, MCMCPosteriorParameters, marks=pytest.mark.mcmc),
25+
pytest.param(NRE, MCMCPosteriorParameters, marks=pytest.mark.mcmc),
26+
pytest.param(NRE, VIPosteriorParameters, marks=pytest.mark.mcmc),
27+
(NRE, RejectionPosteriorParameters),
2228
),
2329
)
2430
def test_picklability(
25-
inference_method, sampling_method: str, tmp_path, mcmc_params_fast
31+
inference_method, posterior_parameter, tmp_path, mcmc_params_fast
2632
):
2733
num_dim = 2
2834
prior = utils.BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
@@ -33,10 +39,14 @@ def test_picklability(
3339

3440
inference = inference_method(prior=prior)
3541
_ = inference.append_simulations(theta, x).train(max_num_epochs=1)
36-
posterior = inference.build_posterior(
37-
sample_with=sampling_method, mcmc_parameters=mcmc_params_fast
38-
).set_default_x(x_o)
39-
42+
if posterior_parameter is MCMCPosteriorParameters:
43+
posterior = inference.build_posterior(
44+
posterior_parameters=posterior_parameter(**mcmc_params_fast)
45+
).set_default_x(x_o)
46+
else:
47+
posterior = inference.build_posterior(
48+
posterior_parameters=posterior_parameter()
49+
).set_default_x(x_o)
4050
# After sample and log_prob, the posterior should still be picklable
4151
if isinstance(posterior, VIPosterior):
4252
posterior.train(max_num_iters=10)

tests/sbc_test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
from sbi.diagnostics import check_sbc, get_nltp, run_sbc
1515
from sbi.inference import NLE, NPE, NPSE
1616
from sbi.inference.posteriors.base_posterior import NeuralPosterior
17+
from sbi.inference.posteriors.posterior_parameters import (
18+
MCMCPosteriorParameters,
19+
VIPosteriorParameters,
20+
)
1721
from sbi.simulators.linear_gaussian import linear_gaussian
1822
from sbi.utils import BoxUniform, MultipleIndependent
1923
from tests.test_utils import PosteriorPotential, TractablePosterior
@@ -106,9 +110,11 @@ def simulator(theta):
106110
posterior_kwargs = {}
107111
if method == NLE:
108112
posterior_kwargs = {
109-
"sample_with": "mcmc" if sampler == "mcmc" else "vi",
110-
"mcmc_method": "slice_np_vectorized",
111-
"mcmc_parameters": mcmc_params_fast,
113+
"posterior_parameters": MCMCPosteriorParameters(
114+
method="slice_np_vectorized", **mcmc_params_fast
115+
)
116+
if sampler == "mcmc"
117+
else VIPosteriorParameters()
112118
}
113119

114120
posterior = train_inference_method(

0 commit comments

Comments
 (0)