Commit 8a95c8f

add batch dim for x, add test.
1 parent 074efa0 commit 8a95c8f

2 files changed: +147 -46 lines

sbi/inference/potentials/likelihood_based_potential.py

Lines changed: 56 additions & 36 deletions
@@ -116,20 +116,28 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         )
         return log_likelihood_batches + self.prior.log_prob(theta)  # type: ignore

-    def condition_on(self, condition: Tensor, dims_to_sample: List[int]) -> Callable:
-        """Returns a potential conditioned on a subset of theta dimensions.
+    def condition_on_theta(
+        self, theta_condition: Tensor, dims_to_sample: List[int]
+    ) -> Callable:
+        """Returns a potential function conditioned on a subset of theta dimensions.

         The condition is a part of theta, but is assumed to correspond to a batch of iid
-        x_o.
+        x_o. For example, it can be a batch of experimental conditions that corresponds
+        to a batch of i.i.d. trials in x_o.

         Args:
-            condition: The condition to fix.
-            dims_to_sample: The indices of the parameters to sample.
+            theta_condition: The condition values to condition on.
+            dims_to_sample: The indices of the columns in theta that will be sampled,
+                i.e., that are *not* conditioned on. For example, if the original theta
+                has shape `(batch_dim, 3)` and `dims_to_sample=[0, 1]`, then the
+                potential will set `theta[:, 2] = theta_condition` at inference time.

         Returns:
-            A potential function conditioned on the condition.
+            A potential function conditioned on the theta_condition.
         """

+        assert self.x_is_iid, "Conditioning is only supported for iid data."
+
         def conditioned_potential(
             theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
         ) -> Tensor:
@@ -138,10 +146,10 @@ def conditioned_potential(
             ), "dims_to_sample must match the number of parameters to sample."
             theta_without_condition = theta[:, dims_to_sample]

-            return _log_likelihood_with_iid_condition(
+            return _log_likelihood_over_iid_conditions(
                 x=x_o if x_o is not None else self.x_o,
                 theta_without_condition=theta_without_condition,
-                condition=condition,
+                condition=theta_condition,
                 estimator=self.likelihood_estimator,
                 track_gradients=track_gradients,
             )
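For orientation, a minimal usage sketch of the renamed method. This is not part of the commit: `likelihood_estimator`, `prior`, `x_o`, and `theta_condition` are placeholders for a trained estimator, a prior over the full theta, a batch of iid trials, and per-trial condition values, and it assumes the `likelihood_estimator_based_potential` factory (imported in tests/mnle_test.py below) returns the potential object that carries `condition_on_theta`:

    from sbi.inference.potentials.likelihood_based_potential import (
        likelihood_estimator_based_potential,
    )

    # placeholders assumed to exist: likelihood_estimator, prior,
    # x_o with shape (num_trials, *event_shape),
    # theta_condition with shape (num_trials, num_conditions)
    potential_fn, _ = likelihood_estimator_based_potential(
        likelihood_estimator, prior, x_o
    )
    conditioned_potential = potential_fn.condition_on_theta(
        theta_condition, dims_to_sample=[0, 1]
    )
    # evaluate on the sampled columns only; the condition is appended internally
    log_potential = conditioned_potential(theta_batch)  # theta_batch: (batch, 2)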
@@ -205,63 +213,75 @@ def _log_likelihoods_over_trials(
     return log_likelihood_trial_sum


-def _log_likelihood_with_iid_condition(
+def _log_likelihood_over_iid_conditions(
     x: Tensor,
     theta_without_condition: Tensor,
     condition: Tensor,
     estimator: ConditionalDensityEstimator,
     track_gradients: bool = False,
 ) -> Tensor:
-    """Return log likelihoods summed over iid trials of `x` with a matching batch of
-    conditions.
+    """Returns $\\log p(x_o | \\theta, condition)$, where x_o is a batch of iid data,
+    and condition is a matching batch of conditions.

     This function is different from `_log_likelihoods_over_trials` in that it moves the
-    iid batch dimension of `x` onto the batch dimension of `theta`. This is useful when
+    iid batch dimension of `x` onto the batch dimension of `theta`. This is needed when
     the likelihood estimator is conditioned on a batch of conditions that are iid with
-    the batch of `x`. It avoid the evaluation of the likelihood for every combination of
-    `x` and `condition`. Instead, it manually constructs a batch covering all
-    combination of iid trial and theta batch and reshapes to sum over the iid
+    the batch of `x`. It avoids the evaluation of the likelihood for every combination
+    of `x` and `condition`. Instead, it manually constructs a batch covering all
+    combinations of iid trials and the theta batch and reshapes to sum over the iid
     likelihoods.

     Args:
-        x: Batch of iid data of shape `(iid_dim, *event_shape)`.
-        theta_without_condition: Batch of parameters `(batch_dim, *event_shape)`
-        condition: Batch of conditions of shape `(iid_dim, *condition_shape)`.
+        x: data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where
+            sample_dim holds the i.i.d. trials and x_batch_dim holds a batch of xs,
+            e.g., non-iid observations.
+        theta_without_condition: Batch of parameters `(theta_batch_dim,
+            num_parameters)`.
+        condition: Batch of conditions of shape `(sample_dim, num_conditions)`; must
+            match x's `sample_dim`.
         estimator: DensityEstimator.
         track_gradients: Whether to track gradients.

     Returns:
-        log_likelihood_trial_sum: log likelihood for each parameter, summed over all
-            batch entries (iid trials) in `x`.
+        log_likelihood: log likelihood for each x in x_batch_dim, for each theta in
+            theta_batch_dim, summed over all i.i.d. trials. Shape
+            `(x_batch_dim, theta_batch_dim)`.
     """
+    assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
     assert (
-        condition.shape[0] == x.shape[0]
-    ), "Condition and iid x must have the same batch size."
-    num_trials = x.shape[0]
-    num_theta = theta_without_condition.shape[0]
-    x = reshape_to_sample_batch_event(
-        x, event_shape=x.shape[1:], leading_is_sample=True
-    )
+        condition.dim() == 2
+    ), "condition must have shape (sample_dim, num_conditions)."
+    assert (
+        theta_without_condition.dim() == 2
+    ), "theta must have shape (batch_dim, num_parameters)."
+    num_trials, num_xs = x.shape[:2]
+    num_thetas = theta_without_condition.shape[0]
+    assert (
+        condition.shape[0] == num_trials
+    ), "Condition batch size must match the number of iid trials in x."

     # move the iid batch dimension onto the batch dimension of theta and repeat it there
-    x_expanded = x.reshape(1, num_trials, -1).repeat_interleave(num_theta, dim=1)
-    # for this to work we construct theta and condition to cover all combinations in the
-    # trial batch and the theta batch.
-    theta = torch.cat(
+    # note: in-place transpose to (x_batch_dim, sample_dim, *event_shape)
+    x.transpose_(0, 1)
+    x_repeated = x.repeat_interleave(num_thetas, dim=1)
+
+    # construct theta and condition to cover all trial-theta combinations
+    theta_with_condition = torch.cat(
         [
             theta_without_condition.repeat(num_trials, 1),  # repeat ABAB
-            condition.repeat_interleave(num_theta, dim=0),  # repeat AABB
+            condition.repeat_interleave(num_thetas, dim=0),  # repeat AABB
         ],
         dim=-1,
     )

     with torch.set_grad_enabled(track_gradients):
-        # Calculate likelihood in one batch. Returns (1, num_trials * theta_batch_size)
-        log_likelihood_trial_batch = estimator.log_prob(x_expanded, condition=theta)
+        # Calculate likelihood in one batch. Returns (num_xs, num_trials * num_thetas).
+        log_likelihood_trial_batch = estimator.log_prob(
+            x_repeated, condition=theta_with_condition
+        )
         # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
         log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
-            num_trials, num_theta
-        ).sum(0)
+            num_xs, num_trials, num_thetas
+        ).sum(1)

     return log_likelihood_trial_sum
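Because the trials are iid given theta and a per-trial condition, the quantity computed here is $\log p(x_o | \theta, c) = \sum_i \log p(x_{o,i} | \theta, c_i)$, evaluated for every (x, theta) pair in one forward pass. The following pure-torch toy (not part of the commit) checks that the ABAB/AABB layout enumerates every trial-theta combination exactly once, in the order the final reshape expects:

    import torch

    num_trials, num_thetas = 2, 3
    theta = torch.tensor([[10.0], [20.0], [30.0]])  # (num_thetas, 1)
    condition = torch.tensor([[0.25], [0.5]])       # (num_trials, 1)

    rows = torch.cat(
        [
            theta.repeat(num_trials, 1),                     # ABAB: 10, 20, 30, 10, 20, 30
            condition.repeat_interleave(num_thetas, dim=0),  # AABB: .25 x3, .5 x3
        ],
        dim=-1,
    )
    # row t * num_thetas + j pairs trial t with theta j, exactly once each
    assert rows.tolist() == [
        [10.0, 0.25], [20.0, 0.25], [30.0, 0.25],
        [10.0, 0.5], [20.0, 0.5], [30.0, 0.5],
    ]

    # x follows the same layout: each trial repeated num_thetas times per x
    x = torch.tensor([[[1.0]], [[2.0]]])  # (num_trials, num_xs=1, 1)
    x_rep = x.transpose(0, 1).repeat_interleave(num_thetas, dim=1)
    assert x_rep.squeeze(-1).tolist() == [[1.0, 1.0, 1.0, 2.0, 2.0, 2.0]]
    # reshaping the per-row log probs to (num_xs, num_trials, num_thetas) and
    # summing over dim 1 (trials) then yields the iid sum per (x, theta) pair.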

tests/mnle_test.py

Lines changed: 91 additions & 10 deletions
@@ -14,6 +14,7 @@
 from sbi.inference.posteriors.vi_posterior import VIPosterior
 from sbi.inference.potentials.base_potential import BasePotential
 from sbi.inference.potentials.likelihood_based_potential import (
+    _log_likelihood_over_iid_conditions,
     likelihood_estimator_based_potential,
 )
 from sbi.neural_nets import likelihood_nn
@@ -39,6 +40,15 @@ def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.
     return torch.cat((rts, choices), dim=1)


+def wrapped_simulator(
+    theta_and_condition: Tensor, last_idx_parameters: int = 2
+) -> Tensor:
+    # simulate with experiment conditions
+    theta = theta_and_condition[:, :last_idx_parameters]
+    condition = theta_and_condition[:, last_idx_parameters:]
+    return mixed_simulator(theta, condition)
+
+
 @pytest.mark.mcmc
 @pytest.mark.gpu
 @pytest.mark.parametrize("device", ("cpu", "gpu"))
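The wrapper fixes a packing convention used throughout this file: the first `last_idx_parameters` columns of the input are model parameters, the remaining columns are experimental conditions. A toy call (values are illustrative; assumes the definitions above in this file):

    import torch

    theta = torch.tensor([[1.0, 0.5], [2.0, 0.3]])  # (batch, 2) model parameters
    condition = torch.tensor([[0.25], [0.75]])      # (batch, 1) experimental condition
    theta_and_condition = torch.cat([theta, condition], dim=1)  # (batch, 3)

    x = wrapped_simulator(theta_and_condition)  # splits columns, calls mixed_simulator
    assert x.shape == (2, 2)  # per row: (reaction time, choice)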
@@ -256,14 +266,6 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     num_simulations = 10000
     num_samples = 1000

-    def sim_wrapper(
-        theta_and_condition: Tensor, last_idx_parameters: int = 2
-    ) -> Tensor:
-        # simulate with experiment conditions
-        theta = theta_and_condition[:, :last_idx_parameters]
-        condition = theta_and_condition[:, last_idx_parameters:]
-        return mixed_simulator(theta, condition)
-
     proposal = MultipleIndependent(
         [
             Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
@@ -274,7 +276,7 @@ def sim_wrapper(
     )

     theta = proposal.sample((num_simulations,))
-    x = sim_wrapper(theta)
+    x = wrapped_simulator(theta)
     assert x.shape == (num_simulations, 2)

     num_trials = 10
@@ -285,7 +287,7 @@ def sim_wrapper(
     condition_o = theta_and_condition[:, 2:]
     theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1)

-    x_o = sim_wrapper(theta_and_conditions_o)
+    x_o = wrapped_simulator(theta_and_conditions_o)

     mcmc_kwargs = dict(
         method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate
@@ -331,3 +333,82 @@ def sim_wrapper(
         true_posterior_samples,
         alg=f"MNLE trained with {num_simulations} simulations",
     )
+
+
+@pytest.mark.parametrize("num_thetas", [1, 10])
+@pytest.mark.parametrize("num_trials", [1, 5])
+@pytest.mark.parametrize("num_xs", [1, 3])
+@pytest.mark.parametrize(
+    "num_conditions",
+    [
+        1,
+        pytest.param(
+            2,
+            marks=pytest.mark.xfail(reason="Batched theta_condition is not supported"),
+        ),
+    ],
+)
+def test_log_likelihood_over_iid_conditions(
+    num_thetas, num_trials, num_xs, num_conditions
+):
+    """Test log likelihood over iid conditions using MNLE.
+
+    Args:
+        num_thetas: size of the batch of parameters to evaluate.
+        num_trials: number of i.i.d. trials in x.
+        num_xs: size of the batch of xs, e.g., different subjects in a study.
+        num_conditions: number of batches of conditions, e.g., different conditions
+            for each x (not implemented yet).
+    """
+
+    # train mnle on mixed data
+    trainer = MNLE(
+        density_estimator=likelihood_nn(model="mnle", z_score_x=None),
+    )
+    proposal = MultipleIndependent(
+        [
+            Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
+            Beta(torch.tensor([2.0]), torch.tensor([2.0])),
+            BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])),
+        ],
+        validate_args=False,
+    )
+
+    num_simulations = 100
+    theta = proposal.sample((num_simulations,))
+    x = wrapped_simulator(theta)
+    estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)
+
+    # simulate observed data: one theta per x, per-trial conditions shared across xs
+    theta_o = proposal.sample((num_xs,))[:, :2]
+
+    x_o = torch.zeros(num_trials, num_xs, 2)
+    condition_o = proposal.sample((
+        num_conditions,
+        num_trials,
+    ))[:, 2:].reshape(num_trials, 1)
+    for i in range(num_xs):
+        # simulate with same iid theta but different conditions
+        x_o[:, i, :] = mixed_simulator(theta_o[i].repeat(num_trials, 1), condition_o)
+
+    # batched conditioning
+    theta = proposal.sample((num_thetas,))[:, :2]
+    # x_o has shape (num_trials, num_xs, 2), i.e., (sample_dim, x_batch_dim, *event)
+    # condition_o has shape (num_trials, 1), i.e., (sample_dim, num_conditions)
+    ll_batched = _log_likelihood_over_iid_conditions(x_o, theta, condition_o, estimator)
+
+    # looped conditioning; note that the batched call above transposed x_o in
+    # place, so x_o now has shape (num_xs, num_trials, 2)
+    ll_single = []
+    for i in range(num_trials):
+        theta_and_condition = torch.cat(
+            (theta, condition_o[i].repeat(num_thetas, 1)), dim=1
+        )
+        x_i = x_o[:, i].reshape(num_xs, 1, -1).repeat(1, num_thetas, 1)
+        ll_single.append(estimator.log_prob(input=x_i, condition=theta_and_condition))
+    ll_single = torch.stack(ll_single).sum(0)  # sum over trials
+
+    assert ll_batched.shape == torch.Size([num_xs, num_thetas])
+    assert ll_batched.shape == ll_single.shape
+    assert torch.allclose(ll_batched, ll_single, atol=1e-5)
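As a sanity check independent of MNLE, the reshape-and-sum bookkeeping behind these assertions can be replayed with plain tensors (a sketch, not part of the commit):

    import torch

    num_xs, num_trials, num_thetas = 3, 5, 10
    # stand-in per-row log likelihoods, flat as the estimator returns them
    flat = torch.randn(num_xs, num_trials * num_thetas)

    batched = flat.reshape(num_xs, num_trials, num_thetas).sum(1)

    # per-trial loop over the same values
    looped = torch.stack(
        [flat.reshape(num_xs, num_trials, num_thetas)[:, i, :] for i in range(num_trials)]
    ).sum(0)

    assert batched.shape == (num_xs, num_thetas)
    assert torch.allclose(batched, looped)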
