@@ -554,42 +554,66 @@ def handle_nan_infs(samples: List[np.ndarray]) -> List[np.ndarray]:
554
554
return samples
555
555
556
556
557
+ def convert_to_list_of_numpy (
558
+ arr : Union [List [np .ndarray ], List [torch .Tensor ], np .ndarray , torch .Tensor ],
559
+ ) -> List [np .ndarray ]:
560
+ """Converts a list of torch.Tensor to a list of np.ndarray."""
561
+ if not isinstance (arr , list ):
562
+ arr = ensure_numpy (arr )
563
+ return [arr ]
564
+ return [ensure_numpy (a ) for a in arr ]
565
+
566
+
567
+ def infer_limits (
568
+ samples : List [np .ndarray ],
569
+ dim : int ,
570
+ points : Optional [List [np .ndarray ]] = None ,
571
+ eps : float = 0.1 ,
572
+ ) -> List :
573
+ """Infer limits for the plot.
574
+
575
+ Args:
576
+ samples: List of samples.
577
+ dim: Dimension of the samples.
578
+ points: List of points.
579
+ eps: Relative margin for the limits.
580
+ """
581
+ limits = []
582
+ for d in range (dim ):
583
+ min_val = min (np .min (sample [:, d ]) for sample in samples )
584
+ max_val = max (np .max (sample [:, d ]) for sample in samples )
585
+ if points is not None :
586
+ min_val = min (min_val , min (np .min (point [:, d ]) for point in points ))
587
+ max_val = max (max_val , max (np .max (point [:, d ]) for point in points ))
588
+ limits .append ([min_val * (1 + eps ), max_val * (1 + eps )])
589
+ return limits
590
+
591
+
557
592
def prepare_for_plot (
558
593
samples : Union [List [np .ndarray ], List [torch .Tensor ], np .ndarray , torch .Tensor ],
559
- limits : Optional [Union [List , torch .Tensor , np .ndarray ]],
594
+ limits : Optional [Union [List , torch .Tensor , np .ndarray ]] = None ,
595
+ points : Optional [
596
+ Union [List [np .ndarray ], List [torch .Tensor ], np .ndarray , torch .Tensor ]
597
+ ] = None ,
560
598
) -> Tuple [List [np .ndarray ], int , torch .Tensor ]:
561
599
"""
562
600
Ensures correct formatting for samples and limits, and returns dimension
563
601
of the samples.
564
602
"""
565
603
566
- # Prepare samples
567
- if not isinstance (samples , list ):
568
- samples = ensure_numpy (samples )
569
- samples = [samples ]
570
- else :
571
- samples = [ensure_numpy (sample ) for sample in samples ]
604
+ samples = convert_to_list_of_numpy (samples )
605
+ if points is not None :
606
+ points = convert_to_list_of_numpy (points )
572
607
573
- # check if nans and infs
574
608
samples = handle_nan_infs (samples )
575
609
576
- # Dimensionality of the problem.
577
610
dim = samples [0 ].shape [1 ]
578
611
579
- # Prepare limits. Infer them from samples if they had not been passed.
580
- if limits == [] or limits is None :
581
- limits = []
582
- for d in range (dim ):
583
- min = + np .inf
584
- max = - np .inf
585
- for sample in samples :
586
- min_ = np .min (sample [:, d ])
587
- min = min_ if min_ < min else min
588
- max_ = np .max (sample [:, d ])
589
- max = max_ if max_ > max else max
590
- limits .append ([min , max ])
612
+ if limits is None or limits == []:
613
+ limits = infer_limits (samples , dim , points )
591
614
else :
592
615
limits = [limits [0 ] for _ in range (dim )] if len (limits ) == 1 else limits
616
+
593
617
limits = torch .as_tensor (limits )
594
618
return samples , dim , limits
595
619
@@ -737,7 +761,7 @@ def pairplot(
737
761
)
738
762
return fig , axes
739
763
740
- samples , dim , limits = prepare_for_plot (samples , limits )
764
+ samples , dim , limits = prepare_for_plot (samples , limits , points )
741
765
742
766
# prepate figure kwargs
743
767
fig_kwargs_filled = _get_default_fig_kwargs ()
0 commit comments