Skip to content

Commit d79d642

Browse files
Add more tests
1 parent 067860f commit d79d642

File tree

2 files changed

+130
-47
lines changed

2 files changed

+130
-47
lines changed

tests/inference/laplace_approx/test_idata.py

Lines changed: 121 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from contextlib import contextmanager
2+
13
import arviz as az
24
import numpy as np
35
import pymc as pm
@@ -16,6 +18,11 @@
1618
)
1719

1820

21+
@contextmanager
22+
def no_op():
23+
yield
24+
25+
1926
@pytest.fixture
2027
def rng():
2128
return np.random.default_rng()
@@ -62,14 +69,18 @@ def hierarchical_model(rng):
6269
return model, mu_val, H_inv, test_point
6370

6471

65-
def test_laplace_draws_to_inferencedata(simple_model, rng):
66-
# Simulate posterior draws: 2 variables, each (chains, draws)
72+
@pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"])
73+
def test_laplace_draws_to_inferencedata(use_context, simple_model, rng):
6774
chains, draws = 2, 5
6875
mu_draws = rng.normal(size=(chains, draws))
6976
sigma_draws = np.abs(rng.normal(size=(chains, draws)))
7077
model, *_ = simple_model
7178

72-
idata = laplace_draws_to_inferencedata([mu_draws, sigma_draws], model=model)
79+
context = model if use_context else no_op()
80+
model_arg = model if not use_context else None
81+
82+
with context:
83+
idata = laplace_draws_to_inferencedata([mu_draws, sigma_draws], model=model_arg)
7384

7485
assert isinstance(idata, az.InferenceData)
7586
assert "mu" in idata.posterior
@@ -93,14 +104,21 @@ def check_idata(self, idata, var_names, n_vars):
93104
assert fit.coords["rows"].values.tolist() == var_names
94105
assert fit.coords["columns"].values.tolist() == var_names
95106

96-
def test_add_fit_to_inferencedata(self, simple_model, rng):
107+
@pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"])
108+
def test_add_fit_to_inferencedata(self, use_context, simple_model, rng):
97109
model, mu_val, H_inv, test_point = simple_model
98110
idata = az.from_dict(posterior={"mu": rng.normal(size=()), "sigma": rng.normal(size=())})
99-
idata2 = add_fit_to_inference_data(idata, test_point, H_inv, model=model)
111+
112+
context = model if use_context else no_op()
113+
model_arg = model if not use_context else None
114+
115+
with context:
116+
idata2 = add_fit_to_inference_data(idata, test_point, H_inv, model=model_arg)
100117

101118
self.check_idata(idata2, ["mu", "sigma"], 2)
102119

103-
def test_add_fit_with_coords_to_inferencedata(self, hierarchical_model, rng):
120+
@pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"])
121+
def test_add_fit_with_coords_to_inferencedata(self, use_context, hierarchical_model, rng):
104122
model, mu_val, H_inv, test_point = hierarchical_model
105123
idata = az.from_dict(
106124
posterior={
@@ -111,26 +129,38 @@ def test_add_fit_with_coords_to_inferencedata(self, hierarchical_model, rng):
111129
}
112130
)
113131

114-
idata2 = add_fit_to_inference_data(idata, test_point, H_inv, model=model)
132+
context = model if use_context else no_op()
133+
model_arg = model if not use_context else None
134+
135+
with context:
136+
idata2 = add_fit_to_inference_data(idata, test_point, H_inv, model=model_arg)
115137

116138
self.check_idata(
117139
idata2, ["mu_loc", "mu_scale", "mu[1]", "mu[2]", "mu[3]", "mu[4]", "mu[5]", "sigma"], 8
118140
)
119141

120142

121-
def test_add_data_to_inferencedata(simple_model, rng):
143+
@pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"])
144+
def test_add_data_to_inferencedata(use_context, simple_model, rng):
122145
model, *_ = simple_model
123146

124147
idata = az.from_dict(
125148
posterior={"mu": rng.standard_normal((1, 1)), "sigma": rng.standard_normal((1, 1))}
126149
)
127-
idata2 = add_data_to_inference_data(idata, model=model)
150+
151+
context = model if use_context else no_op()
152+
model_arg = model if not use_context else None
153+
154+
with context:
155+
idata2 = add_data_to_inference_data(idata, model=model_arg)
156+
128157
assert "observed_data" in idata2.groups()
129158
assert "constant_data" in idata2.groups()
130159
assert "obs" in idata2.observed_data
131160

132161

133-
def test_optimizer_result_to_dataset_basic(simple_model, rng):
162+
@pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"])
163+
def test_optimizer_result_to_dataset_basic(use_context, simple_model, rng):
134164
model, mu_val, H_inv, test_point = simple_model
135165
result = OptimizeResult(
136166
x=np.array([1.0, 2.0]),
@@ -144,7 +174,11 @@ def test_optimizer_result_to_dataset_basic(simple_model, rng):
144174
status=0,
145175
)
146176

147-
ds = optimizer_result_to_dataset(result, method="BFGS", model=model, mu=test_point)
177+
context = model if use_context else no_op()
178+
model_arg = model if not use_context else None
179+
with context:
180+
ds = optimizer_result_to_dataset(result, method="BFGS", model=model_arg, mu=test_point)
181+
148182
assert isinstance(ds, xr.Dataset)
149183
assert all(
150184
key in ds
@@ -169,48 +203,68 @@ def test_optimizer_result_to_dataset_basic(simple_model, rng):
169203
assert ds["jac"].coords["variables"].values.tolist() == ["mu", "sigma"]
170204

171205

172-
def test_optimizer_result_to_dataset_hess_inv_matrix(hierarchical_model, rng):
173-
model, mu_val, H_inv, test_point = hierarchical_model
174-
result = OptimizeResult(
175-
x=np.zeros((8,)),
176-
hess_inv=np.eye(8),
206+
@pytest.mark.parametrize(
207+
"optimizer_method, use_context, model_name",
208+
[("BFGS", True, "hierarchical_model"), ("L-BFGS-B", False, "simple_model")],
209+
)
210+
def test_optimizer_result_to_dataset_hess_inv_types(
211+
optimizer_method, use_context, model_name, rng, request
212+
):
213+
def get_hess_inv_and_expected_names(method):
214+
model, mu_val, H_inv, test_point = request.getfixturevalue(model_name)
215+
n = mu_val.shape[0]
216+
217+
if method == "BFGS":
218+
hess_inv = np.eye(n)
219+
expected_names = [
220+
"mu_loc",
221+
"mu_scale",
222+
"mu[1]",
223+
"mu[2]",
224+
"mu[3]",
225+
"mu[4]",
226+
"mu[5]",
227+
"sigma",
228+
]
229+
result = OptimizeResult(
230+
x=np.zeros((n,)),
231+
hess_inv=hess_inv,
232+
)
233+
elif method == "L-BFGS-B":
234+
235+
def linop_func(x):
236+
return np.array([2 * xi for xi in x])
237+
238+
linop = LinearOperator((n, n), matvec=linop_func)
239+
hess_inv = 2 * np.eye(n)
240+
expected_names = ["mu", "sigma"]
241+
result = OptimizeResult(
242+
x=np.ones(n),
243+
hess_inv=linop,
244+
)
245+
else:
246+
raise ValueError("Unknown optimizer_method")
247+
248+
return model, test_point, hess_inv, expected_names, result
249+
250+
model, test_point, hess_inv, expected_names, result = get_hess_inv_and_expected_names(
251+
optimizer_method
177252
)
178-
ds = optimizer_result_to_dataset(result, method="BFGS", model=model, mu=test_point)
179253

180-
assert "hess_inv" in ds
181-
assert ds["hess_inv"].shape == (8, 8)
182-
assert list(ds["hess_inv"].coords.keys()) == ["variables", "variables_aux"]
183-
184-
expected_names = ["mu_loc", "mu_scale", "mu[1]", "mu[2]", "mu[3]", "mu[4]", "mu[5]", "sigma"]
185-
assert ds["hess_inv"].coords["variables"].values.tolist() == expected_names
186-
assert ds["hess_inv"].coords["variables_aux"].values.tolist() == expected_names
187-
188-
189-
def test_optimizer_result_to_dataset_hess_inv_linear_operator(simple_model, rng):
190-
model, mu_val, H_inv, test_point = simple_model
191-
n = mu_val.shape[0]
192-
193-
def matvec(x):
194-
return np.array([2 * xi for xi in x])
195-
196-
linop = LinearOperator((n, n), matvec=matvec)
197-
result = OptimizeResult(
198-
x=np.ones(n),
199-
hess_inv=linop,
200-
)
254+
context = model if use_context else no_op()
255+
model_arg = model if not use_context else None
201256

202-
with model:
203-
ds = optimizer_result_to_dataset(result, method="BFGS", mu=test_point)
257+
with context:
258+
ds = optimizer_result_to_dataset(
259+
result, method=optimizer_method, mu=test_point, model=model_arg
260+
)
204261

205262
assert "hess_inv" in ds
206-
assert ds["hess_inv"].shape == (n, n)
263+
assert ds["hess_inv"].shape == (len(expected_names), len(expected_names))
207264
assert list(ds["hess_inv"].coords.keys()) == ["variables", "variables_aux"]
208-
209-
expected_names = ["mu", "sigma"]
210265
assert ds["hess_inv"].coords["variables"].values.tolist() == expected_names
211266
assert ds["hess_inv"].coords["variables_aux"].values.tolist() == expected_names
212-
213-
np.testing.assert_allclose(ds["hess_inv"].values, 2 * np.eye(n))
267+
np.testing.assert_allclose(ds["hess_inv"].values, hess_inv)
214268

215269

216270
def test_optimizer_result_to_dataset_extra_fields(simple_model, rng):
@@ -228,3 +282,25 @@ def test_optimizer_result_to_dataset_extra_fields(simple_model, rng):
228282
assert ds["custom_stat"].shape == (2,)
229283
assert list(ds["custom_stat"].coords.keys()) == ["custom_stat_dim_0"]
230284
assert ds["custom_stat"].coords["custom_stat_dim_0"].values.tolist() == [0, 1]
285+
286+
287+
def test_optimizer_result_to_dataset_hess_inv_basinhopping(simple_model, rng):
288+
model, mu_val, H_inv, test_point = simple_model
289+
n = mu_val.shape[0]
290+
hess_inv_inner = np.eye(n) * 3.0
291+
292+
# Basinhopping returns an OptimizeResult with a nested OptimizeResult
293+
result = OptimizeResult(
294+
x=np.ones(n),
295+
lowest_optimization_result=OptimizeResult(x=np.ones(n), hess_inv=hess_inv_inner),
296+
)
297+
298+
with model:
299+
ds = optimizer_result_to_dataset(result, method="basinhopping", mu=test_point)
300+
301+
assert "hess_inv" in ds
302+
assert ds["hess_inv"].shape == (n, n)
303+
np.testing.assert_allclose(ds["hess_inv"].values, hess_inv_inner)
304+
expected_names = ["mu", "sigma"]
305+
assert ds["hess_inv"].coords["variables"].values.tolist() == expected_names
306+
assert ds["hess_inv"].coords["variables_aux"].values.tolist() == expected_names

tests/inference/laplace_approx/test_laplace.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,22 @@ def test_fit_laplace_ragged_coords(rng):
157157
assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all()
158158

159159

160-
def test_laplace_scalar():
160+
# Test these three optimizers because they are either special cases for H_inv (BFGS, L-BFGS-B) or are
161+
# gradient free and require re-compilation of hessp (powell).
162+
@pytest.mark.parametrize("optimizer_method", ["BFGS", "L-BFGS-B", "powell"])
163+
def test_laplace_scalar_basinhopping(optimizer_method):
161164
# Example model from Statistical Rethinking
162165
data = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
163166

164167
with pm.Model():
165168
p = pm.Uniform("p", 0, 1)
166169
w = pm.Binomial("w", n=len(data), p=p, observed=data.sum())
167170

168-
idata_laplace = pmx.fit_laplace(optimize_method="powell", progressbar=False)
171+
idata_laplace = pmx.fit_laplace(
172+
optimize_method="basinhopping",
173+
optimizer_kwargs={"minimizer_kwargs": {"method": optimizer_method}, "niter": 1},
174+
progressbar=False,
175+
)
169176

170177
assert idata_laplace.fit.mean_vector.shape == (1,)
171178
assert idata_laplace.fit.covariance_matrix.shape == (1, 1)

0 commit comments

Comments
 (0)