
Commit e7940dc

feat: NLE with multiple iid conditions (#1331)
* add method for iid-batched conditioning.
  - deprecate MNLE-based potential (can be NLE-based)
  - adapt tests for conditioned MNLE.
* update notebook, bugfixes
* add batch dim for x, add test.
* fix shape handling, adapt tutorial.
1 parent 390a518 commit e7940dc

7 files changed: +575 / -190 lines changed


sbi/inference/potentials/likelihood_based_potential.py

Lines changed: 135 additions & 2 deletions
@@ -1,7 +1,8 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
-from typing import Callable, Optional, Tuple
+import warnings
+from typing import Callable, List, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -115,6 +116,54 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         )
         return log_likelihood_batches + self.prior.log_prob(theta)  # type: ignore
 
+    def condition_on_theta(
+        self, local_theta: Tensor, dims_global_theta: List[int]
+    ) -> Callable:
+        r"""Returns a potential function conditioned on a subset of theta dimensions.
+
+        The goal of this function is to divide the original `theta` into a
+        `global_theta` we do inference over, and a `local_theta` we condition on (in
+        addition to conditioning on `x_o`). Thus, the returned potential function will
+        calculate $\prod_{i=1}^{N}p(x_i | local\_theta_i, global\_theta)$, where `x_i`
+        and `local_theta_i` are fixed and `global_theta` varies at inference time.
+
+        Args:
+            local_theta: The condition values to be conditioned on.
+            dims_global_theta: The indices of the columns in `theta` that will be
+                sampled, i.e., that are *not* conditioned on. For example, if theta
+                has shape `(batch_dim, 3)` and `dims_global_theta=[0, 1]`, then the
+                potential will set `theta[:, 2] = local_theta` at inference time.
+
+        Returns:
+            A potential function conditioned on the `local_theta`.
+        """
+
+        assert self.x_is_iid, "Conditioning is only supported for iid data."
+
+        def conditioned_potential(
+            theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
+        ) -> Tensor:
+            assert (
+                len(dims_global_theta) == theta.shape[1]
+            ), "dims_global_theta must match the number of parameters to sample."
+            global_theta = theta[:, dims_global_theta]
+            x_o = x_o if x_o is not None else self.x_o
+            # x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
+            if x_o.dim() < 3:
+                x_o = reshape_to_sample_batch_event(
+                    x_o, event_shape=x_o.shape[1:], leading_is_sample=self.x_is_iid
+                )
+
+            return _log_likelihood_over_iid_trials_and_local_theta(
+                x=x_o,
+                global_theta=global_theta,
+                local_theta=local_theta,
+                estimator=self.likelihood_estimator,
+                track_gradients=track_gradients,
+            )
+
+        return conditioned_potential
+
 
 def _log_likelihoods_over_trials(
     x: Tensor,
@@ -172,6 +221,77 @@ def _log_likelihoods_over_trials(
     return log_likelihood_trial_sum
 
 
+def _log_likelihood_over_iid_trials_and_local_theta(
+    x: Tensor,
+    global_theta: Tensor,
+    local_theta: Tensor,
+    estimator: ConditionalDensityEstimator,
+    track_gradients: bool = False,
+) -> Tensor:
+    """Returns $\\sum_{i=1}^N \\log p(x_i | \\theta, local\\_theta_i)$.
+
+    `x` is a batch of iid data, and `local_theta` is a matching batch of condition
+    values that were part of `theta` but are treated as local iid variables at
+    inference time.
+
+    This function differs from `_log_likelihoods_over_trials` in that it moves the
+    iid batch dimension of `x` onto the batch dimension of `theta`. This is needed when
+    the likelihood estimator is conditioned on a batch of conditions that are iid with
+    the batch of `x`. It avoids evaluating the likelihood for every combination of `x`
+    and `local_theta`.
+
+    Args:
+        x: data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where sample_dim
+            holds the i.i.d. trials and batch_dim holds a batch of xs, e.g., non-iid
+            observations.
+        global_theta: Batch of parameters `(theta_batch_dim, num_parameters)`.
+        local_theta: Batch of conditions of shape `(sample_dim, num_local_thetas)`, must
+            match x's `sample_dim`.
+        estimator: DensityEstimator.
+        track_gradients: Whether to track gradients.
+
+    Returns:
+        log_likelihood: log likelihood for each x in x_batch_dim, for each theta in
+            theta_batch_dim, summed over all iid trials. Shape `(x_batch_dim,
+            theta_batch_dim)`.
+    """
+    assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
+    assert (
+        local_theta.dim() == 2
+    ), "condition must have shape (sample_dim, num_conditions)."
+    assert global_theta.dim() == 2, "theta must have shape (batch_dim, num_parameters)."
+    num_trials, num_xs = x.shape[:2]
+    num_thetas = global_theta.shape[0]
+    assert (
+        local_theta.shape[0] == num_trials
+    ), "Condition batch size must match the number of iid trials in x."
+
+    # Move the iid batch dimension onto the batch dimension of theta and repeat it there.
+    x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)
+
+    # Construct theta and condition to cover all trial-theta combinations.
+    theta_with_condition = torch.cat(
+        [
+            global_theta.repeat(num_trials, 1),  # repeat ABAB
+            local_theta.repeat_interleave(num_thetas, dim=0),  # repeat AABB
+        ],
+        dim=-1,
+    )
+
+    with torch.set_grad_enabled(track_gradients):
+        # Calculate the likelihood in one batch. Returns (num_xs, num_trials * num_thetas).
+        log_likelihood_trial_batch = estimator.log_prob(
+            x_repeated, condition=theta_with_condition
+        )
+        # Reshape to (num_xs, num_trials, num_thetas) and sum over the trial log likelihoods.
+        log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
+            num_xs, num_trials, num_thetas
+        ).sum(1)
+
+    return log_likelihood_trial_sum
+
+
 def mixed_likelihood_estimator_based_potential(
     likelihood_estimator: MixedDensityEstimator,
     prior: Distribution,
@@ -192,6 +312,13 @@ def mixed_likelihood_estimator_based_potential(
         to unconstrained space.
     """
 
+    warnings.warn(
+        "This function is deprecated and will be removed in a future release. Use "
+        "`likelihood_estimator_based_potential` instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
     device = str(next(likelihood_estimator.discrete_net.parameters()).device)
 
     potential_fn = MixedLikelihoodBasedPotential(
@@ -212,6 +339,13 @@ def __init__(
     ):
         super().__init__(likelihood_estimator, prior, x_o, device)
 
+        warnings.warn(
+            "This class is deprecated and will be removed in a future release. Use "
+            "`LikelihoodBasedPotential` instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         prior_log_prob = self.prior.log_prob(theta)  # type: ignore
 
@@ -231,7 +365,6 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         with torch.set_grad_enabled(track_gradients):
             # Call the specific log prob method of the mixed likelihood estimator as
             # this optimizes the evaluation of the discrete data part.
-            # TODO log_prob_iid
             log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
                 input=x,
                 condition=theta.to(self.device),
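
To make the batching trick behind `_log_likelihood_over_iid_trials_and_local_theta` concrete, here is a minimal, self-contained PyTorch sketch (not part of the commit; the tensor sizes are made-up illustration values). It reproduces only the shape bookkeeping: every global theta is paired with every iid trial and its local condition in one flat batch, so the estimator is called once instead of once per trial.

import torch

num_trials, num_xs, num_thetas = 3, 1, 4   # iid trials, observations, parameter sets
x = torch.randn(num_trials, num_xs, 2)     # (sample_dim, x_batch_dim, *event_shape)
global_theta = torch.randn(num_thetas, 2)  # parameters we infer over
local_theta = torch.randn(num_trials, 1)   # per-trial conditions that stay fixed

# Move the iid trial dimension onto the theta batch: each trial is repeated once
# per theta, giving one entry per (trial, theta) pair.
x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)

# Pair parameters and conditions the same way: thetas cycle ABAB, conditions AABB.
theta_with_condition = torch.cat(
    [
        global_theta.repeat(num_trials, 1),                # ABAB...
        local_theta.repeat_interleave(num_thetas, dim=0),  # AABB...
    ],
    dim=-1,
)

print(x_repeated.shape)            # torch.Size([1, 12, 2])
print(theta_with_condition.shape)  # torch.Size([12, 3])

A likelihood estimator evaluated on this flat batch returns one value per (trial, theta) pair; reshaping to `(num_xs, num_trials, num_thetas)` and summing over the trial axis yields the iid log likelihood per theta, which is what the conditioned potential returned by `condition_on_theta` evaluates.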

sbi/inference/trainers/nle/mnle.py

Lines changed: 2 additions & 4 deletions
@@ -7,7 +7,7 @@
 from torch.distributions import Distribution
 
 from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
-from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
+from sbi.inference.potentials import likelihood_estimator_based_potential
 from sbi.inference.trainers.nle.nle_base import LikelihoodEstimator
 from sbi.neural_nets.estimators import MixedDensityEstimator
 from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
@@ -155,9 +155,7 @@ def build_posterior(
         (
             potential_fn,
             theta_transform,
-        ) = mixed_likelihood_estimator_based_potential(
-            likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
-        )
+        ) = likelihood_estimator_based_potential(likelihood_estimator, prior, x_o=None)
 
         if sample_with == "mcmc":
             self._posterior = MCMCPosterior(

sbi/utils/conditional_density_utils.py

Lines changed: 1 addition & 1 deletion
@@ -293,7 +293,7 @@ def __init__(
             masked outside of prior.
         """
         condition = torch.atleast_2d(condition)
-        if condition.shape[0] != 1:
+        if condition.shape[0] > 1:
             raise ValueError("Condition with batch size > 1 not supported.")
 
         self.potential_fn = potential_fn

sbi/utils/sbiutils.py

Lines changed: 2 additions & 2 deletions
@@ -60,8 +60,8 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -
 
     if num_unique_z < num_unique * (1 - duplicate_tolerance):
         warnings.warn(
-            "Z-scoring these simulation outputs resulted in {num_unique_z} unique "
-            "datapoints. Before z-scoring, it had been {num_unique}. This can "
+            f"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
+            f"datapoints. Before z-scoring, it had been {num_unique}. This can "
             "occur due to numerical inaccuracies when the data covers a large "
             "range of values. Consider either setting `z_score_x=False` (but "
             "beware that this can be problematic for training the NN) or exclude "
