Skip to content

Commit d343855

Browse files
committed
refactor: callable potential fun; add test
1 parent 7c58d50 commit d343855

File tree

3 files changed

+54
-16
lines changed

3 files changed

+54
-16
lines changed

sbi/inference/posteriors/base_posterior.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
import inspect
54
from abc import abstractmethod
65
from typing import Any, Callable, Dict, Optional, Union
76
from warnings import warn
@@ -54,14 +53,6 @@ def __init__(
5453

5554
# Wrap as `CallablePotentialWrapper` if `potential_fn` is a Callable.
5655
if not isinstance(potential_fn, BasePotential):
57-
kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
58-
for key in ["theta", "x_o"]:
59-
assert key in kwargs_of_callable, (
60-
"If you pass a `Callable` as `potential_fn` then it must have "
61-
"`theta` and `x_o` as inputs, even if some of these keyword "
62-
"arguments are unused."
63-
)
64-
6556
# If the `potential_fn` is a Callable then we wrap it as a
6657
# `CallablePotentialWrapper` which inherits from `BasePotential`.
6758
potential_device = "cpu" if device is None else device

sbi/inference/potentials/base_potential.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4+
import inspect
45
from abc import ABCMeta, abstractmethod
5-
from typing import Optional
6+
from typing import Callable, Optional
67

78
import torch
89
from torch import Tensor
@@ -85,18 +86,39 @@ def return_x_o(self) -> Optional[Tensor]:
8586
class CallablePotentialWrapper(BasePotential):
8687
"""If `potential_fn` is a callable it gets wrapped as this."""
8788

88-
allow_iid_x = True # type: ignore
89-
9089
def __init__(
9190
self,
92-
callable_potential,
91+
potential_fn: Callable,
9392
prior: Optional[Distribution],
9493
x_o: Optional[Tensor] = None,
9594
device: str = "cpu",
9695
):
96+
"""Wraps a callable potential function.
97+
98+
Args:
99+
potential_fn: Callable potential function, must have `theta` and `x_o` as
100+
arguments.
101+
prior: Prior distribution.
102+
x_o: Observed data.
103+
device: Device on which to evaluate the potential function.
104+
105+
"""
97106
super().__init__(prior, x_o, device)
98-
self.callable_potential = callable_potential
107+
108+
kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
109+
required_keys = ["theta", "x_o"]
110+
for key in required_keys:
111+
assert key in kwargs_of_callable, (
112+
"If you pass a `Callable` as `potential_fn` then it must have "
113+
"`theta` and `x_o` as inputs, even if some of these keyword "
114+
"arguments are unused."
115+
)
116+
self.potential_fn = potential_fn
99117

100118
def __call__(self, theta, track_gradients: bool = True):
119+
"""Call the callable potential function on given theta.
120+
121+
Note, x_o is re-used from the initialization of the potential function.
122+
"""
101123
with torch.set_grad_enabled(track_gradients):
102-
return self.callable_potential(theta=theta, x_o=self.x_o)
124+
return self.potential_fn(theta=theta, x_o=self.x_o)

tests/potential_test.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77
import torch
8-
from torch import eye, ones, zeros
8+
from torch import Tensor, eye, ones, zeros
99
from torch.distributions import MultivariateNormal
1010

1111
from sbi.inference import (
@@ -14,6 +14,9 @@
1414
RejectionPosterior,
1515
VIPosterior,
1616
)
17+
from sbi.inference.potentials.base_potential import CallablePotentialWrapper
18+
from sbi.utils import BoxUniform
19+
from sbi.utils.conditional_density_utils import ConditionedPotential
1720

1821

1922
@pytest.mark.parametrize(
@@ -64,3 +67,25 @@ def potential(theta, x_o):
6467
sample_std = torch.std(approx_samples, dim=0)
6568
assert torch.allclose(sample_mean, torch.as_tensor(mean) - x_o, atol=0.2)
6669
assert torch.allclose(sample_std, torch.sqrt(torch.as_tensor(cov)), atol=0.1)
70+
71+
72+
@pytest.mark.parametrize(
73+
"condition",
74+
[
75+
torch.rand(1, 2),
76+
pytest.param(
77+
torch.rand(2, 2),
78+
marks=pytest.mark.xfail(
79+
raises=ValueError,
80+
match="Condition with batch size > 1 not supported",
81+
),
82+
),
83+
],
84+
)
85+
def test_conditioned_potential(condition: Tensor):
86+
potential_fn = CallablePotentialWrapper(
87+
potential_fn=lambda theta, x_o: theta,
88+
prior=BoxUniform(low=zeros(2), high=ones(2)),
89+
)
90+
91+
ConditionedPotential(potential_fn, condition=condition, dims_to_sample=[0])

0 commit comments

Comments
 (0)