File tree Expand file tree Collapse file tree 2 files changed +9
-3
lines changed Expand file tree Collapse file tree 2 files changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -404,7 +404,13 @@ def log_prob(
404
404
`(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ in the
405
405
support of the prior, -∞ (corresponding to 0 probability) outside.
406
406
"""
407
- self .potential_fn .set_x (self ._x_else_default_x (x ), ** (ode_kwargs or {}))
407
+ x = self ._x_else_default_x (x )
408
+ x = reshape_to_batch_event (x , self .vector_field_estimator .condition_shape )
409
+ is_iid = x .shape [0 ] > 1
410
+ self .potential_fn .set_x (
411
+ x ,
412
+ x_is_iid = is_iid ,
413
+ )
408
414
409
415
theta = ensure_theta_batched (torch .as_tensor (theta ))
410
416
return self .potential_fn (
Original file line number Diff line number Diff line change @@ -115,8 +115,6 @@ def set_x(
115
115
super ().set_x (x_o , x_is_iid )
116
116
self .iid_method = iid_method or self .iid_method
117
117
self .iid_params = iid_params
118
- # NOTE: Once IID potential evaluation is supported. This needs to be adapted.
119
- # See #1450.
120
118
if not x_is_iid and (self ._x_o is not None ):
121
119
self .flow = self .rebuild_flow (** ode_kwargs )
122
120
elif self ._x_o is not None :
@@ -163,6 +161,8 @@ def __call__(
163
161
),
164
162
dim = 0 ,
165
163
)
164
+ # Apply the adjustment for iid observations i.e. we have to subtract
165
+ # (n-1) times the log prior.
166
166
log_probs = iid_posteriors_prob - (n - 1 ) * self .prior .log_prob (
167
167
theta_density_estimator
168
168
).squeeze (- 1 )
You can’t perform that action at this time.
0 commit comments