@@ -150,7 +150,7 @@ def __call__(
150
150
assert self .flows is not None , (
151
151
"Flows for each iid x are required for evaluating log_prob."
152
152
)
153
- n = self .x_o .shape [0 ] # number of iid samples
153
+ num_iid = self .x_o .shape [0 ] # number of iid samples
154
154
iid_posteriors_prob = torch .sum (
155
155
torch .stack (
156
156
[
@@ -162,8 +162,8 @@ def __call__(
162
162
dim = 0 ,
163
163
)
164
164
# Apply the adjustment for iid observations i.e. we have to subtract
165
- # (n -1) times the log prior.
166
- log_probs = iid_posteriors_prob - (n - 1 ) * self .prior .log_prob (
165
+ # (num_iid -1) times the log prior.
166
+ log_probs = iid_posteriors_prob - (num_iid - 1 ) * self .prior .log_prob (
167
167
theta_density_estimator
168
168
).squeeze (- 1 )
169
169
else :
@@ -243,8 +243,8 @@ def rebuild_flow(self, **kwargs) -> NormalizingFlow:
243
243
"""
244
244
if self ._x_o is None :
245
245
raise ValueError (
246
- "No observed data x_o is available. Please reinitialize \
247
- the potential or manually set self._x_o."
246
+ "No observed data x_o is available. Please reinitialize"
247
+ " the potential or manually set self._x_o."
248
248
)
249
249
x_density_estimator = reshape_to_batch_event (
250
250
self .x_o , event_shape = self .vector_field_estimator .condition_shape
@@ -253,17 +253,15 @@ def rebuild_flow(self, **kwargs) -> NormalizingFlow:
253
253
flow = self .neural_ode (x_density_estimator , ** kwargs )
254
254
return flow
255
255
256
- def rebuild_flows_for_batch (
257
- self , atol : float = 1e-5 , rtol : float = 1e-6 , exact : bool = True
258
- ) -> List [NormalizingFlow ]:
256
+ def rebuild_flows_for_batch (self , ** kwargs ) -> List [NormalizingFlow ]:
259
257
"""
260
258
Rebuilds the continuous normalizing flows for each iid in x_o. This is used when
261
259
a new default x_o is set, or to evaluate the log probs at higher precision.
262
260
"""
263
261
if self ._x_o is None :
264
262
raise ValueError (
265
- "No observed data x_o is available. Please reinitialize \
266
- the potential or manually set self._x_o."
263
+ "No observed data x_o is available. Please reinitialize "
264
+ " the potential or manually set self._x_o."
267
265
)
268
266
flows = []
269
267
for i in range (self ._x_o .shape [0 ]):
@@ -272,9 +270,7 @@ def rebuild_flows_for_batch(
272
270
iid_x , event_shape = self .vector_field_estimator .condition_shape
273
271
)
274
272
275
- flow = self .neural_ode (
276
- condition = x_density_estimator , atol = atol , rtol = rtol , exact = exact
277
- )
273
+ flow = self .neural_ode (x_density_estimator , ** kwargs )
278
274
flows .append (flow )
279
275
return flows
280
276
0 commit comments