@@ -554,42 +554,56 @@ def handle_nan_infs(samples: List[np.ndarray]) -> List[np.ndarray]:
554
554
return samples
555
555
556
556
557
+ def convert_to_list_of_numpy (
558
+ arr : Union [List [np .ndarray ], List [torch .Tensor ], np .ndarray , torch .Tensor ],
559
+ ) -> List [np .ndarray ]:
560
+ """Converts a list of torch.Tensor to a list of np.ndarray."""
561
+ if not isinstance (arr , list ):
562
+ arr = ensure_numpy (arr )
563
+ return [arr ]
564
+ return [ensure_numpy (a ) for a in arr ]
565
+
566
+
567
+ def infer_limits (
568
+ samples : List [np .ndarray ], dim : int , points : Optional [List [np .ndarray ]] = None
569
+ ) -> List :
570
+ """Infer limits for the plot."""
571
+ limits = []
572
+ for d in range (dim ):
573
+ min_val = min (np .min (sample [:, d ]) for sample in samples )
574
+ max_val = max (np .max (sample [:, d ]) for sample in samples )
575
+ if points is not None :
576
+ min_val = min (min_val , min (np .min (point [:, d ]) for point in points ))
577
+ max_val = max (max_val , max (np .max (point [:, d ]) for point in points ))
578
+ limits .append ([min_val , max_val ])
579
+ return limits
580
+
581
+
557
582
def prepare_for_plot (
558
583
samples : Union [List [np .ndarray ], List [torch .Tensor ], np .ndarray , torch .Tensor ],
559
- limits : Optional [Union [List , torch .Tensor , np .ndarray ]],
584
+ limits : Optional [Union [List , torch .Tensor , np .ndarray ]] = None ,
585
+ points : Optional [
586
+ Union [List [np .ndarray ], List [torch .Tensor ], np .ndarray , torch .Tensor ]
587
+ ] = None ,
560
588
) -> Tuple [List [np .ndarray ], int , torch .Tensor ]:
561
589
"""
562
590
Ensures correct formatting for samples and limits, and returns dimension
563
591
of the samples.
564
592
"""
565
593
566
- # Prepare samples
567
- if not isinstance (samples , list ):
568
- samples = ensure_numpy (samples )
569
- samples = [samples ]
570
- else :
571
- samples = [ensure_numpy (sample ) for sample in samples ]
594
+ samples = convert_to_list_of_numpy (samples )
595
+ if points is not None :
596
+ points = convert_to_list_of_numpy (points )
572
597
573
- # check if nans and infs
574
598
samples = handle_nan_infs (samples )
575
599
576
- # Dimensionality of the problem.
577
600
dim = samples [0 ].shape [1 ]
578
601
579
- # Prepare limits. Infer them from samples if they had not been passed.
580
- if limits == [] or limits is None :
581
- limits = []
582
- for d in range (dim ):
583
- min = + np .inf
584
- max = - np .inf
585
- for sample in samples :
586
- min_ = np .min (sample [:, d ])
587
- min = min_ if min_ < min else min
588
- max_ = np .max (sample [:, d ])
589
- max = max_ if max_ > max else max
590
- limits .append ([min , max ])
602
+ if limits is None or limits == []:
603
+ limits = infer_limits (samples , dim , points )
591
604
else :
592
605
limits = [limits [0 ] for _ in range (dim )] if len (limits ) == 1 else limits
606
+
593
607
limits = torch .as_tensor (limits )
594
608
return samples , dim , limits
595
609
@@ -737,7 +751,7 @@ def pairplot(
737
751
)
738
752
return fig , axes
739
753
740
- samples , dim , limits = prepare_for_plot (samples , limits )
754
+ samples , dim , limits = prepare_for_plot (samples , limits , points )
741
755
742
756
# prepate figure kwargs
743
757
fig_kwargs_filled = _get_default_fig_kwargs ()
@@ -978,6 +992,7 @@ def _get_default_diag_kwargs(diag: Optional[str], i: int = 0) -> Dict:
978
992
"density" : False ,
979
993
"histtype" : "step" ,
980
994
},
995
+ "bins" : "auto" ,
981
996
}
982
997
elif diag == "scatter" :
983
998
diag_kwargs = {
0 commit comments