Skip to content

Commit a6a220d

Browse files
fix: leakage correction for log prob batched (#1355)
* Fixing leakage correction inconsistency. * Improving test to cover batched log_prob on bounded support priors * Fixing test * Formating fix * add type
1 parent 9152e93 commit a6a220d

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

sbi/samplers/rejection/rejection.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,4 @@ def accept_reject_sample(
362362
samples.shape[0] == num_samples
363363
), "Number of accepted samples must match required samples."
364364

365-
# NOTE: Restriction prior does currently require a float as return for the
366-
# acceptance rate, which is why we for now also return the minimum acceptance rate.
367-
return samples, as_tensor(min_acceptance_rate)
365+
return samples, as_tensor(acceptance_rate)

sbi/utils/restriction_estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,11 @@ def sample(
692692
max_sampling_batch_size=max_sampling_batch_size,
693693
alternative_method="sample_with='sir'",
694694
)
695+
# NOTE: This currently requires a float acceptance rate. A previous version
696+
# of accept_reject_sample returned a float. In favour to batched sampling
697+
# it now returns a tensor.
698+
acceptance_rate = acceptance_rate.min().item()
699+
695700
if save_acceptance_rate:
696701
self.acceptance_rate = torch.as_tensor(acceptance_rate)
697702
if print_rejected_frac:

tests/posterior_nn_test.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
import torch
88
from torch import eye, ones, zeros
9-
from torch.distributions import MultivariateNormal
9+
from torch.distributions import Independent, MultivariateNormal, Uniform
1010

1111
from sbi.inference import (
1212
NLE_A,
@@ -98,13 +98,20 @@ def test_importance_posterior_sample_log_prob(snplre_method: type):
9898

9999
@pytest.mark.parametrize("snpe_method", [NPE_A, NPE_C])
100100
@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
101+
@pytest.mark.parametrize("prior", ("mvn", "uniform"))
101102
def test_batched_sample_log_prob_with_different_x(
102-
snpe_method: type, x_o_batch_dim: bool
103+
snpe_method: type,
104+
x_o_batch_dim: bool,
105+
prior: str,
103106
):
104107
num_dim = 2
105108
num_simulations = 1000
106109

107-
prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
110+
# We also want to test on bounded support! Which will invoke leakage correction.
111+
if prior == "mvn":
112+
prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
113+
elif prior == "uniform":
114+
prior = Independent(Uniform(-1.0 * ones(num_dim), 1.0 * ones(num_dim)), 1)
108115
simulator = diagonal_linear_gaussian
109116

110117
inference = snpe_method(prior=prior)
@@ -116,6 +123,7 @@ def test_batched_sample_log_prob_with_different_x(
116123

117124
posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior)
118125

126+
torch.manual_seed(0)
119127
samples = posterior.sample_batched((10,), x_o)
120128
batched_log_probs = posterior.log_prob_batched(samples, x_o)
121129

@@ -126,6 +134,20 @@ def test_batched_sample_log_prob_with_different_x(
126134
), "Sample shape wrong"
127135
assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)), "logprob shape wrong"
128136

137+
# Test consistency with non-batched log_prob
138+
# NOTE: Leakage factor is a MC estimate, so we need to relax the tolerance here.
139+
if x_o_batch_dim == 0:
140+
log_probs = posterior.log_prob(samples, x=x_o)
141+
assert torch.allclose(
142+
log_probs, batched_log_probs[:, 0], atol=1e-1, rtol=1e-1
143+
), "Batched log probs different from non-batched log probs"
144+
else:
145+
for idx in range(x_o_batch_dim):
146+
log_probs = posterior.log_prob(samples[:, idx], x=x_o[idx])
147+
assert torch.allclose(
148+
log_probs, batched_log_probs[:, idx], atol=1e-1, rtol=1e-1
149+
), "Batched log probs different from non-batched log probs"
150+
129151

130152
@pytest.mark.mcmc
131153
@pytest.mark.parametrize("snlre_method", [NLE_A, NRE_A, NRE_B, NRE_C, NPE_C])

0 commit comments

Comments
 (0)