|
1 | 1 | # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
|
2 | 2 | # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
|
3 | 3 |
|
| 4 | +import inspect |
4 | 5 | from abc import ABCMeta, abstractmethod
|
5 |
| -from typing import Optional |
| 6 | +from typing import Callable, Optional |
6 | 7 |
|
7 | 8 | import torch
|
8 | 9 | from torch import Tensor
|
@@ -85,18 +86,39 @@ def return_x_o(self) -> Optional[Tensor]:
|
class CallablePotentialWrapper(BasePotential):
    """If `potential_fn` is a callable it gets wrapped as this.

    Exposes the standard `BasePotential` interface for a user-supplied plain
    callable, so downstream code can treat it like any other potential.
    """

    def __init__(
        self,
        potential_fn: Callable,
        prior: Optional[Distribution],
        x_o: Optional[Tensor] = None,
        device: str = "cpu",
    ):
        """Wraps a callable potential function.

        Args:
            potential_fn: Callable potential function, must have `theta` and
                `x_o` as arguments (either named explicitly or accepted via
                `**kwargs`).
            prior: Prior distribution.
            x_o: Observed data.
            device: Device on which to evaluate the potential function.

        Raises:
            ValueError: If `potential_fn` does not accept both `theta` and
                `x_o` as arguments.
        """
        super().__init__(prior, x_o, device)

        # Validate the callable's signature eagerly so that a misnamed
        # argument fails at construction time rather than at the first call.
        # Use `raise` instead of `assert`: asserts are stripped under `-O`
        # and must not be relied on for input validation.
        params = inspect.signature(potential_fn).parameters
        # A `**kwargs` parameter can absorb any keyword, so such callables
        # are accepted without naming `theta`/`x_o` explicitly.
        accepts_var_kwargs = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        if not accepts_var_kwargs:
            for key in ("theta", "x_o"):
                if key not in params:
                    raise ValueError(
                        f"The `potential_fn` callable is missing the keyword "
                        f"argument `{key}`. If you pass a `Callable` as "
                        "`potential_fn` then it must have `theta` and `x_o` "
                        "as inputs, even if some of these keyword arguments "
                        "are unused."
                    )
        self.potential_fn = potential_fn

    def __call__(self, theta, track_gradients: bool = True):
        """Call the callable potential function on given theta.

        Note, x_o is re-used from the initialization of the potential function.

        Args:
            theta: Parameters at which to evaluate the potential.
            track_gradients: Whether to track gradients through the call.

        Returns:
            The value of `potential_fn(theta=theta, x_o=self.x_o)`.
        """
        with torch.set_grad_enabled(track_gradients):
            return self.potential_fn(theta=theta, x_o=self.x_o)
0 commit comments