Skip to content

Commit 2c409fb

Browse files
committed
fix typing in plotting
1 parent bc54032 commit 2c409fb

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

sbi/analysis/plot.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,10 @@ def plt_kde_2d(
179179
ax.imshow(
180180
Z,
181181
extent=(
182-
limits_col[0],
183-
limits_col[1],
184-
limits_row[0],
185-
limits_row[1],
182+
limits_col[0].item(),
183+
limits_col[1].item(),
184+
limits_row[0].item(),
185+
limits_row[1].item(),
186186
),
187187
**offdiag_kwargs["mpl_kwargs"],
188188
)
@@ -350,7 +350,7 @@ def get_offdiag_funcs(
350350
def _format_subplot(
351351
ax: Axes,
352352
current: str,
353-
limits: Union[List, torch.Tensor],
353+
limits: Union[List[List[float]], torch.Tensor],
354354
ticks: Optional[Union[List, torch.Tensor]],
355355
labels_dim: List[str],
356356
fig_kwargs: Dict,
@@ -384,6 +384,9 @@ def _format_subplot(
384384
):
385385
ax.set_facecolor(fig_kwargs["fig_bg_colors"][current])
386386
# Limits
387+
if isinstance(limits, Tensor):
388+
assert limits.dim() == 2, "Limits should be a 2D tensor."
389+
limits = limits.tolist()
387390
if current == "diag":
388391
eps = fig_kwargs["x_lim_add_eps"]
389392
ax.set_xlim((limits[col][0] - eps, limits[col][1] + eps))

0 commit comments

Comments
 (0)