Skip to content

Commit e53ed53

Browse files
committed
fix map error handling and tests.
1 parent 4b3fc61 commit e53ed53

File tree

5 files changed

+14
-11
lines changed

5 files changed

+14
-11
lines changed

sbi/inference/posteriors/base_posterior.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,8 @@ def __init__(
5252
stacklevel=2,
5353
)
5454

55-
if not isinstance(potential_fn, BasePotential) and not isinstance(
56-
potential_fn, BasePotential
57-
):
55+
# Wrap as `CallablePotentialWrapper` if `potential_fn` is a Callable.
56+
if not isinstance(potential_fn, BasePotential):
5857
kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
5958
for key in ["theta", "x_o"]:
6059
assert key in kwargs_of_callable, (

sbi/utils/sbiutils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -945,13 +945,13 @@ def gradient_ascent(
945945
try:
946946
optimize_inits.requires_grad_(False) # type: ignore
947947
gradient = potential_fn.gradient(optimize_inits)
948-
except NotImplementedError:
948+
except (NotImplementedError, AttributeError):
949949
optimize_inits.requires_grad_(True) # type: ignore
950950
probs = potential_fn(optimize_inits).squeeze()
951951
loss = probs.sum()
952952
loss.backward()
953953
gradient = optimize_inits.grad
954-
assert gradient is Tensor, "Gradient must be a tensor."
954+
assert isinstance(gradient, Tensor), "Gradient must be a tensor."
955955

956956
# Update the parameters with gradient descent.
957957
# See https://discuss.pytorch.org/t/updatation-of-parameters-without-using-optimizer-step/34244/2

tests/linearGaussian_npse_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ def test_npse_iid_inference(num_trials):
206206

207207
@pytest.mark.slow
208208
@pytest.mark.xfail(
209-
raises=AssertionError, reason="MAP optimization via score not working accurately."
209+
raises=NotImplementedError,
210+
reason="MAP optimization via score not working accurately.",
210211
)
211212
def test_npse_map():
212213
num_dim = 2

tests/posterior_nn_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@
3232
(
3333
0,
3434
1,
35-
pytest.param(2, marks=pytest.mark.xfail(raises=AssertionError)),
35+
pytest.param(
36+
2,
37+
marks=pytest.mark.xfail(
38+
raises=AssertionError,
39+
reason=".log_prob() supports only batch size 1 for x_o.",
40+
),
41+
),
3642
),
3743
)
3844
def test_log_prob_with_different_x(snpe_method: type, x_o_batch_dim: bool):

tests/sbc_test.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
from sbi.analysis import sbc_rank_plot
1414
from sbi.diagnostics import check_sbc, get_nltp, run_sbc
1515
from sbi.inference import NPSE, SNLE, SNPE
16-
from sbi.simulators import linear_gaussian
17-
from sbi.simulators.linear_gaussian import (
18-
linear_gaussian,
19-
)
16+
from sbi.simulators.linear_gaussian import linear_gaussian
2017
from sbi.utils import BoxUniform, MultipleIndependent
2118
from tests.test_utils import PosteriorPotential, TractablePosterior
2219

0 commit comments

Comments
 (0)