Skip to content

Commit ce31030

Browse files
authored
fix: patch torch bug in tarp, run torch.histogram on cpu (#1596)
* patch for torch bug in tarp, run torch.histogram on with cpu-only tensor * pin gpu tests to one param combination * more clear variable naming * ruff formatting
1 parent 2e1509e commit ce31030

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

sbi/diagnostics/tarp.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def _run_tarp(
138138
139139
"""
140140
num_posterior_samples, num_tarp_samples, _ = posterior_samples.shape
141+
input_device = posterior_samples.device
141142

142143
assert references.shape == thetas.shape, (
143144
"references must have the same shape as thetas"
@@ -162,11 +163,20 @@ def _run_tarp(
162163
coverage_values = (
163164
torch.sum(sample_dists < theta_dists, dim=0) / num_posterior_samples
164165
)
165-
hist, alpha_grid = torch.histogram(coverage_values, density=True, bins=num_bins)
166+
167+
# enforce execution on the CPU due to
168+
# https://github.com/pytorch/pytorch/issues/69519
169+
hist, alpha_grid = torch.histogram(
170+
coverage_values.cpu(), density=True, bins=num_bins
171+
)
172+
173+
# return all tensors to input_device to keep contract valid
174+
hist, alpha_grid = hist.to(input_device), alpha_grid.to(input_device)
175+
166176
# calculate empirical CDF via cumsum and normalize
167177
ecp = torch.cumsum(hist, dim=0) / hist.sum()
168178
# add 0 to the beginning of the ecp curve to match the alpha grid
169-
ecp = torch.cat([Tensor([0]), ecp])
179+
ecp = torch.cat([torch.zeros((1,), device=input_device), ecp])
170180

171181
return ecp, alpha_grid
172182

tests/tarp_test.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from scipy.stats import uniform
3-
from torch import Tensor, allclose, exp, eye, ones
3+
from torch import allclose, device, exp, eye, ones, zeros
44
from torch.distributions import Normal, Uniform
55
from torch.nn import L1Loss
66

@@ -141,12 +141,46 @@ def test_run_tarp_correct(distance, z_score_theta, accurate_samples):
141141
num_bins=30,
142142
)
143143

144-
assert allclose((ecp - alpha).abs().max(), Tensor([0.0]), atol=1e-1)
144+
assert allclose((ecp - alpha).abs().max(), zeros((1,)), atol=1e-1)
145145
assert (
146146
ecp - alpha
147147
).abs().sum() < 1.0 # integral of residuals should vanish, fig.2 in paper
148148

149149

150+
@pytest.mark.gpu
151+
def test_run_tarp_correct_on_cuda_device(accurate_samples):
152+
z_score_theta = True
153+
distance = l2
154+
dev = device("cuda")
155+
theta, samples = accurate_samples
156+
theta, samples = theta.to(dev), samples.to(dev)
157+
158+
with pytest.raises(NotImplementedError):
159+
# let's make sure the execution problem is still there
160+
# if torch fixes https://github.com/pytorch/pytorch/issues/69519
161+
# this context manager should ensure, the case fails
162+
# then we can fix the tarp code
163+
from torch import histogram
164+
165+
histogram(zeros((3,)).cuda(), bins=4)
166+
167+
references = get_tarp_references(theta).to(dev)
168+
169+
ecp, alpha = _run_tarp(
170+
samples,
171+
theta,
172+
references,
173+
distance=distance,
174+
z_score_theta=z_score_theta,
175+
num_bins=30,
176+
)
177+
178+
assert allclose((ecp - alpha).abs().max(), zeros((1,), device=dev), atol=1e-1)
179+
assert (
180+
ecp - alpha
181+
).abs().sum() < 1.05 # integral of residuals should vanish, fig.2 in paper
182+
183+
150184
@pytest.mark.parametrize("distance", (l1, l2))
151185
def test_run_tarp_detect_overdispersed(distance, overdispersed_samples):
152186
theta, samples = overdispersed_samples
@@ -158,7 +192,7 @@ def test_run_tarp_detect_overdispersed(distance, overdispersed_samples):
158192

159193
# TARP detects that this is NOT a correct representation of the posterior
160194
# hence we test for not allclose
161-
assert not allclose((ecp - alpha).abs().max(), Tensor([0.0]), atol=1e-1)
195+
assert not allclose((ecp - alpha).abs().max(), zeros((1,)), atol=1e-1)
162196
assert (ecp - alpha).abs().sum() > 3.0 # integral is nonzero, fig.2 in paper
163197

164198

@@ -173,7 +207,7 @@ def test_run_tarp_detect_underdispersed(distance, underdispersed_samples):
173207

174208
# TARP detects that this is NOT a correct representation of the posterior
175209
# hence we test for not allclose
176-
assert not allclose((ecp - alpha).abs().max(), Tensor([0.0]), atol=1e-1)
210+
assert not allclose((ecp - alpha).abs().max(), zeros((1,)), atol=1e-1)
177211
assert (ecp - alpha).abs().sum() > 3.0 # integral is nonzero, fig.2 in paper
178212

179213

@@ -188,7 +222,7 @@ def test_run_tarp_detect_bias(distance, biased_samples):
188222

189223
# TARP detects that this is NOT a correct representation of the posterior
190224
# hence we test for not allclose
191-
assert not allclose((ecp - alpha).abs().max(), Tensor([0.0]), atol=1e-1)
225+
assert not allclose((ecp - alpha).abs().max(), zeros((1,)), atol=1e-1)
192226
assert (ecp - alpha).abs().sum() > 3.0 # integral is nonzero, fig.2 in paper
193227

194228

0 commit comments

Comments
 (0)