Skip to content

Commit 6b5994f

Browse files
committed
fix xfail test, fix deprecation warnings
1 parent 3681a15 commit 6b5994f

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

sbi/inference/trainers/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -678,13 +678,13 @@ def _raise_deprecation_warning(
678678

679679
# Check if deprecated parameters are used
680680
if (
681-
kwargs.get("mcmc_method") == default_mcmc_method
682-
or kwargs.get("mcmc_method") is None
681+
kwargs.get("mcmc_method") is not None
682+
and kwargs.get("mcmc_method") != default_mcmc_method
683683
):
684684
deprecated_params.append("mcmc_method")
685685
if (
686-
kwargs.get("vi_method") == default_vi_method
687-
or kwargs.get("vi_method") is None
686+
kwargs.get("vi_method") is not None
687+
and kwargs.get("vi_method") != default_vi_method
688688
):
689689
deprecated_params.append("vi_method")
690690

tests/linearGaussian_vector_field_test.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -350,12 +350,7 @@ def test_vector_field_sde_ode_sampling_equivalence(vector_field_trained_model):
350350
@pytest.mark.parametrize(
351351
"iid_method, num_trial",
352352
[
353-
pytest.param(
354-
"fnpe",
355-
3,
356-
id="fnpe-3trials",
357-
marks=pytest.mark.xfail(reason="c2st to high, fixed in PR #1501/1544"),
358-
),
353+
pytest.param("fnpe", 3, id="fnpe-3trials"),
359354
pytest.param("gauss", 3, id="gauss-3trials"),
360355
pytest.param("auto_gauss", 8, id="auto_gauss-8trials"),
361356
pytest.param("auto_gauss", 16, id="auto_gauss-16trials"),

0 commit comments

Comments
 (0)