From 23140a5d07e13a94eb584f9c152a7bf1b8e7067e Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 28 Jun 2025 21:43:39 +0800 Subject: [PATCH 1/7] Move laplace and find_map to submodule --- pymc_extras/inference/__init__.py | 4 +- pymc_extras/inference/fit.py | 2 +- .../inference/laplace_approx/__init__.py | 0 .../{ => laplace_approx}/find_map.py | 138 +++++++----- .../inference/{ => laplace_approx}/laplace.py | 207 +----------------- .../inference/pathfinder/pathfinder.py | 2 +- tests/inference/__init__.py | 0 tests/inference/laplace_approx/__init__.py | 0 .../laplace_approx}/test_find_map.py | 47 ++-- .../laplace_approx}/test_laplace.py | 4 +- 10 files changed, 136 insertions(+), 268 deletions(-) create mode 100644 pymc_extras/inference/laplace_approx/__init__.py rename pymc_extras/inference/{ => laplace_approx}/find_map.py (81%) rename pymc_extras/inference/{ => laplace_approx}/laplace.py (76%) create mode 100644 tests/inference/__init__.py create mode 100644 tests/inference/laplace_approx/__init__.py rename tests/{ => inference/laplace_approx}/test_find_map.py (79%) rename tests/{ => inference/laplace_approx}/test_laplace.py (98%) diff --git a/pymc_extras/inference/__init__.py b/pymc_extras/inference/__init__.py index a01fdd5c3..a536f91e6 100644 --- a/pymc_extras/inference/__init__.py +++ b/pymc_extras/inference/__init__.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pymc_extras.inference.find_map import find_MAP from pymc_extras.inference.fit import fit -from pymc_extras.inference.laplace import fit_laplace +from pymc_extras.inference.laplace_approx.find_map import find_MAP +from pymc_extras.inference.laplace_approx.laplace import fit_laplace from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder __all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"] diff --git a/pymc_extras/inference/fit.py b/pymc_extras/inference/fit.py index 5b83ff1f3..ac51e76bb 100644 --- a/pymc_extras/inference/fit.py +++ b/pymc_extras/inference/fit.py @@ -37,6 +37,6 @@ def fit(method: str, **kwargs) -> az.InferenceData: return fit_pathfinder(**kwargs) if method == "laplace": - from pymc_extras.inference.laplace import fit_laplace + from pymc_extras.inference import fit_laplace return fit_laplace(**kwargs) diff --git a/pymc_extras/inference/laplace_approx/__init__.py b/pymc_extras/inference/laplace_approx/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymc_extras/inference/find_map.py b/pymc_extras/inference/laplace_approx/find_map.py similarity index 81% rename from pymc_extras/inference/find_map.py rename to pymc_extras/inference/laplace_approx/find_map.py index a4d664789..6100097a6 100644 --- a/pymc_extras/inference/find_map.py +++ b/pymc_extras/inference/laplace_approx/find_map.py @@ -114,14 +114,14 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, def _compile_grad_and_hess_to_jax( - f_loss: Function, use_hess: bool, use_hessp: bool + f_fused: Function, use_hess: bool, use_hessp: bool ) -> tuple[Callable | None, Callable | None]: """ Compile loss function gradients using JAX. Parameters ---------- - f_loss: Function + f_fused: Function The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss, compiled with mode="JAX". use_hess: bool @@ -131,43 +131,40 @@ def _compile_grad_and_hess_to_jax( Returns ------- - f_loss_and_grad: Callable - The compiled loss function and gradient function. 
- f_hess: Callable | None - The compiled hessian function, or None if use_hess is False. + f_fused: Callable + The compiled loss function and gradient function, which may also compute the hessian if requested. f_hessp: Callable | None The compiled hessian-vector product function, or None if use_hessp is False. """ import jax - f_hess = None f_hessp = None - orig_loss_fn = f_loss.vm.jit_fn + orig_loss_fn = f_fused.vm.jit_fn - @jax.jit - def loss_fn_jax_grad(x): - return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) + if use_hess: + + @jax.jit + def loss_fn_fused(x): + loss_and_grad = jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) + hess = jax.hessian(lambda x: orig_loss_fn(x)[0])(x) + return *loss_and_grad, hess + + else: - f_loss_and_grad = loss_fn_jax_grad + @jax.jit + def loss_fn_fused(x): + return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) if use_hessp: def f_hessp_jax(x, p): - y, u = jax.jvp(lambda x: f_loss_and_grad(x)[1], (x,), (p,)) + y, u = jax.jvp(lambda x: loss_fn_fused(x)[1], (x,), (p,)) return jax.numpy.stack(u) f_hessp = jax.jit(f_hessp_jax) - if use_hess: - _f_hess_jax = jax.jacfwd(lambda x: f_loss_and_grad(x)[1]) - - def f_hess_jax(x): - return jax.numpy.stack(_f_hess_jax(x)) - - f_hess = jax.jit(f_hess_jax) - - return f_loss_and_grad, f_hess, f_hessp + return loss_fn_fused, f_hessp def _compile_functions_for_scipy_optimize( @@ -199,33 +196,47 @@ def _compile_functions_for_scipy_optimize( Returns ------- - f_loss: Function - - f_hess: Function | None + f_fused: Function + The compiled loss function, which may also include gradients and hessian if requested. f_hessp: Function | None + The compiled hessian-vector product function, or None if compute_hessp is False. """ + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + loss = pm.pytensorf.rewrite_pregrad(loss) - f_hess = None f_hessp = None - if compute_grad: - grads = pytensor.gradient.grad(loss, inputs) - grad = pt.concatenate([grad.ravel() for grad in grads]) - f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs) - else: + # In the simplest case, we only compile the loss function. Return it as a list to keep the return type consistent + # with the case where we also compute gradients, hessians, or hessian-vector products. + if not (compute_grad or compute_hess or compute_hessp): f_loss = pm.compile(inputs, loss, **compile_kwargs) return [f_loss] - if compute_hess: - hess = pytensor.gradient.jacobian(grad, inputs)[0] - f_hess = pm.compile(inputs, hess, **compile_kwargs) + # Otherwise there are three cases. If the user only wants the loss function and gradients, we compile a single + # fused function and retun it. If the user also wants the hession, the fused function will return the loss, + # gradients and hessian. If the user wants gradients and hess_p, we return a fused function that returns the loss + # and gradients, and a separate function for the hessian-vector product. if compute_hessp: + # Handle this first, since it can be compiled alone. 
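+        # (The Hessian-vector product H @ p is the gradient of (gradient . p), so it can be compiled and
+        # evaluated without ever materializing the full Hessian matrix.)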
p = pt.tensor("p", shape=inputs[0].type.shape) hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p) f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs) - return [f_loss_and_grad, f_hess, f_hessp] + outputs = [loss] + + if compute_grad: + grads = pytensor.gradient.grad(loss, inputs) + grad = pt.concatenate([grad.ravel() for grad in grads]) + outputs.append(grad) + + if compute_hess: + hess = pytensor.gradient.jacobian(grad, inputs)[0] + outputs.append(hess) + + f_fused = pm.compile(inputs, outputs, **compile_kwargs) + + return [f_fused, f_hessp] def scipy_optimize_funcs_from_loss( @@ -262,10 +273,8 @@ def scipy_optimize_funcs_from_loss( Returns ------- - f_loss: Callable - The compiled loss function. - f_hess: Callable | None - The compiled hessian function, or None if use_hess is False. + f_fused: Callable + The compiled loss function, which may also include gradients and hessian if requested. f_hessp: Callable | None The compiled hessian-vector product function, or None if use_hessp is False. """ @@ -322,16 +331,15 @@ def scipy_optimize_funcs_from_loss( compile_kwargs=compile_kwargs, ) - # f_loss here is f_loss_and_grad if compute_grad = True. The name is unchanged to simplify the return values - f_loss = funcs.pop(0) - f_hess = funcs.pop(0) if compute_grad else None - f_hessp = funcs.pop(0) if compute_grad else None + # Depending on the requested functions, f_fused will either be the loss function, the loss function with gradients, + # or the loss function with gradients and hessian. + f_fused = funcs.pop(0) + f_hessp = funcs.pop(0) if compute_hessp else None if use_jax_gradients: - # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values - f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp) + f_fused, f_hessp = _compile_grad_and_hess_to_jax(f_fused, use_hess, use_hessp) - return f_loss, f_hess, f_hessp + return f_fused, f_hessp def find_MAP( @@ -434,7 +442,7 @@ def find_MAP( method, use_grad, use_hess, use_hessp ) - f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss( + f_fused, f_hessp = scipy_optimize_funcs_from_loss( loss=-frozen_model.logp(jacobian=False), inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars, initial_point_dict=start_dict, @@ -445,7 +453,7 @@ def find_MAP( compile_kwargs=compile_kwargs, ) - args = optimizer_kwargs.pop("args", None) + args = optimizer_kwargs.pop("args", ()) # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument # if so. That is why the jac argument is not passed here in either branch. 
@@ -453,15 +461,13 @@ def find_MAP( if do_basinhopping: if "args" not in minimizer_kwargs: minimizer_kwargs["args"] = args - if "hess" not in minimizer_kwargs: - minimizer_kwargs["hess"] = f_hess if "hessp" not in minimizer_kwargs: minimizer_kwargs["hessp"] = f_hessp if "method" not in minimizer_kwargs: minimizer_kwargs["method"] = method optimizer_result = basinhopping( - func=f_logp, + func=f_fused, x0=cast(np.ndarray[float], initial_params.data), progressbar=progressbar, minimizer_kwargs=minimizer_kwargs, @@ -470,10 +476,9 @@ def find_MAP( else: optimizer_result = minimize( - f=f_logp, + f=f_fused, x0=cast(np.ndarray[float], initial_params.data), args=args, - hess=f_hess, hessp=f_hessp, progressbar=progressbar, method=method, @@ -486,6 +491,33 @@ def find_MAP( DictToArrayBijection.rmap(raveled_optimized) ) + # Downstream computation will probably want the covaraince matrix at the optimized point, so we compute it here, + # while we still have access to the compiled function. + x_star = optimizer_result.x + n_vars = len(x_star) + + if method == "BFGS": + # If we used BFGS, the optimizer result will contain the inverse Hessian -- we can just use that rather than + # re-computing something + getattr(optimizer_result, "hess_inv", None) + elif method == "L-BFGS-B": + # Here we will have a LinearOperator representing the inverse Hessian-Vector product. + f_hessp_inv = optimizer_result.hess_inv + basis = np.eye(n_vars) + np.stack([f_hessp_inv(basis[:, i]) for i in range(n_vars)], axis=-1) + + elif f_hessp is not None: + # In the case that hessp was used, the results object will not save the inverse Hessian, so we can compute it from + # the hessp function, using euclidian basis vector. + basis = np.eye(n_vars) + H = np.stack([f_hessp(optimizer_result.x, basis[:, i]) for i in range(n_vars)], axis=-1) + np.linalg.inv(get_nearest_psd(H)) + + elif use_hess: + # If we compiled a hessian function, just use it + _, _, H = f_fused(x_star) + np.linalg.inv(get_nearest_psd(H)) + optimized_point = { var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values) } diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py similarity index 76% rename from pymc_extras/inference/laplace.py rename to pymc_extras/inference/laplace_approx/laplace.py index d64d2adab..488d41911 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -16,9 +16,7 @@ import logging from collections.abc import Callable -from functools import reduce from importlib.util import find_spec -from itertools import product from typing import Literal import arviz as az @@ -26,32 +24,29 @@ import pymc as pm import pytensor import pytensor.tensor as pt -import xarray as xr -from arviz import dict_to_dataset from better_optimize.constants import minimize_method from numpy.typing import ArrayLike from pymc import DictToArrayBijection -from pymc.backends.arviz import ( - coords_and_dims_for_inferencedata, - find_constants, - find_observations, -) from pymc.blocking import RaveledVars from pymc.model.transform.conditioning import remove_value_transforms from pymc.model.transform.optimization import freeze_dims_and_data -from pymc.util import get_default_varnames from pytensor.tensor import TensorVariable from pytensor.tensor.optimize import minimize from scipy import stats -from pymc_extras.inference.find_map import ( +from pymc_extras.inference.laplace_approx.find_map import ( GradientBackend, _unconstrained_vector_to_constrained_rvs, find_MAP, 
get_nearest_psd, scipy_optimize_funcs_from_loss, ) +from pymc_extras.inference.laplace_approx.idata import ( + add_data_to_inferencedata, + add_fit_to_inferencedata, + laplace_draws_to_inferencedata, +) _log = logging.getLogger(__name__) @@ -152,186 +147,6 @@ def get_conditional_gaussian_approximation( return pytensor.function(args, [x0, conditional_gaussian_approx]) -def laplace_draws_to_inferencedata( - posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None -) -> az.InferenceData: - """ - Convert draws from a posterior estimated with the Laplace approximation to an InferenceData object. - - - Parameters - ---------- - posterior_draws: list of np.ndarray - A list of arrays containing the posterior draws. Each array should have shape (chains, draws, *shape), where - shape is the shape of the variable in the posterior. - model: Model, optional - A PyMC model. If None, the model is taken from the current model context. - - Returns - ------- - idata: az.InferenceData - An InferenceData object containing the approximated posterior samples - """ - model = pm.modelcontext(model) - chains, draws, *_ = posterior_draws[0].shape - - def make_rv_coords(name): - coords = {"chain": range(chains), "draw": range(draws)} - extra_dims = model.named_vars_to_dims.get(name) - if extra_dims is None: - return coords - return coords | {dim: list(model.coords[dim]) for dim in extra_dims} - - def make_rv_dims(name): - dims = ["chain", "draw"] - extra_dims = model.named_vars_to_dims.get(name) - if extra_dims is None: - return dims - return dims + list(extra_dims) - - names = [ - x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False) - ] - idata = { - name: xr.DataArray( - data=draws, - coords=make_rv_coords(name), - dims=make_rv_dims(name), - name=name, - ) - for name, draws in zip(names, posterior_draws) - } - - coords, dims = coords_and_dims_for_inferencedata(model) - idata = az.convert_to_inference_data(idata, coords=coords, dims=dims) - - return idata - - -def add_fit_to_inferencedata( - idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None -) -> az.InferenceData: - """ - Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object. - - - Parameters - ---------- - idata: az.InfereceData - An InferenceData object containing the approximated posterior samples. - mu: RaveledVars - The MAP estimate of the model parameters. - H_inv: np.ndarray - The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. - model: Model, optional - A PyMC model. If None, the model is taken from the current model context. - - Returns - ------- - idata: az.InferenceData - The provided InferenceData, with the mean vector and covariance matrix added to the "fit" group. 
- """ - model = pm.modelcontext(model) - coords = model.coords - - variable_names, *_ = zip(*mu.point_map_info) - - def make_unpacked_variable_names(name): - value_to_dim = { - x.name: model.named_vars_to_dims.get(model.values_to_rvs[x].name, None) - for x in model.value_vars - } - value_to_dim = {k: v for k, v in value_to_dim.items() if v is not None} - - rv_to_dim = model.named_vars_to_dims - dims_dict = rv_to_dim | value_to_dim - - dims = dims_dict.get(name) - if dims is None: - return [name] - labels = product(*(coords[dim] for dim in dims)) - return [f"{name}[{','.join(map(str, label))}]" for label in labels] - - unpacked_variable_names = reduce( - lambda lst, name: lst + make_unpacked_variable_names(name), variable_names, [] - ) - - mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names}) - cov_dataarray = xr.DataArray( - H_inv, - dims=["rows", "columns"], - coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names}, - ) - - dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray}) - idata.add_groups(fit=dataset) - - return idata - - -def add_data_to_inferencedata( - idata: az.InferenceData, - progressbar: bool = True, - model: pm.Model | None = None, - compile_kwargs: dict | None = None, -) -> az.InferenceData: - """ - Add observed and constant data to an InferenceData object. - - Parameters - ---------- - idata: az.InferenceData - An InferenceData object containing the approximated posterior samples. - progressbar: bool - Whether to display a progress bar during computations. Default is True. - model: Model, optional - A PyMC model. If None, the model is taken from the current model context. - compile_kwargs: dict, optional - Additional keyword arguments to pass to pytensor.function. - - Returns - ------- - idata: az.InferenceData - The provided InferenceData, with observed and constant data added. - """ - model = pm.modelcontext(model) - - if model.deterministics: - idata.posterior = pm.compute_deterministics( - idata.posterior, - model=model, - merge_dataset=True, - progressbar=progressbar, - compile_kwargs=compile_kwargs, - ) - - coords, dims = coords_and_dims_for_inferencedata(model) - - observed_data = dict_to_dataset( - find_observations(model), - library=pm, - coords=coords, - dims=dims, - default_dims=[], - ) - - constant_data = dict_to_dataset( - find_constants(model), - library=pm, - coords=coords, - dims=dims, - default_dims=[], - ) - - idata.add_groups( - {"observed_data": observed_data, "constant_data": constant_data}, - coords=coords, - dims=dims, - ) - - return idata - - def fit_mvn_at_MAP( optimized_point: dict[str, np.ndarray], model: pm.Model | None = None, @@ -348,8 +163,8 @@ def fit_mvn_at_MAP( Parameters ---------- - optimized_point : dict[str, np.ndarray] - Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map + optimized_point : idata + Local maximum a posteriori (MAP) point returned from pymc_extras.inference.find_MAP model : Model, optional A PyMC model. If None, the model is taken from the current model context. 
on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' @@ -396,7 +211,7 @@ def fit_mvn_at_MAP( optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names} mu = DictToArrayBijection.map(optimized_free_params) - _, f_hess, _ = scipy_optimize_funcs_from_loss( + f_fused, _ = scipy_optimize_funcs_from_loss( loss=-logp, inputs=variables, initial_point_dict=optimized_free_params, @@ -407,7 +222,7 @@ def fit_mvn_at_MAP( compile_kwargs=compile_kwargs, ) - H = -f_hess(mu.data) + H = -f_fused(mu.data)[-1] if H.ndim == 1: H = np.expand_dims(H, axis=1) H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H)) @@ -619,7 +434,7 @@ def fit_laplace( Examples -------- - >>> from pymc_extras.inference.laplace import fit_laplace + >>> from pymc_extras.inference import fit_laplace >>> import numpy as np >>> import pymc as pm >>> import arviz as az diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index cddc175ba..4f280f8fc 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -63,7 +63,7 @@ # TODO: change to typing.Self after Python versions greater than 3.10 from typing_extensions import Self -from pymc_extras.inference.laplace import add_data_to_inferencedata +from pymc_extras.inference.laplace_approx.idata import add_data_to_inferencedata from pymc_extras.inference.pathfinder.importance_sampling import ( importance_sampling as _importance_sampling, ) diff --git a/tests/inference/__init__.py b/tests/inference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/laplace_approx/__init__.py b/tests/inference/laplace_approx/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_find_map.py b/tests/inference/laplace_approx/test_find_map.py similarity index 79% rename from tests/test_find_map.py rename to tests/inference/laplace_approx/test_find_map.py index f5aa549c7..d09fb3d87 100644 --- a/tests/test_find_map.py +++ b/tests/inference/laplace_approx/test_find_map.py @@ -4,7 +4,7 @@ import pytensor.tensor as pt import pytest -from pymc_extras.inference.find_map import ( +from pymc_extras.inference.laplace_approx.find_map import ( GradientBackend, find_MAP, scipy_optimize_funcs_from_loss, @@ -29,7 +29,7 @@ def compute_z(x): return z1, z2 z = pt.stack(compute_z(x)) - f_loss, f_hess, f_hessp = scipy_optimize_funcs_from_loss( + f_fused, f_hessp = scipy_optimize_funcs_from_loss( loss=z.sum(), inputs=[x], initial_point_dict={"x": np.array([1.0, 2.0])}, @@ -43,11 +43,11 @@ def compute_z(x): x_val = np.array([1.0, 2.0]) expected_z = sum(compute_z(x_val)) - z_jax, grad_val = f_loss(x_val) + z_jax, grad_val, hess_val = f_fused(x_val) np.testing.assert_allclose(z_jax, expected_z) np.testing.assert_allclose(grad_val.squeeze(), np.array([2 * x_val[0] + x_val[1], x_val[0]])) - hess_val = np.array(f_hess(x_val)) + hess_val = np.array(hess_val) np.testing.assert_allclose(hess_val.squeeze(), np.array([[2, 1], [1, 0]])) hessp_val = np.array(f_hessp(x_val, np.array([1.0, 0.0]))) @@ -75,8 +75,15 @@ def compute_z(x): ("trust-constr", True, True, False), ], ) -@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str) -def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng): +@pytest.mark.parametrize( + "backend, gradient_backend", + # JAX backend is faster, so only test it + [("jax", "jax"), ("jax", "pytensor")], + ids=str, +) +def test_find_MAP( 
+ method, use_grad, use_hess, use_hessp, backend, gradient_backend: GradientBackend, rng +): extra_kwargs = {} if method == "dogleg": # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point @@ -96,7 +103,7 @@ def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: Gradie use_hessp=use_hessp, progressbar=False, gradient_backend=gradient_backend, - compile_kwargs={"mode": "JAX"}, + compile_kwargs={"mode": backend.upper()}, ) mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] @@ -104,7 +111,13 @@ def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: Gradie assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5) -def test_JAX_map_shared_variables(): +@pytest.mark.parametrize( + "backend, gradient_backend", + # JAX backend is faster, so only test it + [("jax", "jax")], + ids=str, +) +def test_map_shared_variables(backend, gradient_backend: GradientBackend): with pm.Model() as m: data = pytensor.shared(np.random.normal(loc=3, scale=1.5, size=100), name="shared_data") mu = pm.Normal("mu") @@ -117,8 +130,8 @@ def test_JAX_map_shared_variables(): use_hess=False, use_hessp=False, progressbar=False, - gradient_backend="jax", - compile_kwargs={"mode": "JAX"}, + gradient_backend=gradient_backend, + compile_kwargs={"mode": backend.upper()}, ) mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] @@ -135,7 +148,15 @@ def test_JAX_map_shared_variables(): ("trust-ncg", True, False, True), ], ) -def test_find_MAP_basinhopping(method, use_grad, use_hess, use_hessp, rng): +@pytest.mark.parametrize( + "backend, gradient_backend", + # JAX backend is faster, so only test it) + [("jax", "pytensor")], + ids=str, +) +def test_find_MAP_basinhopping( + method, use_grad, use_hess, use_hessp, backend, gradient_backend, rng +): with pm.Model() as m: mu = pm.Normal("mu") sigma = pm.Exponential("sigma", 1) @@ -147,8 +168,8 @@ def test_find_MAP_basinhopping(method, use_grad, use_hess, use_hessp, rng): use_hess=use_hess, use_hessp=use_hessp, progressbar=False, - gradient_backend="pytensor", - compile_kwargs={"mode": "JAX"}, + gradient_backend=gradient_backend, + compile_kwargs={"mode": backend.upper()}, minimizer_kwargs=dict(method=method), ) diff --git a/tests/test_laplace.py b/tests/inference/laplace_approx/test_laplace.py similarity index 98% rename from tests/test_laplace.py rename to tests/inference/laplace_approx/test_laplace.py index 72ff3e937..31c8eaf2b 100644 --- a/tests/test_laplace.py +++ b/tests/inference/laplace_approx/test_laplace.py @@ -19,8 +19,8 @@ import pymc_extras as pmx -from pymc_extras.inference.find_map import GradientBackend, find_MAP -from pymc_extras.inference.laplace import ( +from pymc_extras.inference.laplace_approx.find_map import GradientBackend, find_MAP +from pymc_extras.inference.laplace_approx.laplace import ( fit_laplace, fit_mvn_at_MAP, get_conditional_gaussian_approximation, From 25b0805bd6ac3d85b9ff6b0562ae76ddabae5f9d Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 29 Jun 2025 05:18:32 +0800 Subject: [PATCH 2/7] Split idata utilities into `idata.py` --- pymc_extras/inference/laplace_approx/idata.py | 378 ++++++++++++++++++ .../inference/pathfinder/pathfinder.py | 4 +- tests/inference/laplace_approx/test_idata.py | 297 ++++++++++++++ 3 files changed, 677 insertions(+), 2 deletions(-) create mode 100644 pymc_extras/inference/laplace_approx/idata.py create mode 100644 tests/inference/laplace_approx/test_idata.py diff --git 
a/pymc_extras/inference/laplace_approx/idata.py b/pymc_extras/inference/laplace_approx/idata.py new file mode 100644 index 000000000..b82031f94 --- /dev/null +++ b/pymc_extras/inference/laplace_approx/idata.py @@ -0,0 +1,378 @@ +from itertools import product +from typing import Any, Literal + +import arviz as az +import numpy as np +import pymc as pm +import xarray as xr + +from arviz import dict_to_dataset +from better_optimize.constants import minimize_method +from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations +from pymc.blocking import RaveledVars +from scipy.optimize import OptimizeResult +from scipy.sparse.linalg import LinearOperator + + +def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]: + coords = model.coords + initial_point = model.initial_point() + + value_to_dim = { + value.name: model.named_vars_to_dims.get(model.values_to_rvs[value].name, None) + for value in model.value_vars + } + value_to_dim = {k: v for k, v in value_to_dim.items() if v is not None} + + rv_to_dim = model.named_vars_to_dims + dims_dict = rv_to_dim | value_to_dim + + unpacked_variable_names = [] + for name in names: + shape = initial_point[name].shape + if shape: + labels_by_dim = [ + coords[dim] if shape[i] == len(coords[dim]) else np.arange(shape[i]) + for i, dim in enumerate(dims_dict.get(name, [name])) + ] + labels = product(*labels_by_dim) + unpacked_variable_names.extend( + [f"{name}[{','.join(map(str, label))}]" for label in labels] + ) + else: + unpacked_variable_names.extend([name]) + return unpacked_variable_names + + +def map_results_to_inference_data(results: dict[str, Any], model: pm.Model | None = None): + """ + Convert a dictionary of results to an InferenceData object. + + Parameters + ---------- + results: dict + A dictionary containing the results to convert. + model: Model, optional + A PyMC model. If None, the model is taken from the current model context. + + Returns + ------- + idata: az.InferenceData + An InferenceData object containing the results. + """ + model = pm.modelcontext(model) + coords, dims = coords_and_dims_for_inferencedata(model) + + idata = az.convert_to_inference_data(results, coords=coords, dims=dims) + return idata + + +def add_map_posterior_to_inference_data( + idata: az.InferenceData, + map_point: dict[str, float | int | np.ndarray], + model: pm.Model | None = None, +): + """ + Add the MAP point to an InferenceData object in the posterior group. + + Unlike a typical posterior, the MAP point is a single point estimate rather than a distribution. As a result, it + does not have a chain or draw dimension, and is stored as a single point in the posterior group. + + Parameters + ---------- + idata: az.InferenceData + An InferenceData object to which the MAP point will be added. + map_point: dict + A dictionary containing the MAP point estimates for each variable. The keys should be the variable names, and + the values should be the corresponding MAP estimates. + model: Model, optional + A PyMC model. If None, the model is taken from the current model context. + + Returns + ------- + idata: az.InferenceData + The provided InferenceData, with the MAP point added to the posterior group. 
+    """
+
+    model = pm.modelcontext(model) if model is None else model
+    coords, dims = coords_and_dims_for_inferencedata(model)
+    initial_point = model.initial_point()
+
+    # The MAP point will have both the transformed and untransformed variables, so we need to ensure that
+    # we have the correct dimensions for each variable.
+    var_name_to_value_name = {
+        rv.name: value.name
+        for rv, value in model.rvs_to_values.items()
+        if rv not in model.observed_RVs
+    }
+    dims.update(
+        {
+            value_name: dims[var_name]
+            for var_name, value_name in var_name_to_value_name.items()
+            if var_name in dims and (initial_point[value_name].shape == map_point[var_name].shape)
+        }
+    )
+
+    idata = az.from_dict(
+        {k: np.expand_dims(v, (0, 1)) for k, v in map_point.items()}, coords=coords, dims=dims
+    )
+
+    return idata
+
+
+def add_fit_to_inference_data(
+    idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
+) -> az.InferenceData:
+    """
+    Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
+
+    Parameters
+    ----------
+    idata: az.InferenceData
+        An InferenceData object containing the approximated posterior samples.
+    mu: RaveledVars
+        The MAP estimate of the model parameters.
+    H_inv: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
+    model: Model, optional
+        A PyMC model. If None, the model is taken from the current model context.
+
+    Returns
+    -------
+    idata: az.InferenceData
+        The provided InferenceData, with the mean vector and covariance matrix added to the "fit" group.
+    """
+    model = pm.modelcontext(model) if model is None else model
+
+    variable_names, *_ = zip(*mu.point_map_info)
+
+    unpacked_variable_names = make_unpacked_variable_names(variable_names, model)
+
+    mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names})
+
+    data = {"mean_vector": mean_dataarray}
+
+    if H_inv is not None:
+        cov_dataarray = xr.DataArray(
+            H_inv,
+            dims=["rows", "columns"],
+            coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
+        )
+        data["covariance_matrix"] = cov_dataarray
+
+    dataset = xr.Dataset(data)
+    idata.add_groups(fit=dataset)
+
+    return idata
+
+
+def add_data_to_inference_data(
+    idata: az.InferenceData,
+    progressbar: bool = True,
+    model: pm.Model | None = None,
+    compile_kwargs: dict | None = None,
+) -> az.InferenceData:
+    """
+    Add observed and constant data to an InferenceData object.
+
+    Parameters
+    ----------
+    idata: az.InferenceData
+        An InferenceData object containing the approximated posterior samples.
+    progressbar: bool
+        Whether to display a progress bar during computations. Default is True.
+    model: Model, optional
+        A PyMC model. If None, the model is taken from the current model context.
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to pytensor.function.
+
+    Returns
+    -------
+    idata: az.InferenceData
+        The provided InferenceData, with observed and constant data added.
+ """ + model = pm.modelcontext(model) if model is None else model + + if model.deterministics: + expand_dims = {} + if "chain" not in idata.posterior.coords: + expand_dims["chain"] = [0] + if "draw" not in idata.posterior.coords: + expand_dims["draw"] = [0] + + idata.posterior = pm.compute_deterministics( + idata.posterior.expand_dims(expand_dims), + model=model, + merge_dataset=True, + progressbar=progressbar, + compile_kwargs=compile_kwargs, + ) + + coords, dims = coords_and_dims_for_inferencedata(model) + + observed_data = dict_to_dataset( + find_observations(model), + library=pm, + coords=coords, + dims=dims, + default_dims=[], + ) + + constant_data = dict_to_dataset( + find_constants(model), + library=pm, + coords=coords, + dims=dims, + default_dims=[], + ) + + idata.add_groups( + {"observed_data": observed_data, "constant_data": constant_data}, + coords=coords, + dims=dims, + ) + + return idata + + +def optimizer_result_to_dataset( + result: OptimizeResult, + method: minimize_method | Literal["basinhopping"], + mu: RaveledVars | None = None, + model: pm.Model | None = None, +) -> xr.Dataset: + """ + Convert an OptimizeResult object to an xarray Dataset object. + + Parameters + ---------- + result: OptimizeResult + The result of the optimization process. + method: minimize_method or "basinhopping" + The optimization method used. + + Returns + ------- + dataset: xr.Dataset + An xarray Dataset containing the optimization results. + """ + if not isinstance(result, OptimizeResult): + raise TypeError("result must be an instance of OptimizeResult") + + model = pm.modelcontext(model) if model is None else model + variable_names, *_ = zip(*mu.point_map_info) + unpacked_variable_names = make_unpacked_variable_names(variable_names, model) + + data_vars = {} + + if hasattr(result, "lowest_optimization_result"): + # If we did basinhopping, there's a results inside the results. 
We want to pop this out and collapse them, + # overwriting outer keys with the inner keys + inner_res = result.pop("lowest_optimization_result") + for key in inner_res.keys(): + result[key] = inner_res[key] + + if hasattr(result, "x"): + data_vars["x"] = xr.DataArray( + result.x, dims=["variables"], coords={"variables": unpacked_variable_names} + ) + if hasattr(result, "fun"): + data_vars["fun"] = xr.DataArray(result.fun, dims=[]) + if hasattr(result, "success"): + data_vars["success"] = xr.DataArray(result.success, dims=[]) + if hasattr(result, "message"): + data_vars["message"] = xr.DataArray(str(result.message), dims=[]) + if hasattr(result, "jac") and result.jac is not None: + jac = np.asarray(result.jac) + if jac.ndim == 1: + data_vars["jac"] = xr.DataArray( + jac, dims=["variables"], coords={"variables": unpacked_variable_names} + ) + else: + data_vars["jac"] = xr.DataArray( + jac, + dims=["variables", "variables_aux"], + coords={ + "variables": unpacked_variable_names, + "variables_aux": unpacked_variable_names, + }, + ) + + if hasattr(result, "hess_inv") and result.hess_inv is not None: + hess_inv = result.hess_inv + if isinstance(hess_inv, LinearOperator): + n = hess_inv.shape[0] + eye = np.eye(n) + hess_inv_mat = np.column_stack([hess_inv.matvec(eye[:, i]) for i in range(n)]) + hess_inv = hess_inv_mat + else: + hess_inv = np.asarray(hess_inv) + data_vars["hess_inv"] = xr.DataArray( + hess_inv, + dims=["variables", "variables_aux"], + coords={"variables": unpacked_variable_names, "variables_aux": unpacked_variable_names}, + ) + + if hasattr(result, "nit"): + data_vars["nit"] = xr.DataArray(result.nit, dims=[]) + if hasattr(result, "nfev"): + data_vars["nfev"] = xr.DataArray(result.nfev, dims=[]) + if hasattr(result, "njev"): + data_vars["njev"] = xr.DataArray(result.njev, dims=[]) + if hasattr(result, "status"): + data_vars["status"] = xr.DataArray(result.status, dims=[]) + + # Add any other fields present in result + for key, value in result.items(): + if key in data_vars: + continue # already added + if value is None: + continue + arr = np.asarray(value) + + # TODO: We can probably do something smarter here with a dictionary of all possible values and their expected + # dimensions. + dims = [f"{key}_dim_{i}" for i in range(arr.ndim)] + data_vars[key] = xr.DataArray( + arr, + dims=dims, + coords={f"{key}_dim_{i}": np.arange(arr.shape[i]) for i in range(len(dims))}, + ) + + data_vars["method"] = xr.DataArray(np.array(method), dims=[]) + + return xr.Dataset(data_vars) + + +def add_optimizer_result_to_inference_data( + idata: az.InferenceData, + result: OptimizeResult, + method: minimize_method | Literal["basinhopping"], + mu: RaveledVars | None = None, + model: pm.Model | None = None, +) -> az.InferenceData: + """ + Add the optimization result to an InferenceData object. + + Parameters + ---------- + idata: az.InferenceData + An InferenceData object containing the approximated posterior samples. + result: OptimizeResult + The result of the optimization process. + method: minimize_method or "basinhopping" + The optimization method used. + mu: RaveledVars, optional + The MAP estimate of the model parameters. + model: Model, optional + A PyMC model. If None, the model is taken from the current model context. + + Returns + ------- + idata: az.InferenceData + The provided InferenceData, with the optimization results added to the "optimizer" group. 
+ """ + dataset = optimizer_result_to_dataset(result, method=method, mu=mu, model=model) + idata.add_groups({"optimizer_result": dataset}) + + return idata diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index 4f280f8fc..774541bc4 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -63,7 +63,7 @@ # TODO: change to typing.Self after Python versions greater than 3.10 from typing_extensions import Self -from pymc_extras.inference.laplace_approx.idata import add_data_to_inferencedata +from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data from pymc_extras.inference.pathfinder.importance_sampling import ( importance_sampling as _importance_sampling, ) @@ -1759,6 +1759,6 @@ def fit_pathfinder( importance_sampling=importance_sampling, ) - idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs) + idata = add_data_to_inference_data(idata, progressbar, model, compile_kwargs) return idata diff --git a/tests/inference/laplace_approx/test_idata.py b/tests/inference/laplace_approx/test_idata.py new file mode 100644 index 000000000..8a2cd4444 --- /dev/null +++ b/tests/inference/laplace_approx/test_idata.py @@ -0,0 +1,297 @@ +from contextlib import contextmanager + +import arviz as az +import numpy as np +import pymc as pm +import pytest +import xarray as xr + +from pymc.blocking import RaveledVars +from scipy.optimize import OptimizeResult +from scipy.sparse.linalg import LinearOperator + +from pymc_extras.inference.laplace_approx.idata import ( + add_data_to_inference_data, + add_fit_to_inference_data, + optimizer_result_to_dataset, +) + + +@contextmanager +def no_op(): + yield + + +@pytest.fixture +def rng(): + return np.random.default_rng() + + +@pytest.fixture +def simple_model(rng): + with pm.Model() as model: + x = pm.Data("data", rng.normal(size=(10,))) + mu = pm.Normal("mu", 0, 1) + sigma = pm.HalfNormal("sigma", 1) + obs = pm.Normal("obs", mu + x, sigma, observed=rng.normal(size=(10,))) + + mu_val = np.array([0.5, 1.0]) + H_inv = np.eye(2) + + point_map_info = (("mu", (), 1, "float64"), ("sigma_log__", (), 1, "float64")) + test_point = RaveledVars(mu_val, point_map_info) + + return model, mu_val, H_inv, test_point + + +@pytest.fixture +def hierarchical_model(rng): + with pm.Model(coords={"group": [1, 2, 3, 4, 5]}) as model: + mu_loc = pm.Normal("mu_loc", 0, 1) + mu_scale = pm.HalfNormal("mu_scale", 1) + mu = pm.Normal("mu", mu_loc, mu_scale, dims="group") + sigma = pm.HalfNormal("sigma", 1) + obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=rng.normal(size=(5, 10))) + + mu_val = rng.normal(size=(8,)) + H_inv = np.eye(8) + + point_map_info = ( + ("mu_loc", (), 1, "float64"), + ("mu_scale_log__", (), 1, "float64"), + ("mu", (5,), 5, "float64"), + ("sigma_log__", (), 1, "float64"), + ) + + test_point = RaveledVars(mu_val, point_map_info) + + return model, mu_val, H_inv, test_point + + +class TestFittoInferenceData: + def check_idata(self, idata, var_names, n_vars): + assert "fit" in idata.groups() + + fit = idata.fit + assert "mean_vector" in fit + assert "covariance_matrix" in fit + assert fit["mean_vector"].shape[0] == n_vars + assert fit["covariance_matrix"].shape == (n_vars, n_vars) + + assert list(fit.coords.keys()) == ["rows", "columns"] + assert fit.coords["rows"].values.tolist() == var_names + assert fit.coords["columns"].values.tolist() == var_names + + @pytest.mark.parametrize("use_context", [False, True], 
ids=["model_arg", "model_context"]) + def test_add_fit_to_inferencedata(self, use_context, simple_model, rng): + model, mu_val, H_inv, test_point = simple_model + idata = az.from_dict( + posterior={"mu": rng.normal(size=()), "sigma_log__": rng.normal(size=())} + ) + + context = model if use_context else no_op() + model_arg = model if not use_context else None + + with context: + idata2 = add_fit_to_inference_data(idata, test_point, H_inv, model=model_arg) + + self.check_idata(idata2, ["mu", "sigma_log__"], 2) + + @pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"]) + def test_add_fit_with_coords_to_inferencedata(self, use_context, hierarchical_model, rng): + model, mu_val, H_inv, test_point = hierarchical_model + idata = az.from_dict( + posterior={ + "mu_loc": rng.normal(size=()), + "mu_scale_log__": rng.normal(size=()), + "mu": rng.normal(size=(5,)), + "sigma_log__": rng.normal(size=()), + } + ) + + context = model if use_context else no_op() + model_arg = model if not use_context else None + + with context: + idata2 = add_fit_to_inference_data(idata, test_point, H_inv, model=model_arg) + + self.check_idata( + idata2, + [ + "mu_loc", + "mu_scale_log__", + "mu[1]", + "mu[2]", + "mu[3]", + "mu[4]", + "mu[5]", + "sigma_log__", + ], + 8, + ) + + +@pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"]) +def test_add_data_to_inferencedata(use_context, simple_model, rng): + model, *_ = simple_model + + idata = az.from_dict( + posterior={"mu": rng.standard_normal((1, 1)), "sigma_log__": rng.standard_normal((1, 1))} + ) + + context = model if use_context else no_op() + model_arg = model if not use_context else None + + with context: + idata2 = add_data_to_inference_data(idata, model=model_arg) + + assert "observed_data" in idata2.groups() + assert "constant_data" in idata2.groups() + assert "obs" in idata2.observed_data + + +@pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"]) +def test_optimizer_result_to_dataset_basic(use_context, simple_model, rng): + model, mu_val, H_inv, test_point = simple_model + result = OptimizeResult( + x=np.array([1.0, 2.0]), + fun=0.5, + success=True, + message="Optimization succeeded", + jac=np.array([0.1, 0.2]), + nit=5, + nfev=10, + njev=3, + status=0, + ) + + context = model if use_context else no_op() + model_arg = model if not use_context else None + with context: + ds = optimizer_result_to_dataset(result, method="BFGS", model=model_arg, mu=test_point) + + assert isinstance(ds, xr.Dataset) + assert all( + key in ds + for key in [ + "x", + "fun", + "success", + "message", + "jac", + "nit", + "nfev", + "njev", + "status", + "method", + ] + ) + + assert list(ds["x"].coords.keys()) == ["variables"] + assert ds["x"].coords["variables"].values.tolist() == ["mu", "sigma_log__"] + + assert list(ds["jac"].coords.keys()) == ["variables"] + assert ds["jac"].coords["variables"].values.tolist() == ["mu", "sigma_log__"] + + +@pytest.mark.parametrize( + "optimizer_method, use_context, model_name", + [("BFGS", True, "hierarchical_model"), ("L-BFGS-B", False, "simple_model")], +) +def test_optimizer_result_to_dataset_hess_inv_types( + optimizer_method, use_context, model_name, rng, request +): + def get_hess_inv_and_expected_names(method): + model, mu_val, H_inv, test_point = request.getfixturevalue(model_name) + n = mu_val.shape[0] + + if method == "BFGS": + hess_inv = np.eye(n) + expected_names = [ + "mu_loc", + "mu_scale_log__", + "mu[1]", + "mu[2]", + "mu[3]", + 
"mu[4]", + "mu[5]", + "sigma_log__", + ] + result = OptimizeResult( + x=np.zeros((n,)), + hess_inv=hess_inv, + ) + elif method == "L-BFGS-B": + + def linop_func(x): + return np.array([2 * xi for xi in x]) + + linop = LinearOperator((n, n), matvec=linop_func) + hess_inv = 2 * np.eye(n) + expected_names = ["mu", "sigma_log__"] + result = OptimizeResult( + x=np.ones(n), + hess_inv=linop, + ) + else: + raise ValueError("Unknown optimizer_method") + + return model, test_point, hess_inv, expected_names, result + + model, test_point, hess_inv, expected_names, result = get_hess_inv_and_expected_names( + optimizer_method + ) + + context = model if use_context else no_op() + model_arg = model if not use_context else None + + with context: + ds = optimizer_result_to_dataset( + result, method=optimizer_method, mu=test_point, model=model_arg + ) + + assert "hess_inv" in ds + assert ds["hess_inv"].shape == (len(expected_names), len(expected_names)) + assert list(ds["hess_inv"].coords.keys()) == ["variables", "variables_aux"] + assert ds["hess_inv"].coords["variables"].values.tolist() == expected_names + assert ds["hess_inv"].coords["variables_aux"].values.tolist() == expected_names + np.testing.assert_allclose(ds["hess_inv"].values, hess_inv) + + +def test_optimizer_result_to_dataset_extra_fields(simple_model, rng): + model, mu_val, H_inv, test_point = simple_model + + result = OptimizeResult( + x=np.array([1.0, 2.0]), + custom_stat=np.array([42, 43]), + ) + + with model: + ds = optimizer_result_to_dataset(result, method="BFGS", mu=test_point) + + assert "custom_stat" in ds + assert ds["custom_stat"].shape == (2,) + assert list(ds["custom_stat"].coords.keys()) == ["custom_stat_dim_0"] + assert ds["custom_stat"].coords["custom_stat_dim_0"].values.tolist() == [0, 1] + + +def test_optimizer_result_to_dataset_hess_inv_basinhopping(simple_model, rng): + model, mu_val, H_inv, test_point = simple_model + n = mu_val.shape[0] + hess_inv_inner = np.eye(n) * 3.0 + + # Basinhopping returns an OptimizeResult with a nested OptimizeResult + result = OptimizeResult( + x=np.ones(n), + lowest_optimization_result=OptimizeResult(x=np.ones(n), hess_inv=hess_inv_inner), + ) + + with model: + ds = optimizer_result_to_dataset(result, method="basinhopping", mu=test_point) + + assert "hess_inv" in ds + assert ds["hess_inv"].shape == (n, n) + np.testing.assert_allclose(ds["hess_inv"].values, hess_inv_inner) + expected_names = ["mu", "sigma_log__"] + assert ds["hess_inv"].coords["variables"].values.tolist() == expected_names + assert ds["hess_inv"].coords["variables_aux"].values.tolist() == expected_names From 02d703277aef01ee3032fd0aafac310fd87b7e8e Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 5 Jul 2025 12:50:22 +0800 Subject: [PATCH 3/7] Refactor find_MAP --- .../inference/laplace_approx/find_map.py | 423 +++++------------- .../laplace_approx/scipy_interface.py | 242 ++++++++++ .../inference/laplace_approx/test_find_map.py | 193 ++++++-- .../laplace_approx/test_scipy_interface.py | 118 +++++ 4 files changed, 629 insertions(+), 347 deletions(-) create mode 100644 pymc_extras/inference/laplace_approx/scipy_interface.py create mode 100644 tests/inference/laplace_approx/test_scipy_interface.py diff --git a/pymc_extras/inference/laplace_approx/find_map.py b/pymc_extras/inference/laplace_approx/find_map.py index 6100097a6..c20f7e8b0 100644 --- a/pymc_extras/inference/laplace_approx/find_map.py +++ b/pymc_extras/inference/laplace_approx/find_map.py @@ -1,30 +1,33 @@ import logging from collections.abc import Callable -from 
importlib.util import find_spec -from typing import Literal, cast, get_args +from typing import Literal, cast +import arviz as az import numpy as np import pymc as pm -import pytensor -import pytensor.tensor as pt from better_optimize import basinhopping, minimize from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.initial_point import make_initial_point_fn from pymc.model.transform.optimization import freeze_dims_and_data -from pymc.pytensorf import join_nonshared_inputs from pymc.util import get_default_varnames -from pytensor.compile import Function -from pytensor.compile.mode import Mode from pytensor.tensor import TensorVariable from scipy.optimize import OptimizeResult -_log = logging.getLogger(__name__) +from pymc_extras.inference.laplace_approx.idata import ( + add_data_to_inference_data, + add_fit_to_inference_data, + add_map_posterior_to_inference_data, + add_optimizer_result_to_inference_data, +) +from pymc_extras.inference.laplace_approx.scipy_interface import ( + GradientBackend, + scipy_optimize_funcs_from_loss, +) -GradientBackend = Literal["pytensor", "jax"] -VALID_BACKENDS = get_args(GradientBackend) +_log = logging.getLogger(__name__) def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp): @@ -81,265 +84,105 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray: return eigvec @ np.diag(eigval) @ eigvec.T -def _unconstrained_vector_to_constrained_rvs(model): - constrained_rvs, unconstrained_vector = join_nonshared_inputs( - model.initial_point(), - inputs=model.value_vars, - outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False), - ) - - unconstrained_vector.name = "unconstrained_vector" - return constrained_rvs, unconstrained_vector - - -def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, chains, draws): - X = pt.tensor("transformed_draws", shape=(chains, draws, H_inv.shape[0])) - out = [] - for rv, idx in slices.items(): - f = model.rvs_to_transforms[rv] - untransformed_X = f.backward(X[..., idx]) if f is not None else X[..., idx] - - if rv in out_shapes: - new_shape = (chains, draws) + out_shapes[rv] - untransformed_X = untransformed_X.reshape(new_shape) - - out.append(untransformed_X) +def _make_initial_point(model, initvals=None, random_seed=None, jitter_rvs=None): + jitter_rvs = [] if jitter_rvs is None else jitter_rvs - f_untransform = pytensor.function( - inputs=[pytensor.In(X, borrow=True)], - outputs=pytensor.Out(out, borrow=True), - mode=Mode(linker="py", optimizer="FAST_COMPILE"), + ipfn = make_initial_point_fn( + model=model, + jitter_rvs=set(jitter_rvs), + return_transformed=True, + overrides=initvals, ) - return f_untransform(posterior_draws) - - -def _compile_grad_and_hess_to_jax( - f_fused: Function, use_hess: bool, use_hessp: bool -) -> tuple[Callable | None, Callable | None]: - """ - Compile loss function gradients using JAX. - - Parameters - ---------- - f_fused: Function - The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss, - compiled with mode="JAX". - use_hess: bool - Whether to compile a function to compute the hessian of the loss function. - use_hessp: bool - Whether to compile a function to compute the hessian-vector product of the loss function. - - Returns - ------- - f_fused: Callable - The compiled loss function and gradient function, which may also compute the hessian if requested. 
- f_hessp: Callable | None - The compiled hessian-vector product function, or None if use_hessp is False. - """ - import jax - - f_hessp = None - - orig_loss_fn = f_fused.vm.jit_fn - - if use_hess: - - @jax.jit - def loss_fn_fused(x): - loss_and_grad = jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) - hess = jax.hessian(lambda x: orig_loss_fn(x)[0])(x) - return *loss_and_grad, hess - - else: - @jax.jit - def loss_fn_fused(x): - return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) - - if use_hessp: - - def f_hessp_jax(x, p): - y, u = jax.jvp(lambda x: loss_fn_fused(x)[1], (x,), (p,)) - return jax.numpy.stack(u) - - f_hessp = jax.jit(f_hessp_jax) - - return loss_fn_fused, f_hessp - - -def _compile_functions_for_scipy_optimize( - loss: TensorVariable, - inputs: list[TensorVariable], - compute_grad: bool, - compute_hess: bool, - compute_hessp: bool, - compile_kwargs: dict | None = None, -) -> list[Function] | list[Function, Function | None, Function | None]: - """ - Compile loss functions for use with scipy.optimize.minimize. - - Parameters - ---------- - loss: TensorVariable - The loss function to compile. - inputs: list[TensorVariable] - A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines - expect the function signature to be f(x, *args), where x is a 1D array of parameters. - compute_grad: bool - Whether to compile a function that computes the gradients of the loss function. - compute_hess: bool - Whether to compile a function that computes the Hessian of the loss function. - compute_hessp: bool - Whether to compile a function that computes the Hessian-vector product of the loss function. - compile_kwargs: dict, optional - Additional keyword arguments to pass to the ``pm.compile`` function. - - Returns - ------- - f_fused: Function - The compiled loss function, which may also include gradients and hessian if requested. - f_hessp: Function | None - The compiled hessian-vector product function, or None if compute_hessp is False. - """ - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - - loss = pm.pytensorf.rewrite_pregrad(loss) - f_hessp = None - - # In the simplest case, we only compile the loss function. Return it as a list to keep the return type consistent - # with the case where we also compute gradients, hessians, or hessian-vector products. - if not (compute_grad or compute_hess or compute_hessp): - f_loss = pm.compile(inputs, loss, **compile_kwargs) - return [f_loss] - - # Otherwise there are three cases. If the user only wants the loss function and gradients, we compile a single - # fused function and retun it. If the user also wants the hession, the fused function will return the loss, - # gradients and hessian. If the user wants gradients and hess_p, we return a fused function that returns the loss - # and gradients, and a separate function for the hessian-vector product. - - if compute_hessp: - # Handle this first, since it can be compiled alone. 
- p = pt.tensor("p", shape=inputs[0].type.shape) - hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p) - f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs) - - outputs = [loss] - - if compute_grad: - grads = pytensor.gradient.grad(loss, inputs) - grad = pt.concatenate([grad.ravel() for grad in grads]) - outputs.append(grad) - - if compute_hess: - hess = pytensor.gradient.jacobian(grad, inputs)[0] - outputs.append(hess) - - f_fused = pm.compile(inputs, outputs, **compile_kwargs) + start_dict = ipfn(random_seed) + vars_dict = {var.name: var for var in model.continuous_value_vars} + initial_params = DictToArrayBijection.map( + {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} + ) - return [f_fused, f_hessp] + return initial_params -def scipy_optimize_funcs_from_loss( - loss: TensorVariable, - inputs: list[TensorVariable], - initial_point_dict: dict[str, np.ndarray | float | int], - use_grad: bool, +def _compute_inverse_hessian( + optimizer_result: OptimizeResult | None, + optimal_point: np.ndarray | None, + f_fused: Callable | None, + f_hessp: Callable | None, use_hess: bool, - use_hessp: bool, - gradient_backend: GradientBackend = "pytensor", - compile_kwargs: dict | None = None, -) -> tuple[Callable, ...]: + method: minimize_method | Literal["BFGS", "L-BFGS-B"], +): """ - Compile loss functions for use with scipy.optimize.minimize. + Compute the Hessian matrix or its inverse based on the optimization result and the method used. + + Downstream functions (e.g. laplace approximation) will need the inverse Hessian matrix. This function computes it + in the cheapest way possible, depending on the optimization method used and the available compiled functions. Parameters ---------- - loss: TensorVariable - The loss function to compile. - inputs: list[TensorVariable] - The input variables to the loss function. - initial_point_dict: dict[str, np.ndarray | float | int] - Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables. - use_grad: bool - Whether to compile a function that computes the gradients of the loss function. + optimizer_result: OptimizeResult, optional + The result of the optimization, containing the optimized parameters and possibly an approximate inverse Hessian. + optimal_point: np.ndarray, optional + The optimal point found by the optimizer, used to compute the Hessian if necessary. If not provided, it will be + extracted from the optimizer result. + f_fused: callable, optional + The compiled function representing the loss and possibly its gradient and Hessian. + f_hessp: callable, optional + The compiled function for Hessian-vector products, if available. use_hess: bool - Whether to compile a function that computes the Hessian of the loss function. - use_hessp: bool - Whether to compile a function that computes the Hessian-vector product of the loss function. - gradient_backend: str, default "pytensor" - Which backend to use to compute gradients. Must be one of "jax" or "pytensor" - compile_kwargs: - Additional keyword arguments to pass to the ``pm.compile`` function. + Whether the Hessian matrix was used in the optimization. + method: minimize_method + The optimization method used, which determines how the Hessian is computed. Returns ------- - f_fused: Callable - The compiled loss function, which may also include gradients and hessian if requested. - f_hessp: Callable | None - The compiled hessian-vector product function, or None if use_hessp is False. 
+ H_inv: np.ndarray + The inverse Hessian matrix, computed based on the optimization method and available functions. """ + if optimal_point is None and optimizer_result is None: + raise ValueError("At least one of `optimal_point` or `optimizer_result` must be provided.") - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - - if (use_hess or use_hessp) and not use_grad: - raise ValueError( - "Cannot compute hessian or hessian-vector product without also computing the gradient" - ) - - if gradient_backend not in VALID_BACKENDS: - raise ValueError( - f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}" - ) - - use_jax_gradients = (gradient_backend == "jax") and use_grad - if use_jax_gradients and not find_spec("jax"): - raise ImportError("JAX must be installed to use JAX gradients") - - mode = compile_kwargs.get("mode", None) - if mode is None and use_jax_gradients: - compile_kwargs["mode"] = "JAX" - elif mode != "JAX" and use_jax_gradients: - raise ValueError( - 'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"' - ) - - if not isinstance(inputs, list): - inputs = [inputs] - - [loss], flat_input = join_nonshared_inputs( - point=initial_point_dict, outputs=[loss], inputs=inputs - ) - - # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When - # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them - # away. - if use_jax_gradients: - from pymc.sampling.jax import _replace_shared_variables + x_star = optimizer_result.x if optimizer_result is not None else optimal_point + n_vars = len(x_star) - [loss] = _replace_shared_variables([loss]) + if method == "BFGS" and optimizer_result is not None: + # If we used BFGS, the optimizer result will contain the inverse Hessian -- we can just use that rather than + # re-computing something + if hasattr(optimizer_result, "lowest_optimization_result"): + # We did basinhopping, need to get the inner optimizer results + H_inv = getattr(optimizer_result.lowest_optimization_result, "hess_inv", None) + else: + H_inv = getattr(optimizer_result, "hess_inv", None) - compute_grad = use_grad and not use_jax_gradients - compute_hess = use_hess and not use_jax_gradients - compute_hessp = use_hessp and not use_jax_gradients + elif method == "L-BFGS-B" and optimizer_result is not None: + # Here we will have a LinearOperator representing the inverse Hessian-Vector product. + if hasattr(optimizer_result, "lowest_optimization_result"): + # We did basinhopping, need to get the inner optimizer results + f_hessp_inv = getattr(optimizer_result.lowest_optimization_result, "hess_inv", None) + else: + f_hessp_inv = getattr(optimizer_result, "hess_inv", None) + + if f_hessp_inv is not None: + basis = np.eye(n_vars) + H_inv = np.stack([f_hessp_inv(basis[:, i]) for i in range(n_vars)], axis=-1) + else: + H_inv = None - funcs = _compile_functions_for_scipy_optimize( - loss=loss, - inputs=[flat_input], - compute_grad=compute_grad, - compute_hess=compute_hess, - compute_hessp=compute_hessp, - compile_kwargs=compile_kwargs, - ) + elif f_hessp is not None: + # In the case that hessp was used, the results object will not save the inverse Hessian, so we can compute it from + # the hessp function, using euclidian basis vector. 
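# [Editor's note: illustrative sketch, not part of this patch.] The branches above recover the
# inverse Hessian in three ways: BFGS stores a dense `hess_inv`, L-BFGS-B stores a LinearOperator
# that has to be densified column by column, and otherwise H is rebuilt from hessian-vector
# products. A standalone toy version of the latter two paths, using only numpy/scipy; every name
# below is illustrative and nothing here is pymc-extras API.
import numpy as np
from scipy.optimize import minimize

A = np.diag([1.0, 10.0])  # Hessian of the quadratic loss below is 2 * A

def loss_and_grad(x):
    return x @ A @ x, 2 * A @ x

res = minimize(loss_and_grad, x0=np.ones(2), jac=True, method="L-BFGS-B")
n = res.x.size

# Densify the LinearOperator returned by L-BFGS-B by applying it to the standard basis vectors.
H_inv_lbfgsb = np.stack([res.hess_inv(np.eye(n)[:, i]) for i in range(n)], axis=-1)

# Rebuild H from hessian-vector products, then invert (what the `elif f_hessp` branch does).
def hessp(x, p):
    return 2 * A @ p

H = np.stack([hessp(res.x, np.eye(n)[:, i]) for i in range(n)], axis=-1)
H_inv_from_hessp = np.linalg.inv(H)
# The diff lines that follow perform this same column-by-column reconstruction inside
# `_compute_inverse_hessian`, adding a nearest-PSD projection before the inverse.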
+ basis = np.eye(n_vars) + H = np.stack([f_hessp(x_star, basis[:, i]) for i in range(n_vars)], axis=-1) + H_inv = np.linalg.inv(get_nearest_psd(H)) - # Depending on the requested functions, f_fused will either be the loss function, the loss function with gradients, - # or the loss function with gradients and hessian. - f_fused = funcs.pop(0) - f_hessp = funcs.pop(0) if compute_hessp else None + elif use_hess and f_fused is not None: + # If we compiled a hessian function, just use it + _, _, H = f_fused(x_star) + H_inv = np.linalg.inv(get_nearest_psd(H)) - if use_jax_gradients: - f_fused, f_hessp = _compile_grad_and_hess_to_jax(f_fused, use_hess, use_hessp) + else: + H_inv = None - return f_fused, f_hessp + return H_inv def find_MAP( @@ -351,14 +194,18 @@ def find_MAP( use_hess: bool | None = None, initvals: dict | None = None, random_seed: int | np.random.Generator | None = None, - return_raw: bool = False, jitter_rvs: list[TensorVariable] | None = None, progressbar: bool = True, include_transformed: bool = True, gradient_backend: GradientBackend = "pytensor", compile_kwargs: dict | None = None, **optimizer_kwargs, -) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]: +) -> ( + dict[str, np.ndarray] + | tuple[dict[str, np.ndarray], np.ndarray] + | tuple[dict[str, np.ndarray], OptimizeResult] + | tuple[dict[str, np.ndarray], OptimizeResult, np.ndarray] +): """ Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize. @@ -381,12 +228,10 @@ def find_MAP( Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on the ``method``. initvals : None | dict, optional - Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted. + Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted. If None, the model's default initial values are used. random_seed : None | int | np.random.Generator, optional Seed for the random number generator or a numpy Generator for reproducibility - return_raw: bool | False, optinal - Whether to also return the full output of `scipy.optimize.minimize` jitter_rvs : list of TensorVariables, optional Variables whose initial values should be jittered. If None, all variables are jittered. progressbar : bool, optional @@ -404,28 +249,15 @@ def find_MAP( Returns ------- - optimizer_result: dict[str, np.ndarray] or tuple[dict[str, np.ndarray], OptimizerResult] - Dictionary with names of random variables as keys, and optimization results as values. If return_raw is True, - also returns the object returned by ``scipy.optimize.minimize``. + map_result: az.InferenceData + Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed + latent variables, and optimizer results. 
""" - model = pm.modelcontext(model) + model = pm.modelcontext(model) if model is None else model frozen_model = freeze_dims_and_data(model) - - jitter_rvs = [] if jitter_rvs is None else jitter_rvs compile_kwargs = {} if compile_kwargs is None else compile_kwargs - ipfn = make_initial_point_fn( - model=frozen_model, - jitter_rvs=set(jitter_rvs), - return_transformed=True, - overrides=initvals, - ) - - start_dict = ipfn(random_seed) - vars_dict = {var.name: var for var in frozen_model.continuous_value_vars} - initial_params = DictToArrayBijection.map( - {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} - ) + initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs) do_basinhopping = method == "basinhopping" minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {}) @@ -443,9 +275,9 @@ def find_MAP( ) f_fused, f_hessp = scipy_optimize_funcs_from_loss( - loss=-frozen_model.logp(jacobian=False), + loss=-frozen_model.logp(), inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars, - initial_point_dict=start_dict, + initial_point_dict=DictToArrayBijection.rmap(initial_params), use_grad=use_grad, use_hess=use_hess, use_hessp=use_hessp, @@ -491,38 +323,27 @@ def find_MAP( DictToArrayBijection.rmap(raveled_optimized) ) - # Downstream computation will probably want the covaraince matrix at the optimized point, so we compute it here, - # while we still have access to the compiled function. - x_star = optimizer_result.x - n_vars = len(x_star) - - if method == "BFGS": - # If we used BFGS, the optimizer result will contain the inverse Hessian -- we can just use that rather than - # re-computing something - getattr(optimizer_result, "hess_inv", None) - elif method == "L-BFGS-B": - # Here we will have a LinearOperator representing the inverse Hessian-Vector product. - f_hessp_inv = optimizer_result.hess_inv - basis = np.eye(n_vars) - np.stack([f_hessp_inv(basis[:, i]) for i in range(n_vars)], axis=-1) - - elif f_hessp is not None: - # In the case that hessp was used, the results object will not save the inverse Hessian, so we can compute it from - # the hessp function, using euclidian basis vector. 
- basis = np.eye(n_vars) - H = np.stack([f_hessp(optimizer_result.x, basis[:, i]) for i in range(n_vars)], axis=-1) - np.linalg.inv(get_nearest_psd(H)) - - elif use_hess: - # If we compiled a hessian function, just use it - _, _, H = f_fused(x_star) - np.linalg.inv(get_nearest_psd(H)) - optimized_point = { var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values) } - if return_raw: - return optimized_point, optimizer_result + H_inv = _compute_inverse_hessian( + optimizer_result=optimizer_result, + optimal_point=None, + f_fused=f_fused, + f_hessp=f_hessp, + use_hess=use_hess, + method=method, + ) + + idata = az.InferenceData() + idata = add_map_posterior_to_inference_data(idata, optimized_point, frozen_model) + idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv) + idata = add_optimizer_result_to_inference_data( + idata, optimizer_result, method, raveled_optimized, model + ) + idata = add_data_to_inference_data( + idata, progressbar=False, model=model, compile_kwargs=compile_kwargs + ) - return optimized_point + return idata diff --git a/pymc_extras/inference/laplace_approx/scipy_interface.py b/pymc_extras/inference/laplace_approx/scipy_interface.py new file mode 100644 index 000000000..a7489be3e --- /dev/null +++ b/pymc_extras/inference/laplace_approx/scipy_interface.py @@ -0,0 +1,242 @@ +from collections.abc import Callable +from importlib.util import find_spec +from typing import Literal, get_args + +import numpy as np +import pymc as pm +import pytensor + +from pymc import join_nonshared_inputs +from pytensor import tensor as pt +from pytensor.compile import Function +from pytensor.tensor import TensorVariable + +GradientBackend = Literal["pytensor", "jax"] +VALID_BACKENDS = get_args(GradientBackend) + + +def _compile_grad_and_hess_to_jax( + f_fused: Function, use_hess: bool, use_hessp: bool +) -> tuple[Callable | None, Callable | None]: + """ + Compile loss function gradients using JAX. + + Parameters + ---------- + f_fused: Function + The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss, + compiled with mode="JAX". + use_hess: bool + Whether to compile a function to compute the hessian of the loss function. + use_hessp: bool + Whether to compile a function to compute the hessian-vector product of the loss function. + + Returns + ------- + f_fused: Callable + The compiled loss function and gradient function, which may also compute the hessian if requested. + f_hessp: Callable | None + The compiled hessian-vector product function, or None if use_hessp is False. + """ + import jax + + f_hessp = None + + orig_loss_fn = f_fused.vm.jit_fn + + if use_hess: + + @jax.jit + def loss_fn_fused(x): + loss_and_grad = jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) + hess = jax.hessian(lambda x: orig_loss_fn(x)[0])(x) + return *loss_and_grad, hess + + else: + + @jax.jit + def loss_fn_fused(x): + return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) + + if use_hessp: + + def f_hessp_jax(x, p): + y, u = jax.jvp(lambda x: loss_fn_fused(x)[1], (x,), (p,)) + return jax.numpy.stack(u) + + f_hessp = jax.jit(f_hessp_jax) + + return loss_fn_fused, f_hessp + + +def _compile_functions_for_scipy_optimize( + loss: TensorVariable, + inputs: list[TensorVariable], + compute_grad: bool, + compute_hess: bool, + compute_hessp: bool, + compile_kwargs: dict | None = None, +) -> list[Function] | list[Function, Function | None, Function | None]: + """ + Compile loss functions for use with scipy.optimize.minimize. 
+ + Parameters + ---------- + loss: TensorVariable + The loss function to compile. + inputs: list[TensorVariable] + A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines + expect the function signature to be f(x, *args), where x is a 1D array of parameters. + compute_grad: bool + Whether to compile a function that computes the gradients of the loss function. + compute_hess: bool + Whether to compile a function that computes the Hessian of the loss function. + compute_hessp: bool + Whether to compile a function that computes the Hessian-vector product of the loss function. + compile_kwargs: dict, optional + Additional keyword arguments to pass to the ``pm.compile`` function. + + Returns + ------- + f_fused: Function + The compiled loss function, which may also include gradients and hessian if requested. + f_hessp: Function | None + The compiled hessian-vector product function, or None if compute_hessp is False. + """ + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + + loss = pm.pytensorf.rewrite_pregrad(loss) + f_hessp = None + + # In the simplest case, we only compile the loss function. Return it as a list to keep the return type consistent + # with the case where we also compute gradients, hessians, or hessian-vector products. + if not (compute_grad or compute_hess or compute_hessp): + f_loss = pm.compile(inputs, loss, **compile_kwargs) + return [f_loss] + + # Otherwise there are three cases. If the user only wants the loss function and gradients, we compile a single + # fused function and return it. If the user also wants the hessian, the fused function will return the loss, + # gradients and hessian. If the user wants gradients and hess_p, we return a fused function that returns the loss + # and gradients, and a separate function for the hessian-vector product. + + if compute_hessp: + # Handle this first, since it can be compiled alone. + p = pt.tensor("p", shape=inputs[0].type.shape) + hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p) + f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs) + + outputs = [loss] + + if compute_grad: + grads = pytensor.gradient.grad(loss, inputs) + grad = pt.concatenate([grad.ravel() for grad in grads]) + outputs.append(grad) + + if compute_hess: + hess = pytensor.gradient.jacobian(grad, inputs)[0] + outputs.append(hess) + + f_fused = pm.compile(inputs, outputs, **compile_kwargs) + + return [f_fused, f_hessp] + + +def scipy_optimize_funcs_from_loss( + loss: TensorVariable, + inputs: list[TensorVariable], + initial_point_dict: dict[str, np.ndarray | float | int], + use_grad: bool, + use_hess: bool, + use_hessp: bool, + gradient_backend: GradientBackend = "pytensor", + compile_kwargs: dict | None = None, +) -> tuple[Callable, ...]: + """ + Compile loss functions for use with scipy.optimize.minimize. + + Parameters + ---------- + loss: TensorVariable + The loss function to compile. + inputs: list[TensorVariable] + The input variables to the loss function. + initial_point_dict: dict[str, np.ndarray | float | int] + Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables. + use_grad: bool + Whether to compile a function that computes the gradients of the loss function. + use_hess: bool + Whether to compile a function that computes the Hessian of the loss function. + use_hessp: bool + Whether to compile a function that computes the Hessian-vector product of the loss function. 
+ gradient_backend: str, default "pytensor" + Which backend to use to compute gradients. Must be one of "jax" or "pytensor" + compile_kwargs: + Additional keyword arguments to pass to the ``pm.compile`` function. + + Returns + ------- + f_fused: Callable + The compiled loss function, which may also include gradients and hessian if requested. + f_hessp: Callable | None + The compiled hessian-vector product function, or None if use_hessp is False. + """ + + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + + if use_hess and not use_grad: + raise ValueError("Cannot compute hessian without also computing the gradient") + + if gradient_backend not in VALID_BACKENDS: + raise ValueError( + f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}" + ) + + use_jax_gradients = (gradient_backend == "jax") and use_grad + if use_jax_gradients and not find_spec("jax"): + raise ImportError("JAX must be installed to use JAX gradients") + + mode = compile_kwargs.get("mode", None) + if mode is None and use_jax_gradients: + compile_kwargs["mode"] = "JAX" + elif mode != "JAX" and use_jax_gradients: + raise ValueError( + 'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"' + ) + + if not isinstance(inputs, list): + inputs = [inputs] + + [loss], flat_input = join_nonshared_inputs( + point=initial_point_dict, outputs=[loss], inputs=inputs + ) + + # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When + # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them + # away. + if use_jax_gradients: + from pymc.sampling.jax import _replace_shared_variables + + [loss] = _replace_shared_variables([loss]) + + compute_grad = use_grad and not use_jax_gradients + compute_hess = use_hess and not use_jax_gradients + compute_hessp = use_hessp and not use_jax_gradients + + funcs = _compile_functions_for_scipy_optimize( + loss=loss, + inputs=[flat_input], + compute_grad=compute_grad, + compute_hess=compute_hess, + compute_hessp=compute_hessp, + compile_kwargs=compile_kwargs, + ) + + # Depending on the requested functions, f_fused will either be the loss function, the loss function with gradients, + # or the loss function with gradients and hessian. 
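# [Editor's note: hedged usage sketch, not part of this patch.] When only the gradient is
# requested, the fused function compiled above returns `(loss, grad)`, which is exactly the
# shape `scipy.optimize.minimize` expects when called with `jac=True`. A minimal pure-pytensor
# analogue (the variable names are made up for illustration):
import numpy as np
import pytensor
import pytensor.tensor as pt
from scipy.optimize import minimize

x = pt.tensor("x", shape=(3,))
loss = pt.sum((x - 3.0) ** 2)
grad = pytensor.gradient.grad(loss, x)
f_fused_demo = pytensor.function([x], [loss, grad])

res = minimize(f_fused_demo, x0=np.zeros(3), jac=True, method="L-BFGS-B")
# res.x is approximately [3., 3., 3.]
# With `use_hess=True` the fused function gains a third (Hessian) output, and with
# `use_hessp=True` a separate hessian-vector product function is returned alongside it, which is
# what the unpacking in the diff lines that follow handles.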
+ f_fused = funcs.pop(0) + f_hessp = funcs.pop(0) if compute_hessp else None + + if use_jax_gradients: + f_fused, f_hessp = _compile_grad_and_hess_to_jax(f_fused, use_hess, use_hessp) + + return f_fused, f_hessp diff --git a/tests/inference/laplace_approx/test_find_map.py b/tests/inference/laplace_approx/test_find_map.py index d09fb3d87..61f99e6bf 100644 --- a/tests/inference/laplace_approx/test_find_map.py +++ b/tests/inference/laplace_approx/test_find_map.py @@ -1,17 +1,18 @@ import numpy as np import pymc as pm -import pytensor import pytensor.tensor as pt import pytest from pymc_extras.inference.laplace_approx.find_map import ( - GradientBackend, find_MAP, + get_nearest_psd, + set_optimizer_function_defaults, +) +from pymc_extras.inference.laplace_approx.scipy_interface import ( + GradientBackend, scipy_optimize_funcs_from_loss, ) -pytest.importorskip("jax") - @pytest.fixture(scope="session") def rng(): @@ -19,8 +20,71 @@ def rng(): return np.random.default_rng(seed) +def test_get_nearest_psd_returns_psd(rng): + # Matrix with negative eigenvalues + A = np.array([[2, -3], [-3, 2]]) + psd = get_nearest_psd(A) + + # Should be symmetric + np.testing.assert_allclose(psd, psd.T) + + # All eigenvalues should be >= 0 + eigvals = np.linalg.eigvalsh(psd) + assert np.all(eigvals >= -1e-12), "All eigenvalues should be non-negative" + + +def test_get_nearest_psd_given_psd_input(rng): + L = rng.normal(size=(2, 2)) + A = L @ L.T + psd = get_nearest_psd(A) + + # Given PSD input, should return the same matrix + assert np.allclose(psd, A) + + +def test_set_optimizer_function_defaults_warns_and_prefers_hessp(caplog): + # "trust-ncg" uses_grad=True, uses_hess=True, uses_hessp=True + method = "trust-ncg" + with caplog.at_level("WARNING"): + use_grad, use_hess, use_hessp = set_optimizer_function_defaults(method, True, True, True) + + message = caplog.messages[0] + assert message.startswith('Both "use_hess" and "use_hessp" are set to True') + + assert use_grad + assert not use_hess + assert use_hessp + + +def test_set_optimizer_function_defaults_infers_hess_and_hessp(): + # "trust-ncg" uses_grad=True, uses_hess=True, uses_hessp=True + method = "trust-ncg" + + # If only use_hessp is set, use_hess should be False but use_grad should be inferred as True + use_grad, use_hess, use_hessp = set_optimizer_function_defaults(method, None, None, True) + assert use_grad + assert not use_hess + assert use_hessp + + # Only use_hess is set + use_grad, use_hess, use_hessp = set_optimizer_function_defaults(method, None, True, None) + assert use_hess + assert not use_hessp + + +def test_set_optimizer_function_defaults_defaults(): + # "trust-ncg" uses_grad=True, uses_hess=True, uses_hessp=True + method = "trust-ncg" + use_grad, use_hess, use_hessp = set_optimizer_function_defaults(method, None, None, None) + assert use_grad + assert not use_hess + assert use_hessp + + @pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str) def test_jax_functions_from_graph(gradient_backend: GradientBackend): + pytest.importorskip("jax") + x = pt.tensor("x", shape=(2,)) def compute_z(x): @@ -57,74 +121,70 @@ def compute_z(x): @pytest.mark.parametrize( "method, use_grad, use_hess, use_hessp", [ - ("nelder-mead", False, False, False), - ("powell", False, False, False), - ("CG", True, False, False), + ( + "Newton-CG", + True, + True, + False, + ), + ("Newton-CG", True, False, True), ("BFGS", True, False, False), ("L-BFGS-B", True, False, False), - ("TNC", True, False, False), - ("SLSQP", True, False, False), - ("dogleg", True, 
True, False), - ("Newton-CG", True, True, False), - ("Newton-CG", True, False, True), - ("trust-ncg", True, True, False), - ("trust-ncg", True, False, True), - ("trust-exact", True, True, False), - ("trust-krylov", True, True, False), - ("trust-krylov", True, False, True), - ("trust-constr", True, True, False), ], ) @pytest.mark.parametrize( "backend, gradient_backend", - # JAX backend is faster, so only test it [("jax", "jax"), ("jax", "pytensor")], ids=str, ) def test_find_MAP( method, use_grad, use_hess, use_hessp, backend, gradient_backend: GradientBackend, rng ): - extra_kwargs = {} - if method == "dogleg": - # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point - # where this is true - extra_kwargs = {"initvals": {"mu": 2, "sigma_log__": 1}} + pytest.importorskip("jax") with pm.Model() as m: mu = pm.Normal("mu") sigma = pm.Exponential("sigma", 1) - pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100)) + pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=10)) - optimized_point = find_MAP( + idata = find_MAP( method=method, - **extra_kwargs, use_grad=use_grad, use_hess=use_hess, use_hessp=use_hessp, progressbar=False, gradient_backend=gradient_backend, compile_kwargs={"mode": backend.upper()}, + maxiter=5, ) - mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] - assert np.isclose(mu_hat, 3, atol=0.5) - assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5) + assert hasattr(idata, "posterior") + assert hasattr(idata, "fit") + assert hasattr(idata, "optimizer_result") + assert hasattr(idata, "observed_data") + + posterior = idata.posterior.squeeze(["chain", "draw"]) + assert "mu" in posterior and "sigma_log__" in posterior and "sigma" in posterior + assert posterior["mu"].shape == () + assert posterior["sigma_log__"].shape == () + assert posterior["sigma"].shape == () @pytest.mark.parametrize( "backend, gradient_backend", - # JAX backend is faster, so only test it [("jax", "jax")], ids=str, ) def test_map_shared_variables(backend, gradient_backend: GradientBackend): + pytest.importorskip("jax") + with pm.Model() as m: - data = pytensor.shared(np.random.normal(loc=3, scale=1.5, size=100), name="shared_data") + data = pm.Data("data", np.random.normal(loc=3, scale=1.5, size=10)) mu = pm.Normal("mu") sigma = pm.Exponential("sigma", 1) y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=data) - optimized_point = find_MAP( + idata = find_MAP( method="L-BFGS-B", use_grad=True, use_hess=False, @@ -133,36 +193,43 @@ def test_map_shared_variables(backend, gradient_backend: GradientBackend): gradient_backend=gradient_backend, compile_kwargs={"mode": backend.upper()}, ) - mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] - assert np.isclose(mu_hat, 3, atol=0.5) - assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5) + assert hasattr(idata, "posterior") + assert hasattr(idata, "fit") + assert hasattr(idata, "optimizer_result") + assert hasattr(idata, "observed_data") + assert hasattr(idata, "constant_data") + + posterior = idata.posterior.squeeze(["chain", "draw"]) + assert "mu" in posterior and "sigma_log__" in posterior and "sigma" in posterior + assert posterior["mu"].shape == () + assert posterior["sigma_log__"].shape == () + assert posterior["sigma"].shape == () @pytest.mark.parametrize( "method, use_grad, use_hess, use_hessp", [ - ("nelder-mead", False, False, False), - ("L-BFGS-B", True, False, False), - ("trust-exact", 
True, True, False), - ("trust-ncg", True, False, True), + ("Newton-CG", True, True, False), + ("Newton-CG", True, False, True), ], ) @pytest.mark.parametrize( "backend, gradient_backend", - # JAX backend is faster, so only test it) [("jax", "pytensor")], ids=str, ) def test_find_MAP_basinhopping( method, use_grad, use_hess, use_hessp, backend, gradient_backend, rng ): + pytest.importorskip("jax") + with pm.Model() as m: mu = pm.Normal("mu") sigma = pm.Exponential("sigma", 1) - pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100)) + pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=10)) - optimized_point = find_MAP( + idata = find_MAP( method="basinhopping", use_grad=use_grad, use_hess=use_hess, @@ -171,9 +238,43 @@ def test_find_MAP_basinhopping( gradient_backend=gradient_backend, compile_kwargs={"mode": backend.upper()}, minimizer_kwargs=dict(method=method), + niter=1, ) - mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] + assert hasattr(idata, "posterior") + posterior = idata.posterior.squeeze(["chain", "draw"]) + assert "mu" in posterior and "sigma_log__" in posterior + assert posterior["mu"].shape == () + assert posterior["sigma_log__"].shape == () + + +def test_find_MAP_with_coords(): + with pm.Model(coords={"group": [1, 2, 3, 4, 5]}) as m: + mu_loc = pm.Normal("mu_loc", 0, 1) + mu_scale = pm.HalfNormal("mu_scale", 1) + + mu = pm.Normal("mu", mu_loc, mu_scale, dims=["group"]) + sigma = pm.HalfNormal("sigma", 1, dims=["group"]) - assert np.isclose(mu_hat, 3, atol=0.5) - assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5) + obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=np.random.normal(size=(10, 5))) + + idata = find_MAP(progressbar=False, method="L-BFGS-B") + + assert hasattr(idata, "posterior") + assert hasattr(idata, "fit") + + posterior = idata.posterior.squeeze(["chain", "draw"]) + assert ( + "mu_loc" in posterior + and "mu_scale" in posterior + and "mu_scale_log__" in posterior + and "mu" in posterior + and "sigma_log__" in posterior + and "sigma" in posterior + ) + assert posterior["mu_loc"].shape == () + assert posterior["mu_scale"].shape == () + assert posterior["mu_scale_log__"].shape == () + assert posterior["mu"].shape == (5,) + assert posterior["sigma_log__"].shape == (5,) + assert posterior["sigma"].shape == (5,) diff --git a/tests/inference/laplace_approx/test_scipy_interface.py b/tests/inference/laplace_approx/test_scipy_interface.py new file mode 100644 index 000000000..4e8b56a08 --- /dev/null +++ b/tests/inference/laplace_approx/test_scipy_interface.py @@ -0,0 +1,118 @@ +import numpy as np +import pytest + +from pytensor import tensor as pt + +from pymc_extras.inference.laplace_approx import scipy_interface + + +@pytest.fixture +def simple_loss_and_inputs(): + x = pt.vector("x") + loss = pt.sum(x**2) + return loss, [x] + + +def test_compile_functions_for_scipy_optimize_loss_only(simple_loss_and_inputs): + loss, inputs = simple_loss_and_inputs + funcs = scipy_interface._compile_functions_for_scipy_optimize( + loss, inputs, compute_grad=False, compute_hess=False, compute_hessp=False + ) + assert len(funcs) == 1 + f_loss = funcs[0] + x_val = np.array([1.0, 2.0, 3.0]) + result = f_loss(x_val) + assert np.isclose(result, np.sum(x_val**2)) + + +def test_compile_functions_for_scipy_optimize_with_grad(simple_loss_and_inputs): + loss, inputs = simple_loss_and_inputs + funcs = scipy_interface._compile_functions_for_scipy_optimize( + loss, inputs, compute_grad=True, 
compute_hess=False, compute_hessp=False + ) + f_fused = funcs[0] + x_val = np.array([1.0, 2.0, 3.0]) + loss_val, grad_val = f_fused(x_val) + assert np.isclose(loss_val, np.sum(x_val**2)) + assert np.allclose(grad_val, 2 * x_val) + + +def test_compile_functions_for_scipy_optimize_with_hess(simple_loss_and_inputs): + loss, inputs = simple_loss_and_inputs + funcs = scipy_interface._compile_functions_for_scipy_optimize( + loss, inputs, compute_grad=True, compute_hess=True, compute_hessp=False + ) + f_fused = funcs[0] + x_val = np.array([1.0, 2.0]) + loss_val, grad_val, hess_val = f_fused(x_val) + assert np.isclose(loss_val, np.sum(x_val**2)) + assert np.allclose(grad_val, 2 * x_val) + assert np.allclose(hess_val, 2 * np.eye(len(x_val))) + + +def test_compile_functions_for_scipy_optimize_with_hessp(simple_loss_and_inputs): + loss, inputs = simple_loss_and_inputs + funcs = scipy_interface._compile_functions_for_scipy_optimize( + loss, inputs, compute_grad=True, compute_hess=False, compute_hessp=True + ) + f_fused, f_hessp = funcs + x_val = np.array([1.0, 2.0]) + p_val = np.array([1.0, 0.0]) + + loss_val, grad_val = f_fused(x_val) + assert np.isclose(loss_val, np.sum(x_val**2)) + assert np.allclose(grad_val, 2 * x_val) + + hessp_val = f_hessp(x_val, p_val) + assert np.allclose(hessp_val, 2 * p_val) + + +def test_scipy_optimize_funcs_from_loss_invalid_backend(simple_loss_and_inputs): + loss, inputs = simple_loss_and_inputs + with pytest.raises(ValueError, match="Invalid gradient backend"): + scipy_interface.scipy_optimize_funcs_from_loss( + loss, + inputs, + {"x": np.array([1.0, 2.0])}, + use_grad=True, + use_hess=False, + use_hessp=False, + gradient_backend="not_a_backend", + ) + + +def test_scipy_optimize_funcs_from_loss_hess_without_grad(simple_loss_and_inputs): + loss, inputs = simple_loss_and_inputs + with pytest.raises( + ValueError, match="Cannot compute hessian without also computing the gradient" + ): + scipy_interface.scipy_optimize_funcs_from_loss( + loss, + inputs, + {"x": np.array([1.0, 2.0])}, + use_grad=False, + use_hess=True, + use_hessp=False, + ) + + +@pytest.mark.parametrize("backend", ["pytensor", "jax"], ids=str) +def test_scipy_optimize_funcs_from_loss_backend(backend, simple_loss_and_inputs): + if backend == "jax": + pytest.importorskip("jax", reason="JAX is not installed") + + loss, inputs = simple_loss_and_inputs + f_fused, f_hessp = scipy_interface.scipy_optimize_funcs_from_loss( + loss, + inputs, + {"x": np.array([1.0, 2.0])}, + use_grad=True, + use_hess=False, + use_hessp=False, + gradient_backend=backend, + ) + x_val = np.array([1.0, 2.0]) + loss_val, grad_val = f_fused(x_val) + assert np.isclose(loss_val, np.sum(x_val**2)) + assert np.allclose(grad_val, 2 * x_val) + assert f_hessp is None From 372295823906c140602fb2d82707e7f15f595628 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 5 Jul 2025 12:50:28 +0800 Subject: [PATCH 4/7] Refactor fit_laplace --- .../inference/laplace_approx/laplace.py | 381 +++++++----------- .../inference/laplace_approx/test_laplace.py | 200 ++++----- 2 files changed, 225 insertions(+), 356 deletions(-) diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py index 488d41911..ef1224e81 100644 --- a/pymc_extras/inference/laplace_approx/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -16,36 +16,35 @@ import logging from collections.abc import Callable -from importlib.util import find_spec +from functools import partial from typing import Literal +from typing import 
cast as type_cast import arviz as az import numpy as np import pymc as pm import pytensor import pytensor.tensor as pt +import xarray as xr from better_optimize.constants import minimize_method from numpy.typing import ArrayLike -from pymc import DictToArrayBijection -from pymc.blocking import RaveledVars -from pymc.model.transform.conditioning import remove_value_transforms +from pymc import DictToArrayBijection, join_nonshared_inputs from pymc.model.transform.optimization import freeze_dims_and_data +from pymc.util import get_default_varnames +from pytensor.graph import vectorize_graph from pytensor.tensor import TensorVariable from pytensor.tensor.optimize import minimize -from scipy import stats +from pytensor.tensor.type import Variable from pymc_extras.inference.laplace_approx.find_map import ( - GradientBackend, - _unconstrained_vector_to_constrained_rvs, + _compute_inverse_hessian, + _make_initial_point, find_MAP, - get_nearest_psd, - scipy_optimize_funcs_from_loss, ) -from pymc_extras.inference.laplace_approx.idata import ( - add_data_to_inferencedata, - add_fit_to_inferencedata, - laplace_draws_to_inferencedata, +from pymc_extras.inference.laplace_approx.scipy_interface import ( + GradientBackend, + scipy_optimize_funcs_from_loss, ) _log = logging.getLogger(__name__) @@ -147,189 +146,110 @@ def get_conditional_gaussian_approximation( return pytensor.function(args, [x0, conditional_gaussian_approx]) -def fit_mvn_at_MAP( - optimized_point: dict[str, np.ndarray], - model: pm.Model | None = None, - on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", - transform_samples: bool = False, - gradient_backend: GradientBackend = "pytensor", - zero_tol: float = 1e-8, - diag_jitter: float | None = 1e-8, - compile_kwargs: dict | None = None, -) -> tuple[RaveledVars, np.ndarray]: - """ - Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior - evaluated at the MAP estimate. This is the basis of the Laplace approximation. - - Parameters - ---------- - optimized_point : idata - Local maximum a posteriori (MAP) point returned from pymc_extras.inference.find_MAP - model : Model, optional - A PyMC model. If None, the model is taken from the current model context. - on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' - What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. - If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned. - If 'error', an error will be raised. - transform_samples : bool - Whether to transform the samples back to the original parameter space. Default is True. - gradient_backend: str, default "pytensor" - The backend to use for gradient computations. Must be one of "pytensor" or "jax". - zero_tol: float - Value below which an element of the Hessian matrix is counted as 0. - This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8. - diag_jitter: float | None - A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. - If None, no jitter is added. Default is 1e-8. - compile_kwargs: dict, optional - Additional keyword arguments to pass to pytensor.function when compiling loss functions - - Returns - ------- - map_estimate: RaveledVars - The MAP estimate of the model parameters, raveled into a 1D array. - - inverse_hessian: np.ndarray - The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. 
- """ - if gradient_backend == "jax" and not find_spec("jax"): - raise ImportError("JAX must be installed to use JAX gradients") +def _unconstrained_vector_to_constrained_rvs(model): + outputs = get_default_varnames(model.unobserved_value_vars, include_transformed=False) + names = [x.name for x in outputs] - model = pm.modelcontext(model) - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - frozen_model = freeze_dims_and_data(model) - - if not transform_samples: - untransformed_model = remove_value_transforms(frozen_model) - logp = untransformed_model.logp(jacobian=False) - variables = untransformed_model.continuous_value_vars - else: - logp = frozen_model.logp(jacobian=True) - variables = frozen_model.continuous_value_vars - - variable_names = {var.name for var in variables} - optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names} - mu = DictToArrayBijection.map(optimized_free_params) - - f_fused, _ = scipy_optimize_funcs_from_loss( - loss=-logp, - inputs=variables, - initial_point_dict=optimized_free_params, - use_grad=True, - use_hess=True, - use_hessp=False, - gradient_backend=gradient_backend, - compile_kwargs=compile_kwargs, + constrained_rvs, unconstrained_vector = join_nonshared_inputs( + model.initial_point(), + inputs=model.value_vars, + outputs=outputs, ) - H = -f_fused(mu.data)[-1] - if H.ndim == 1: - H = np.expand_dims(H, axis=1) - H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H)) - - def stabilize(x, jitter): - return x + np.eye(x.shape[0]) * jitter + unconstrained_vector.name = "unconstrained_vector" + return names, constrained_rvs, unconstrained_vector + + +def model_to_laplace_approx( + model: pm.Model, unpacked_variable_names: list[str], chains: int = 1, draws: int = 500 +): + initial_point = model.initial_point() + raveled_vars = DictToArrayBijection.map(initial_point) + raveled_shape = raveled_vars.data.shape[0] + + # temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov, + # so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved. 
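# [Editor's note: illustrative aside, not part of this patch.] The raveling step above relies on
# pymc's DictToArrayBijection, which flattens a point dict into a single vector (the space the
# MvNormal below lives in) and can invert the mapping with `rmap`. A self-contained sketch; the
# model here is made up purely for illustration:
import numpy as np
import pymc as pm
from pymc import DictToArrayBijection

with pm.Model() as demo_model:
    mu = pm.Normal("mu", shape=(2,))
    sigma = pm.HalfNormal("sigma")

point = demo_model.initial_point()          # {"mu": ..., "sigma_log__": ...}
raveled = DictToArrayBijection.map(point)
assert raveled.data.shape == (3,)           # two entries for mu, one for sigma_log__
assert set(DictToArrayBijection.rmap(raveled)) == set(point)
# The diff lines that follow draw a single MvNormal over this raveled vector and then use
# vectorize_graph to push batched draws back through the model's constraining transforms.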
+ names, constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model) + + coords = model.coords | { + "temp_chain": np.arange(chains), + "temp_draw": np.arange(draws), + "unpacked_variable_names": unpacked_variable_names, + } + + with pm.Model(coords=coords, model=None) as laplace_model: + mu = pm.Flat("mean_vector", shape=(raveled_shape,)) + cov = pm.Flat("covariance_matrix", shape=(raveled_shape, raveled_shape)) + laplace_approximation = pm.MvNormal( + "laplace_approximation", + mu=mu, + cov=cov, + dims=["temp_chain", "temp_draw", "unpacked_variable_names"], + method="svd", + ) - H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter) + cast_to_var = partial(type_cast, Variable) + batched_rvs = vectorize_graph( + type_cast(list[Variable], constrained_rvs), + replace={cast_to_var(unconstrained_vector): cast_to_var(laplace_approximation)}, + ) - try: - np.linalg.cholesky(H_inv) - except np.linalg.LinAlgError: - if on_bad_cov == "error": - raise np.linalg.LinAlgError( - "Inverse Hessian not positive-semi definite at the provided point" - ) - H_inv = get_nearest_psd(H_inv) - if on_bad_cov == "warn": - _log.warning( - "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD " - "matrix in L1-norm instead" + for name, batched_rv in zip(names, batched_rvs): + pm.Deterministic( + name, + batched_rv, + dims=("temp_chain", "temp_draw", *model.named_vars_to_dims.get(name, ())), ) - return mu, H_inv + return laplace_model -def sample_laplace_posterior( - mu: RaveledVars, - H_inv: np.ndarray, - model: pm.Model | None = None, - chains: int = 2, - draws: int = 500, - transform_samples: bool = False, - progressbar: bool = True, - random_seed: int | np.random.Generator | None = None, - compile_kwargs: dict | None = None, -) -> az.InferenceData: +def unstack_laplace_draws(idata, model): """ - Generate samples from a multivariate normal distribution with mean `mu` and inverse covariance matrix `H_inv`. - - Parameters - ---------- - mu: RaveledVars - The MAP estimate of the model parameters. - H_inv: np.ndarray - The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. - model : Model - A PyMC model - chains : int - The number of sampling chains running in parallel. Default is 2. - draws : int - The number of samples to draw from the approximated posterior. Default is 500. - transform_samples : bool - Whether to transform the samples back to the original parameter space. Default is True. - progressbar : bool - Whether to display a progress bar during computations. Default is True. - random_seed: int | np.random.Generator | None - Seed for the random number generator or a numpy Generator for reproducibility + The `model_to_laplace_approx` function returns a model with a single MvNormal distribution, draws from which are + in the unconstrained variable space. These might be interesting to the user, but since they come back stacked in a + single vector, it's not easy to work with. - Returns - ------- - idata: az.InferenceData - An InferenceData object containing the approximated posterior samples. + This function unpacks each component of the vector into its own DataArray, with the appropriate dimensions and + coordinates, where possible. 
""" - model = pm.modelcontext(model) - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - rng = np.random.default_rng(random_seed) + initial_point = DictToArrayBijection.map(model.initial_point()) - posterior_dist = stats.multivariate_normal( - mean=mu.data, cov=H_inv, allow_singular=True, seed=rng - ) + cursor = 0 + chains = idata.coords["chain"].size + draws = idata.coords["draw"].size - posterior_draws = posterior_dist.rvs(size=(chains, draws)) - if mu.data.shape == (1,): - posterior_draws = np.expand_dims(posterior_draws, -1) + unstacked_laplace_draws = {} + laplace_data = idata.laplace_approximation.values - if transform_samples: - constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model) - batched_values = pt.tensor( - "batched_values", - shape=(chains, draws, *unconstrained_vector.type.shape), - dtype=unconstrained_vector.type.dtype, - ) - batched_rvs = pytensor.graph.vectorize_graph( - constrained_rvs, replace={unconstrained_vector: batched_values} - ) + coords = model.coords | {"chain": range(chains), "draw": range(draws)} - f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs) - posterior_draws = f_constrain(posterior_draws) + # There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g. + # simplex transform of a Dirichlet). In these cases, we don't try to guess what the labels should be, and just + # add an arviz-style default dim and label. + for rv, (name, shape, size, dtype) in zip(model.free_RVs, initial_point.point_map_info): + rv_dims = [] + for i, dim in enumerate(model.named_vars_to_dims.get(rv.name, ())): + if shape[i] == len(coords[dim]): + rv_dims.append(dim) + else: + rv_dims.append(f"{name}_dim_{i}") + coords[f"{name}_dim_{i}"] = np.arange(shape[i]) - else: - info = mu.point_map_info - flat_shapes = [size for _, _, size, _ in info] - slices = [ - slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes)) - ] + dims = ("chain", "draw", *rv_dims) - posterior_draws = [ - posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype) - for idx, (name, shape, _, dtype) in zip(slices, info) - ] + values = ( + laplace_data[..., cursor : cursor + size].reshape((chains, draws, *shape)).astype(dtype) + ) + unstacked_laplace_draws[name] = xr.DataArray( + values, dims=dims, coords={dim: list(coords[dim]) for dim in dims} + ) - idata = laplace_draws_to_inferencedata(posterior_draws, model) - idata = add_fit_to_inferencedata(idata, mu, H_inv) - idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs) + cursor += size - return idata + unstacked_laplace_draws = xr.Dataset(unstacked_laplace_draws) + + return unstacked_laplace_draws def fit_laplace( @@ -341,17 +261,12 @@ def fit_laplace( use_hess: bool | None = None, initvals: dict | None = None, random_seed: int | np.random.Generator | None = None, - return_raw: bool = False, jitter_rvs: list[pt.TensorVariable] | None = None, progressbar: bool = True, include_transformed: bool = True, gradient_backend: GradientBackend = "pytensor", chains: int = 2, draws: int = 500, - on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", - fit_in_unconstrained_space: bool = False, - zero_tol: float = 1e-8, - diag_jitter: float | None = 1e-8, optimizer_kwargs: dict | None = None, compile_kwargs: dict | None = None, ) -> az.InferenceData: @@ -385,23 +300,13 @@ def fit_laplace( If None, the model's default initial values are used. 
random_seed : None | int | np.random.Generator, optional Seed for the random number generator or a numpy Generator for reproducibility - return_raw: bool | False, optinal - Whether to also return the full output of `scipy.optimize.minimize` jitter_rvs : list of TensorVariables, optional Variables whose initial values should be jittered. If None, all variables are jittered. progressbar : bool, optional Whether to display a progress bar during optimization. Defaults to True. - fit_in_unconstrained_space: bool, default False - Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will be drawn - from a mean and covariance matrix computed at a point in the **unconstrained** parameter space. Samples will - then be transformed back to the original parameter space. This will guarantee that the samples will respect - the domain of prior distributions (for exmaple, samples from a Beta distribution will be strictly between 0 - and 1). - - .. warning:: - This argument should be considered highly experimental. It has not been verified if this method produces - valid draws from the posterior. **Use at your own risk**. - + include_transformed: bool, default True + Whether to include transformed variables in the output. If True, transformed variables will be included in the + output InferenceData object. If False, only the original variables will be included. gradient_backend: str, default "pytensor" The backend to use for gradient computations. Must be one of "pytensor" or "jax". chains: int, default: 2 @@ -410,16 +315,6 @@ def fit_laplace( compatible with the ArviZ library. draws: int, default: 500 The number of samples to draw from the approximated posterior. Totals samples will be chains * draws. - on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' - What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. - If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned. - If 'error', an error will be raised. - zero_tol: float - Value below which an element of the Hessian matrix is counted as 0. - This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8. - diag_jitter: float | None - A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. - If None, no jitter is added. Default is 1e-8. optimizer_kwargs Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. 
For ``basinhopping``, @@ -458,8 +353,9 @@ def fit_laplace( """ compile_kwargs = {} if compile_kwargs is None else compile_kwargs optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs + model = pm.modelcontext(model) if model is None else model - optimized_point = find_MAP( + idata = find_MAP( method=optimize_method, model=model, use_grad=use_grad, @@ -467,7 +363,6 @@ def fit_laplace( use_hess=use_hess, initvals=initvals, random_seed=random_seed, - return_raw=return_raw, jitter_rvs=jitter_rvs, progressbar=progressbar, include_transformed=include_transformed, @@ -476,25 +371,57 @@ def fit_laplace( **optimizer_kwargs, ) - mu, H_inv = fit_mvn_at_MAP( - optimized_point=optimized_point, - model=model, - on_bad_cov=on_bad_cov, - transform_samples=fit_in_unconstrained_space, - gradient_backend=gradient_backend, - zero_tol=zero_tol, - diag_jitter=diag_jitter, - compile_kwargs=compile_kwargs, - ) + unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist() + + if "covariance_matrix" not in idata.fit: + # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so + # we have to go back and compute the Hessian at the MAP point now. + frozen_model = freeze_dims_and_data(model) + initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs) + + _, f_hessp = scipy_optimize_funcs_from_loss( + loss=-frozen_model.logp(jacobian=False), + inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars, + initial_point_dict=DictToArrayBijection.rmap(initial_params), + use_grad=False, + use_hess=False, + use_hessp=True, + gradient_backend=gradient_backend, + compile_kwargs=compile_kwargs, + ) + H_inv = _compute_inverse_hessian( + optimizer_result=None, + optimal_point=idata.fit.mean_vector.values, + f_fused=None, + f_hessp=f_hessp, + use_hess=False, + method=optimize_method, + ) - return sample_laplace_posterior( - mu=mu, - H_inv=H_inv, - model=model, - chains=chains, - draws=draws, - transform_samples=fit_in_unconstrained_space, - progressbar=progressbar, - random_seed=random_seed, - compile_kwargs=compile_kwargs, - ) + idata.fit["covariance_matrix"] = xr.DataArray( + H_inv, + dims=("rows", "columns"), + coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names}, + ) + + with model_to_laplace_approx(model, unpacked_variable_names, chains, draws) as laplace_model: + laplace_idata = pm.sample_posterior_predictive( + idata.fit.expand_dims(chain=[0], draw=[0]), + extend_inferencedata=False, + random_seed=random_seed, + var_names=["laplace_approximation", *[x.name for x in laplace_model.deterministics]], + ) + new_posterior = ( + laplace_idata.posterior_predictive.squeeze(["chain", "draw"]) + .drop_vars(["chain", "draw"]) + .rename({"temp_chain": "chain", "temp_draw": "draw"}) + ) + + new_posterior.update(unstack_laplace_draws(new_posterior, model)) + new_posterior = new_posterior.drop_vars( + ["laplace_approximation", "unpacked_variable_names"] + ) + + idata.posterior = new_posterior + + return idata diff --git a/tests/inference/laplace_approx/test_laplace.py b/tests/inference/laplace_approx/test_laplace.py index 31c8eaf2b..68ed30cc1 100644 --- a/tests/inference/laplace_approx/test_laplace.py +++ b/tests/inference/laplace_approx/test_laplace.py @@ -19,12 +19,10 @@ import pymc_extras as pmx -from pymc_extras.inference.laplace_approx.find_map import GradientBackend, find_MAP +from pymc_extras.inference.laplace_approx.find_map import GradientBackend from 
pymc_extras.inference.laplace_approx.laplace import ( fit_laplace, - fit_mvn_at_MAP, get_conditional_gaussian_approximation, - sample_laplace_posterior, ) @@ -42,7 +40,7 @@ def rng(): "mode, gradient_backend", [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")], ) -def test_laplace(mode, gradient_backend: GradientBackend): +def test_fit_laplace_basic(mode, gradient_backend: GradientBackend): # Example originates from Bayesian Data Analyses, 3rd Edition # By Andrew Gelman, John Carlin, Hal Stern, David Dunson, # Aki Vehtari, and Donald Rubin. @@ -53,8 +51,8 @@ def test_laplace(mode, gradient_backend: GradientBackend): draws = 100000 with pm.Model() as m: - mu = pm.Uniform("mu", -10000, 10000) - logsigma = pm.Uniform("logsigma", 1, 100) + mu = pm.Flat("mu") + logsigma = pm.Flat("logsigma") yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) vars = [mu, logsigma] @@ -67,6 +65,7 @@ def test_laplace(mode, gradient_backend: GradientBackend): chains=1, compile_kwargs={"mode": mode}, gradient_backend=gradient_backend, + optimizer_kwargs=dict(tol=1e-20), ) assert idata.posterior["mu"].shape == (1, draws) @@ -78,59 +77,13 @@ def test_laplace(mode, gradient_backend: GradientBackend): bda_map = [y.mean(), np.log(y.std())] bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]]) - np.testing.assert_allclose(idata.fit["mean_vector"].values, bda_map) - np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) + np.testing.assert_allclose(idata.posterior["mu"].mean(), bda_map[0], atol=1) + np.testing.assert_allclose(idata.posterior["logsigma"].mean(), bda_map[1], rtol=1e-3) + np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, rtol=1e-3, atol=1e-3) -@pytest.mark.parametrize( - "mode, gradient_backend", - [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")], -) -def test_laplace_only_fit(mode, gradient_backend: GradientBackend): - # Example originates from Bayesian Data Analyses, 3rd Edition - # By Andrew Gelman, John Carlin, Hal Stern, David Dunson, - # Aki Vehtari, and Donald Rubin. - # See section. 
4.1 - - y = np.array([2642, 3503, 4358], dtype=np.float64) - n = y.size - - with pm.Model() as m: - logsigma = pm.Uniform("logsigma", 1, 100) - mu = pm.Uniform("mu", -10000, 10000) - yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) - vars = [mu, logsigma] - - idata = pmx.fit( - method="laplace", - optimize_method="BFGS", - progressbar=True, - gradient_backend=gradient_backend, - compile_kwargs={"mode": mode}, - optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100), - random_seed=173300, - ) - - assert idata.fit["mean_vector"].shape == (len(vars),) - assert idata.fit["covariance_matrix"].shape == (len(vars), len(vars)) - - bda_map = [np.log(y.std()), y.mean()] - bda_cov = np.array([[1 / (2 * n), 0], [0, y.var() / n]]) - - np.testing.assert_allclose(idata.fit["mean_vector"].values, bda_map) - np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) - -@pytest.mark.parametrize( - "transform_samples", - [True, False], - ids=["transformed", "untransformed"], -) -@pytest.mark.parametrize( - "mode, gradient_backend", - [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")], -) -def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: GradientBackend): +def test_fit_laplace_coords(rng): coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)} with pm.Model(coords=coords) as model: mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"]) @@ -143,49 +96,32 @@ def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: Grad dims=["obs_idx", "city"], ) - optimized_point = find_MAP( - method="trust-ncg", - use_grad=True, - use_hessp=True, - progressbar=False, - compile_kwargs=dict(mode=mode), - gradient_backend=gradient_backend, - ) - - for value in optimized_point.values(): - assert value.shape == (3,) - - mu, H_inv = fit_mvn_at_MAP( - optimized_point=optimized_point, - model=model, - transform_samples=transform_samples, - ) - - idata = sample_laplace_posterior( - mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples + idata = pmx.fit( + method="laplace", + optimize_method="trust-ncg", + chains=1, + draws=1000, + optimizer_kwargs=dict(tol=1e-20), ) - np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.5) np.testing.assert_allclose( - np.mean(idata.posterior.sigma, axis=1), np.full((2, 3), 1.5), atol=0.3 + idata.posterior.mu.mean(dim=["chain", "draw"]).values, np.full((3,), 3), atol=0.5 + ) + np.testing.assert_allclose( + idata.posterior.sigma.mean(dim=["chain", "draw"]).values, np.full((3,), 1.5), atol=0.3 ) - suffix = "_log__" if transform_samples else "" assert idata.fit.rows.values.tolist() == [ "mu[A]", "mu[B]", "mu[C]", - f"sigma{suffix}[A]", - f"sigma{suffix}[B]", - f"sigma{suffix}[C]", + "sigma_log__[A]", + "sigma_log__[B]", + "sigma_log__[C]", ] -@pytest.mark.parametrize( - "mode, gradient_backend", - [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")], -) -def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng): +def test_fit_laplace_ragged_coords(rng): coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)} with pm.Model(coords=coords) as ragged_dim_model: X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"]) @@ -210,10 +146,12 @@ def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng) progressbar=False, use_grad=True, use_hessp=True, - gradient_backend=gradient_backend, - compile_kwargs={"mode": mode}, ) + # These 
should have been dropped when the laplace idata was created + assert "laplace_approximation" not in list(idata.posterior.data_vars.keys()) + assert "unpacked_var_names" not in list(idata.posterior.coords.keys()) + assert idata["posterior"].beta.shape[-2:] == (3, 2) assert idata["posterior"].sigma.shape[-1:] == (3,) @@ -223,50 +161,48 @@ def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng) assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all() -@pytest.mark.parametrize( - "fit_in_unconstrained_space", - [True, False], - ids=["transformed", "untransformed"], -) -@pytest.mark.parametrize( - "mode, gradient_backend", - [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")], -) -def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: GradientBackend): - with pm.Model() as simp_model: - mu = pm.Normal("mu", mu=3, sigma=0.5) - sigma = pm.Exponential("sigma", 1) - obs = pm.Normal( - "obs", - mu=mu, - sigma=sigma, - observed=np.random.default_rng().normal(loc=3, scale=1.5, size=(10000,)), - ) +def test_model_with_nonstandard_dimensionality_1(rng): + y_obs = np.concatenate( + [rng.normal(-1, 2, size=150), rng.normal(3, 1, size=350), rng.normal(5, 4, size=50)] + ) - idata = fit_laplace( - optimize_method="trust-ncg", - use_grad=True, - use_hessp=True, - fit_in_unconstrained_space=fit_in_unconstrained_space, - optimizer_kwargs=dict(maxiter=100_000, tol=1e-100), - compile_kwargs={"mode": mode}, - gradient_backend=gradient_backend, - ) + with pm.Model(coords={"obs_idx": range(y_obs.size), "class": ["A", "B", "C"]}) as model: + y = pm.Data("y", y_obs, dims=["obs_idx"]) + + mu = pm.Normal("mu", mu=1, sigma=3, dims=["class"]) + sigma = pm.HalfNormal("sigma", sigma=3, dims=["class"]) - np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1) - np.testing.assert_allclose( - np.mean(idata.posterior.sigma, axis=1), np.full((2,), 1.5), atol=0.1 + w = pm.Dirichlet( + "w", + a=np.ones( + 3, + ), + dims=["class"], + ) + class_idx = pm.Categorical("class_idx", p=w, dims=["obs_idx"]) + y_hat = pm.Normal( + "obs", mu=mu[class_idx], sigma=sigma[class_idx], observed=y, dims=["obs_idx"] ) - if fit_in_unconstrained_space: - assert idata.fit.rows.values.tolist() == ["mu", "sigma_log__"] - np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 0.4]), atol=0.1) - else: - assert idata.fit.rows.values.tolist() == ["mu", "sigma"] - np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 1.5]), atol=0.1) + with pmx.marginalize(model, [class_idx]): + idata = pmx.fit_laplace(progressbar=False) + + # The dirichlet value variable has a funky shape; check that it got a default + assert "w_simplex___dim_0" in list(idata.posterior.w_simplex__.coords.keys()) + assert "class" not in list(idata.posterior.w_simplex__.coords.keys()) + assert len(idata.posterior.coords["w_simplex___dim_0"]) == 2 + + # On the other hand, check that the actual w has the correct dims + assert "class" in list(idata.posterior.w.coords.keys()) + # The log transform is 1-to-1, so it should have the same dims as the original rv + assert "class" in list(idata.posterior.sigma_log__.coords.keys()) -def test_laplace_scalar(): + +# Test these three optimizers because they are either special cases for H_inv (BFGS, L-BFGS-B) or are +# gradient free and require re-compilation of hessp (powell). 
+@pytest.mark.parametrize("optimizer_method", ["BFGS", "L-BFGS-B", "powell"]) +def test_laplace_scalar_basinhopping(optimizer_method): # Example model from Statistical Rethinking data = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1]) @@ -274,12 +210,18 @@ def test_laplace_scalar(): p = pm.Uniform("p", 0, 1) w = pm.Binomial("w", n=len(data), p=p, observed=data.sum()) - idata_laplace = pmx.fit_laplace(progressbar=False) + idata_laplace = pmx.fit_laplace( + optimize_method="basinhopping", + optimizer_kwargs={"minimizer_kwargs": {"method": optimizer_method}, "niter": 1}, + progressbar=False, + ) assert idata_laplace.fit.mean_vector.shape == (1,) assert idata_laplace.fit.covariance_matrix.shape == (1, 1) - np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1) + np.testing.assert_allclose( + idata_laplace.posterior.p.mean(dim=["chain", "draw"]), data.mean(), atol=0.1 + ) def test_get_conditional_gaussian_approximation(): From 3f2aa8b3d104be7193684901bd4c7e6cff604e39 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 5 Jul 2025 12:50:43 +0800 Subject: [PATCH 5/7] Update better-optimize version pin --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0be357d96..c90ff1c4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dynamic = ["version"] # specify the version in the __init__.py file dependencies = [ "pymc>=5.21.1", "scikit-learn", - "better-optimize>=0.1.2", + "better-optimize>=0.1.4", "pydantic>=2.0.0", ] From daff0c9071d2182506e7ade2dd0e07caf09bdb49 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 8 Jul 2025 23:53:16 +0800 Subject: [PATCH 6/7] Handle labeling of non-scalar RVs without dims --- pymc_extras/inference/laplace_approx/idata.py | 19 +++++++++--- .../inference/laplace_approx/laplace.py | 31 ++++++++++++++----- .../inference/laplace_approx/test_laplace.py | 25 ++++++++++++++- 3 files changed, 62 insertions(+), 13 deletions(-) diff --git a/pymc_extras/inference/laplace_approx/idata.py b/pymc_extras/inference/laplace_approx/idata.py index b82031f94..515adb8c6 100644 --- a/pymc_extras/inference/laplace_approx/idata.py +++ b/pymc_extras/inference/laplace_approx/idata.py @@ -14,6 +14,13 @@ from scipy.sparse.linalg import LinearOperator +def make_default_labels(name: str, shape: tuple[int, ...]) -> list: + if len(shape) == 0: + return [name] + + return [list(range(dim)) for dim in shape] + + def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]: coords = model.coords initial_point = model.initial_point() @@ -31,10 +38,14 @@ def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str] for name in names: shape = initial_point[name].shape if shape: - labels_by_dim = [ - coords[dim] if shape[i] == len(coords[dim]) else np.arange(shape[i]) - for i, dim in enumerate(dims_dict.get(name, [name])) - ] + dims = dims_dict.get(name) + if dims: + labels_by_dim = [ + coords[dim] if shape[i] == len(coords[dim]) else np.arange(shape[i]) + for i, dim in enumerate(dims) + ] + else: + labels_by_dim = make_default_labels(name, shape) labels = product(*labels_by_dim) unpacked_variable_names.extend( [f"{name}[{','.join(map(str, label))}]" for label in labels] diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py index ef1224e81..f7a5df347 100644 --- a/pymc_extras/inference/laplace_approx/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -195,11 +195,21 @@ def 
model_to_laplace_approx( ) for name, batched_rv in zip(names, batched_rvs): - pm.Deterministic( - name, - batched_rv, - dims=("temp_chain", "temp_draw", *model.named_vars_to_dims.get(name, ())), - ) + batch_dims = ("temp_chain", "temp_draw") + if batched_rv.ndim == 2: + dims = batch_dims + elif name in model.named_vars_to_dims: + dims = (*batch_dims, *model.named_vars_to_dims[name]) + else: + dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)]) + laplace_model.add_coords( + { + name: np.arange(shape) + for name, shape in zip(dims[2:], batched_rv.type.shape[2:]) + } + ) + + pm.Deterministic(name, batched_rv, dims=dims) return laplace_model @@ -221,16 +231,21 @@ def unstack_laplace_draws(idata, model): unstacked_laplace_draws = {} laplace_data = idata.laplace_approximation.values - coords = model.coords | {"chain": range(chains), "draw": range(draws)} + # # There might + # idata_coords = {k: v.tolist() for k, v in zip(idata.coords.keys(), [x.values for x in idata.coords.values()]) + # if k not in ['chain', 'draw', 'unpacked_variable_names']} + # There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g. # simplex transform of a Dirichlet). In these cases, we don't try to guess what the labels should be, and just # add an arviz-style default dim and label. for rv, (name, shape, size, dtype) in zip(model.free_RVs, initial_point.point_map_info): rv_dims = [] - for i, dim in enumerate(model.named_vars_to_dims.get(rv.name, ())): - if shape[i] == len(coords[dim]): + for i, dim in enumerate( + model.named_vars_to_dims.get(rv.name, [f"{name}_dim_{i}" for i in range(len(shape))]) + ): + if coords.get(dim) and shape[i] == len(coords[dim]): rv_dims.append(dim) else: rv_dims.append(f"{name}_dim_{i}") diff --git a/tests/inference/laplace_approx/test_laplace.py b/tests/inference/laplace_approx/test_laplace.py index 68ed30cc1..2ec10bd05 100644 --- a/tests/inference/laplace_approx/test_laplace.py +++ b/tests/inference/laplace_approx/test_laplace.py @@ -161,7 +161,7 @@ def test_fit_laplace_ragged_coords(rng): assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all() -def test_model_with_nonstandard_dimensionality_1(rng): +def test_model_with_nonstandard_dimensionality(rng): y_obs = np.concatenate( [rng.normal(-1, 2, size=150), rng.normal(3, 1, size=350), rng.normal(5, 4, size=50)] ) @@ -199,6 +199,29 @@ def test_model_with_nonstandard_dimensionality_1(rng): assert "class" in list(idata.posterior.sigma_log__.coords.keys()) +def test_nonscalar_rv_without_dims(): + with pm.Model(coords={"test": ["A", "B", "C"]}) as model: + x_loc = pm.Normal("x_loc", mu=0, sigma=1, dims=["test"]) + x = pm.Normal("x", mu=x_loc, sigma=1, shape=(2, 3)) + y = pm.Normal("y", mu=x, sigma=1, observed=np.random.randn(10, 2, 3)) + + idata = pmx.fit_laplace(progressbar=False) + + assert idata.posterior["x"].shape == (2, 500, 2, 3) + assert all(f"x_dim_{i}" in idata.posterior.coords for i in range(2)) + assert idata.fit.rows.values.tolist() == [ + "x_loc[A]", + "x_loc[B]", + "x_loc[C]", + "x[0,0]", + "x[0,1]", + "x[0,2]", + "x[1,0]", + "x[1,1]", + "x[1,2]", + ] + + # Test these three optimizers because they are either special cases for H_inv (BFGS, L-BFGS-B) or are # gradient free and require re-compilation of hessp (powell). 
@pytest.mark.parametrize("optimizer_method", ["BFGS", "L-BFGS-B", "powell"]) From 7540ec079b5688399b93f74b868e209f2e5f65b6 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 10 Jul 2025 08:20:26 +0800 Subject: [PATCH 7/7] Add unconstrained posterior draws/points to unconstrained_posterior --- .../inference/laplace_approx/find_map.py | 26 ++++--- pymc_extras/inference/laplace_approx/idata.py | 57 ++++++++-------- .../inference/laplace_approx/laplace.py | 67 +++++++++++-------- .../inference/laplace_approx/test_find_map.py | 62 ++++++++++++++--- .../inference/laplace_approx/test_laplace.py | 10 +-- 5 files changed, 137 insertions(+), 85 deletions(-) diff --git a/pymc_extras/inference/laplace_approx/find_map.py b/pymc_extras/inference/laplace_approx/find_map.py index c20f7e8b0..930137e22 100644 --- a/pymc_extras/inference/laplace_approx/find_map.py +++ b/pymc_extras/inference/laplace_approx/find_map.py @@ -3,7 +3,6 @@ from collections.abc import Callable from typing import Literal, cast -import arviz as az import numpy as np import pymc as pm @@ -19,8 +18,8 @@ from pymc_extras.inference.laplace_approx.idata import ( add_data_to_inference_data, add_fit_to_inference_data, - add_map_posterior_to_inference_data, add_optimizer_result_to_inference_data, + map_results_to_inference_data, ) from pymc_extras.inference.laplace_approx.scipy_interface import ( GradientBackend, @@ -186,7 +185,7 @@ def _compute_inverse_hessian( def find_MAP( - method: minimize_method | Literal["basinhopping"], + method: minimize_method | Literal["basinhopping"] = "L-BFGS-B", *, model: pm.Model | None = None, use_grad: bool | None = None, @@ -317,6 +316,15 @@ def find_MAP( **optimizer_kwargs, ) + H_inv = _compute_inverse_hessian( + optimizer_result=optimizer_result, + optimal_point=None, + f_fused=f_fused, + f_hessp=f_hessp, + use_hess=use_hess, + method=method, + ) + raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info) unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed) unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")( @@ -327,17 +335,7 @@ def find_MAP( var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values) } - H_inv = _compute_inverse_hessian( - optimizer_result=optimizer_result, - optimal_point=None, - f_fused=f_fused, - f_hessp=f_hessp, - use_hess=use_hess, - method=method, - ) - - idata = az.InferenceData() - idata = add_map_posterior_to_inference_data(idata, optimized_point, frozen_model) + idata = map_results_to_inference_data(optimized_point, frozen_model) idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv) idata = add_optimizer_result_to_inference_data( idata, optimizer_result, method, raveled_optimized, model diff --git a/pymc_extras/inference/laplace_approx/idata.py b/pymc_extras/inference/laplace_approx/idata.py index 515adb8c6..edf011dd4 100644 --- a/pymc_extras/inference/laplace_approx/idata.py +++ b/pymc_extras/inference/laplace_approx/idata.py @@ -1,5 +1,5 @@ from itertools import product -from typing import Any, Literal +from typing import Literal import arviz as az import numpy as np @@ -10,6 +10,7 @@ from better_optimize.constants import minimize_method from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations from pymc.blocking import RaveledVars +from pymc.util import get_default_varnames from scipy.optimize import OptimizeResult from scipy.sparse.linalg import LinearOperator @@ -55,31 +56,7 @@ def 
make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str] return unpacked_variable_names -def map_results_to_inference_data(results: dict[str, Any], model: pm.Model | None = None): - """ - Convert a dictionary of results to an InferenceData object. - - Parameters - ---------- - results: dict - A dictionary containing the results to convert. - model: Model, optional - A PyMC model. If None, the model is taken from the current model context. - - Returns - ------- - idata: az.InferenceData - An InferenceData object containing the results. - """ - model = pm.modelcontext(model) - coords, dims = coords_and_dims_for_inferencedata(model) - - idata = az.convert_to_inference_data(results, coords=coords, dims=dims) - return idata - - -def add_map_posterior_to_inference_data( - idata: az.InferenceData, +def map_results_to_inference_data( map_point: dict[str, float | int | np.ndarray], model: pm.Model | None = None, ): @@ -124,10 +101,36 @@ def add_map_posterior_to_inference_data( } ) + constrained_names = [ + x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False) + ] + all_varnames = [ + x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True) + ] + + unconstrained_names = set(all_varnames) - set(constrained_names) + idata = az.from_dict( - {k: np.expand_dims(v, (0, 1)) for k, v in map_point.items()}, coords=coords, dims=dims + posterior={ + k: np.expand_dims(v, (0, 1)) for k, v in map_point.items() if k in constrained_names + }, + coords=coords, + dims=dims, ) + if unconstrained_names: + unconstrained_posterior = az.from_dict( + posterior={ + k: np.expand_dims(v, (0, 1)) + for k, v in map_point.items() + if k in unconstrained_names + }, + coords=coords, + dims=dims, + ) + + idata["unconstrained_posterior"] = unconstrained_posterior.posterior + return idata diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py index f7a5df347..2b5ef6a16 100644 --- a/pymc_extras/inference/laplace_approx/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -29,8 +29,9 @@ from better_optimize.constants import minimize_method from numpy.typing import ArrayLike -from pymc import DictToArrayBijection, join_nonshared_inputs +from pymc.blocking import DictToArrayBijection from pymc.model.transform.optimization import freeze_dims_and_data +from pymc.pytensorf import join_nonshared_inputs from pymc.util import get_default_varnames from pytensor.graph import vectorize_graph from pytensor.tensor import TensorVariable @@ -147,17 +148,29 @@ def get_conditional_gaussian_approximation( def _unconstrained_vector_to_constrained_rvs(model): - outputs = get_default_varnames(model.unobserved_value_vars, include_transformed=False) + outputs = get_default_varnames(model.unobserved_value_vars, include_transformed=True) + constrained_names = [ + x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False) + ] names = [x.name for x in outputs] - constrained_rvs, unconstrained_vector = join_nonshared_inputs( + unconstrained_names = [name for name in names if name not in constrained_names] + + new_outputs, unconstrained_vector = join_nonshared_inputs( model.initial_point(), inputs=model.value_vars, outputs=outputs, ) + constrained_rvs = [x for x, name in zip(new_outputs, names) if name in constrained_names] + value_rvs = [x for x in new_outputs if x not in constrained_rvs] + unconstrained_vector.name = "unconstrained_vector" - return names, 
constrained_rvs, unconstrained_vector + + # Redo the names list to ensure it is sorted to match the return order + names = [*constrained_names, *unconstrained_names] + + return names, constrained_rvs, value_rvs, unconstrained_vector def model_to_laplace_approx( @@ -169,7 +182,9 @@ def model_to_laplace_approx( # temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov, # so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved. - names, constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model) + names, constrained_rvs, value_rvs, unconstrained_vector = ( + _unconstrained_vector_to_constrained_rvs(model) + ) coords = model.coords | { "temp_chain": np.arange(chains), @@ -202,11 +217,10 @@ def model_to_laplace_approx( dims = (*batch_dims, *model.named_vars_to_dims[name]) else: dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)]) + initval = initial_point.get(name, None) + dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:] laplace_model.add_coords( - { - name: np.arange(shape) - for name, shape in zip(dims[2:], batched_rv.type.shape[2:]) - } + {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)} ) pm.Deterministic(name, batched_rv, dims=dims) @@ -214,7 +228,7 @@ def model_to_laplace_approx( return laplace_model -def unstack_laplace_draws(idata, model): +def unstack_laplace_draws(laplace_data, model, chains=2, draws=500): """ The `model_to_laplace_approx` function returns a model with a single MvNormal distribution, draws from which are in the unconstrained variable space. These might be interesting to the user, but since they come back stacked in a @@ -226,17 +240,9 @@ def unstack_laplace_draws(idata, model): initial_point = DictToArrayBijection.map(model.initial_point()) cursor = 0 - chains = idata.coords["chain"].size - draws = idata.coords["draw"].size - unstacked_laplace_draws = {} - laplace_data = idata.laplace_approximation.values coords = model.coords | {"chain": range(chains), "draw": range(draws)} - # # There might - # idata_coords = {k: v.tolist() for k, v in zip(idata.coords.keys(), [x.values for x in idata.coords.values()]) - # if k not in ['chain', 'draw', 'unpacked_variable_names']} - # There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g. # simplex transform of a Dirichlet). In these cases, we don't try to guess what the labels should be, and just # add an arviz-style default dim and label. 
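# Editor's aside (not part of the patch): a minimal runnable sketch of the corner case the
# comment above describes. The simplex-transformed value variable of a Dirichlet is one element
# shorter than the RV itself, so the RV's dims cannot be reused and a default arviz-style label
# such as "w_simplex___dim_0" has to be generated. Only pymc/numpy are assumed; the "class"
# coord name is illustrative.
import numpy as np
import pymc as pm

with pm.Model(coords={"class": ["A", "B", "C"]}) as sketch_model:
    w = pm.Dirichlet("w", a=np.ones(3), dims=["class"])

initial_point = sketch_model.initial_point()
# The value variable lives on the 2-dimensional simplex, so its shape is (2,), not (3,):
# the length-3 "class" coord cannot label it, hence the fallback default dim.
print(initial_point["w_simplex__"].shape)  # (2,)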
@@ -420,23 +426,26 @@ def fit_laplace( ) with model_to_laplace_approx(model, unpacked_variable_names, chains, draws) as laplace_model: - laplace_idata = pm.sample_posterior_predictive( - idata.fit.expand_dims(chain=[0], draw=[0]), - extend_inferencedata=False, - random_seed=random_seed, - var_names=["laplace_approximation", *[x.name for x in laplace_model.deterministics]], - ) new_posterior = ( - laplace_idata.posterior_predictive.squeeze(["chain", "draw"]) + pm.sample_posterior_predictive( + idata.fit.expand_dims(chain=[0], draw=[0]), + extend_inferencedata=False, + random_seed=random_seed, + var_names=[ + "laplace_approximation", + *[x.name for x in laplace_model.deterministics], + ], + ) + .posterior_predictive.squeeze(["chain", "draw"]) .drop_vars(["chain", "draw"]) .rename({"temp_chain": "chain", "temp_draw": "draw"}) ) - new_posterior.update(unstack_laplace_draws(new_posterior, model)) - new_posterior = new_posterior.drop_vars( + idata.unconstrained_posterior = unstack_laplace_draws( + new_posterior.laplace_approximation.values, model, chains=chains, draws=draws + ) + idata.posterior = new_posterior.drop_vars( ["laplace_approximation", "unpacked_variable_names"] ) - idata.posterior = new_posterior - return idata diff --git a/tests/inference/laplace_approx/test_find_map.py b/tests/inference/laplace_approx/test_find_map.py index 61f99e6bf..bf0cb292e 100644 --- a/tests/inference/laplace_approx/test_find_map.py +++ b/tests/inference/laplace_approx/test_find_map.py @@ -159,16 +159,20 @@ def test_find_MAP( ) assert hasattr(idata, "posterior") + assert hasattr(idata, "unconstrained_posterior") assert hasattr(idata, "fit") assert hasattr(idata, "optimizer_result") assert hasattr(idata, "observed_data") posterior = idata.posterior.squeeze(["chain", "draw"]) - assert "mu" in posterior and "sigma_log__" in posterior and "sigma" in posterior + assert "mu" in posterior and "sigma" in posterior assert posterior["mu"].shape == () - assert posterior["sigma_log__"].shape == () assert posterior["sigma"].shape == () + unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"]) + assert "sigma_log__" in unconstrained_posterior + assert unconstrained_posterior["sigma_log__"].shape == () + @pytest.mark.parametrize( "backend, gradient_backend", @@ -195,17 +199,22 @@ def test_map_shared_variables(backend, gradient_backend: GradientBackend): ) assert hasattr(idata, "posterior") + assert hasattr(idata, "unconstrained_posterior") assert hasattr(idata, "fit") assert hasattr(idata, "optimizer_result") assert hasattr(idata, "observed_data") assert hasattr(idata, "constant_data") posterior = idata.posterior.squeeze(["chain", "draw"]) - assert "mu" in posterior and "sigma_log__" in posterior and "sigma" in posterior + unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"]) + + assert "mu" in posterior and "sigma" in posterior assert posterior["mu"].shape == () - assert posterior["sigma_log__"].shape == () assert posterior["sigma"].shape == () + assert "sigma_log__" in unconstrained_posterior + assert unconstrained_posterior["sigma_log__"].shape == () + @pytest.mark.parametrize( "method, use_grad, use_hess, use_hessp", @@ -242,10 +251,15 @@ def test_find_MAP_basinhopping( ) assert hasattr(idata, "posterior") + assert hasattr(idata, "unconstrained_posterior") + posterior = idata.posterior.squeeze(["chain", "draw"]) - assert "mu" in posterior and "sigma_log__" in posterior + unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"]) + assert 
"mu" in posterior assert posterior["mu"].shape == () - assert posterior["sigma_log__"].shape == () + + assert "sigma_log__" in unconstrained_posterior + assert unconstrained_posterior["sigma_log__"].shape == () def test_find_MAP_with_coords(): @@ -261,20 +275,48 @@ def test_find_MAP_with_coords(): idata = find_MAP(progressbar=False, method="L-BFGS-B") assert hasattr(idata, "posterior") + assert hasattr(idata, "unconstrained_posterior") assert hasattr(idata, "fit") posterior = idata.posterior.squeeze(["chain", "draw"]) + unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"]) + assert ( "mu_loc" in posterior and "mu_scale" in posterior - and "mu_scale_log__" in posterior and "mu" in posterior - and "sigma_log__" in posterior and "sigma" in posterior ) + assert "mu_scale_log__" in unconstrained_posterior and "sigma_log__" in unconstrained_posterior + assert posterior["mu_loc"].shape == () assert posterior["mu_scale"].shape == () - assert posterior["mu_scale_log__"].shape == () assert posterior["mu"].shape == (5,) - assert posterior["sigma_log__"].shape == (5,) assert posterior["sigma"].shape == (5,) + + assert unconstrained_posterior["mu_scale_log__"].shape == () + assert unconstrained_posterior["sigma_log__"].shape == (5,) + + +def test_map_nonscalar_rv_without_dims(): + with pm.Model(coords={"test": ["A", "B", "C"]}) as model: + x_loc = pm.Normal("x_loc", mu=0, sigma=1, dims=["test"]) + x = pm.Normal("x", mu=x_loc, sigma=1, shape=(2, 3)) + y = pm.Normal("y", mu=x, sigma=1, observed=np.random.randn(10, 2, 3)) + + idata = find_MAP(method="L-BFGS-B", progressbar=False) + + assert idata.posterior["x"].shape == (1, 1, 2, 3) + assert all(f"x_dim_{i}" in idata.posterior.coords for i in range(2)) + + assert idata.fit.rows.values.tolist() == [ + "x_loc[A]", + "x_loc[B]", + "x_loc[C]", + "x[0,0]", + "x[0,1]", + "x[0,2]", + "x[1,0]", + "x[1,1]", + "x[1,2]", + ] diff --git a/tests/inference/laplace_approx/test_laplace.py b/tests/inference/laplace_approx/test_laplace.py index 2ec10bd05..be5665d07 100644 --- a/tests/inference/laplace_approx/test_laplace.py +++ b/tests/inference/laplace_approx/test_laplace.py @@ -188,18 +188,18 @@ def test_model_with_nonstandard_dimensionality(rng): idata = pmx.fit_laplace(progressbar=False) # The dirichlet value variable has a funky shape; check that it got a default - assert "w_simplex___dim_0" in list(idata.posterior.w_simplex__.coords.keys()) - assert "class" not in list(idata.posterior.w_simplex__.coords.keys()) - assert len(idata.posterior.coords["w_simplex___dim_0"]) == 2 + assert "w_simplex___dim_0" in list(idata.unconstrained_posterior.w_simplex__.coords.keys()) + assert "class" not in list(idata.unconstrained_posterior.w_simplex__.coords.keys()) + assert len(idata.unconstrained_posterior.coords["w_simplex___dim_0"]) == 2 # On the other hand, check that the actual w has the correct dims assert "class" in list(idata.posterior.w.coords.keys()) # The log transform is 1-to-1, so it should have the same dims as the original rv - assert "class" in list(idata.posterior.sigma_log__.coords.keys()) + assert "class" in list(idata.unconstrained_posterior.sigma_log__.coords.keys()) -def test_nonscalar_rv_without_dims(): +def test_laplace_nonscalar_rv_without_dims(): with pm.Model(coords={"test": ["A", "B", "C"]}) as model: x_loc = pm.Normal("x_loc", mu=0, sigma=1, dims=["test"]) x = pm.Normal("x", mu=x_loc, sigma=1, shape=(2, 3))