Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion sbi/inference/posteriors/vector_field_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,13 @@ def log_prob(
`(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ in the
support of the prior, -∞ (corresponding to 0 probability) outside.
"""
self.potential_fn.set_x(self._x_else_default_x(x), **(ode_kwargs or {}))
x = self._x_else_default_x(x)
x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
is_iid = x.shape[0] > 1
self.potential_fn.set_x(
x,
x_is_iid=is_iid,
)

theta = ensure_theta_batched(torch.as_tensor(theta))
return self.potential_fn(
Expand Down
81 changes: 55 additions & 26 deletions sbi/inference/potentials/vector_field_potential.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Any, Dict, Literal, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -115,10 +115,10 @@ def set_x(
super().set_x(x_o, x_is_iid)
self.iid_method = iid_method or self.iid_method
self.iid_params = iid_params
# NOTE: Once IID potential evaluation is supported. This needs to be adapted.
# See #1450.
if not x_is_iid and (self._x_o is not None):
self.flow = self.rebuild_flow(**ode_kwargs)
elif self._x_o is not None:
self.flows = self.rebuild_flows_for_batch(**ode_kwargs)

def __call__(
self,
Expand All @@ -135,26 +135,6 @@ def __call__(
Returns:
The potential function, i.e., the log probability of the posterior.
"""
# TODO: incorporate iid setting. See issue #1450 and PR #1508
if self.x_is_iid:
if (
self.vector_field_estimator.MARGINALS_DEFINED
and self.vector_field_estimator.SCORE_DEFINED
):
raise NotImplementedError(
"Potential function evaluation in the "
"IID setting is not yet supported"
" for vector field based methods. "
"Sampling does however work via `.sample`. "
"If you intended to evaluate the posterior "
"given a batch of (non-iid) "
"x use `log_prob_batched`."
)
else:
raise NotImplementedError(
"IID is not supported for this vector field estimator "
"since the required methods (marginals or score) are not defined."
)

theta = ensure_theta_batched(torch.as_tensor(theta))
theta_density_estimator = reshape_to_sample_batch_event(
Expand All @@ -163,7 +143,31 @@ def __call__(
self.vector_field_estimator.eval()

with torch.set_grad_enabled(track_gradients):
log_probs = self.flow.log_prob(theta_density_estimator).squeeze(-1)
if self.x_is_iid:
assert self.prior is not None, (
"Prior is required for evaluating log_prob with iid observations."
)
assert self.flows is not None, (
"Flows for each iid x are required for evaluating log_prob."
)
n = self.x_o.shape[0] # number of iid samples
iid_posteriors_prob = torch.sum(
torch.stack(
[
flow.log_prob(theta_density_estimator).squeeze(-1)
for flow in self.flows
],
dim=0,
),
dim=0,
)
# Apply the adjustment for iid observations i.e. we have to subtract
# (n-1) times the log prior.
log_probs = iid_posteriors_prob - (n - 1) * self.prior.log_prob(
theta_density_estimator
).squeeze(-1)
else:
log_probs = self.flow.log_prob(theta_density_estimator).squeeze(-1)
# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)

Expand Down Expand Up @@ -208,8 +212,8 @@ def gradient(

if self._x_o is None:
raise ValueError(
"No observed data x_o is available. Please reinitialize \
the potential or manually set self._x_o."
"No observed data x_o is available. Please reinitialize"
"the potential or manually set self._x_o."
)

with torch.set_grad_enabled(track_gradients):
Expand Down Expand Up @@ -249,6 +253,31 @@ def rebuild_flow(self, **kwargs) -> NormalizingFlow:
flow = self.neural_ode(x_density_estimator, **kwargs)
return flow

def rebuild_flows_for_batch(
    self, atol: float = 1e-5, rtol: float = 1e-6, exact: bool = True
) -> List[NormalizingFlow]:
    """Build one continuous normalizing flow per iid observation in `x_o`.

    Every entry along the first dimension of `self._x_o` is treated as a
    separate observation: it is reshaped to a batched condition with
    `reshape_to_batch_event` and handed to `self.neural_ode`, which returns the
    flow conditioned on that single observation.

    Args:
        atol: Forwarded unchanged to `self.neural_ode` (presumably the absolute
            tolerance of the ODE solver -- confirm against `neural_ode`).
        rtol: Forwarded unchanged to `self.neural_ode` (presumably the relative
            tolerance of the ODE solver -- confirm against `neural_ode`).
        exact: Forwarded unchanged to `self.neural_ode`.

    Returns:
        A list with one flow per entry of `self._x_o`, in the same order.

    Raises:
        ValueError: If no observation is set, i.e. `self._x_o` is None.
    """
    if self._x_o is None:
        # Adjacent string literals instead of a backslash continuation inside
        # the literal: the latter leaks the source indentation into the message.
        raise ValueError(
            "No observed data x_o is available. Please reinitialize "
            "the potential or manually set self._x_o."
        )

    condition_shape = self.vector_field_estimator.condition_shape
    # Iterating `_x_o` walks its first (iid) dimension, one observation at a time.
    return [
        self.neural_ode(
            condition=reshape_to_batch_event(iid_x, event_shape=condition_shape),
            atol=atol,
            rtol=rtol,
            exact=exact,
        )
        for iid_x in self._x_o
    ]


def vector_field_estimator_based_potential(
vector_field_estimator: ConditionalVectorFieldEstimator,
Expand Down
46 changes: 46 additions & 0 deletions tests/linearGaussian_vector_field_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,3 +588,49 @@ def simulator(theta):

max_err = np.max(error)
assert max_err < 0.0027


@pytest.mark.slow
@pytest.mark.parametrize("vector_field_type", ["ve", "vp", "fmpe"])
@pytest.mark.parametrize("prior_type", ["gaussian"])
@pytest.mark.parametrize("iid_batch_size", [1, 2])
def test_iid_log_prob(vector_field_type, prior_type, iid_batch_size):
    """Compare the learned posterior's `log_prob` with the analytic posterior.

    A trained vector field model is used to build a posterior, which is then
    evaluated on `iid_batch_size` observations simulated from the linear
    Gaussian simulator at theta = 0. Samples drawn from the analytic posterior
    are scored under both posteriors, and the mean absolute difference of the
    log probabilities must stay below 0.4.
    """
    trained = train_vector_field_model(vector_field_type, prior_type)
    (
        prior,
        vf_estimator,
        inference,
        likelihood_shift,
        likelihood_cov,
        prior_mean,
        prior_cov,
        num_dim,
    ) = (
        trained[key]
        for key in (
            "prior",
            "estimator",
            "inference",
            "likelihood_shift",
            "likelihood_cov",
            "prior_mean",
            "prior_cov",
            "num_dim",
        )
    )
    num_posterior_samples = 1000

    # Simulate the iid observations, all generated from the ground truth theta = 0.
    ground_truth_theta = zeros(num_dim)
    x_o = linear_gaussian(
        ground_truth_theta.repeat(iid_batch_size, 1),
        likelihood_shift=likelihood_shift,
        likelihood_cov=likelihood_cov,
    )

    # Analytic reference posterior and the learned approximation.
    reference_posterior = true_posterior_linear_gaussian_mvn_prior(
        x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov
    )
    learned_posterior = inference.build_posterior(vf_estimator, prior=prior)

    # Score the same reference samples under both posteriors.
    eval_thetas = reference_posterior.sample((num_posterior_samples,))
    reference_log_prob = reference_posterior.log_prob(eval_thetas)
    learned_log_prob = learned_posterior.log_prob(eval_thetas, x=x_o)

    diff = (reference_log_prob - learned_log_prob).abs()
    assert diff.mean() < 0.4, (
        f"Probs diff: {diff.mean()} too big "
        f"for number of samples {num_posterior_samples}"
    )