|
15 | 15 |
|
16 | 16 | import logging
|
17 | 17 |
|
| 18 | +from collections.abc import Callable |
18 | 19 | from functools import reduce
|
19 | 20 | from importlib.util import find_spec
|
20 | 21 | from itertools import product
|
|
29 | 30 |
|
30 | 31 | from arviz import dict_to_dataset
|
31 | 32 | from better_optimize.constants import minimize_method
|
| 33 | +from numpy.typing import ArrayLike |
32 | 34 | from pymc import DictToArrayBijection
|
33 | 35 | from pymc.backends.arviz import (
|
34 | 36 | coords_and_dims_for_inferencedata,
|
|
39 | 41 | from pymc.model.transform.conditioning import remove_value_transforms
|
40 | 42 | from pymc.model.transform.optimization import freeze_dims_and_data
|
41 | 43 | from pymc.util import get_default_varnames
|
| 44 | +from pytensor.tensor import TensorVariable |
| 45 | +from pytensor.tensor.optimize import minimize |
42 | 46 | from scipy import stats
|
43 | 47 |
|
44 | 48 | from pymc_extras.inference.find_map import (
|
|
52 | 56 | _log = logging.getLogger(__name__)
|
53 | 57 |
|
54 | 58 |
|
| 59 | +def get_conditional_gaussian_approximation( |
| 60 | + x: TensorVariable, |
| 61 | + Q: TensorVariable | ArrayLike, |
| 62 | + mu: TensorVariable | ArrayLike, |
| 63 | + args: list[TensorVariable] | None = None, |
| 64 | + model: pm.Model | None = None, |
| 65 | + method: minimize_method = "BFGS", |
| 66 | + use_jac: bool = True, |
| 67 | + use_hess: bool = False, |
| 68 | + optimizer_kwargs: dict | None = None, |
| 69 | +) -> Callable: |
| 70 | + """ |
| 71 | + Returns a function that estimates the log posterior probability of a latent Gaussian field x, together with its mode x0, using the Laplace approximation. |
| 72 | +
|
| 73 | + That is: |
| 74 | + y | x, sigma ~ N(Ax, sigma^2 W) |
| 75 | + x | params ~ N(mu, Q(params)^-1) |
| 76 | +
|
| 77 | + We seek to estimate log(p(x | y, params)): |
| 78 | +
|
| 79 | + log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const |
| 80 | +
|
| 81 | + Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q), up to an additive constant. |
| 82 | +
|
| 83 | + This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode. |
| 84 | +
|
| 85 | + Thus: |
| 86 | +
|
| 87 | + 1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0. |
| 88 | +
|
| 89 | + 2. Substitute x0 into the second-order Taylor expansion of f about the mode, f(x) ~= f(x0) + f'(x0).T (x - x0) + 0.5*(x - x0).T f''(x0) (x - x0). Collecting the terms that depend on x and dropping constants gives log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q mu + f'(x0) - f''(x0) x0) + 0.5*logdet(Q). |
| 90 | +
|
| 91 | + Parameters |
| 92 | + ---------- |
| 93 | + x: TensorVariable |
| 94 | + The variable over which to maximize, i.e. the variable in which the mode x0 is found. In INLA, this is the latent field x ~ N(mu, Q^-1). |
| 95 | + Q: TensorVariable | ArrayLike |
| 96 | + The precision matrix of the latent field x. |
| 97 | + mu: TensorVariable | ArrayLike |
| 98 | + The mean of the latent field x. |
| 99 | + args: list[TensorVariable] |
| 100 | + Inputs to supply to the compiled function, i.e. (x0, logp) = f(x, *args). If None, the model's continuous and discrete value variables are used. |
| 101 | + model: Model |
| 102 | + PyMC model to use. If None, the model on the current context stack is used. |
| 103 | + method: minimize_method |
| 104 | + Which minimization algorithm to use. |
| 105 | + use_jac: bool |
| 106 | + If True, the minimizer will compute the gradient of log(p(x | y, params)). |
| 107 | + use_hess: bool |
| 108 | + If True, the minimizer will compute the Hessian of log(p(x | y, params)). |
| 109 | + optimizer_kwargs: dict |
| 110 | + Kwargs to pass to scipy.optimize.minimize. |
| 111 | +
|
| 112 | + Returns |
| 113 | + ------- |
| 114 | + f: Callable |
| 115 | + A compiled function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. Note that x is currently both the point at which logp is evaluated and the initial guess for the minimizer (see the usage sketch below). |
| 116 | + """ |
| 117 | + model = pm.modelcontext(model) |
| 118 | + |
| 119 | + if args is None: |
| 120 | + args = model.continuous_value_vars + model.discrete_value_vars |
| 121 | + |
| 122 | + # f = log(p(y | x, params)), taken here from the model's log-probability graph |
| 123 | + f_x = model.logp() |
| 124 | + jac = pytensor.gradient.grad(f_x, x) |
| 125 | + hess = pytensor.gradient.jacobian(jac.flatten(), x) |
| 126 | + |
| 127 | + # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x) |
| 128 | + log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu) |
| 129 | + |
| 130 | + # Maximize log(p(x | y, params)) wrt x to find mode x0 |
| 131 | + x0, _ = minimize( |
| 132 | + objective=-log_x_posterior, |
| 133 | + x=x, |
| 134 | + method=method, |
| 135 | + jac=use_jac, |
| 136 | + hess=use_hess, |
| 137 | + optimizer_kwargs=optimizer_kwargs, |
| 138 | + ) |
| 139 | + |
| 140 | + # require f'(x0) and f''(x0) for Laplace approx |
| 141 | + jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) |
| 142 | + hess = pytensor.graph.replace.graph_replace(hess, {x: x0}) |
| 143 | + |
| 144 | + # Full log(p(x | y, params)) using the Laplace approximation (up to a constant) |
| 145 | + _, logdetQ = pt.nlinalg.slogdet(Q) |
| 146 | + conditional_gaussian_approx = ( |
| 147 | + -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ |
| 148 | + ) |
| 149 | + |
| 150 | + # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is |
| 151 | + # far from the mode x0 or in a neighbourhood which results in poor convergence. |
| 152 | + return pytensor.function(args, [x0, conditional_gaussian_approx]) |
| 153 | + |
| 154 | + |
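A minimal usage sketch of the new helper, under the following assumptions: the module is importable as pymc_extras.inference.laplace; the latent field's Gaussian prior is supplied only through Q and mu, so x is declared with pm.Flat and the model's log-probability contributes just the likelihood term f(x); the data and variable names (n, Q_val, mu_val, y_obs) are illustrative.

import numpy as np
import pymc as pm

from pymc_extras.inference.laplace import get_conditional_gaussian_approximation

# Toy setup: a 2-dimensional latent field observed with unit-variance Gaussian noise.
n = 2
Q_val = 2.0 * np.eye(n)        # prior precision of the latent field
mu_val = np.zeros(n)           # prior mean of the latent field
y_obs = np.array([0.5, -0.3])  # observed data

with pm.Model() as model:
    # x is declared Flat so that model.logp() contributes only the likelihood term
    # f(x) = log(p(y | x)); the Gaussian prior N(mu, Q^-1) is added internally by
    # get_conditional_gaussian_approximation via Q and mu.
    x = pm.Flat("x", shape=(n,))
    pm.Normal("y", mu=x, sigma=1.0, observed=y_obs)

    f = get_conditional_gaussian_approximation(
        x=model.rvs_to_values[x],  # the value variable of the latent field
        Q=Q_val,
        mu=mu_val,
        model=model,
    )

# The compiled function takes the model's value variables (here only x's value) and
# returns the mode x0, found by the internal minimizer started from the supplied x,
# and the Laplace-approximated log(p(x | y, params)) evaluated at the supplied x.
x0, logp_approx = f(np.zeros(n))

# For this linear-Gaussian toy model the objective is quadratic, so x0 should be
# close to the exact conditional mean (Q + I)^-1 (Q mu + y) = y / 3.

Because the value handed to the compiled function doubles as the optimizer's starting point (as noted in the docstring), a query point reasonably close to the mode helps the minimizer converge.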
55 | 155 | def laplace_draws_to_inferencedata(
|
56 | 156 | posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
|
57 | 157 | ) -> az.InferenceData:
|
@@ -308,6 +408,8 @@ def fit_mvn_at_MAP(
|
308 | 408 | )
|
309 | 409 |
|
310 | 410 | H = -f_hess(mu.data)
|
| 411 | + if H.ndim == 1:  # promote a 1-D Hessian (e.g. from a single free parameter) to 2-D so np.linalg.pinv below receives a matrix |
| 412 | + H = np.expand_dims(H, axis=1) |
311 | 413 | H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))
|
312 | 414 |
|
313 | 415 | def stabilize(x, jitter):
|
|