Skip to content

Commit e9355bf

Browse files
fix: precision matrix in jac score fn iid (#1636)
* Fix * Do not fully skip all iid_test but only what is currently adressed in other PRs and what is expected to fail * Additional skip for FMPE * xfail
1 parent ebb8c6f commit e9355bf

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

sbi/inference/potentials/score_fn_iid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ def marginal_denoising_posterior_precision_est_fn(
727727
std = self.vector_field_estimator.std_fn(time)
728728
cov0 = std**2 * jac + torch.eye(d)[None, None, :, :]
729729

730-
denoising_posterior_precision = m**2 / std**2 + torch.inverse(cov0)
730+
denoising_posterior_precision = m**2 / std**2 * torch.inverse(cov0)
731731

732732
return denoising_posterior_precision
733733

tests/linearGaussian_vector_field_test.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,18 +357,21 @@ def test_vector_field_sde_ode_sampling_equivalence(vector_field_trained_model):
357357
# TODO: Currently, c2st is too high for FMPE (e.g., > 3 number of observations),
358358
# so some tests are skipped so far. This seems to be an issue with the
359359
# neural network architecture and can be addressed in PR #1501
360-
@pytest.mark.skip(
361-
reason="c2st too high for some cases, has to be fixed in PR #1501 or #1544"
362-
)
363360
@pytest.mark.slow
364361
@pytest.mark.parametrize(
365362
"iid_method, num_trial",
366363
[
367-
pytest.param("fnpe", 3, id="fnpe-2trials"),
364+
pytest.param(
365+
"fnpe",
366+
3,
367+
id="fnpe-3trials",
368+
marks=pytest.mark.xfail(reason="c2st to high, fixed in PR #1501/1544"),
369+
),
368370
pytest.param("gauss", 3, id="gauss-3trials"),
369371
pytest.param("auto_gauss", 8, id="auto_gauss-8trials"),
370372
pytest.param("auto_gauss", 16, id="auto_gauss-16trials"),
371373
pytest.param("jac_gauss", 8, id="jac_gauss-8trials"),
374+
pytest.param("jac_gauss", 16, id="jac_gauss-16trials"),
372375
],
373376
)
374377
def test_vector_field_iid_inference(
@@ -377,6 +380,10 @@ def test_vector_field_iid_inference(
377380
"""
378381
Test whether NPSE and FMPE infers well a simple example with available ground truth.
379382
"""
383+
if vector_field_type == "fmpe":
384+
# TODO: Remove on merge
385+
pytest.xfail(reason="c2st to high, fixed in PR #1501/1544")
386+
380387
num_samples = 1000
381388

382389
# Extract data from fixture

0 commit comments

Comments
 (0)