1
+ from contextlib import contextmanager
2
+
1
3
import arviz as az
2
4
import numpy as np
3
5
import pymc as pm
16
18
)
17
19
18
20
21
+ @contextmanager
22
+ def no_op ():
23
+ yield
24
+
25
+
19
26
@pytest .fixture
20
27
def rng ():
21
28
return np .random .default_rng ()
@@ -62,14 +69,18 @@ def hierarchical_model(rng):
62
69
return model , mu_val , H_inv , test_point
63
70
64
71
65
- def test_laplace_draws_to_inferencedata ( simple_model , rng ):
66
- # Simulate posterior draws: 2 variables, each (chains, draws)
72
+ @ pytest . mark . parametrize ( "use_context" , [ False , True ], ids = [ "model_arg" , "model_context" ])
73
+ def test_laplace_draws_to_inferencedata ( use_context , simple_model , rng ):
67
74
chains , draws = 2 , 5
68
75
mu_draws = rng .normal (size = (chains , draws ))
69
76
sigma_draws = np .abs (rng .normal (size = (chains , draws )))
70
77
model , * _ = simple_model
71
78
72
- idata = laplace_draws_to_inferencedata ([mu_draws , sigma_draws ], model = model )
79
+ context = model if use_context else no_op ()
80
+ model_arg = model if not use_context else None
81
+
82
+ with context :
83
+ idata = laplace_draws_to_inferencedata ([mu_draws , sigma_draws ], model = model_arg )
73
84
74
85
assert isinstance (idata , az .InferenceData )
75
86
assert "mu" in idata .posterior
@@ -93,14 +104,21 @@ def check_idata(self, idata, var_names, n_vars):
93
104
assert fit .coords ["rows" ].values .tolist () == var_names
94
105
assert fit .coords ["columns" ].values .tolist () == var_names
95
106
96
- def test_add_fit_to_inferencedata (self , simple_model , rng ):
107
+ @pytest .mark .parametrize ("use_context" , [False , True ], ids = ["model_arg" , "model_context" ])
108
+ def test_add_fit_to_inferencedata (self , use_context , simple_model , rng ):
97
109
model , mu_val , H_inv , test_point = simple_model
98
110
idata = az .from_dict (posterior = {"mu" : rng .normal (size = ()), "sigma" : rng .normal (size = ())})
99
- idata2 = add_fit_to_inference_data (idata , test_point , H_inv , model = model )
111
+
112
+ context = model if use_context else no_op ()
113
+ model_arg = model if not use_context else None
114
+
115
+ with context :
116
+ idata2 = add_fit_to_inference_data (idata , test_point , H_inv , model = model_arg )
100
117
101
118
self .check_idata (idata2 , ["mu" , "sigma" ], 2 )
102
119
103
- def test_add_fit_with_coords_to_inferencedata (self , hierarchical_model , rng ):
120
+ @pytest .mark .parametrize ("use_context" , [False , True ], ids = ["model_arg" , "model_context" ])
121
+ def test_add_fit_with_coords_to_inferencedata (self , use_context , hierarchical_model , rng ):
104
122
model , mu_val , H_inv , test_point = hierarchical_model
105
123
idata = az .from_dict (
106
124
posterior = {
@@ -111,26 +129,38 @@ def test_add_fit_with_coords_to_inferencedata(self, hierarchical_model, rng):
111
129
}
112
130
)
113
131
114
- idata2 = add_fit_to_inference_data (idata , test_point , H_inv , model = model )
132
+ context = model if use_context else no_op ()
133
+ model_arg = model if not use_context else None
134
+
135
+ with context :
136
+ idata2 = add_fit_to_inference_data (idata , test_point , H_inv , model = model_arg )
115
137
116
138
self .check_idata (
117
139
idata2 , ["mu_loc" , "mu_scale" , "mu[1]" , "mu[2]" , "mu[3]" , "mu[4]" , "mu[5]" , "sigma" ], 8
118
140
)
119
141
120
142
121
- def test_add_data_to_inferencedata (simple_model , rng ):
143
+ @pytest .mark .parametrize ("use_context" , [False , True ], ids = ["model_arg" , "model_context" ])
144
+ def test_add_data_to_inferencedata (use_context , simple_model , rng ):
122
145
model , * _ = simple_model
123
146
124
147
idata = az .from_dict (
125
148
posterior = {"mu" : rng .standard_normal ((1 , 1 )), "sigma" : rng .standard_normal ((1 , 1 ))}
126
149
)
127
- idata2 = add_data_to_inference_data (idata , model = model )
150
+
151
+ context = model if use_context else no_op ()
152
+ model_arg = model if not use_context else None
153
+
154
+ with context :
155
+ idata2 = add_data_to_inference_data (idata , model = model_arg )
156
+
128
157
assert "observed_data" in idata2 .groups ()
129
158
assert "constant_data" in idata2 .groups ()
130
159
assert "obs" in idata2 .observed_data
131
160
132
161
133
- def test_optimizer_result_to_dataset_basic (simple_model , rng ):
162
+ @pytest .mark .parametrize ("use_context" , [False , True ], ids = ["model_arg" , "model_context" ])
163
+ def test_optimizer_result_to_dataset_basic (use_context , simple_model , rng ):
134
164
model , mu_val , H_inv , test_point = simple_model
135
165
result = OptimizeResult (
136
166
x = np .array ([1.0 , 2.0 ]),
@@ -144,7 +174,11 @@ def test_optimizer_result_to_dataset_basic(simple_model, rng):
144
174
status = 0 ,
145
175
)
146
176
147
- ds = optimizer_result_to_dataset (result , method = "BFGS" , model = model , mu = test_point )
177
+ context = model if use_context else no_op ()
178
+ model_arg = model if not use_context else None
179
+ with context :
180
+ ds = optimizer_result_to_dataset (result , method = "BFGS" , model = model_arg , mu = test_point )
181
+
148
182
assert isinstance (ds , xr .Dataset )
149
183
assert all (
150
184
key in ds
@@ -169,48 +203,68 @@ def test_optimizer_result_to_dataset_basic(simple_model, rng):
169
203
assert ds ["jac" ].coords ["variables" ].values .tolist () == ["mu" , "sigma" ]
170
204
171
205
172
- def test_optimizer_result_to_dataset_hess_inv_matrix (hierarchical_model , rng ):
173
- model , mu_val , H_inv , test_point = hierarchical_model
174
- result = OptimizeResult (
175
- x = np .zeros ((8 ,)),
176
- hess_inv = np .eye (8 ),
206
+ @pytest .mark .parametrize (
207
+ "optimizer_method, use_context, model_name" ,
208
+ [("BFGS" , True , "hierarchical_model" ), ("L-BFGS-B" , False , "simple_model" )],
209
+ )
210
+ def test_optimizer_result_to_dataset_hess_inv_types (
211
+ optimizer_method , use_context , model_name , rng , request
212
+ ):
213
+ def get_hess_inv_and_expected_names (method ):
214
+ model , mu_val , H_inv , test_point = request .getfixturevalue (model_name )
215
+ n = mu_val .shape [0 ]
216
+
217
+ if method == "BFGS" :
218
+ hess_inv = np .eye (n )
219
+ expected_names = [
220
+ "mu_loc" ,
221
+ "mu_scale" ,
222
+ "mu[1]" ,
223
+ "mu[2]" ,
224
+ "mu[3]" ,
225
+ "mu[4]" ,
226
+ "mu[5]" ,
227
+ "sigma" ,
228
+ ]
229
+ result = OptimizeResult (
230
+ x = np .zeros ((n ,)),
231
+ hess_inv = hess_inv ,
232
+ )
233
+ elif method == "L-BFGS-B" :
234
+
235
+ def linop_func (x ):
236
+ return np .array ([2 * xi for xi in x ])
237
+
238
+ linop = LinearOperator ((n , n ), matvec = linop_func )
239
+ hess_inv = 2 * np .eye (n )
240
+ expected_names = ["mu" , "sigma" ]
241
+ result = OptimizeResult (
242
+ x = np .ones (n ),
243
+ hess_inv = linop ,
244
+ )
245
+ else :
246
+ raise ValueError ("Unknown optimizer_method" )
247
+
248
+ return model , test_point , hess_inv , expected_names , result
249
+
250
+ model , test_point , hess_inv , expected_names , result = get_hess_inv_and_expected_names (
251
+ optimizer_method
177
252
)
178
- ds = optimizer_result_to_dataset (result , method = "BFGS" , model = model , mu = test_point )
179
253
180
- assert "hess_inv" in ds
181
- assert ds ["hess_inv" ].shape == (8 , 8 )
182
- assert list (ds ["hess_inv" ].coords .keys ()) == ["variables" , "variables_aux" ]
183
-
184
- expected_names = ["mu_loc" , "mu_scale" , "mu[1]" , "mu[2]" , "mu[3]" , "mu[4]" , "mu[5]" , "sigma" ]
185
- assert ds ["hess_inv" ].coords ["variables" ].values .tolist () == expected_names
186
- assert ds ["hess_inv" ].coords ["variables_aux" ].values .tolist () == expected_names
187
-
188
-
189
- def test_optimizer_result_to_dataset_hess_inv_linear_operator (simple_model , rng ):
190
- model , mu_val , H_inv , test_point = simple_model
191
- n = mu_val .shape [0 ]
192
-
193
- def matvec (x ):
194
- return np .array ([2 * xi for xi in x ])
195
-
196
- linop = LinearOperator ((n , n ), matvec = matvec )
197
- result = OptimizeResult (
198
- x = np .ones (n ),
199
- hess_inv = linop ,
200
- )
254
+ context = model if use_context else no_op ()
255
+ model_arg = model if not use_context else None
201
256
202
- with model :
203
- ds = optimizer_result_to_dataset (result , method = "BFGS" , mu = test_point )
257
+ with context :
258
+ ds = optimizer_result_to_dataset (
259
+ result , method = optimizer_method , mu = test_point , model = model_arg
260
+ )
204
261
205
262
assert "hess_inv" in ds
206
- assert ds ["hess_inv" ].shape == (n , n )
263
+ assert ds ["hess_inv" ].shape == (len ( expected_names ), len ( expected_names ) )
207
264
assert list (ds ["hess_inv" ].coords .keys ()) == ["variables" , "variables_aux" ]
208
-
209
- expected_names = ["mu" , "sigma" ]
210
265
assert ds ["hess_inv" ].coords ["variables" ].values .tolist () == expected_names
211
266
assert ds ["hess_inv" ].coords ["variables_aux" ].values .tolist () == expected_names
212
-
213
- np .testing .assert_allclose (ds ["hess_inv" ].values , 2 * np .eye (n ))
267
+ np .testing .assert_allclose (ds ["hess_inv" ].values , hess_inv )
214
268
215
269
216
270
def test_optimizer_result_to_dataset_extra_fields (simple_model , rng ):
@@ -228,3 +282,25 @@ def test_optimizer_result_to_dataset_extra_fields(simple_model, rng):
228
282
assert ds ["custom_stat" ].shape == (2 ,)
229
283
assert list (ds ["custom_stat" ].coords .keys ()) == ["custom_stat_dim_0" ]
230
284
assert ds ["custom_stat" ].coords ["custom_stat_dim_0" ].values .tolist () == [0 , 1 ]
285
+
286
+
287
+ def test_optimizer_result_to_dataset_hess_inv_basinhopping (simple_model , rng ):
288
+ model , mu_val , H_inv , test_point = simple_model
289
+ n = mu_val .shape [0 ]
290
+ hess_inv_inner = np .eye (n ) * 3.0
291
+
292
+ # Basinhopping returns an OptimizeResult with a nested OptimizeResult
293
+ result = OptimizeResult (
294
+ x = np .ones (n ),
295
+ lowest_optimization_result = OptimizeResult (x = np .ones (n ), hess_inv = hess_inv_inner ),
296
+ )
297
+
298
+ with model :
299
+ ds = optimizer_result_to_dataset (result , method = "basinhopping" , mu = test_point )
300
+
301
+ assert "hess_inv" in ds
302
+ assert ds ["hess_inv" ].shape == (n , n )
303
+ np .testing .assert_allclose (ds ["hess_inv" ].values , hess_inv_inner )
304
+ expected_names = ["mu" , "sigma" ]
305
+ assert ds ["hess_inv" ].coords ["variables" ].values .tolist () == expected_names
306
+ assert ds ["hess_inv" ].coords ["variables_aux" ].values .tolist () == expected_names
0 commit comments