
Commit f973cfb: Refactor find_MAP
Parent: 93e3aa2

File tree: 8 files changed (+825, -422 lines)

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 113 additions & 360 deletions
Large diffs are not rendered by default.
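The find_map.py changes themselves are too large to render here, but the idata.py helpers below show the pieces the refactor wires together. As orientation, a minimal usage sketch of the module's entry point; the call signature (beyond `method` and `model`) and the InferenceData return value are assumptions inferred from this commit, not shown in it:

```python
import pymc as pm

# The module path comes from the file listed above; the exact find_MAP
# signature is an assumption for illustration.
from pymc_extras.inference.laplace_approx.find_map import find_MAP

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu, sigma, observed=[0.5, 1.2, -0.3])

    # "L-BFGS-B" is a standard scipy.optimize.minimize method, matching the
    # minimize_method type referenced in idata.py below.
    idata = find_MAP(method="L-BFGS-B", model=model)
```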

pymc_extras/inference/laplace_approx/idata.py

Lines changed: 133 additions & 11 deletions
```diff
@@ -1,6 +1,6 @@
 from functools import reduce
 from itertools import product
-from typing import Literal
+from typing import Any, Literal
 
 import arviz as az
 import numpy as np
@@ -35,6 +35,29 @@ def make_unpacked_variable_names(name, model: pm.Model) -> list[str]:
     return [f"{name}[{','.join(map(str, label))}]" for label in labels]
 
 
+def map_results_to_inferece_data(results: dict[str, Any], model: pm.Model | None = None):
+    """
+    Convert a dictionary of results to an InferenceData object.
+
+    Parameters
+    ----------
+    results: dict
+        A dictionary containing the results to convert.
+    model: Model, optional
+        A PyMC model. If None, the model is taken from the current model context.
+
+    Returns
+    -------
+    idata: az.InferenceData
+        An InferenceData object containing the results.
+    """
+    model = pm.modelcontext(model)
+    coords, dims = coords_and_dims_for_inferencedata(model)
+
+    idata = az.convert_to_inference_data(results, coords=coords, dims=dims)
+    return idata
+
+
 def laplace_draws_to_inferencedata(
     posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
 ) -> az.InferenceData:
```
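The new `map_results_to_inferece_data` helper simply wraps `az.convert_to_inference_data` with the model's coordinates and dimensions. A minimal sketch of how it could be called; the toy model and the (chain, draw, *dims) array layout follow ArviZ's convention for dict inputs and are illustrative, not taken from the commit:

```python
import numpy as np
import pymc as pm

from pymc_extras.inference.laplace_approx.idata import map_results_to_inferece_data

with pm.Model(coords={"group": ["a", "b"]}) as model:
    theta = pm.Normal("theta", 0.0, 1.0, dims="group")

# Arrays keyed by variable name; az.convert_to_inference_data treats the two
# leading axes as (chain, draw) and labels the trailing axis with the model's dims.
results = {"theta": np.zeros((1, 100, 2))}
idata = map_results_to_inferece_data(results, model=model)
```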
```diff
@@ -91,13 +114,67 @@ def make_rv_dims(name):
     return idata
 
 
-def add_fit_to_inferencedata(
+def add_map_posterior_to_inference_data(
+    idata: az.InferenceData,
+    map_point: dict[str, float | int | np.ndarray],
+    model: pm.Model | None = None,
+):
+    """
+    Add the MAP point to an InferenceData object in the posterior group.
+
+    Unlike a typical posterior, the MAP point is a single point estimate rather than a distribution. As a result, it
+    does not have a chain or draw dimension, and is stored as a single point in the posterior group.
+
+    Parameters
+    ----------
+    idata: az.InferenceData
+        An InferenceData object to which the MAP point will be added.
+    map_point: dict
+        A dictionary containing the MAP point estimates for each variable. The keys should be the variable names, and
+        the values should be the corresponding MAP estimates.
+    model: Model, optional
+        A PyMC model. If None, the model is taken from the current model context.
+
+    Returns
+    -------
+    idata: az.InferenceData
+        The provided InferenceData, with the MAP point added to the posterior group.
+    """
+
+    model = pm.modelcontext(model) if model is None else model
+    coords, dims = coords_and_dims_for_inferencedata(model)
+
+    # The MAP point will have both the transformed and untransformed variables, so we need to ensure that
+    # we have the correct dimensions for each variable.
+    var_name_to_value_name = {rv.name: value.name for rv, value in model.rvs_to_values.items()}
+    dims.update(
+        {
+            value_name: dims[var_name]
+            for var_name, value_name in var_name_to_value_name.items()
+            if var_name in dims
+        }
+    )
+
+    posterior_data = {
+        name: xr.DataArray(
+            data=np.asarray(value),
+            coords={dim: coords[dim] for dim in dims.get(name, [])},
+            dims=dims.get(name),
+            name=name,
+        )
+        for name, value in map_point.items()
+    }
+    idata.add_groups(posterior=xr.Dataset(posterior_data))
+
+    return idata
+
+
+def add_fit_to_inference_data(
     idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
 ) -> az.InferenceData:
     """
     Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
 
-
     Parameters
     ----------
     idata: az.InfereceData
```
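`add_map_posterior_to_inference_data` stores a single MAP point in the posterior group, copying each random variable's dims onto its transformed value variable as well. A sketch of how it might be used; the model, the transformed value-variable name, and the empty starting InferenceData are illustrative assumptions:

```python
import arviz as az
import numpy as np
import pymc as pm

from pymc_extras.inference.laplace_approx.idata import add_map_posterior_to_inference_data

with pm.Model(coords={"group": ["a", "b"]}) as model:
    mu = pm.Normal("mu", 0.0, 1.0, dims="group")
    sigma = pm.HalfNormal("sigma", 1.0)

# A MAP point may contain both untransformed variables ("sigma") and their
# transformed value variables ("sigma_log__"); dims are propagated to both.
map_point = {
    "mu": np.array([0.1, -0.2]),
    "sigma": np.array(0.8),
    "sigma_log__": np.array(np.log(0.8)),
}

idata = add_map_posterior_to_inference_data(az.InferenceData(), map_point, model=model)
```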
```diff
@@ -123,19 +200,24 @@ def add_fit_to_inferencedata(
     )
 
     mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names})
-    cov_dataarray = xr.DataArray(
-        H_inv,
-        dims=["rows", "columns"],
-        coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
-    )
 
-    dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray})
+    data = {"mean_vector": mean_dataarray}
+
+    if H_inv is not None:
+        cov_dataarray = xr.DataArray(
+            H_inv,
+            dims=["rows", "columns"],
+            coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
+        )
+        data["covariance_matrix"] = cov_dataarray
+
+    dataset = xr.Dataset(data)
     idata.add_groups(fit=dataset)
 
     return idata
 
 
-def add_data_to_inferencedata(
+def add_data_to_inference_data(
     idata: az.InferenceData,
     progressbar: bool = True,
     model: pm.Model | None = None,
@@ -163,8 +245,14 @@ def add_data_to_inferencedata(
     model = pm.modelcontext(model)
 
     if model.deterministics:
+        expand_dims = {}
+        if "chain" not in idata.posterior.coords:
+            expand_dims["chain"] = [0]
+        if "draw" not in idata.posterior.coords:
+            expand_dims["draw"] = [0]
+
         idata.posterior = pm.compute_deterministics(
-            idata.posterior,
+            idata.posterior.expand_dims(expand_dims),
             model=model,
             merge_dataset=True,
             progressbar=progressbar,
```
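The new `expand_dims` guard is needed because `pm.compute_deterministics` operates over chain and draw dimensions, which a single-point MAP posterior lacks. A small standalone xarray illustration of what the guard does (the variable name is illustrative):

```python
import numpy as np
import xarray as xr

# A MAP-style posterior: one scalar per variable, no chain/draw dimensions.
posterior = xr.Dataset({"mu": xr.DataArray(np.float64(0.5))})

expand_dims = {}
if "chain" not in posterior.coords:
    expand_dims["chain"] = [0]
if "draw" not in posterior.coords:
    expand_dims["draw"] = [0]

# After expansion the variable has length-1 chain and draw dimensions,
# which is the layout pm.compute_deterministics expects.
print(posterior.expand_dims(expand_dims)["mu"].dims)  # ('chain', 'draw')
```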
```diff
@@ -299,3 +387,37 @@ def optimizer_result_to_dataset(
     data_vars["method"] = xr.DataArray(np.array(method), dims=[])
 
     return xr.Dataset(data_vars)
+
+
+def add_optimizer_result_to_inference_data(
+    idata: az.InferenceData,
+    result: OptimizeResult,
+    method: minimize_method | Literal["basinhopping"],
+    mu: RaveledVars | None = None,
+    model: pm.Model | None = None,
+) -> az.InferenceData:
+    """
+    Add the optimization result to an InferenceData object.
+
+    Parameters
+    ----------
+    idata: az.InferenceData
+        An InferenceData object containing the approximated posterior samples.
+    result: OptimizeResult
+        The result of the optimization process.
+    method: minimize_method or "basinhopping"
+        The optimization method used.
+    mu: RaveledVars, optional
+        The MAP estimate of the model parameters.
+    model: Model, optional
+        A PyMC model. If None, the model is taken from the current model context.
+
+    Returns
+    -------
+    idata: az.InferenceData
+        The provided InferenceData, with the optimization results added to the "optimizer" group.
+    """
+    dataset = optimizer_result_to_dataset(result, method=method, mu=mu, model=model)
+    idata.add_groups({"optimizer_result": dataset})
+
+    return idata
```
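Finally, `add_optimizer_result_to_inference_data` tucks the raw optimizer output into an `optimizer_result` group. A rough usage sketch; the toy scipy result stands in for whatever `find_MAP` produces internally, and whether a bare `OptimizeResult` is sufficient depends on `optimizer_result_to_dataset`, whose body is not shown in this diff:

```python
import arviz as az
import numpy as np
import pymc as pm
from scipy.optimize import minimize

from pymc_extras.inference.laplace_approx.idata import add_optimizer_result_to_inference_data

with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0)

# An illustrative stand-in for the optimizer output produced during find_MAP.
result = minimize(lambda v: float(np.sum(v**2)), x0=np.array([1.0]), method="L-BFGS-B")

idata = add_optimizer_result_to_inference_data(
    az.InferenceData(), result, method="L-BFGS-B", model=model
)
# The result is stored in a new "optimizer_result" group on the InferenceData.
```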
