Skip to content

Commit 9e5bed4

Browse files
authored
fix: posterior potential iid handling (#1276)
1 parent 25dd902 commit 9e5bed4

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

sbi/inference/potentials/posterior_based_potential.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,11 @@ def set_x(self, x_o: Optional[Tensor], x_is_iid: Optional[bool] = False):
8989
For posterior-based methods, `x_o` is not allowed to be iid, as we assume that
9090
iid `x` is handled by a Permutation Invariant embedding net.
9191
"""
92-
if x_is_iid:
92+
if x_is_iid and x_o is not None and x_o.shape[0] > 1:
9393
raise NotImplementedError(
94-
"For NPE, iid `x` must be handled by a Permutation Invariant embedding \
95-
net. Therefore, the iid dimension of `x` is added to the event\
96-
dimension of `x`. Please set `x_is_iid=False`."
94+
"For NPE, iid `x` must be handled by a permutation invariant embedding "
95+
"net. Therefore, the iid dimension of `x` is added to the event "
96+
"dimension of `x`. Please set `x_is_iid=False`."
9797
)
9898
else:
9999
super().set_x(x_o, x_is_iid=False)

tests/linearGaussian_fmpe_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,8 @@ def simulator(theta):
287287
# Evaluate the conditional density be drawing samples and smoothing with a Gaussian
288288
# kde.
289289
potential_fn, theta_transform = posterior_estimator_based_potential(
290-
posterior_estimator, prior=prior, x_o=x_o
291-
)
290+
posterior_estimator, prior=prior
291+
).set_x(x_o, x_is_iid=False)
292292
(
293293
conditioned_potential_fn,
294294
restricted_tf,

tests/linearGaussian_snpe_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,8 +558,8 @@ def simulator(theta):
558558
# Evaluate the conditional density be drawing samples and smoothing with a Gaussian
559559
# kde.
560560
potential_fn, theta_transform = posterior_estimator_based_potential(
561-
posterior_estimator, prior=prior, x_o=x_o
562-
)
561+
posterior_estimator, prior=prior
562+
).set_x(x_o, x_is_iid=False)
563563
(
564564
conditioned_potential_fn,
565565
restricted_tf,

0 commit comments

Comments
 (0)