
Commit 448cef2

fix #1364: conditional posterior shape and device bugs. (#1373)
* fix shape and device bugs.
* fix: do not test batched and iid x
* fix coverage and testing bugs
1 parent 80740b2 commit 448cef2

File tree

4 files changed (+94 -13 lines)


sbi/inference/potentials/likelihood_based_potential.py

Lines changed: 13 additions & 3 deletions
@@ -143,9 +143,14 @@ def condition_on_theta(
         def conditioned_potential(
             theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
         ) -> Tensor:
-            assert len(dims_global_theta) == theta.shape[1], (
+            assert len(dims_global_theta) == theta.shape[-1], (
                 "dims_global_theta must match the number of parameters to sample."
             )
+            if theta.dim() > 2:
+                assert theta.shape[0] == 1, (
+                    "condition_on_theta does not support sample shape for theta."
+                )
+                theta = theta.squeeze(0)
             global_theta = theta[:, dims_global_theta]
             x_o = x_o if x_o is not None else self.x_o
             # x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
@@ -155,7 +160,7 @@ def conditioned_potential(
             )
 
             return _log_likelihood_over_iid_trials_and_local_theta(
-                x=x_o,
+                x=x_o.to(self.device),
                 global_theta=global_theta,
                 local_theta=local_theta,
                 estimator=self.likelihood_estimator,
@@ -266,6 +271,10 @@ def _log_likelihood_over_iid_trials_and_local_theta(
     assert local_theta.shape[0] == num_trials, (
         "Condition batch size must match the number of iid trials in x."
     )
+    if num_xs > 1:
+        raise NotImplementedError(
+            "Batched sampling for multiple `x` is not supported for iid conditions."
+        )
 
     # move the iid batch dimension onto the batch dimension of theta and repeat it there
     x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)
@@ -289,7 +298,8 @@
         num_xs, num_trials, num_thetas
     ).sum(1)
 
-    return log_likelihood_trial_sum
+    # remove xs batch dimension
+    return log_likelihood_trial_sum.squeeze(0)
 
 
 def mixed_likelihood_estimator_based_potential(
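
The squeeze-and-assert logic above exists because MCMC samplers can hand the potential a theta with a leading sample dimension, i.e. shape (1, batch, dim) rather than (batch, dim), in which case the old theta.shape[1] check compared the parameter count against the batch size. A minimal sketch of the new shape handling, using a hypothetical toy potential (the sum is a placeholder, not the actual sbi log-likelihood):

import torch
from torch import Tensor

def toy_conditioned_potential(theta: Tensor, dims_global_theta: list) -> Tensor:
    # Mirrors the fixed checks: validate against the last (parameter) dim ...
    assert len(dims_global_theta) == theta.shape[-1]
    if theta.dim() > 2:
        # ... and tolerate a singleton sample dim, as MCMC samplers produce.
        assert theta.shape[0] == 1
        theta = theta.squeeze(0)
    return theta.sum(-1)  # placeholder for the trial-summed log-likelihood

assert toy_conditioned_potential(torch.randn(5, 2), [0, 1]).shape == torch.Size([5])
assert toy_conditioned_potential(torch.randn(1, 5, 2), [0, 1]).shape == torch.Size([5])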

sbi/utils/sbiutils.py

Lines changed: 1 addition & 1 deletion
@@ -947,7 +947,7 @@ def gradient_ascent(
         )
         best_theta_iter = optimize_inits[  # type: ignore
             torch.argmax(log_probs_of_optimized)
-        ].view(1, -1)
+        ].unsqueeze(0)  # add batch dim
         best_log_prob_iter = potential_fn(
             theta_transform.inv(best_theta_iter)
         )
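
The switch from .view(1, -1) to .unsqueeze(0) changes behavior only for parameters with more than one event dimension: .view(1, -1) flattens everything into a single row, while .unsqueeze(0) just prepends the batch dimension. A toy illustration (hypothetical tensor, not the actual gradient_ascent state):

import torch

best = torch.randn(3, 4)  # hypothetical parameter with a 2D event shape
assert best.view(1, -1).shape == torch.Size([1, 12])     # event shape flattened away
assert best.unsqueeze(0).shape == torch.Size([1, 3, 4])  # batch dim added, event shape kept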

tests/inference_on_device_test.py

Lines changed: 52 additions & 1 deletion
@@ -24,6 +24,7 @@
     ratio_estimator_based_potential,
 )
 from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
+from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
 from sbi.inference.potentials.base_potential import BasePotential
 from sbi.neural_nets.embedding_nets import FCEmbedding
 from sbi.neural_nets.factory import (
@@ -33,7 +34,11 @@
     posterior_nn,
 )
 from sbi.simulators import diagonal_linear_gaussian, linear_gaussian
-from sbi.utils.torchutils import BoxUniform, gpu_available, process_device
+from sbi.utils.torchutils import (
+    BoxUniform,
+    gpu_available,
+    process_device,
+)
 from sbi.utils.user_input_checks import (
     validate_theta_and_x,
 )
@@ -465,3 +470,49 @@ def test_multiround_mdn_training_on_device(method: Union[NPE_A, NPE_C], device:
     proposal = trainer.build_posterior().set_default_x(torch.zeros(num_dim))
     theta = proposal.sample((num_simulations,))
     x = simulator(theta)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize("device", ["cpu", "gpu"])
+def test_conditioned_posterior_on_gpu(device: str, mcmc_params_fast: dict):
+    device = process_device(device)
+    num_dims = 3
+
+    proposal = BoxUniform(
+        low=-torch.ones(num_dims, device=device),
+        high=torch.ones(num_dims, device=device),
+    )
+
+    inference = NPE_C(device=device, show_progress_bars=False)
+
+    num_simulations = 100
+    theta = proposal.sample((num_simulations,))
+    x = torch.randn_like(theta)
+    x_o = torch.zeros(1, num_dims).to(device)
+    inference = inference.append_simulations(theta, x)
+
+    estimator = inference.train(max_num_epochs=2)
+
+    # condition on one dim of theta
+    condition_o = torch.ones(1, 1).to(device)
+    prior = BoxUniform(
+        low=-torch.ones(num_dims - 1, device=device),
+        high=torch.ones(num_dims - 1, device=device),
+    )
+    prior_transform = utils.mcmc_transform(prior, device=device)
+
+    potential_fn, _ = likelihood_estimator_based_potential(estimator, proposal, x_o)
+    conditioned_potential_fn = potential_fn.condition_on_theta(
+        condition_o, dims_global_theta=[0, 1]
+    )
+
+    conditional_posterior = MCMCPosterior(
+        potential_fn=conditioned_potential_fn,
+        theta_transform=prior_transform,
+        proposal=prior,
+        device=device,
+        **mcmc_params_fast,
+    ).set_default_x(x_o)
+    samples = conditional_posterior.sample((1,), x=x_o)
+    conditional_posterior.potential_fn(samples)
+    conditional_posterior.map()

tests/mnle_test.py

Lines changed: 28 additions & 8 deletions
@@ -40,9 +40,10 @@ def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.
     return torch.cat((rts, choices), dim=1)
 
 
-def wrapped_simulator(
+def mixed_simulator_with_conditions(
     theta_and_condition: Tensor, last_idx_parameters: int = 2
 ) -> Tensor:
+    """Simulator for mixed data with experimental conditions."""
     # simulate with experiment conditions
     theta = theta_and_condition[:, :last_idx_parameters]
     condition = theta_and_condition[:, last_idx_parameters:]
@@ -278,7 +279,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     )
 
     theta = proposal.sample((num_simulations,))
-    x = wrapped_simulator(theta)
+    x = mixed_simulator_with_conditions(theta)
     assert x.shape == (num_simulations, 2)
 
     num_trials = 10
@@ -289,7 +290,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     condition_o = theta_and_condition[:, 2:]
     theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1)
 
-    x_o = wrapped_simulator(theta_and_conditions_o)
+    x_o = mixed_simulator_with_conditions(theta_and_conditions_o)
 
     mcmc_kwargs = dict(
         method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate
@@ -313,6 +314,9 @@
         ],
         validate_args=False,
     )
+    # test theta with sample shape.
+    conditioned_potential_fn(prior.sample((10,)).unsqueeze(0))
+
     prior_transform = mcmc_transform(prior)
     true_posterior_samples = MCMCPosterior(
         BinomialGammaPotential(
@@ -339,14 +343,28 @@
 
 @pytest.mark.parametrize("num_thetas", [1, 10])
 @pytest.mark.parametrize("num_trials", [1, 5])
-@pytest.mark.parametrize("num_xs", [1, 3])
+@pytest.mark.parametrize(
+    "num_xs",
+    [
+        1,
+        pytest.param(
+            2,
+            marks=pytest.mark.xfail(
+                reason="Batched x not supported for iid trials.",
+                raises=NotImplementedError,
+            ),
+        ),
+    ],
+)
 @pytest.mark.parametrize(
     "num_conditions",
     [
         1,
         pytest.param(
             2,
-            marks=pytest.mark.xfail(reason="Batched theta_condition is not supported"),
+            marks=pytest.mark.xfail(
+                reason="Batched theta_condition is not supported",
+            ),
         ),
     ],
 )
@@ -376,7 +394,7 @@ def test_log_likelihood_over_local_iid_theta(
 
     num_simulations = 100
     theta = proposal.sample((num_simulations,))
-    x = wrapped_simulator(theta)
+    x = mixed_simulator_with_conditions(theta)
     estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)
 
     # condition on multiple conditions
@@ -407,8 +425,10 @@
         )
         x_i = x_o[i].reshape(num_xs, 1, -1).repeat(1, num_thetas, 1)
         ll_single.append(estimator.log_prob(input=x_i, condition=theta_and_condition))
-    ll_single = torch.stack(ll_single).sum(0)  # sum over trials
+    ll_single = (
+        torch.stack(ll_single).sum(0).squeeze(0)
+    )  # sum over trials, squeeze x batch.
 
-    assert ll_batched.shape == torch.Size([num_xs, num_thetas])
+    assert ll_batched.shape == torch.Size([num_thetas])
     assert ll_batched.shape == ll_single.shape
     assert torch.allclose(ll_batched, ll_single, atol=1e-5)
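
The equivalence this test asserts can be reproduced in isolation. The sketch below swaps the trained MNLE estimator for a hypothetical dummy_log_prob (which ignores the condition, so both routes are trivially comparable) and mirrors the transpose/repeat_interleave/reshape/sum bookkeeping of _log_likelihood_over_iid_trials_and_local_theta for the single-x case:

import torch

num_trials, num_thetas, x_dim = 5, 10, 2
x = torch.randn(num_trials, 1, x_dim)  # (trials, num_xs=1, event)
theta = torch.randn(num_thetas, 3)     # (thetas, condition_dim)

def dummy_log_prob(x_batch, theta_batch):
    # Hypothetical stand-in for estimator.log_prob(input=..., condition=...).
    return torch.distributions.Normal(0.0, 1.0).log_prob(x_batch).sum(-1)

# Batched route: move trials onto the theta batch, evaluate once,
# sum the per-trial terms, then squeeze the singleton xs batch dim.
x_rep = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)
ll_batched = (
    dummy_log_prob(x_rep, theta).reshape(1, num_trials, num_thetas).sum(1).squeeze(0)
)

# Looped route: evaluate one trial at a time, then sum over trials.
ll_single = torch.stack(
    [dummy_log_prob(x[i].repeat(num_thetas, 1), theta) for i in range(num_trials)]
).sum(0)

assert ll_batched.shape == torch.Size([num_thetas])
assert torch.allclose(ll_batched, ll_single)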
