 import pytest
 import torch
 from torch import eye, ones, zeros
-from torch.distributions import MultivariateNormal
+from torch.distributions import Independent, MultivariateNormal, Uniform
 
 from sbi.inference import (
     NLE_A,
@@ -98,13 +98,20 @@ def test_importance_posterior_sample_log_prob(snplre_method: type):
 
 @pytest.mark.parametrize("snpe_method", [NPE_A, NPE_C])
 @pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
+@pytest.mark.parametrize("prior", ("mvn", "uniform"))
 def test_batched_sample_log_prob_with_different_x(
-    snpe_method: type, x_o_batch_dim: bool
+    snpe_method: type,
+    x_o_batch_dim: bool,
+    prior: str,
 ):
     num_dim = 2
     num_simulations = 1000
 
-    prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
+    # We also want to test on bounded support! Which will invoke leakage correction.
+    if prior == "mvn":
+        prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
+    elif prior == "uniform":
+        prior = Independent(Uniform(-1.0 * ones(num_dim), 1.0 * ones(num_dim)), 1)
     simulator = diagonal_linear_gaussian
 
     inference = snpe_method(prior=prior)
@@ -116,6 +123,7 @@ def test_batched_sample_log_prob_with_different_x(
 
     posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior)
 
+    torch.manual_seed(0)
     samples = posterior.sample_batched((10,), x_o)
     batched_log_probs = posterior.log_prob_batched(samples, x_o)
 
@@ -126,6 +134,20 @@ def test_batched_sample_log_prob_with_different_x(
     ), "Sample shape wrong"
     assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)), "logprob shape wrong"
 
+    # Test consistency with non-batched log_prob
+    # NOTE: Leakage factor is a MC estimate, so we need to relax the tolerance here.
+    if x_o_batch_dim == 0:
+        log_probs = posterior.log_prob(samples, x=x_o)
+        assert torch.allclose(
+            log_probs, batched_log_probs[:, 0], atol=1e-1, rtol=1e-1
+        ), "Batched log probs different from non-batched log probs"
+    else:
+        for idx in range(x_o_batch_dim):
+            log_probs = posterior.log_prob(samples[:, idx], x=x_o[idx])
+            assert torch.allclose(
+                log_probs, batched_log_probs[:, idx], atol=1e-1, rtol=1e-1
+            ), "Batched log probs different from non-batched log probs"
+
 
 @pytest.mark.mcmc
 @pytest.mark.parametrize("snlre_method", [NLE_A, NRE_A, NRE_B, NRE_C, NPE_C])