diff --git a/sbi/inference/trainers/base.py b/sbi/inference/trainers/base.py index 46238ecb6..558ee1c71 100644 --- a/sbi/inference/trainers/base.py +++ b/sbi/inference/trainers/base.py @@ -678,13 +678,13 @@ def _raise_deprecation_warning( # Check if deprecated parameters are used if ( - kwargs.get("mcmc_method") == default_mcmc_method - or kwargs.get("mcmc_method") is None + kwargs.get("mcmc_method") is not None + and kwargs.get("mcmc_method") != default_mcmc_method ): deprecated_params.append("mcmc_method") if ( - kwargs.get("vi_method") == default_vi_method - or kwargs.get("vi_method") is None + kwargs.get("vi_method") is not None + and kwargs.get("vi_method") != default_vi_method ): deprecated_params.append("vi_method") diff --git a/tests/linearGaussian_vector_field_test.py b/tests/linearGaussian_vector_field_test.py index 1184c7677..e9e5cc486 100644 --- a/tests/linearGaussian_vector_field_test.py +++ b/tests/linearGaussian_vector_field_test.py @@ -350,12 +350,7 @@ def test_vector_field_sde_ode_sampling_equivalence(vector_field_trained_model): @pytest.mark.parametrize( "iid_method, num_trial", [ - pytest.param( - "fnpe", - 3, - id="fnpe-3trials", - marks=pytest.mark.xfail(reason="c2st to high, fixed in PR #1501/1544"), - ), + pytest.param("fnpe", 3, id="fnpe-3trials"), pytest.param("gauss", 3, id="gauss-3trials"), pytest.param("auto_gauss", 8, id="auto_gauss-8trials"), pytest.param("auto_gauss", 16, id="auto_gauss-16trials"),