File tree Expand file tree Collapse file tree 1 file changed +8
-5
lines changed Expand file tree Collapse file tree 1 file changed +8
-5
lines changed Original file line number Diff line number Diff line change @@ -179,10 +179,10 @@ def plt_kde_2d(
179
179
ax .imshow (
180
180
Z ,
181
181
extent = (
182
- limits_col [0 ],
183
- limits_col [1 ],
184
- limits_row [0 ],
185
- limits_row [1 ],
182
+ limits_col [0 ]. item () ,
183
+ limits_col [1 ]. item () ,
184
+ limits_row [0 ]. item () ,
185
+ limits_row [1 ]. item () ,
186
186
),
187
187
** offdiag_kwargs ["mpl_kwargs" ],
188
188
)
@@ -350,7 +350,7 @@ def get_offdiag_funcs(
350
350
def _format_subplot (
351
351
ax : Axes ,
352
352
current : str ,
353
- limits : Union [List , torch .Tensor ],
353
+ limits : Union [List [ List [ float ]] , torch .Tensor ],
354
354
ticks : Optional [Union [List , torch .Tensor ]],
355
355
labels_dim : List [str ],
356
356
fig_kwargs : Dict ,
@@ -384,6 +384,9 @@ def _format_subplot(
384
384
):
385
385
ax .set_facecolor (fig_kwargs ["fig_bg_colors" ][current ])
386
386
# Limits
387
+ if isinstance (limits , Tensor ):
388
+ assert limits .dim () == 2 , "Limits should be a 2D tensor."
389
+ limits = limits .tolist ()
387
390
if current == "diag" :
388
391
eps = fig_kwargs ["x_lim_add_eps" ]
389
392
ax .set_xlim ((limits [col ][0 ] - eps , limits [col ][1 ] + eps ))
You can’t perform that action at this time.
0 commit comments