Skip to content

Commit 15bf836

Browse files
committed
fix #1260: include points in plotting limits
1 parent becc93c commit 15bf836

File tree

2 files changed

+49
-25
lines changed

2 files changed

+49
-25
lines changed

sbi/analysis/plot.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -554,42 +554,66 @@ def handle_nan_infs(samples: List[np.ndarray]) -> List[np.ndarray]:
554554
return samples
555555

556556

557+
def convert_to_list_of_numpy(
558+
arr: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor],
559+
) -> List[np.ndarray]:
560+
"""Converts a list of torch.Tensor to a list of np.ndarray."""
561+
if not isinstance(arr, list):
562+
arr = ensure_numpy(arr)
563+
return [arr]
564+
return [ensure_numpy(a) for a in arr]
565+
566+
567+
def infer_limits(
568+
samples: List[np.ndarray],
569+
dim: int,
570+
points: Optional[List[np.ndarray]] = None,
571+
eps: float = 0.1,
572+
) -> List:
573+
"""Infer limits for the plot.
574+
575+
Args:
576+
samples: List of samples.
577+
dim: Dimension of the samples.
578+
points: List of points.
579+
eps: Relative margin for the limits.
580+
"""
581+
limits = []
582+
for d in range(dim):
583+
min_val = min(np.min(sample[:, d]) for sample in samples)
584+
max_val = max(np.max(sample[:, d]) for sample in samples)
585+
if points is not None:
586+
min_val = min(min_val, min(np.min(point[:, d]) for point in points))
587+
max_val = max(max_val, max(np.max(point[:, d]) for point in points))
588+
limits.append([min_val * (1 + eps), max_val * (1 + eps)])
589+
return limits
590+
591+
557592
def prepare_for_plot(
558593
samples: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor],
559-
limits: Optional[Union[List, torch.Tensor, np.ndarray]],
594+
limits: Optional[Union[List, torch.Tensor, np.ndarray]] = None,
595+
points: Optional[
596+
Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor]
597+
] = None,
560598
) -> Tuple[List[np.ndarray], int, torch.Tensor]:
561599
"""
562600
Ensures correct formatting for samples and limits, and returns dimension
563601
of the samples.
564602
"""
565603

566-
# Prepare samples
567-
if not isinstance(samples, list):
568-
samples = ensure_numpy(samples)
569-
samples = [samples]
570-
else:
571-
samples = [ensure_numpy(sample) for sample in samples]
604+
samples = convert_to_list_of_numpy(samples)
605+
if points is not None:
606+
points = convert_to_list_of_numpy(points)
572607

573-
# check if nans and infs
574608
samples = handle_nan_infs(samples)
575609

576-
# Dimensionality of the problem.
577610
dim = samples[0].shape[1]
578611

579-
# Prepare limits. Infer them from samples if they had not been passed.
580-
if limits == [] or limits is None:
581-
limits = []
582-
for d in range(dim):
583-
min = +np.inf
584-
max = -np.inf
585-
for sample in samples:
586-
min_ = np.min(sample[:, d])
587-
min = min_ if min_ < min else min
588-
max_ = np.max(sample[:, d])
589-
max = max_ if max_ > max else max
590-
limits.append([min, max])
612+
if limits is None or limits == []:
613+
limits = infer_limits(samples, dim, points)
591614
else:
592615
limits = [limits[0] for _ in range(dim)] if len(limits) == 1 else limits
616+
593617
limits = torch.as_tensor(limits)
594618
return samples, dim, limits
595619

@@ -737,7 +761,7 @@ def pairplot(
737761
)
738762
return fig, axes
739763

740-
samples, dim, limits = prepare_for_plot(samples, limits)
764+
samples, dim, limits = prepare_for_plot(samples, limits, points)
741765

742766
# prepate figure kwargs
743767
fig_kwargs_filled = _get_default_fig_kwargs()

tests/plot_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
@pytest.mark.parametrize("samples", (torch.randn(100, 1),))
18-
@pytest.mark.parametrize("limits", ([(-1, 1)],))
18+
@pytest.mark.parametrize("limits", ([(-1, 1)], None))
1919
def test_pairplot1D(samples, limits):
2020
fig, axs = pairplot(**{k: v for k, v in locals().items() if v is not None})
2121
assert isinstance(fig, Figure)
@@ -24,7 +24,7 @@ def test_pairplot1D(samples, limits):
2424

2525

2626
@pytest.mark.parametrize("samples", (torch.randn(100, 2),))
27-
@pytest.mark.parametrize("limits", ([(-1, 1)],))
27+
@pytest.mark.parametrize("limits", ([(-1, 1)], None))
2828
def test_nan_inf(samples, limits):
2929
samples[0, 0] = np.nan
3030
samples[5, 1] = np.inf
@@ -37,7 +37,7 @@ def test_nan_inf(samples, limits):
3737

3838
@pytest.mark.parametrize("samples", (torch.randn(100, 2), [torch.randn(100, 3)] * 2))
3939
@pytest.mark.parametrize("points", (torch.ones(1, 3),))
40-
@pytest.mark.parametrize("limits", ([(-3, 3)],))
40+
@pytest.mark.parametrize("limits", ([(-3, 3)], None))
4141
@pytest.mark.parametrize("subset", (None, [0, 1]))
4242
@pytest.mark.parametrize("upper", ("scatter",))
4343
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)