Skip to content

Commit 3a999de

Browse files
committed
refactor: move tarp plotting to analysis, fix tutorial.
1 parent ba6c1e6 commit 3a999de

File tree

4 files changed

+224
-556
lines changed

4 files changed

+224
-556
lines changed

sbi/analysis/plot.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,7 +1510,6 @@ def _sbc_rank_plot(
15101510
line_alpha: float = 0.8,
15111511
show_uniform_region: bool = True,
15121512
uniform_region_alpha: float = 0.3,
1513-
uniform_region_color: str = "gray",
15141513
xlim_offset_factor: float = 0.1,
15151514
num_cols: int = 4,
15161515
params_in_subplots: bool = False,
@@ -1571,7 +1570,7 @@ def _sbc_rank_plot(
15711570
), "plot type {plot_type} not implemented, use one in {plot_types}."
15721571

15731572
if legend_kwargs is None:
1574-
legend_kwargs = dict(loc=1, handlelength=0.8)
1573+
legend_kwargs = dict(loc="best", handlelength=0.8)
15751574

15761575
num_sbc_runs, num_parameters = ranks_list[0].shape
15771576
num_ranks = len(ranks_list)
@@ -2095,6 +2094,41 @@ def pp_plot_lc2st(
20952094
)
20962095

20972096

2097+
def plot_tarp(ecp: Tensor, alpha: Tensor, title="") -> Tuple[Figure, Axes]:
2098+
"""
2099+
Plots the expected coverage probability (ECP) against the credibility
2100+
level,alpha, for a given alpha grid.
2101+
2102+
Args:
2103+
ecp : numpy.ndarray
2104+
Array of expected coverage probabilities.
2105+
alpha : numpy.ndarray
2106+
Array of credibility levels.
2107+
title : str, optional
2108+
Title for the plot. The default is "".
2109+
2110+
Returns
2111+
fig : matplotlib.figure.Figure
2112+
The figure object.
2113+
ax : matplotlib.axes.Axes
2114+
The axes object.
2115+
2116+
"""
2117+
2118+
fig = plt.figure(figsize=(6, 6))
2119+
ax: Axes = plt.gca()
2120+
2121+
ax.plot(alpha, ecp, color="blue", label="TARP")
2122+
ax.plot(alpha, alpha, color="black", linestyle="--", label="ideal")
2123+
ax.set_xlabel(r"Credibility Level $\alpha$")
2124+
ax.set_ylabel(r"Expected Coverage Probility")
2125+
ax.set_xlim(0.0, 1.0)
2126+
ax.set_ylim(0.0, 1.0)
2127+
ax.set_title(title)
2128+
ax.legend()
2129+
return fig, ax # type: ignore
2130+
2131+
20982132
# TO BE DEPRECATED
20992133
# ----------------
21002134
def pairplot_dep(

sbi/diagnostics/tarp.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@
88

99
from typing import Callable, Optional, Tuple
1010

11-
import matplotlib.pyplot as plt
1211
import torch
13-
from matplotlib.axes import Axes
14-
from matplotlib.figure import Figure
1512
from scipy.stats import kstest
1613
from torch import Tensor
1714

@@ -216,38 +213,3 @@ def check_tarp(
216213
kstest_pvals: float = kstest(ecp.numpy(), alpha.numpy())[1] # type: ignore
217214

218215
return atc, kstest_pvals
219-
220-
221-
def plot_tarp(ecp: Tensor, alpha: Tensor, title="") -> Tuple[Figure, Axes]:
222-
"""
223-
Plots the expected coverage probability (ECP) against the credibility
224-
level,alpha, for a given alpha grid.
225-
226-
Args:
227-
ecp : numpy.ndarray
228-
Array of expected coverage probabilities.
229-
alpha : numpy.ndarray
230-
Array of credibility levels.
231-
title : str, optional
232-
Title for the plot. The default is "".
233-
234-
Returns
235-
fig : matplotlib.figure.Figure
236-
The figure object.
237-
ax : matplotlib.axes.Axes
238-
The axes object.
239-
240-
"""
241-
242-
fig = plt.figure(figsize=(6, 6))
243-
ax: Axes = plt.gca()
244-
245-
ax.plot(alpha, ecp, color="blue", label="TARP")
246-
ax.plot(alpha, alpha, color="black", linestyle="--", label="ideal")
247-
ax.set_xlabel(r"Credibility Level $\alpha$")
248-
ax.set_ylabel(r"Expected Coverage Probility")
249-
ax.set_xlim(0.0, 1.0)
250-
ax.set_ylim(0.0, 1.0)
251-
ax.set_title(title)
252-
ax.legend()
253-
return fig, ax # type: ignore

0 commit comments

Comments
 (0)