Commit 862e52d

Implement a minimizer for INLA (#513)
1 parent 7eaaccd commit 862e52d

File tree: 2 files changed (+185 lines, -0 lines)


pymc_extras/inference/laplace.py

Lines changed: 102 additions & 0 deletions
@@ -15,6 +15,7 @@

 import logging

+from collections.abc import Callable
 from functools import reduce
 from importlib.util import find_spec
 from itertools import product
@@ -29,6 +30,7 @@

 from arviz import dict_to_dataset
 from better_optimize.constants import minimize_method
+from numpy.typing import ArrayLike
 from pymc import DictToArrayBijection
 from pymc.backends.arviz import (
     coords_and_dims_for_inferencedata,
@@ -39,6 +41,8 @@
3941
from pymc.model.transform.conditioning import remove_value_transforms
4042
from pymc.model.transform.optimization import freeze_dims_and_data
4143
from pymc.util import get_default_varnames
44+
from pytensor.tensor import TensorVariable
45+
from pytensor.tensor.optimize import minimize
4246
from scipy import stats
4347

4448
from pymc_extras.inference.find_map import (
@@ -52,6 +56,102 @@
 _log = logging.getLogger(__name__)


+def get_conditional_gaussian_approximation(
+    x: TensorVariable,
+    Q: TensorVariable | ArrayLike,
+    mu: TensorVariable | ArrayLike,
+    args: list[TensorVariable] | None = None,
+    model: pm.Model | None = None,
+    method: minimize_method = "BFGS",
+    use_jac: bool = True,
+    use_hess: bool = False,
+    optimizer_kwargs: dict | None = None,
+) -> Callable:
+    """
+    Returns a function to estimate the log posterior probability of a latent Gaussian field x and its mode x0 using the Laplace approximation.
+
+    That is:
+    y | x, sigma ~ N(Ax, sigma^2 W)
+    x | params ~ N(mu, Q(params)^-1)
+
+    We seek to estimate log(p(x | y, params)):
+
+    log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const
+
+    Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).
+
+    This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode.
+
+    Thus:
+
+    1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.
+
+    2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q).
+
+    Parameters
+    ----------
+    x: TensorVariable
+        The parameter to maximize with respect to (that is, the variable in which to find the mode). In INLA, this is the latent field x ~ N(mu, Q^-1).
+    Q: TensorVariable | ArrayLike
+        The precision matrix of the latent field x.
+    mu: TensorVariable | ArrayLike
+        The mean of the latent field x.
+    args: list[TensorVariable]
+        Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model's value variables are the args.
+    model: Model
+        PyMC model to use.
+    method: minimize_method
+        Which minimization algorithm to use.
+    use_jac: bool
+        If true, the minimizer will compute the gradient of log(p(x | y, params)).
+    use_hess: bool
+        If true, the minimizer will compute the Hessian of log(p(x | y, params)).
+    optimizer_kwargs: dict
+        Kwargs to pass to scipy.optimize.minimize.
+
+    Returns
+    -------
+    f: Callable
+        A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. x is currently both the point at which to evaluate logp and the initial guess for the minimizer.
+    """
+    model = pm.modelcontext(model)
+
+    if args is None:
+        args = model.continuous_value_vars + model.discrete_value_vars
+
+    # f = log(p(y | x, params))
+    f_x = model.logp()
+    jac = pytensor.gradient.grad(f_x, x)
+    hess = pytensor.gradient.jacobian(jac.flatten(), x)
+
+    # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
+    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)
+
+    # Maximize log(p(x | y, params)) wrt x to find mode x0
+    x0, _ = minimize(
+        objective=-log_x_posterior,
+        x=x,
+        method=method,
+        jac=use_jac,
+        hess=use_hess,
+        optimizer_kwargs=optimizer_kwargs,
+    )
+
+    # require f'(x0) and f''(x0) for Laplace approx
+    jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
+    hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
+
+    # Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
+    _, logdetQ = pt.nlinalg.slogdet(Q)
+    conditional_gaussian_approx = (
+        -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
+    )
+
+    # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
+    # far from the mode x0 or in a neighbourhood which results in poor convergence.
+    return pytensor.function(args, [x0, conditional_gaussian_approx])
+
+
 def laplace_draws_to_inferencedata(
     posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
 ) -> az.InferenceData:
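For reference (a worked restatement of step 2 of the docstring, not part of the diff itself): Taylor expanding f about the mode x0,

$$ f(x) \approx f(x_0) + \nabla f(x_0)^\top (x - x_0) + \tfrac{1}{2}(x - x_0)^\top \nabla^2 f(x_0)\,(x - x_0), $$

and substituting into log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q), then keeping only the terms that depend on x (with the Hessian symmetric), gives

$$ \log p(x \mid y, \mathrm{params}) \approx -\tfrac{1}{2}\, x^\top \left(-\nabla^2 f(x_0) + Q\right) x + x^\top \left(Q\mu + \nabla f(x_0) - \nabla^2 f(x_0)\, x_0\right) + \tfrac{1}{2}\log\det Q + \mathrm{const}, $$

which is the expression assembled in conditional_gaussian_approx above, with jac playing the role of the gradient and hess the Hessian of f at x0.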
@@ -308,6 +408,8 @@ def fit_mvn_at_MAP(
     )

     H = -f_hess(mu.data)
+    if H.ndim == 1:
+        H = np.expand_dims(H, axis=1)
     H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))

     def stabilize(x, jitter):
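Before the tests, a minimal sketch of how the new get_conditional_gaussian_approximation helper might be called. This is an editorial illustration, not part of the commit: the toy model, dimensions, and the names Q_val, mu_val, and y_obs are made up, and the calling pattern simply mirrors the test added below (fixed numeric precision matrix, x passed via its value variable).

import numpy as np
import pymc as pm
import pytensor.tensor as pt

from pymc_extras.inference.laplace import get_conditional_gaussian_approximation

rng = np.random.default_rng(0)
d = 3
Q_val = np.eye(d)                    # precision of the latent field, fixed for this sketch
mu_val = np.zeros(d)                 # prior mean of the latent field
y_obs = rng.normal(size=(100, d))    # toy observations

with pm.Model() as model:
    # Latent Gaussian field x ~ N(mu, Q^-1)
    x = pm.MvNormal("x", mu=mu_val, cov=np.linalg.inv(Q_val))
    # Observation model y | x ~ N(x, I)
    pm.MvNormal("y", mu=x, cov=np.eye(d), observed=y_obs)

    # Compile a function returning [x0, log p(x | y)]; x is supplied as its value variable
    cga = get_conditional_gaussian_approximation(
        x=model.rvs_to_values[x],
        Q=pt.as_tensor_variable(Q_val),
        mu=pt.as_tensor_variable(mu_val),
    )

# The value supplied for x is both the query point and the minimizer's initial guess
x0, log_x_posterior = cga(x=np.zeros(d))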

tests/test_laplace.py

Lines changed: 83 additions & 0 deletions
@@ -23,6 +23,7 @@
 from pymc_extras.inference.laplace import (
     fit_laplace,
     fit_mvn_at_MAP,
+    get_conditional_gaussian_approximation,
     sample_laplace_posterior,
 )

@@ -279,3 +280,85 @@ def test_laplace_scalar():
     assert idata_laplace.fit.covariance_matrix.shape == (1, 1)

     np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1)
+
+
+def test_get_conditional_gaussian_approximation():
+    """
+    Consider the trivial case of:
+
+    y | x ~ N(x, cov_param)
+    x | param ~ N(mu_param, Q^-1)
+
+    cov_param ~ N(cov_mu, cov_cov)
+    mu_param ~ N(mu_mu, mu_cov)
+    Q ~ N(Q_mu, Q_cov)
+
+    This has an analytic solution at the mode which we can compare against.
+    """
+    rng = np.random.default_rng(12345)
+    n = 10000
+    d = 10
+
+    # Initialise arrays
+    mu_true = rng.random(d)
+    cov_true = np.diag(rng.random(d))
+    Q_val = np.diag(rng.random(d))
+    cov_param_val = np.diag(rng.random(d))
+
+    x_val = rng.random(d)
+    mu_val = rng.random(d)
+
+    mu_mu = rng.random(d)
+    mu_cov = np.diag(np.ones(d))
+    cov_mu = rng.random(d**2)
+    cov_cov = np.diag(np.ones(d**2))
+    Q_mu = rng.random(d**2)
+    Q_cov = np.diag(np.ones(d**2))
+
+    with pm.Model() as model:
+        y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n)
+
+        mu_param = pm.MvNormal("mu_param", mu=mu_mu, cov=mu_cov)
+        cov_param = pm.MvNormal("cov_param", mu=cov_mu, cov=cov_cov)
+        Q = pm.MvNormal("Q", mu=Q_mu, cov=Q_cov)
+
+        # Pytensor currently doesn't support autograd for pt inverses, so we use a numeric Q instead
+        x = pm.MvNormal("x", mu=mu_param, cov=np.linalg.inv(Q_val))
+
+        y = pm.MvNormal(
+            "y",
+            mu=x,
+            cov=cov_param.reshape((d, d)),
+            observed=y_obs,
+        )
+
+        # logp(x | y, params)
+        cga = get_conditional_gaussian_approximation(
+            x=model.rvs_to_values[x],
+            Q=Q.reshape((d, d)),
+            mu=mu_param,
+            optimizer_kwargs={"tol": 1e-25},
+        )
+
+    x0, log_x_posterior = cga(
+        x=x_val, mu_param=mu_val, cov_param=cov_param_val.flatten(), Q=Q_val.flatten()
+    )
+
+    # Get analytic values of the mode and Laplace-approximated log posterior
+    cov_param_inv = np.linalg.inv(cov_param_val)
+
+    x0_true = np.linalg.inv(n * cov_param_inv + 2 * Q_val) @ (
+        cov_param_inv @ y_obs.sum(axis=0) + 2 * Q_val @ mu_val
+    )
+
+    jac_true = cov_param_inv @ (y_obs - x0_true).sum(axis=0) - Q_val @ (x0_true - mu_val)
+    hess_true = -n * cov_param_inv - Q_val
+
+    log_x_posterior_laplace_true = (
+        -0.5 * x_val.T @ (-hess_true + Q_val) @ x_val
+        + x_val.T @ (Q_val @ mu_val + jac_true - hess_true @ x0_true)
+        + 0.5 * np.log(np.linalg.det(Q_val))
+    )
+
+    np.testing.assert_allclose(x0, x0_true, atol=0.1, rtol=0.1)
+    np.testing.assert_allclose(log_x_posterior, log_x_posterior_laplace_true, atol=0.1, rtol=0.1)
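Where the analytic reference values come from (a sketch consistent with the code in this commit, not part of the diff itself): f_x = model.logp() is the joint log density, so it already contains the prior term -0.5*(x - mu).T Q (x - mu), and get_conditional_gaussian_approximation adds the same quadratic penalty again. Up to a constant, the objective maximized for the mode is therefore

$$ g(x) = -\tfrac{1}{2}\sum_{i=1}^{n} (y_i - x)^\top C^{-1} (y_i - x) - (x - \mu)^\top Q (x - \mu), $$

with C = cov_param_val, Q = Q_val, and mu = mu_val. Setting the gradient to zero gives

$$ x_0 = \left(n C^{-1} + 2Q\right)^{-1}\left(C^{-1}\sum_i y_i + 2Q\mu\right), $$

which is x0_true (and explains the factor of 2 on Q_val). Likewise, the gradient and Hessian of the joint log density at the mode, C^{-1} \sum_i (y_i - x_0) - Q(x_0 - \mu) and -nC^{-1} - Q, are jac_true and hess_true; substituting them into the quadratic form from the docstring reproduces log_x_posterior_laplace_true.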
