Skip to content

Commit 7738c1a

Browse files
committed
add sensitivity analysis tests.
1 parent cb930e5 commit 7738c1a

File tree

1 file changed

+300
-0
lines changed

1 file changed

+300
-0
lines changed

tests/sensitivity_analysis_test.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
from typing import Tuple
2+
3+
import pytest
4+
import torch
5+
from torch import Tensor, nn
6+
7+
from sbi.analysis.sensitivity_analysis import (
8+
ActiveSubspace,
9+
Destandardize,
10+
build_input_output_layer,
11+
destandardizing_net,
12+
)
13+
14+
# ------------------------
15+
# Fixtures and test helpers
16+
# ------------------------
17+
18+
19+
@pytest.fixture
20+
def toy_theta_property() -> Tuple[Tensor, Tensor]:
21+
# Small synthetic regression task: y = sum(theta) + noise
22+
n, d = 64, 3
23+
theta = torch.randn(n, d)
24+
y = theta.sum(dim=1, keepdim=True) + 0.05 * torch.randn(n, 1)
25+
return theta, y
26+
27+
28+
class _PriorWithStats:
29+
def __init__(self, d: int):
30+
self.mean = torch.zeros(d)
31+
self.stddev = torch.ones(d)
32+
33+
def sample(self, shape: Tuple[int, ...]) -> Tensor:
34+
return torch.randn(*shape, self.mean.numel())
35+
36+
37+
class _PriorWithoutStats:
38+
def __init__(self, d: int):
39+
self._d = d
40+
41+
def sample(self, shape: Tuple[int, ...]) -> Tensor:
42+
return torch.randn(*shape, self._d)
43+
44+
45+
class _PosteriorStub:
46+
def __init__(self, d: int, with_stats: bool = True):
47+
self._device = "cpu"
48+
self.d = d
49+
self.prior = _PriorWithStats(d) if with_stats else _PriorWithoutStats(d)
50+
51+
def sample(self, shape: Tuple[int, ...]) -> Tensor:
52+
return torch.randn(*shape, self.d, requires_grad=False)
53+
54+
def potential(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
55+
# Smooth, concave potential with gradient everywhere
56+
# Return shape (N, 1) to match regression output style
57+
return -(theta**2).sum(dim=1, keepdim=True) * 0.5
58+
59+
60+
@pytest.fixture(
61+
params=[
62+
pytest.param(True, id="prior_with_stats"),
63+
pytest.param(False, id="prior_without_stats"),
64+
]
65+
)
66+
def posterior_stub(request) -> _PosteriorStub:
67+
return _PosteriorStub(d=3, with_stats=request.param)
68+
69+
70+
@pytest.fixture
71+
def embedding_net_theta() -> nn.Module:
72+
# Lightweight embedding to exercise embedding pathway
73+
return nn.Sequential(nn.Linear(3, 3), nn.ReLU())
74+
75+
76+
# ------------------------
77+
# Utility layer tests
78+
# ------------------------
79+
80+
81+
def test_destandardize_and_destandardizing_net_forward() -> None:
82+
# Create a batch with near-zero variance in one dim to exercise min-std flooring
83+
n, d = 50, 2
84+
col0 = torch.randn(n, 1)
85+
col1 = torch.zeros(n, 1)
86+
batch = torch.cat([col0, col1], dim=1)
87+
88+
min_std = 0.5
89+
net = destandardizing_net(batch, min_std=min_std)
90+
# Expected mean and clamped std
91+
mean = batch.mean(dim=0)
92+
std = batch.std(dim=0)
93+
std = torch.where(std < min_std, torch.full_like(std, min_std), std)
94+
95+
# Check that zero maps to mean, ones maps to mean + std
96+
out0 = net(torch.zeros(1, d))
97+
out1 = net(torch.ones(1, d))
98+
assert torch.allclose(out0, mean.unsqueeze(0), atol=1e-6)
99+
assert torch.allclose(out1, (mean + std).unsqueeze(0), atol=1e-6)
100+
101+
# Also test Destandardize directly
102+
dn = Destandardize(mean, std)
103+
assert torch.allclose(dn(torch.zeros(1, d)), out0)
104+
105+
106+
def test_destandardizing_net_single_sample_branch() -> None:
107+
# When batch has a single row, we enter the else-branch and use t_std = 1
108+
batch = torch.tensor([[2.0, -3.0]], dtype=torch.float32)
109+
net = destandardizing_net(batch, min_std=0.5)
110+
# For standardized input 0, output should equal mean (the single row)
111+
out0 = net(torch.zeros(1, 2))
112+
assert torch.allclose(out0, batch, atol=1e-6)
113+
# For standardized input 1, output should be mean + std (std == 1 here)
114+
out1 = net(torch.ones(1, 2))
115+
assert torch.allclose(out1, batch + 1.0, atol=1e-6)
116+
117+
118+
@pytest.mark.parametrize(
119+
"z_theta, z_prop",
120+
[
121+
(True, True),
122+
(True, False),
123+
(False, True),
124+
(False, False),
125+
],
126+
)
127+
def test_build_input_output_layer_shapes_and_types(
128+
toy_theta_property, embedding_net_theta, z_theta, z_prop
129+
) -> None:
130+
"""Sanity check that the built input-output layer can be composed
131+
132+
Args:
133+
toy_theta_property: Fixture providing (theta, property) data
134+
embedding_net_theta: Fixture providing a small embedding net for theta
135+
z_theta: Whether to z-score standardize theta
136+
z_prop: Whether to z-score standardize the property
137+
"""
138+
theta, y = toy_theta_property
139+
inp, out = build_input_output_layer(
140+
batch_theta=theta,
141+
batch_property=y,
142+
z_score_theta=z_theta,
143+
z_score_property=z_prop,
144+
embedding_net_theta=embedding_net_theta,
145+
)
146+
# Compose a tiny regression head to ensure shape compatibility end-to-end
147+
head = nn.Linear(theta.shape[1], 1)
148+
model = nn.Sequential(inp, head, out)
149+
preds = model(theta)
150+
assert preds.shape == y.shape
151+
152+
153+
# ------------------------
154+
# ActiveSubspace.add_property tests
155+
# ------------------------
156+
157+
158+
@pytest.mark.parametrize("model_name", ["mlp", "resnet"])
159+
@pytest.mark.parametrize("clip_max_norm", [None, 5.0])
160+
def test_add_property_and_train_models(
161+
model_name,
162+
clip_max_norm,
163+
toy_theta_property,
164+
posterior_stub,
165+
embedding_net_theta,
166+
) -> None:
167+
theta, y = toy_theta_property
168+
a = ActiveSubspace(posterior_stub)
169+
a.add_property(
170+
theta=theta,
171+
emergent_property=y,
172+
model=model_name,
173+
hidden_features=16,
174+
num_blocks=1,
175+
dropout_probability=0.1,
176+
z_score_theta=True,
177+
z_score_property=True,
178+
embedding_net=embedding_net_theta,
179+
)
180+
net = a.train(
181+
training_batch_size=16,
182+
learning_rate=1e-3,
183+
validation_fraction=0.25,
184+
stop_after_epochs=2,
185+
max_num_epochs=3,
186+
clip_max_norm=clip_max_norm,
187+
)
188+
assert isinstance(net, nn.Module)
189+
assert len(a._validation_log_probs) >= 1
190+
191+
192+
def test_add_property_with_callable_model(toy_theta_property, posterior_stub) -> None:
193+
theta, y = toy_theta_property
194+
195+
def builder(batch_theta: Tensor) -> nn.Module:
196+
d = batch_theta.shape[1]
197+
return nn.Sequential(nn.Identity(), nn.Linear(d, 1))
198+
199+
a = ActiveSubspace(posterior_stub)
200+
a.add_property(theta=theta, emergent_property=y, model=builder)
201+
net = a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)
202+
assert isinstance(net, nn.Module)
203+
204+
205+
def test_add_property_invalid_model_raises(toy_theta_property, posterior_stub) -> None:
206+
theta, y = toy_theta_property
207+
a = ActiveSubspace(posterior_stub)
208+
with pytest.raises(NameError):
209+
a.add_property(theta=theta, emergent_property=y, model="unknown")
210+
211+
212+
def test_train_reuses_existing_net(toy_theta_property, posterior_stub) -> None:
213+
theta, y = toy_theta_property
214+
a = ActiveSubspace(posterior_stub)
215+
a.add_property(theta=theta, emergent_property=y, model="mlp", hidden_features=8)
216+
_ = a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)
217+
first_id = id(a._regression_net)
218+
_ = a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)
219+
assert id(a._regression_net) == first_id # net is reused, not rebuilt
220+
221+
222+
# ------------------------
223+
# ActiveSubspace.find_directions tests
224+
# ------------------------
225+
226+
227+
@pytest.mark.parametrize("norm_gradients", [True, False])
228+
def test_find_directions_with_regression_net(
229+
norm_gradients, toy_theta_property, posterior_stub
230+
) -> None:
231+
"""Tests that find_directions runs and returns correctly shaped outputs."""
232+
theta, y = toy_theta_property
233+
a = ActiveSubspace(posterior_stub)
234+
a.add_property(theta=theta, emergent_property=y, model="mlp", hidden_features=8)
235+
a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)
236+
237+
evals, evecs = a.find_directions(
238+
posterior_log_prob_as_property=False,
239+
norm_gradients_to_prior=norm_gradients,
240+
num_monte_carlo_samples=128,
241+
)
242+
d = theta.shape[1]
243+
assert evals.shape == (d,)
244+
assert evecs.shape == (d, d)
245+
# Ascending order
246+
assert torch.all(evals[1:] >= evals[:-1])
247+
# Columns are unit vectors
248+
assert torch.allclose(torch.linalg.norm(evecs, dim=0), torch.ones(d), atol=1e-5)
249+
250+
251+
def test_find_directions_with_posterior_log_prob_warns(
252+
toy_theta_property, posterior_stub
253+
) -> None:
254+
theta, y = toy_theta_property
255+
a = ActiveSubspace(posterior_stub)
256+
a.add_property(theta=theta, emergent_property=y, model="mlp", hidden_features=8)
257+
a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)
258+
259+
with pytest.warns(UserWarning):
260+
evals, evecs = a.find_directions(
261+
posterior_log_prob_as_property=True,
262+
norm_gradients_to_prior=True,
263+
num_monte_carlo_samples=64,
264+
)
265+
assert evals.numel() == evecs.shape[0]
266+
267+
268+
def test_find_directions_raises_without_property(posterior_stub) -> None:
269+
a = ActiveSubspace(posterior_stub)
270+
with pytest.raises(ValueError):
271+
_ = a.find_directions(
272+
posterior_log_prob_as_property=False, num_monte_carlo_samples=8
273+
)
274+
275+
276+
# ------------------------
277+
# ActiveSubspace.project tests
278+
# ------------------------
279+
280+
281+
@pytest.mark.parametrize("norm_gradients", [True, False])
282+
def test_project_after_find_directions(
283+
norm_gradients, toy_theta_property, posterior_stub
284+
) -> None:
285+
theta, y = toy_theta_property
286+
a = ActiveSubspace(posterior_stub)
287+
a.add_property(theta=theta, emergent_property=y, model="mlp", hidden_features=8)
288+
a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)
289+
_evals, _evecs = a.find_directions(
290+
posterior_log_prob_as_property=False,
291+
norm_gradients_to_prior=norm_gradients,
292+
num_monte_carlo_samples=64,
293+
)
294+
295+
proj1 = a.project(theta, num_dimensions=1)
296+
proj2 = a.project(theta, num_dimensions=2)
297+
assert proj1.shape == (theta.shape[0], 1)
298+
assert proj2.shape == (theta.shape[0], 2)
299+
# Different dimensionality gives different projections
300+
assert not torch.allclose(proj1, proj2[:, :1])

0 commit comments

Comments
 (0)