Commit 8296538

Slight improvements
1 parent 26cf59d commit 8296538

3 files changed: +12 -19 lines

sbi/inference/posteriors/vector_field_posterior.py (1 addition & 4 deletions)

@@ -407,10 +407,7 @@ def log_prob(
         x = self._x_else_default_x(x)
         x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
         is_iid = x.shape[0] > 1
-        self.potential_fn.set_x(
-            x,
-            x_is_iid=is_iid,
-        )
+        self.potential_fn.set_x(x, x_is_iid=is_iid, **(ode_kwargs or {}))

         theta = ensure_theta_batched(torch.as_tensor(theta))
         return self.potential_fn(
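
For context (not part of the commit): with the change above, ODE-solver options passed to log_prob are forwarded to potential_fn.set_x in a single call. A minimal usage sketch, assuming log_prob exposes an ode_kwargs argument as the hunk suggests, and reusing the atol/rtol/exact options that appear elsewhere in this diff; posterior, theta, x_1 and x_2 are placeholders:

    import torch

    # x_o stacks several iid observations along the batch dimension.
    x_o = torch.stack([x_1, x_2])  # shape: (num_iid, *event_shape)

    # ode_kwargs is forwarded to potential_fn.set_x and controls the precision
    # of the per-observation normalizing flows built from the ODE.
    log_probs = posterior.log_prob(
        theta,
        x=x_o,
        ode_kwargs={"atol": 1e-5, "rtol": 1e-6, "exact": True},
    )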

sbi/inference/potentials/vector_field_potential.py (9 additions & 13 deletions)

@@ -150,7 +150,7 @@ def __call__(
             assert self.flows is not None, (
                 "Flows for each iid x are required for evaluating log_prob."
             )
-            n = self.x_o.shape[0]  # number of iid samples
+            num_iid = self.x_o.shape[0]  # number of iid samples
             iid_posteriors_prob = torch.sum(
                 torch.stack(
                     [
@@ -162,8 +162,8 @@ def __call__(
                 dim=0,
             )
             # Apply the adjustment for iid observations i.e. we have to subtract
-            # (n-1) times the log prior.
-            log_probs = iid_posteriors_prob - (n - 1) * self.prior.log_prob(
+            # (num_iid-1) times the log prior.
+            log_probs = iid_posteriors_prob - (num_iid - 1) * self.prior.log_prob(
                 theta_density_estimator
             ).squeeze(-1)
         else:
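
Aside (not part of the commit): the (num_iid - 1) log-prior correction follows from Bayes' rule for conditionally iid observations,

$$
p(\theta \mid x_{1:n}) \;\propto\; p(\theta) \prod_{i=1}^{n} p(x_i \mid \theta)
\;=\; p(\theta) \prod_{i=1}^{n} \frac{p(\theta \mid x_i)\, p(x_i)}{p(\theta)}
\;\propto\; \frac{\prod_{i=1}^{n} p(\theta \mid x_i)}{p(\theta)^{\,n-1}},
$$

so, with $n$ = num_iid,

$$
\log p(\theta \mid x_{1:n}) \;=\; \sum_{i=1}^{n} \log p(\theta \mid x_i) \;-\; (n-1)\,\log p(\theta) \;+\; \text{const},
$$

which is what iid_posteriors_prob - (num_iid - 1) * self.prior.log_prob(...) computes, with each per-observation posterior evaluated through its rebuilt flow.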
@@ -243,8 +243,8 @@ def rebuild_flow(self, **kwargs) -> NormalizingFlow:
         """
         if self._x_o is None:
             raise ValueError(
-                "No observed data x_o is available. Please reinitialize \
-                    the potential or manually set self._x_o."
+                "No observed data x_o is available. Please reinitialize "
+                "the potential or manually set self._x_o."
             )
         x_density_estimator = reshape_to_batch_event(
             self.x_o, event_shape=self.vector_field_estimator.condition_shape
@@ -253,17 +253,15 @@ def rebuild_flow(self, **kwargs) -> NormalizingFlow:
         flow = self.neural_ode(x_density_estimator, **kwargs)
         return flow

-    def rebuild_flows_for_batch(
-        self, atol: float = 1e-5, rtol: float = 1e-6, exact: bool = True
-    ) -> List[NormalizingFlow]:
+    def rebuild_flows_for_batch(self, **kwargs) -> List[NormalizingFlow]:
         """
         Rebuilds the continuous normalizing flows for each iid in x_o. This is used when
         a new default x_o is set, or to evaluate the log probs at higher precision.
         """
         if self._x_o is None:
             raise ValueError(
-                "No observed data x_o is available. Please reinitialize \
-                    the potential or manually set self._x_o."
+                "No observed data x_o is available. Please reinitialize "
+                "the potential or manually set self._x_o."
             )
         flows = []
         for i in range(self._x_o.shape[0]):
@@ -272,9 +270,7 @@ def rebuild_flows_for_batch(
                 iid_x, event_shape=self.vector_field_estimator.condition_shape
             )

-            flow = self.neural_ode(
-                condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact
-            )
+            flow = self.neural_ode(x_density_estimator, **kwargs)
             flows.append(flow)
         return flows
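
For context (not part of the commit): rebuild_flows_for_batch no longer hard-codes solver settings; whatever is passed as **kwargs is forwarded directly to self.neural_ode, mirroring rebuild_flow. A minimal call-site sketch, assuming the former defaults (atol=1e-5, rtol=1e-6, exact=True) are still valid options of the neural ODE; potential is a placeholder for the potential object defined in this file, with x_o already set:

    # One flow per iid observation in x_o; solver settings are forwarded verbatim.
    flows = potential.rebuild_flows_for_batch(atol=1e-5, rtol=1e-6, exact=True)

Each rebuilt flow can then be evaluated per observation, e.g. when accumulating the iid log-probability shown in the __call__ hunks above.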

tests/linearGaussian_vector_field_test.py (2 additions & 2 deletions)

@@ -593,7 +593,7 @@ def simulator(theta):
 @pytest.mark.slow
 @pytest.mark.parametrize("vector_field_type", ["ve", "vp", "fmpe"])
 @pytest.mark.parametrize("prior_type", ["gaussian"])
-@pytest.mark.parametrize("iid_batch_size", [1, 2])
+@pytest.mark.parametrize("iid_batch_size", [1, 2, 5])
 def test_iid_log_prob(vector_field_type, prior_type, iid_batch_size):
     '''
     Tests the log-probability computation of the score-based posterior.
@@ -630,7 +630,7 @@ def test_iid_log_prob(vector_field_type, prior_type, iid_batch_size):
     approx_prob = approx_posterior.log_prob(posterior_samples, x=x_o)

     diff = torch.abs(true_prob - approx_prob)
-    assert diff.mean() < 0.4, (
+    assert diff.mean() < 0.3 * iid_batch_size, (
         f"Probs diff: {diff.mean()} too big "
         f"for number of samples {num_posterior_samples}"
     )
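
Aside (not part of the commit): since the approximate iid log-probability sums num_iid per-observation flow evaluations before the prior correction, per-observation approximation errors accumulate roughly additively, so the fixed tolerance of 0.4 is replaced by one that scales with the batch size, 0.3 * iid_batch_size.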
