Skip to content

Commit 77b4139

Browse files
committed
fix deprecation warning on default args
1 parent 314f45e commit 77b4139

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

sbi/inference/trainers/base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,8 +665,13 @@ def _raise_deprecation_warning(
665665

666666
deprecated_params = deprecated_params.copy()
667667

668-
is_default_mcmc_method = kwargs.get("mcmc_method") == "slice_np_vectorized"
669-
is_default_vi_method = kwargs.get("vi_method") == "rKL"
668+
is_default_mcmc_method = (
669+
kwargs.get("mcmc_method") == "slice_np_vectorized"
670+
or kwargs.get("mcmc_method") is None
671+
)
672+
is_default_vi_method = (
673+
kwargs.get("vi_method") == "rKL" or kwargs.get("vi_method") is None
674+
)
670675

671676
if not is_default_mcmc_method:
672677
deprecated_params.append("mcmc_method")

0 commit comments

Comments
 (0)