Skip to content

Commit fd29c41

Browse files
committed
fix #1260: include points in plotting limits
1 parent bd740d6 commit fd29c41

File tree

1 file changed

+37
-22
lines changed

1 file changed

+37
-22
lines changed

sbi/analysis/plot.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -554,42 +554,56 @@ def handle_nan_infs(samples: List[np.ndarray]) -> List[np.ndarray]:
554554
return samples
555555

556556

557+
def convert_to_list_of_numpy(
558+
arr: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor],
559+
) -> List[np.ndarray]:
560+
"""Converts a list of torch.Tensor to a list of np.ndarray."""
561+
if not isinstance(arr, list):
562+
arr = ensure_numpy(arr)
563+
return [arr]
564+
return [ensure_numpy(a) for a in arr]
565+
566+
567+
def infer_limits(
568+
samples: List[np.ndarray], dim: int, points: Optional[List[np.ndarray]] = None
569+
) -> List:
570+
"""Infer limits for the plot."""
571+
limits = []
572+
for d in range(dim):
573+
min_val = min(np.min(sample[:, d]) for sample in samples)
574+
max_val = max(np.max(sample[:, d]) for sample in samples)
575+
if points is not None:
576+
min_val = min(min_val, min(np.min(point[:, d]) for point in points))
577+
max_val = max(max_val, max(np.max(point[:, d]) for point in points))
578+
limits.append([min_val, max_val])
579+
return limits
580+
581+
557582
def prepare_for_plot(
558583
samples: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor],
559-
limits: Optional[Union[List, torch.Tensor, np.ndarray]],
584+
limits: Optional[Union[List, torch.Tensor, np.ndarray]] = None,
585+
points: Optional[
586+
Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor]
587+
] = None,
560588
) -> Tuple[List[np.ndarray], int, torch.Tensor]:
561589
"""
562590
Ensures correct formatting for samples and limits, and returns dimension
563591
of the samples.
564592
"""
565593

566-
# Prepare samples
567-
if not isinstance(samples, list):
568-
samples = ensure_numpy(samples)
569-
samples = [samples]
570-
else:
571-
samples = [ensure_numpy(sample) for sample in samples]
594+
samples = convert_to_list_of_numpy(samples)
595+
if points is not None:
596+
points = convert_to_list_of_numpy(points)
572597

573-
# check if nans and infs
574598
samples = handle_nan_infs(samples)
575599

576-
# Dimensionality of the problem.
577600
dim = samples[0].shape[1]
578601

579-
# Prepare limits. Infer them from samples if they had not been passed.
580-
if limits == [] or limits is None:
581-
limits = []
582-
for d in range(dim):
583-
min = +np.inf
584-
max = -np.inf
585-
for sample in samples:
586-
min_ = np.min(sample[:, d])
587-
min = min_ if min_ < min else min
588-
max_ = np.max(sample[:, d])
589-
max = max_ if max_ > max else max
590-
limits.append([min, max])
602+
if limits is None or limits == []:
603+
limits = infer_limits(samples, dim, points)
591604
else:
592605
limits = [limits[0] for _ in range(dim)] if len(limits) == 1 else limits
606+
593607
limits = torch.as_tensor(limits)
594608
return samples, dim, limits
595609

@@ -737,7 +751,7 @@ def pairplot(
737751
)
738752
return fig, axes
739753

740-
samples, dim, limits = prepare_for_plot(samples, limits)
754+
samples, dim, limits = prepare_for_plot(samples, limits, points)
741755

742756
# prepate figure kwargs
743757
fig_kwargs_filled = _get_default_fig_kwargs()
@@ -978,6 +992,7 @@ def _get_default_diag_kwargs(diag: Optional[str], i: int = 0) -> Dict:
978992
"density": False,
979993
"histtype": "step",
980994
},
995+
"bins": "auto",
981996
}
982997
elif diag == "scatter":
983998
diag_kwargs = {

0 commit comments

Comments
 (0)