Skip to content

Commit d558b16

Browse files
Fixing bug on merge
1 parent 43d4e4d commit d558b16

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,13 @@ def log_prob(
404404
`(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ in the
405405
support of the prior, -∞ (corresponding to 0 probability) outside.
406406
"""
407-
self.potential_fn.set_x(self._x_else_default_x(x), **(ode_kwargs or {}))
407+
x = self._x_else_default_x(x)
408+
x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
409+
is_iid = x.shape[0] > 1
410+
self.potential_fn.set_x(
411+
x,
412+
x_is_iid=is_iid,
413+
)
408414

409415
theta = ensure_theta_batched(torch.as_tensor(theta))
410416
return self.potential_fn(

sbi/inference/potentials/vector_field_potential.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,6 @@ def set_x(
115115
super().set_x(x_o, x_is_iid)
116116
self.iid_method = iid_method or self.iid_method
117117
self.iid_params = iid_params
118-
# NOTE: Once IID potential evaluation is supported. This needs to be adapted.
119-
# See #1450.
120118
if not x_is_iid and (self._x_o is not None):
121119
self.flow = self.rebuild_flow(**ode_kwargs)
122120
elif self._x_o is not None:
@@ -163,6 +161,8 @@ def __call__(
163161
),
164162
dim=0,
165163
)
164+
# Apply the adjustment for iid observations i.e. we have to subtract
165+
# (n-1) times the log prior.
166166
log_probs = iid_posteriors_prob - (n - 1) * self.prior.log_prob(
167167
theta_density_estimator
168168
).squeeze(-1)

0 commit comments

Comments
 (0)