Skip to content

Commit 1bb0298

Browse files
committed
test: add sbc and tarp plotting tests
1 parent 3a999de commit 1bb0298

File tree

4 files changed

+76
-22
lines changed

4 files changed

+76
-22
lines changed

sbi/analysis/plot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2094,7 +2094,7 @@ def pp_plot_lc2st(
20942094
)
20952095

20962096

2097-
def plot_tarp(ecp: Tensor, alpha: Tensor, title="") -> Tuple[Figure, Axes]:
2097+
def plot_tarp(ecp: Tensor, alpha: Tensor, title: Optional[str]) -> Tuple[Figure, Axes]:
20982098
"""
20992099
Plots the expected coverage probability (ECP) against the credibility
21002100
level,alpha, for a given alpha grid.
@@ -2117,6 +2117,8 @@ def plot_tarp(ecp: Tensor, alpha: Tensor, title="") -> Tuple[Figure, Axes]:
21172117

21182118
fig = plt.figure(figsize=(6, 6))
21192119
ax: Axes = plt.gca()
2120+
if title is None:
2121+
title = ""
21202122

21212123
ax.plot(alpha, ecp, color="blue", label="TARP")
21222124
ax.plot(alpha, alpha, color="black", linestyle="--", label="ideal")

tests/sbc_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33

44
from __future__ import annotations
55

6+
from typing import Union
7+
68
import pytest
79
import torch
810
from torch import eye, ones, zeros
911
from torch.distributions import MultivariateNormal, Uniform
1012

13+
from sbi.analysis import sbc_rank_plot
1114
from sbi.diagnostics import check_sbc, get_nltp, run_sbc
1215
from sbi.inference import SNLE, SNPE, simulate_for_sbi
1316
from sbi.simulators import linear_gaussian
@@ -208,3 +211,32 @@ def test_sbc_checks():
208211
assert (checks["ks_pvals"] > 0.05).all()
209212
assert (checks["c2st_ranks"] < 0.55).all()
210213
assert (checks["c2st_dap"] < 0.55).all()
214+
215+
216+
# add test for sbc plotting
217+
@pytest.mark.parametrize("num_bins", (None, 30))
218+
@pytest.mark.parametrize("plot_type", ("cdf", "hist"))
219+
@pytest.mark.parametrize("legend_kwargs", (None, {"loc": "upper left"}))
220+
@pytest.mark.parametrize("num_rank_sets", (1, 2))
221+
def test_sbc_plotting(
222+
num_bins: int, plot_type: str, legend_kwargs: Union[None, dict], num_rank_sets: int
223+
):
224+
"""Test the uniformity checks for SBC."""
225+
226+
num_dim = 2
227+
num_posterior_samples = 1000
228+
229+
# Ranks should be distributed uniformly in [0, num_posterior_samples]
230+
ranks = [
231+
torch.distributions.Uniform(
232+
zeros(num_dim), num_posterior_samples * ones(num_dim)
233+
).sample((num_posterior_samples,))
234+
] * num_rank_sets
235+
236+
sbc_rank_plot(
237+
ranks,
238+
num_posterior_samples,
239+
num_bins=num_bins,
240+
plot_type=plot_type,
241+
legend_kwargs=legend_kwargs,
242+
)

tests/tarp_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from torch.distributions import Normal, Uniform
55
from torch.nn import L1Loss
66

7+
from sbi.analysis.plot import plot_tarp
78
from sbi.diagnostics.tarp import _run_tarp, check_tarp, get_tarp_references, run_tarp
89
from sbi.inference import SNPE
910
from sbi.simulators import linear_gaussian
@@ -286,3 +287,14 @@ def simulator(theta):
286287
atc, kspvals = check_tarp(ecp, alpha)
287288
assert -0.5 < atc < 0.5
288289
assert kspvals > 0.05
290+
291+
292+
# Test tarp plotting
293+
@pytest.mark.parametrize("title", ["Correct", None])
294+
def test_tarp_plotting(title: str, accurate_samples):
295+
theta, samples = accurate_samples
296+
references = get_tarp_references(theta)
297+
298+
ecp, alpha = _run_tarp(samples, theta, references)
299+
300+
plot_tarp(ecp, alpha, title=title)

0 commit comments

Comments
 (0)