Skip to content

Commit 1ee577e

Browse files
manuelgloecklergmoss13janfb
authored
feat: score-based iid sampling (#1381)
* npse MAP * set default enable_Transform to True * ruff formatting version change * sampling via diffusion twice * batched sampling for score-based posteriors * iid api integration * new ruff * adding corrector back * adding untesed GAUSS * All other methods * reformat * messy version of simple gauss, API to sample method * jac method (still needs feasible Lambda projection to work) * Only jac left * Formating and so on * Adding correct tests * Add empirical support - but this doesnt work that well * A bunch of auto marginalize and denois methods * general prior with GMM approx * Bunch of reffactorings and customizability * Update API docstirngs * New tests, passes all now * Formating linting, type error form other PR * Remove assert that IID data is not supported * Ruff linting * Fixing comments * Rearangements * Missing types * Improving tests doc * Refactored tests with fixtures! * refactored the tests * Format tests * more documentaiton * Fix FNPE intitialization - not that nice * Last issues resolved * improve docstrings * Raise error on potential. Make BaseClass abstract and remove redundant arguments. * cosmetics * update changelog --------- Co-authored-by: Guy Moss <guy.moss13@gmail.com> Co-authored-by: Jan <jan.boelts@mailbox.org>
1 parent 7d43073 commit 1ee577e

File tree

9 files changed

+1705
-60
lines changed

9 files changed

+1705
-60
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
### ✨ Highlights
66

7-
- Add `CategoricalMADE` by @jnsbck in https://github.com/sbi-dev/sbi/pull/1269 **(Major New Feature)**
7+
- feat: add `CategoricalMADE` by @jnsbck in https://github.com/sbi-dev/sbi/pull/1269 **(Major New Feature)**
88
- tests: `mini-sbibm` by @manuelgloeckler in https://github.com/sbi-dev/sbi/pull/1335 **(Major New Feature)**
9+
- feat: Score-based iid sampling by @manuelgloeckler in https://github.com/sbi-dev/sbi/pull/1381 **(Major New Feature)**
910
- Drop python3.9 support, fix ci by @janfb in https://github.com/sbi-dev/sbi/pull/1412 **(Python Version Support Change)**
1011
- additional features for NPSE by @gmoss13 in https://github.com/sbi-dev/sbi/pull/1370 **(Enhancement)**
1112

@@ -21,6 +22,7 @@
2122
- fix mnle tests by @janfb in https://github.com/sbi-dev/sbi/pull/1415
2223
- fix: protocol and refactor for custom potential by @janfb in https://github.com/sbi-dev/sbi/pull/1409
2324
- fix docs workflow by @janfb in https://github.com/sbi-dev/sbi/pull/1419
25+
- fix: gpu-handling for CategoricalMADE by @janfb in https://github.com/sbi-dev/sbi/pull/1448
2426

2527
### 🛠️ Maintenance & Improvements
2628

sbi/inference/posteriors/score_posterior.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def sample(
105105
corrector_params: Optional[Dict] = None,
106106
steps: int = 500,
107107
ts: Optional[Tensor] = None,
108+
iid_method: str = "auto_gauss",
109+
iid_params: Optional[Dict] = None,
108110
max_sampling_batch_size: int = 10_000,
109111
sample_with: Optional[str] = None,
110112
show_progress_bars: bool = True,
@@ -123,6 +125,19 @@ def sample(
123125
steps: Number of steps to take for the Euler-Maruyama method.
124126
ts: Time points at which to evaluate the diffusion process. If None, a
125127
linear grid between t_max and t_min is used.
128+
iid_method: Which method to use for computing the score in the iid setting.
129+
We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss". The
130+
fnpe method is simple and generally applicable. However, it can become
130+
inaccurate already for quite a few iid samples (as it is based on heuristic
132+
approximations), and should be used at best only with a `corrector`. The
133+
"gauss" methods are more accurate, by aiming for an efficient
134+
approximation of the correct marginal score in the iid case. This
135+
however requires estimating some hyperparameters, which is done in a
136+
systematic way in the "auto_gauss" (initial overhead) and "jac_gauss"
137+
(iterative jacobian computations are expensive). We default to
138+
"auto_gauss" for these reasons.
139+
iid_params: Additional parameters passed to the iid method. See the specific
140+
`IIDScoreFunction` child class for details.
126141
max_sampling_batch_size: Maximum batch size for sampling.
127142
sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior to
128143
`.sample()`.
@@ -138,7 +153,10 @@ def sample(
138153

139154
x = self._x_else_default_x(x)
140155
x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
141-
self.potential_fn.set_x(x, x_is_iid=True)
156+
is_iid = x.shape[0] > 1
157+
self.potential_fn.set_x(
158+
x, x_is_iid=is_iid, iid_method=iid_method, iid_params=iid_params
159+
)
142160

143161
num_samples = torch.Size(sample_shape).numel()
144162

sbi/inference/potentials/score_based_potential.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Optional, Tuple
4+
from typing import Any, Dict, Optional, Tuple
55

66
import torch
77
from torch import Tensor
@@ -10,6 +10,7 @@
1010
from zuko.transforms import FreeFormJacobianTransform
1111

1212
from sbi.inference.potentials.base_potential import BasePotential
13+
from sbi.inference.potentials.score_fn_iid import get_iid_method
1314
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
1415
from sbi.neural_nets.estimators.shape_handling import (
1516
reshape_to_batch_event,
@@ -56,7 +57,8 @@ def __init__(
5657
score_estimator: ConditionalScoreEstimator,
5758
prior: Optional[Distribution],
5859
x_o: Optional[Tensor] = None,
59-
iid_method: str = "iid_bridge",
60+
iid_method: str = "auto_gauss",
61+
iid_params: Optional[Dict[str, Any]] = None,
6062
device: str = "cpu",
6163
):
6264
r"""Returns the score function for score-based methods.
@@ -65,19 +67,24 @@ def __init__(
6567
score_estimator: The neural network modelling the score.
6668
prior: The prior distribution.
6769
x_o: The observed data at which to evaluate the posterior.
68-
iid_method: Which method to use for computing the score. Currently, only
69-
`iid_bridge` as proposed in Geffner et al. is implemented.
70+
iid_method: Which method to use for computing the score in the iid setting.
71+
We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss".
72+
iid_params: Parameters for the iid method, for arguments see
73+
`IIDScoreFunction`.
7074
device: The device on which to evaluate the potential.
7175
"""
7276
self.score_estimator = score_estimator
7377
self.score_estimator.eval()
7478
self.iid_method = iid_method
79+
self.iid_params = iid_params
7580
super().__init__(prior, x_o, device=device)
7681

7782
def set_x(
7883
self,
7984
x_o: Optional[Tensor],
8085
x_is_iid: Optional[bool] = False,
86+
iid_method: str = "auto_gauss",
87+
iid_params: Optional[Dict[str, Any]] = None,
8188
atol: float = 1e-5,
8289
rtol: float = 1e-6,
8390
exact: bool = True,
@@ -90,12 +97,20 @@ def set_x(
9097
Args:
9198
x_o: The observed data.
9299
x_is_iid: Whether the observed data is IID (if batch_dim>1).
100+
iid_method: Which method to use for computing the score in the iid setting.
101+
We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss".
102+
iid_params: Parameters for the iid method, for arguments see
103+
`IIDScoreFunction`.
93104
atol: Absolute tolerance for the ODE solver.
94105
rtol: Relative tolerance for the ODE solver.
95106
exact: Whether to use the exact ODE solver.
96107
"""
97108
super().set_x(x_o, x_is_iid)
98-
if self._x_o is not None:
109+
self.iid_method = iid_method
110+
self.iid_params = iid_params
111+
# NOTE: Once IID potential evaluation is supported. This needs to be adapted.
112+
# See #1450.
113+
if not x_is_iid and (self._x_o is not None):
99114
self.flow = self.rebuild_flow(atol=atol, rtol=rtol, exact=exact)
100115

101116
def __call__(
@@ -112,6 +127,15 @@ def __call__(
112127
Returns:
113128
The potential function, i.e., the log probability of the posterior.
114129
"""
130+
131+
if self.x_is_iid:
132+
raise NotImplementedError(
133+
"Potential function evaluation in the IID setting is not yet supported"
134+
" for score-based methods. Sampling does however work via `.sample`. "
135+
"If you intended to evaluate the posterior given a batch of (non-iid) "
136+
"x use `log_prob_batched`."
137+
)
138+
115139
theta = ensure_theta_batched(torch.as_tensor(theta))
116140
theta_density_estimator = reshape_to_sample_batch_event(
117141
theta, theta.shape[1:], leading_is_sample=True
@@ -160,10 +184,15 @@ def gradient(
160184
input=theta, condition=self.x_o, time=time
161185
)
162186
else:
163-
raise NotImplementedError(
164-
"Score accumulation for IID data is not yet implemented."
187+
assert self.prior is not None, "Prior is required for iid methods."
188+
189+
iid_method = get_iid_method(self.iid_method)
190+
score_fn_iid = iid_method(
191+
self.score_estimator, self.prior, **(self.iid_params or {})
165192
)
166193

194+
score = score_fn_iid(theta, self.x_o, time)
195+
167196
return score
168197

169198
def get_continuous_normalizing_flow(
@@ -205,9 +234,6 @@ def rebuild_flow(
205234
x_density_estimator = reshape_to_batch_event(
206235
self.x_o, event_shape=self.score_estimator.condition_shape
207236
)
208-
assert x_density_estimator.shape[0] == 1 or not self.x_is_iid, (
209-
"PosteriorScoreBasedPotential supports only x batchsize of 1`."
210-
)
211237

212238
flow = self.get_continuous_normalizing_flow(
213239
condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact

0 commit comments

Comments
 (0)