Skip to content

Commit 31f8ad6

Browse files
committed
refactor tests: map passing
1 parent 6012fa1 commit 31f8ad6

File tree

1 file changed

+21
-29
lines changed

1 file changed

+21
-29
lines changed

tests/linearGaussian_fmpe_test.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_c2st_fmpe_on_linearGaussian(num_dim: int, prior_str: str):
4646

4747
x_o = zeros(1, num_dim)
4848
num_samples = 1000
49-
num_simulations = 2500
49+
num_simulations = 4000
5050

5151
# likelihood_mean will be likelihood_shift+theta
5252
likelihood_shift = -1.0 * ones(num_dim)
@@ -73,14 +73,10 @@ def test_c2st_fmpe_on_linearGaussian(num_dim: int, prior_str: str):
7373
theta = prior.sample((num_simulations,))
7474
x = linear_gaussian(theta, likelihood_shift, likelihood_cov)
7575

76-
inference = FMPE(prior, show_progress_bars=False)
76+
inference = FMPE(prior, show_progress_bars=True)
7777

78-
posterior_estimator = inference.append_simulations(theta, x).train(
79-
training_batch_size=100
80-
)
81-
posterior = DirectPosterior(
82-
prior=prior, posterior_estimator=posterior_estimator
83-
).set_default_x(x_o)
78+
inference.append_simulations(theta, x).train(training_batch_size=100)
79+
posterior = inference.build_posterior().set_default_x(x_o)
8480
samples = posterior.sample((num_samples,))
8581

8682
# Compute the c2st and assert it is near chance level of 0.5.
@@ -104,6 +100,13 @@ def test_c2st_fmpe_on_linearGaussian(num_dim: int, prior_str: str):
104100
dkl < max_dkl
105101
), f"D-KL={dkl} is more than 2 stds above the average performance."
106102

103+
# test probs
104+
probs = posterior.log_prob(samples).exp()
105+
gt_probs = gt_posterior.log_prob(samples).exp()
106+
assert torch.allclose(
107+
probs, gt_probs, atol=0.2
108+
) # note that this is 0.1 for NPE.
109+
107110
elif prior_str == "uniform":
108111
# Check whether the returned probability outside of the support is zero.
109112
posterior_prob = get_prob_outside_uniform_prior(posterior, prior, num_dim)
@@ -161,12 +164,8 @@ def test_fmpe_with_different_models(model):
161164

162165
inference = FMPE(prior, density_estimator=estimator_build_fun)
163166

164-
posterior_estimator = inference.append_simulations(theta, x).train(
165-
training_batch_size=100
166-
)
167-
posterior = DirectPosterior(
168-
prior=prior, posterior_estimator=posterior_estimator
169-
).set_default_x(x_o)
167+
inference.append_simulations(theta, x).train(training_batch_size=100)
168+
posterior = inference.build_posterior().set_default_x(x_o)
170169
samples = posterior.sample((num_samples,))
171170

172171
# Compute the c2st and assert it is near chance level of 0.5.
@@ -349,17 +348,12 @@ def simulator(theta):
349348

350349

351350
@pytest.mark.slow
352-
@pytest.mark.xfail(
353-
reason="FMPE MAP failing in spite of accuracte c2st.",
354-
strict=True,
355-
raises=AssertionError,
356-
)
357351
def test_fmpe_map():
358352
"""Test whether fmpe can find the MAP of a simple linear Gaussian example."""
359353

360354
num_dim = 3
361355
x_o = zeros(1, num_dim)
362-
num_simulations = 2000
356+
num_simulations = 5000
363357

364358
likelihood_shift = -1.0 * ones(num_dim)
365359
likelihood_cov = 0.3 * eye(num_dim)
@@ -374,19 +368,17 @@ def test_fmpe_map():
374368
theta = prior.sample((num_simulations,))
375369
x = linear_gaussian(theta, likelihood_shift, likelihood_cov)
376370

377-
inference = FMPE(prior, show_progress_bars=False)
371+
inference = FMPE(prior, show_progress_bars=True)
378372

379-
posterior_estimator = inference.append_simulations(theta, x).train(
380-
training_batch_size=100
381-
)
382-
posterior = DirectPosterior(
383-
prior=prior, posterior_estimator=posterior_estimator
384-
).set_default_x(x_o)
373+
inference.append_simulations(theta, x).train(training_batch_size=100)
374+
posterior = inference.build_posterior().set_default_x(x_o)
385375

386-
map_ = posterior.map(num_init_samples=1_000, show_progress_bars=True)
376+
map_ = posterior.map(show_progress_bars=True, num_iter=20)
387377

388378
# Check whether the MAP is close to the ground truth.
389-
assert torch.allclose(map_, gt_posterior.mean, atol=0.1)
379+
assert torch.allclose(
380+
map_, gt_posterior.mean, atol=0.2
381+
), f"{map_} != {gt_posterior.mean}"
390382

391383

392384
def test_multi_round_handling_fmpe():

0 commit comments

Comments
 (0)