
Commit cba45de

Allow 1D prior with batch dimension
1 parent: 4a7ed2b

3 files changed: +82 −2 lines changed

sbi/utils/user_input_checks.py

Lines changed: 6 additions & 0 deletions
@@ -17,6 +17,7 @@
 from sbi.utils.user_input_checks_utils import (
     CustomPriorWrapper,
     MultipleIndependent,
+    OneDimPriorWrapper,
     PytorchReturnTypeWrapper,
 )
 
@@ -220,6 +221,11 @@ def process_pytorch_prior(prior: Distribution) -> Tuple[Distribution, int, bool]
     # This will fail for float64 priors.
     check_prior_return_type(prior)
 
+    # Potentially required wrapper if the prior returns an additional sample
+    # dimension for `.log_prob()`.
+    if prior.log_prob(prior.sample((10,))).shape == (10, 1):
+        prior = OneDimPriorWrapper(prior, validate_args=False)
+
     theta_numel = prior.sample().numel()
 
     return prior, theta_numel, False
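To make the shape check above concrete, here is a minimal sketch (not part of the commit) of the behavior it detects, using `torch.distributions.Exponential`:

```python
import torch
from torch.distributions import Exponential

# A 1D prior with an explicit batch dimension: sampling yields shape (10, 1), ...
prior = Exponential(torch.tensor([3.0]))
samples = prior.sample((10,))
print(samples.shape)  # torch.Size([10, 1])

# ... but `.log_prob()` keeps the batch dimension. This (10, 1) shape is exactly
# what triggers wrapping in `OneDimPriorWrapper` in the hunk above.
print(prior.log_prob(samples).shape)  # torch.Size([10, 1])
```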

sbi/utils/user_input_checks_utils.py

Lines changed: 66 additions & 0 deletions
@@ -373,3 +373,69 @@ def build_support(
         support = constraints.interval(lower_bound, upper_bound)
 
     return support
+
+
+class OneDimPriorWrapper(Distribution):
+    """Wrap batched 1D distributions to get rid of the batch dim of `.log_prob()`.
+
+    1D pytorch distributions such as `torch.distributions.Exponential`, `.Uniform`, or
+    `.Normal` do not, by default, return __any__ sample or batch dimension. E.g.:
+    ```python
+    dist = torch.distributions.Exponential(torch.tensor(3.0))
+    dist.sample((10,)).shape  # (10,)
+    ```
+
+    `sbi` will raise an error that the sample dimension is missing. A simple solution
+    is to add a batch dimension to `dist` as follows:
+    ```python
+    dist = torch.distributions.Exponential(torch.tensor([3.0]))
+    dist.sample((10,)).shape  # (10, 1)
+    ```
+
+    Unfortunately, this `dist` will also return the batch dimension for `.log_prob()`:
+    ```python
+    dist = torch.distributions.Exponential(torch.tensor([3.0]))
+    samples = dist.sample((10,))
+    dist.log_prob(samples).shape  # (10, 1)
+    ```
+
+    This will lead to unexpected errors in `sbi`. The point of this class is to wrap
+    those batched 1D distributions to get rid of their batch dimension in
+    `.log_prob()`.
+    """
+
+    def __init__(
+        self,
+        prior: Distribution,
+        validate_args=None,
+    ) -> None:
+        super().__init__(
+            batch_shape=prior.batch_shape,
+            event_shape=prior.event_shape,
+            validate_args=(
+                prior._validate_args if validate_args is None else validate_args
+            ),
+        )
+        self.prior = prior
+
+    def sample(self, *args, **kwargs) -> Tensor:
+        return self.prior.sample(*args, **kwargs)
+
+    def log_prob(self, *args, **kwargs) -> Tensor:
+        """Override `.log_prob()` to get rid of the additional batch dimension."""
+        return self.prior.log_prob(*args, **kwargs)[..., 0]
+
+    @property
+    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
+        return self.prior.arg_constraints
+
+    @property
+    def support(self):
+        return self.prior.support
+
+    @property
+    def mean(self) -> Tensor:
+        return self.prior.mean
+
+    @property
+    def variance(self) -> Tensor:
+        return self.prior.variance
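For reference, a short usage sketch of the new wrapper (not part of the diff; the import path is the one shown in the test file below):

```python
import torch
from torch.distributions import Exponential

from sbi.utils.user_input_checks_utils import OneDimPriorWrapper

# Wrap a batched 1D prior; sampling behavior is unchanged.
prior = OneDimPriorWrapper(Exponential(torch.tensor([3.0])), validate_args=False)
samples = prior.sample((10,))
print(samples.shape)  # torch.Size([10, 1])

# `.log_prob()` now drops the trailing batch dimension via `[..., 0]`.
print(prior.log_prob(samples).shape)  # torch.Size([10])
```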

tests/user_input_checks_test.py

Lines changed: 10 additions & 2 deletions
@@ -10,7 +10,7 @@
 import torch
 from pyknos.mdn.mdn import MultivariateGaussianMDN
 from torch import Tensor, eye, nn, ones, zeros
-from torch.distributions import Beta, Distribution, Gamma, MultivariateNormal, Uniform
+from torch.distributions import Beta, Distribution, Gamma, MultivariateNormal, Uniform, Exponential
 
 from sbi.inference import NPE_A, NPE_C, simulate_for_sbi
 from sbi.inference.posteriors.direct_posterior import DirectPosterior
@@ -28,6 +28,7 @@
     CustomPriorWrapper,
     MultipleIndependent,
     PytorchReturnTypeWrapper,
+    OneDimPriorWrapper,
 )
 
 
@@ -93,6 +94,11 @@ def matrix_simulator(theta):
             BoxUniform(zeros(3, dtype=torch.float64), ones(3, dtype=torch.float64)),
             dict(),
         ),
+        (
+            OneDimPriorWrapper,
+            Exponential(torch.tensor([3.0])),
+            dict(),
+        ),
     ),
 )
 def test_prior_wrappers(wrapper, prior, kwargs):
@@ -118,6 +124,9 @@ def test_prior_wrappers(wrapper, prior, kwargs):
     # Test transform
     mcmc_transform(prior)
 
+    # For 1D priors, the `log_prob()` should not have a batch dim.
+    assert len(prior.log_prob(prior.sample((10,))).shape) == 1
+
 
 def test_reinterpreted_batch_dim_prior():
     """Test whether the right warning and error are raised for reinterpreted priors."""
@@ -268,7 +277,6 @@ def test_prepare_sbi_problem(simulator: Callable, prior):
         prior: prior as defined by the user (pytorch, scipy, custom)
         x_shape: shape of data as defined by the user.
     """
-
     prior, _, prior_returns_numpy = process_prior(prior)
     simulator = process_simulator(simulator, prior, prior_returns_numpy)
     check_sbi_inputs(simulator, prior)
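Taken together, the commit lets a batched 1D prior pass the user input checks end to end. A minimal sketch, assuming `process_prior` is importable from `sbi.utils` as used in the tests above:

```python
import torch
from torch.distributions import Exponential

from sbi.utils import process_prior  # assumed import path, as used in sbi's tests

# Before this commit, the (10, 1)-shaped `.log_prob()` of this prior caused errors;
# now `process_prior` wraps it in `OneDimPriorWrapper` automatically.
prior, theta_numel, prior_returns_numpy = process_prior(Exponential(torch.tensor([3.0])))
print(theta_numel)  # 1
print(prior.log_prob(prior.sample((10,))).shape)  # torch.Size([10])
```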
