# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

# NOTE: the entire file was drafted by GPT-5 in GH Copilot, then edited by janfb.

from typing import Tuple

import pytest
import torch
from torch import Tensor, nn

from sbi.analysis.sensitivity_analysis import (
    ActiveSubspace,
    Destandardize,
    build_input_output_layer,
    destandardizing_net,
)

# ------------------------
# Fixtures and test helpers
# ------------------------


@pytest.fixture
def toy_theta_property() -> Tuple[Tensor, Tensor]:
    """Small synthetic (theta, property) dataset for regression testing."""
    n, d = 64, 3
    theta = torch.randn(n, d)
    y = theta.sum(dim=1, keepdim=True) + 0.05 * torch.randn(n, 1)
    return theta, y


class _PriorWithStats:
    """Simple Gaussian prior with mean and stddev attributes."""

    def __init__(self, d: int):
        self.mean = torch.zeros(d)
        self.stddev = torch.ones(d)

    def sample(self, shape: Tuple[int, ...]) -> Tensor:
        return torch.randn(*shape, self.mean.numel())


class _PriorWithoutStats:
    """Simple prior without mean and stddev attributes."""

    def __init__(self, d: int):
        self._d = d

    def sample(self, shape: Tuple[int, ...]) -> Tensor:
        return torch.randn(*shape, self._d)


class _PosteriorStub:
    """Stub posterior with simple Gaussian prior and quadratic potential."""

    def __init__(self, d: int, with_stats: bool = True):
        self._device = "cpu"
        self.d = d
        self.prior = _PriorWithStats(d) if with_stats else _PriorWithoutStats(d)

    def sample(self, shape: Tuple[int, ...]) -> Tensor:
        return torch.randn(*shape, self.d, requires_grad=False)

    def potential(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        # Smooth, concave potential with gradient everywhere
        # Return shape (N, 1) to match regression output style
        return -(theta**2).sum(dim=1, keepdim=True) * 0.5


@pytest.fixture(
    params=[
        pytest.param(True, id="prior_with_stats"),
        pytest.param(False, id="prior_without_stats"),
    ]
)
def posterior_stub(request) -> _PosteriorStub:
    """Fixture providing a stub posterior with and without prior stats."""
    return _PosteriorStub(d=3, with_stats=request.param)


@pytest.fixture
def embedding_net_theta() -> nn.Module:
    """Small embedding net for theta."""
    return nn.Sequential(nn.Linear(3, 3), nn.ReLU())


# ------------------------
# Utility layer tests
# ------------------------


def test_destandardize_and_destandardizing_net_forward() -> None:
    """Tests destandardizing_net and Destandardize layer."""
    # Create a batch with near-zero variance in one dim to exercise min-std flooring
    n, d = 50, 2
    col0 = torch.randn(n, 1)
    col1 = torch.zeros(n, 1)
    batch = torch.cat([col0, col1], dim=1)

    min_std = 0.5
    net = destandardizing_net(batch, min_std=min_std)
    # Expected mean and clamped std
    mean = batch.mean(dim=0)
    std = batch.std(dim=0)
    std = torch.where(std < min_std, torch.full_like(std, min_std), std)

    # Check that zero maps to mean, ones maps to mean + std
    out0 = net(torch.zeros(1, d))
    out1 = net(torch.ones(1, d))
    assert torch.allclose(out0, mean.unsqueeze(0), atol=1e-6)
    assert torch.allclose(out1, (mean + std).unsqueeze(0), atol=1e-6)

    # Also test Destandardize directly
    dn = Destandardize(mean, std)
    assert torch.allclose(dn(torch.zeros(1, d)), out0)
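

def test_destandardize_round_trip() -> None:
    """Sketch: Destandardize inverts z-scoring via x -> x * std + mean.

    Illustrative check added alongside the tests above; it assumes only the
    Destandardize(mean, std) behavior already exercised there.
    """
    x = torch.randn(20, 2) * 4.0 + 7.0
    mean, std = x.mean(dim=0), x.std(dim=0)
    z = (x - mean) / std  # manual z-scoring
    dn = Destandardize(mean, std)
    assert torch.allclose(dn(z), x, atol=1e-5)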


def test_destandardizing_net_single_sample_branch() -> None:
    """Tests destandardizing_net when batch has a single row."""
    # When batch has a single row, we enter the else-branch and use t_std = 1
    batch = torch.tensor([[2.0, -3.0]], dtype=torch.float32)
    net = destandardizing_net(batch, min_std=0.5)
    # For standardized input 0, output should equal mean (the single row)
    out0 = net(torch.zeros(1, 2))
    assert torch.allclose(out0, batch, atol=1e-6)
    # For standardized input 1, output should be mean + std (std == 1 here)
    out1 = net(torch.ones(1, 2))
    assert torch.allclose(out1, batch + 1.0, atol=1e-6)


@pytest.mark.parametrize(
    "z_theta, z_prop",
    [
        (True, True),
        (True, False),
        (False, True),
        (False, False),
    ],
)
def test_build_input_output_layer_shapes_and_types(
    toy_theta_property, embedding_net_theta, z_theta, z_prop
) -> None:
    """Sanity check that the built input-output layer can be composed end-to-end.

    Args:
        toy_theta_property: Fixture providing (theta, property) data
        embedding_net_theta: Fixture providing a small embedding net for theta
        z_theta: Whether to z-score standardize theta
        z_prop: Whether to z-score standardize the property
    """
    theta, y = toy_theta_property
    inp, out = build_input_output_layer(
        batch_theta=theta,
        batch_property=y,
        z_score_theta=z_theta,
        z_score_property=z_prop,
        embedding_net_theta=embedding_net_theta,
    )
    # Compose a tiny regression head to ensure shape compatibility end-to-end
    head = nn.Linear(theta.shape[1], 1)
    model = nn.Sequential(inp, head, out)
    preds = model(theta)
    assert preds.shape == y.shape
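    # Design note (inferred from the shapes above, not from sbi docs): the
    # input layer standardizes and embeds theta, the linear head regresses,
    # and the output layer destandardizes the prediction back to the
    # property's original scale.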


# ------------------------
# ActiveSubspace tests
# ------------------------

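# Typical call order exercised by the tests below (a sketch for orientation,
# not executed here; argument values are placeholders):
#
#     a = ActiveSubspace(posterior)
#     a.add_property(theta=theta, emergent_property=y, model="mlp")
#     a.train(max_num_epochs=...)
#     evals, evecs = a.find_directions(num_monte_carlo_samples=...)
#     projected = a.project(theta, num_dimensions=2)
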
@pytest.mark.parametrize("model_name", ["mlp", "resnet"])
@pytest.mark.parametrize("clip_max_norm", [None, 5.0])
def test_add_property_and_train_models(
    model_name,
    clip_max_norm,
    toy_theta_property,
    posterior_stub,
    embedding_net_theta,
) -> None:
    """Tests that add_property and train run without error for different models."""
    theta, y = toy_theta_property
    a = ActiveSubspace(posterior_stub)
    a.add_property(
        theta=theta,
        emergent_property=y,
        model=model_name,
        hidden_features=16,
        num_blocks=1,
        dropout_probability=0.1,
        z_score_theta=True,
        z_score_property=True,
        embedding_net=embedding_net_theta,
    )
    net = a.train(
        training_batch_size=16,
        learning_rate=1e-3,
        validation_fraction=0.25,
        stop_after_epochs=2,
        max_num_epochs=3,
        clip_max_norm=clip_max_norm,
    )
    assert isinstance(net, nn.Module)
    assert len(a._validation_log_probs) >= 1


def test_add_property_with_callable_model(toy_theta_property, posterior_stub) -> None:
    """Tests that add_property works with a user-defined model builder."""
    theta, y = toy_theta_property

    def builder(batch_theta: Tensor) -> nn.Module:
        d = batch_theta.shape[1]
        return nn.Sequential(nn.Identity(), nn.Linear(d, 1))

    a = ActiveSubspace(posterior_stub)
    a.add_property(theta=theta, emergent_property=y, model=builder)
    net = a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)
    assert isinstance(net, nn.Module)


def test_add_property_invalid_model_raises(toy_theta_property, posterior_stub) -> None:
    """Tests that add_property raises for invalid model name."""
    theta, y = toy_theta_property
    a = ActiveSubspace(posterior_stub)
    with pytest.raises(NameError):
        a.add_property(theta=theta, emergent_property=y, model="unknown")


def test_train_reuses_existing_net(toy_theta_property, posterior_stub) -> None:
    """Tests that a second call to train reuses the existing regression net."""
    theta, y = toy_theta_property
    a = ActiveSubspace(posterior_stub)
    a.add_property(theta=theta, emergent_property=y, model="mlp", hidden_features=8)
    _ = a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)
    first_id = id(a._regression_net)
    _ = a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)
    assert id(a._regression_net) == first_id  # net is reused, not rebuilt


@pytest.mark.parametrize("norm_gradients", [True, False])
def test_find_directions_with_regression_net(
    norm_gradients, toy_theta_property, posterior_stub
) -> None:
    """Tests that find_directions runs and returns correctly shaped outputs."""
    theta, y = toy_theta_property
    a = ActiveSubspace(posterior_stub)
    a.add_property(theta=theta, emergent_property=y, model="mlp", hidden_features=8)
    a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)

    evals, evecs = a.find_directions(
        posterior_log_prob_as_property=False,
        norm_gradients_to_prior=norm_gradients,
        num_monte_carlo_samples=128,
    )
    d = theta.shape[1]
    assert evals.shape == (d,)
    assert evecs.shape == (d, d)
    # Ascending order
    assert torch.all(evals[1:] >= evals[:-1])
    # Columns are unit vectors
    assert torch.allclose(torch.linalg.norm(evecs, dim=0), torch.ones(d), atol=1e-5)
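

def test_active_subspace_estimator_identity() -> None:
    """Illustrative sketch of the estimator presumably behind find_directions.

    Assumes the directions come from eigendecomposing the Monte Carlo
    outer-product matrix M = (1/N) * sum_i grad_f(theta_i) grad_f(theta_i)^T.
    We verify the linear-algebra identity directly, without sbi: for a linear
    property f(theta) = w . theta the gradient is the constant vector w, so
    M = w w^T is rank one with top eigenvalue ||w||^2 and eigenvector w/||w||.
    """
    w = torch.tensor([3.0, 0.0, 4.0])
    grads = w.expand(128, 3)  # gradient of a linear property is constant
    m = grads.T @ grads / grads.shape[0]  # Monte Carlo estimate of M
    evals, evecs = torch.linalg.eigh(m)  # eigenvalues in ascending order
    assert torch.allclose(evals[-1], w @ w, atol=1e-4)
    assert torch.allclose(evecs[:, -1].abs(), (w / w.norm()).abs(), atol=1e-4)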


def test_find_directions_with_posterior_log_prob_warns(
    toy_theta_property, posterior_stub
) -> None:
    """Tests that find_directions with posterior log-prob issues a warning."""
    theta, y = toy_theta_property
    a = ActiveSubspace(posterior_stub)
    a.add_property(theta=theta, emergent_property=y, model="mlp", hidden_features=8)
    a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)

    with pytest.warns(UserWarning):
        evals, evecs = a.find_directions(
            posterior_log_prob_as_property=True,
            norm_gradients_to_prior=True,
            num_monte_carlo_samples=64,
        )
    assert evals.numel() == evecs.shape[0]


def test_find_directions_raises_without_property(posterior_stub) -> None:
    """Tests that find_directions raises if no property was added."""
    a = ActiveSubspace(posterior_stub)
    with pytest.raises(ValueError):
        _ = a.find_directions(
            posterior_log_prob_as_property=False, num_monte_carlo_samples=8
        )


@pytest.mark.parametrize("norm_gradients", [True, False])
def test_project_after_find_directions(
    norm_gradients, toy_theta_property, posterior_stub
) -> None:
    """Tests that project works after find_directions and returns correct shapes."""
    theta, y = toy_theta_property
    a = ActiveSubspace(posterior_stub)
    a.add_property(theta=theta, emergent_property=y, model="mlp", hidden_features=8)
    a.train(training_batch_size=16, stop_after_epochs=1, max_num_epochs=2)
    _evals, _evecs = a.find_directions(
        posterior_log_prob_as_property=False,
        norm_gradients_to_prior=norm_gradients,
        num_monte_carlo_samples=64,
    )

    proj1 = a.project(theta, num_dimensions=1)
    proj2 = a.project(theta, num_dimensions=2)
    assert proj1.shape == (theta.shape[0], 1)
    assert proj2.shape == (theta.shape[0], 2)
    # Different dimensionality gives different projections
    assert not torch.allclose(proj1, proj2[:, :1])
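    # Note (an inference, not from sbi docs): project presumably returns
    # theta @ evecs[:, -k:] for the top-k directions; the shape checks and the
    # mismatch above are consistent with that reading but do not pin down the
    # exact column convention.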