1
1
import pytest
2
2
from scipy .stats import uniform
3
- from torch import Tensor , allclose , exp , eye , ones
3
+ from torch import allclose , device , exp , eye , ones , zeros
4
4
from torch .distributions import Normal , Uniform
5
5
from torch .nn import L1Loss
6
6
@@ -141,12 +141,46 @@ def test_run_tarp_correct(distance, z_score_theta, accurate_samples):
141
141
num_bins = 30 ,
142
142
)
143
143
144
- assert allclose ((ecp - alpha ).abs ().max (), Tensor ([ 0.0 ] ), atol = 1e-1 )
144
+ assert allclose ((ecp - alpha ).abs ().max (), zeros (( 1 ,) ), atol = 1e-1 )
145
145
assert (
146
146
ecp - alpha
147
147
).abs ().sum () < 1.0 # integral of residuals should vanish, fig.2 in paper
148
148
149
149
150
+ @pytest .mark .gpu
151
+ def test_run_tarp_correct_on_cuda_device (accurate_samples ):
152
+ z_score_theta = True
153
+ distance = l2
154
+ dev = device ("cuda" )
155
+ theta , samples = accurate_samples
156
+ theta , samples = theta .to (dev ), samples .to (dev )
157
+
158
+ with pytest .raises (NotImplementedError ):
159
+ # let's make sure the execution problem is still there
160
+ # if torch fixes https://github.com/pytorch/pytorch/issues/69519
161
+ # this context manager should ensure, the case fails
162
+ # then we can fix the tarp code
163
+ from torch import histogram
164
+
165
+ histogram (zeros ((3 ,)).cuda (), bins = 4 )
166
+
167
+ references = get_tarp_references (theta ).to (dev )
168
+
169
+ ecp , alpha = _run_tarp (
170
+ samples ,
171
+ theta ,
172
+ references ,
173
+ distance = distance ,
174
+ z_score_theta = z_score_theta ,
175
+ num_bins = 30 ,
176
+ )
177
+
178
+ assert allclose ((ecp - alpha ).abs ().max (), zeros ((1 ,), device = dev ), atol = 1e-1 )
179
+ assert (
180
+ ecp - alpha
181
+ ).abs ().sum () < 1.05 # integral of residuals should vanish, fig.2 in paper
182
+
183
+
150
184
@pytest .mark .parametrize ("distance" , (l1 , l2 ))
151
185
def test_run_tarp_detect_overdispersed (distance , overdispersed_samples ):
152
186
theta , samples = overdispersed_samples
@@ -158,7 +192,7 @@ def test_run_tarp_detect_overdispersed(distance, overdispersed_samples):
158
192
159
193
# TARP detects that this is NOT a correct representation of the posterior
160
194
# hence we test for not allclose
161
- assert not allclose ((ecp - alpha ).abs ().max (), Tensor ([ 0.0 ] ), atol = 1e-1 )
195
+ assert not allclose ((ecp - alpha ).abs ().max (), zeros (( 1 ,) ), atol = 1e-1 )
162
196
assert (ecp - alpha ).abs ().sum () > 3.0 # integral is nonzero, fig.2 in paper
163
197
164
198
@@ -173,7 +207,7 @@ def test_run_tarp_detect_underdispersed(distance, underdispersed_samples):
173
207
174
208
# TARP detects that this is NOT a correct representation of the posterior
175
209
# hence we test for not allclose
176
- assert not allclose ((ecp - alpha ).abs ().max (), Tensor ([ 0.0 ] ), atol = 1e-1 )
210
+ assert not allclose ((ecp - alpha ).abs ().max (), zeros (( 1 ,) ), atol = 1e-1 )
177
211
assert (ecp - alpha ).abs ().sum () > 3.0 # integral is nonzero, fig.2 in paper
178
212
179
213
@@ -188,7 +222,7 @@ def test_run_tarp_detect_bias(distance, biased_samples):
188
222
189
223
# TARP detects that this is NOT a correct representation of the posterior
190
224
# hence we test for not allclose
191
- assert not allclose ((ecp - alpha ).abs ().max (), Tensor ([ 0.0 ] ), atol = 1e-1 )
225
+ assert not allclose ((ecp - alpha ).abs ().max (), zeros (( 1 ,) ), atol = 1e-1 )
192
226
assert (ecp - alpha ).abs ().sum () > 3.0 # integral is nonzero, fig.2 in paper
193
227
194
228
0 commit comments