Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import inspect
from abc import abstractmethod
from typing import Any, Callable, Dict, Optional, Union
from warnings import warn
Expand Down Expand Up @@ -54,14 +53,6 @@ def __init__(

# Wrap as `CallablePotentialWrapper` if `potential_fn` is a Callable.
if not isinstance(potential_fn, BasePotential):
kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
for key in ["theta", "x_o"]:
assert key in kwargs_of_callable, (
"If you pass a `Callable` as `potential_fn` then it must have "
"`theta` and `x_o` as inputs, even if some of these keyword "
"arguments are unused."
)

# If the `potential_fn` is a Callable then we wrap it as a
# `CallablePotentialWrapper` which inherits from `BasePotential`.
potential_device = "cpu" if device is None else device
Expand Down
34 changes: 28 additions & 6 deletions sbi/inference/potentials/base_potential.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import inspect
from abc import ABCMeta, abstractmethod
from typing import Optional
from typing import Callable, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -85,18 +86,39 @@ def return_x_o(self) -> Optional[Tensor]:
class CallablePotentialWrapper(BasePotential):
"""If `potential_fn` is a callable it gets wrapped as this."""

allow_iid_x = True # type: ignore

def __init__(
self,
callable_potential,
potential_fn: Callable,
prior: Optional[Distribution],
x_o: Optional[Tensor] = None,
device: str = "cpu",
):
"""Wraps a callable potential function.

Args:
potential_fn: Callable potential function, must have `theta` and `x_o` as
arguments.
prior: Prior distribution.
x_o: Observed data.
device: Device on which to evaluate the potential function.

"""
super().__init__(prior, x_o, device)
self.callable_potential = callable_potential

kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
required_keys = ["theta", "x_o"]
for key in required_keys:
assert key in kwargs_of_callable, (
"If you pass a `Callable` as `potential_fn` then it must have "
"`theta` and `x_o` as inputs, even if some of these keyword "
"arguments are unused."
)
self.potential_fn = potential_fn

def __call__(self, theta, track_gradients: bool = True):
"""Call the callable potential function on given theta.

Note, x_o is re-used from the initialization of the potential function.
"""
with torch.set_grad_enabled(track_gradients):
return self.callable_potential(theta=theta, x_o=self.x_o)
return self.potential_fn(theta=theta, x_o=self.x_o)
16 changes: 11 additions & 5 deletions sbi/utils/conditional_density_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def condition_mog(
return logits, means, precfs_xx, sumlogdiag


class ConditionedPotential:
class ConditionedPotential(BasePotential):
def __init__(
self,
potential_fn: BasePotential,
Expand All @@ -282,20 +282,26 @@ def __init__(
Return conditional posterior log-probability or $-\infty$ if outside prior.

Args:
theta: Free parameters $\theta_i$, batch dimension 1.
potential_fn: Potential function to condition on.
condition: Fixed parameters $\theta_j$, batch size 1.
dims_to_sample: Which dimensions to sample from. The dimensions not
specified in `dims_to_sample` will be fixed to values given in
`condition`.

Returns:
Conditional posterior log-probability $\log(p(\theta_i|\theta_j, x))$,
masked outside of prior.
"""
condition = torch.atleast_2d(condition)
if condition.shape[0] != 1:
raise ValueError("Condition with batch size > 1 not supported.")

self.potential_fn = potential_fn
self.condition = condition
self.dims_to_sample = dims_to_sample
self.device = self.potential_fn.device

def __call__(
self, theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
) -> Tensor:
def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
r"""
Returns the conditional potential $\log(p(\theta_i|\theta_j, x))$.

Expand Down
49 changes: 18 additions & 31 deletions tests/mnle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
# True posterior samples
transform = mcmc_transform(prior)
true_posterior_samples = MCMCPosterior(
PotentialFunctionProvider(prior, atleast_2d(x_o)),
BinomialGammaPotential(prior, atleast_2d(x_o)),
theta_transform=transform,
proposal=prior,
**mcmc_kwargs,
Expand All @@ -189,14 +189,9 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
)


class PotentialFunctionProvider(BasePotential):
"""Returns potential function for reference posterior of a mixed likelihood."""

allow_iid_x = True # type: ignore

class BinomialGammaPotential(BasePotential):
def __init__(self, prior, x_o, concentration_scaling=1.0, device="cpu"):
super().__init__(prior, x_o, device)

self.concentration_scaling = concentration_scaling

def __call__(self, theta, track_gradients: bool = True):
Expand All @@ -207,33 +202,25 @@ def __call__(self, theta, track_gradients: bool = True):

return iid_ll + self.prior.log_prob(theta)

def iid_likelihood(self, theta: torch.Tensor) -> torch.Tensor:
"""Returns the likelihood summed over a batch of i.i.d. data."""

lp_choices = torch.stack(
[
Binomial(probs=th.reshape(1, -1)).log_prob(self.x_o[:, 1:])
for th in theta[:, 1:]
],
dim=1,
def iid_likelihood(self, theta):
batch_size = theta.shape[0]
num_trials = self.x_o.shape[0]
theta = theta.reshape(batch_size, 1, -1)
beta, rho = theta[:, :, :1], theta[:, :, 1:]
# vectorized
logprob_choices = Binomial(probs=rho).log_prob(
self.x_o[:, 1:].reshape(1, num_trials, -1)
)

lp_rts = torch.stack(
[
InverseGamma(
concentration=self.concentration_scaling * torch.ones_like(beta_i),
rate=beta_i,
).log_prob(self.x_o[:, :1])
for beta_i in theta[:, :1]
],
dim=1,
)
logprob_rts = InverseGamma(
concentration=self.concentration_scaling * torch.ones_like(beta),
rate=beta,
).log_prob(self.x_o[:, :1].reshape(1, num_trials, -1))

joint_likelihood = (lp_choices + lp_rts).reshape(
self.x_o.shape[0], theta.shape[0]
)
joint_likelihood = (logprob_choices + logprob_rts).squeeze()

return joint_likelihood.sum(0)
assert joint_likelihood.shape == torch.Size([theta.shape[0], self.x_o.shape[0]])
return joint_likelihood.sum(1)


@pytest.mark.slow
Expand Down Expand Up @@ -295,7 +282,7 @@ def sim_wrapper(theta):
)
prior_transform = mcmc_transform(prior)
true_posterior_samples = MCMCPosterior(
PotentialFunctionProvider(
BinomialGammaPotential(
prior,
atleast_2d(x_o),
concentration_scaling=float(theta_o[0, 2])
Expand Down
27 changes: 26 additions & 1 deletion tests/potential_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest
import torch
from torch import eye, ones, zeros
from torch import Tensor, eye, ones, zeros
from torch.distributions import MultivariateNormal

from sbi.inference import (
Expand All @@ -14,6 +14,9 @@
RejectionPosterior,
VIPosterior,
)
from sbi.inference.potentials.base_potential import CallablePotentialWrapper
from sbi.utils import BoxUniform
from sbi.utils.conditional_density_utils import ConditionedPotential


@pytest.mark.parametrize(
Expand Down Expand Up @@ -64,3 +67,25 @@ def potential(theta, x_o):
sample_std = torch.std(approx_samples, dim=0)
assert torch.allclose(sample_mean, torch.as_tensor(mean) - x_o, atol=0.2)
assert torch.allclose(sample_std, torch.sqrt(torch.as_tensor(cov)), atol=0.1)


# Condition tensors to exercise: a valid batch-size-1 condition, and a
# batch-size-2 condition that must be rejected with a ValueError.
_CONDITION_CASES = [
    torch.rand(1, 2),
    pytest.param(
        torch.rand(2, 2),
        marks=pytest.mark.xfail(
            raises=ValueError,
            match="Condition with batch size > 1 not supported",
        ),
    ),
]


@pytest.mark.parametrize("condition", _CONDITION_CASES)
def test_conditioned_potential(condition: Tensor):
    """Constructing a ConditionedPotential accepts batch-size-1 conditions only.

    A callable potential is wrapped so it satisfies the `BasePotential`
    interface, then conditioned on the given `condition`; construction with a
    condition of batch size > 1 is expected to raise `ValueError`.
    """
    wrapped_potential = CallablePotentialWrapper(
        potential_fn=lambda theta, x_o: theta,
        prior=BoxUniform(low=zeros(2), high=ones(2)),
    )

    ConditionedPotential(wrapped_potential, condition=condition, dims_to_sample=[0])
Loading