1
1
from functools import reduce
2
2
from itertools import product
3
- from typing import Literal
3
+ from typing import Any , Literal
4
4
5
5
import arviz as az
6
6
import numpy as np
@@ -35,6 +35,29 @@ def make_unpacked_variable_names(name, model: pm.Model) -> list[str]:
35
35
return [f"{ name } [{ ',' .join (map (str , label ))} ]" for label in labels ]
36
36
37
37
38
+ def map_results_to_inferece_data (results : dict [str , Any ], model : pm .Model | None = None ):
39
+ """
40
+ Convert a dictionary of results to an InferenceData object.
41
+
42
+ Parameters
43
+ ----------
44
+ results: dict
45
+ A dictionary containing the results to convert.
46
+ model: Model, optional
47
+ A PyMC model. If None, the model is taken from the current model context.
48
+
49
+ Returns
50
+ -------
51
+ idata: az.InferenceData
52
+ An InferenceData object containing the results.
53
+ """
54
+ model = pm .modelcontext (model )
55
+ coords , dims = coords_and_dims_for_inferencedata (model )
56
+
57
+ idata = az .convert_to_inference_data (results , coords = coords , dims = dims )
58
+ return idata
59
+
60
+
38
61
def laplace_draws_to_inferencedata (
39
62
posterior_draws : list [np .ndarray [float | int ]], model : pm .Model | None = None
40
63
) -> az .InferenceData :
@@ -91,13 +114,67 @@ def make_rv_dims(name):
91
114
return idata
92
115
93
116
94
- def add_fit_to_inferencedata (
117
+ def add_map_posterior_to_inference_data (
118
+ idata : az .InferenceData ,
119
+ map_point : dict [str , float | int | np .ndarray ],
120
+ model : pm .Model | None = None ,
121
+ ):
122
+ """
123
+ Add the MAP point to an InferenceData object in the posterior group.
124
+
125
+ Unlike a typical posterior, the MAP point is a single point estimate rather than a distribution. As a result, it
126
+ does not have a chain or draw dimension, and is stored as a single point in the posterior group.
127
+
128
+ Parameters
129
+ ----------
130
+ idata: az.InferenceData
131
+ An InferenceData object to which the MAP point will be added.
132
+ map_point: dict
133
+ A dictionary containing the MAP point estimates for each variable. The keys should be the variable names, and
134
+ the values should be the corresponding MAP estimates.
135
+ model: Model, optional
136
+ A PyMC model. If None, the model is taken from the current model context.
137
+
138
+ Returns
139
+ -------
140
+ idata: az.InferenceData
141
+ The provided InferenceData, with the MAP point added to the posterior group.
142
+ """
143
+
144
+ model = pm .modelcontext (model ) if model is None else model
145
+ coords , dims = coords_and_dims_for_inferencedata (model )
146
+
147
+ # The MAP point will have both the transformed and untransformed variables, so we need to ensure that
148
+ # we have the correct dimensions for each variable.
149
+ var_name_to_value_name = {rv .name : value .name for rv , value in model .rvs_to_values .items ()}
150
+ dims .update (
151
+ {
152
+ value_name : dims [var_name ]
153
+ for var_name , value_name in var_name_to_value_name .items ()
154
+ if var_name in dims
155
+ }
156
+ )
157
+
158
+ posterior_data = {
159
+ name : xr .DataArray (
160
+ data = np .asarray (value ),
161
+ coords = {dim : coords [dim ] for dim in dims .get (name , [])},
162
+ dims = dims .get (name ),
163
+ name = name ,
164
+ )
165
+ for name , value in map_point .items ()
166
+ }
167
+ idata .add_groups (posterior = xr .Dataset (posterior_data ))
168
+
169
+ return idata
170
+
171
+
172
+ def add_fit_to_inference_data (
95
173
idata : az .InferenceData , mu : RaveledVars , H_inv : np .ndarray , model : pm .Model | None = None
96
174
) -> az .InferenceData :
97
175
"""
98
176
Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
99
177
100
-
101
178
Parameters
102
179
----------
103
180
idata: az.InfereceData
@@ -123,19 +200,24 @@ def add_fit_to_inferencedata(
123
200
)
124
201
125
202
mean_dataarray = xr .DataArray (mu .data , dims = ["rows" ], coords = {"rows" : unpacked_variable_names })
126
- cov_dataarray = xr .DataArray (
127
- H_inv ,
128
- dims = ["rows" , "columns" ],
129
- coords = {"rows" : unpacked_variable_names , "columns" : unpacked_variable_names },
130
- )
131
203
132
- dataset = xr .Dataset ({"mean_vector" : mean_dataarray , "covariance_matrix" : cov_dataarray })
204
+ data = {"mean_vector" : mean_dataarray }
205
+
206
+ if H_inv is not None :
207
+ cov_dataarray = xr .DataArray (
208
+ H_inv ,
209
+ dims = ["rows" , "columns" ],
210
+ coords = {"rows" : unpacked_variable_names , "columns" : unpacked_variable_names },
211
+ )
212
+ data ["covariance_matrix" ] = cov_dataarray
213
+
214
+ dataset = xr .Dataset (data )
133
215
idata .add_groups (fit = dataset )
134
216
135
217
return idata
136
218
137
219
138
- def add_data_to_inferencedata (
220
+ def add_data_to_inference_data (
139
221
idata : az .InferenceData ,
140
222
progressbar : bool = True ,
141
223
model : pm .Model | None = None ,
@@ -163,8 +245,14 @@ def add_data_to_inferencedata(
163
245
model = pm .modelcontext (model )
164
246
165
247
if model .deterministics :
248
+ expand_dims = {}
249
+ if "chain" not in idata .posterior .coords :
250
+ expand_dims ["chain" ] = [0 ]
251
+ if "draw" not in idata .posterior .coords :
252
+ expand_dims ["draw" ] = [0 ]
253
+
166
254
idata .posterior = pm .compute_deterministics (
167
- idata .posterior ,
255
+ idata .posterior . expand_dims ( expand_dims ) ,
168
256
model = model ,
169
257
merge_dataset = True ,
170
258
progressbar = progressbar ,
@@ -299,3 +387,37 @@ def optimizer_result_to_dataset(
299
387
data_vars ["method" ] = xr .DataArray (np .array (method ), dims = [])
300
388
301
389
return xr .Dataset (data_vars )
390
+
391
+
392
+ def add_optimizer_result_to_inference_data (
393
+ idata : az .InferenceData ,
394
+ result : OptimizeResult ,
395
+ method : minimize_method | Literal ["basinhopping" ],
396
+ mu : RaveledVars | None = None ,
397
+ model : pm .Model | None = None ,
398
+ ) -> az .InferenceData :
399
+ """
400
+ Add the optimization result to an InferenceData object.
401
+
402
+ Parameters
403
+ ----------
404
+ idata: az.InferenceData
405
+ An InferenceData object containing the approximated posterior samples.
406
+ result: OptimizeResult
407
+ The result of the optimization process.
408
+ method: minimize_method or "basinhopping"
409
+ The optimization method used.
410
+ mu: RaveledVars, optional
411
+ The MAP estimate of the model parameters.
412
+ model: Model, optional
413
+ A PyMC model. If None, the model is taken from the current model context.
414
+
415
+ Returns
416
+ -------
417
+ idata: az.InferenceData
418
+ The provided InferenceData, with the optimization results added to the "optimizer" group.
419
+ """
420
+ dataset = optimizer_result_to_dataset (result , method = method , mu = mu , model = model )
421
+ idata .add_groups ({"optimizer_result" : dataset })
422
+
423
+ return idata
0 commit comments