Update and Refactor find_MAP and fit_laplace #531

Open: 5 commits to be merged into base branch main

jessegrabowski
Member

I just made an update to better-optimize that uses Hessian matrix caching for better performance. This is something we can immediately take advantage of with find_MAP, re-using sub-computations from the loss or gradient in the Hessian computation. This PR updates the functions generated by find_MAP to take advantage of this.
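To make the caching idea concrete, here is a minimal pure-Python sketch. It is not better-optimize's actual implementation: `FusedCache` and `rosenbrock_fused` are hypothetical names, and the only assumption is that the optimizer asks for the loss, gradient, and Hessian at the same point in sequence, so one fused evaluation can serve all three requests.

```python
class FusedCache:
    """Cache the most recent fused (loss, grad, hess) evaluation, keyed by input point."""

    def __init__(self, fused_fn):
        self.fused_fn = fused_fn
        self.last_x = None
        self.last_result = None
        self.n_evals = 0  # number of times the fused function actually ran

    def _evaluate(self, x):
        key = tuple(x)
        if key != self.last_x:
            # Cache miss: run the fused function once and remember the result.
            self.last_result = self.fused_fn(key)
            self.last_x = key
            self.n_evals += 1
        return self.last_result

    def loss(self, x):
        return self._evaluate(x)[0]

    def grad(self, x):
        return self._evaluate(x)[1]

    def hess(self, x):
        return self._evaluate(x)[2]


def rosenbrock_fused(x):
    """Return loss, gradient, and Hessian of the 2-D Rosenbrock function together."""
    a, b = x
    r = b - a * a  # sub-computation shared by all three outputs
    loss = (1 - a) ** 2 + 100 * r**2
    grad = (-2 * (1 - a) - 400 * a * r, 200 * r)
    hess = ((2 - 400 * r + 800 * a * a, -400 * a), (-400 * a, 200.0))
    return loss, grad, hess


cache = FusedCache(rosenbrock_fused)
x0 = (-1.0, 2.0)
cache.loss(x0)
cache.grad(x0)
cache.hess(x0)  # all three calls share a single fused evaluation
```

After the three calls, `cache.n_evals` is 1. A real optimizer interface would pass `cache.loss`, `cache.grad`, and `cache.hess` as separate callbacks.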

While I was at it, I went ahead and did some cleanup and reorganization of the code. In particular:

  1. I split up the files into smaller, more logical groupings. I moved all the files to a laplace_approx submodule.
  2. find_MAP now returns an idata. This is more consistent with all the other PyMC sampling functions -- it's weird to get back a dictionary in this one case.
  3. When available, find_MAP will now always store the inverse hessian. This is done to try to avoid an extra function compilation when it is used in conjunction with fit_laplace.
  4. fit_laplace was a really dumb function that was inexplicably sampling from scipy distributions. This required a ton of unnecessary work. If only we had a PPL that could help sample from complicated distributions...

fit_laplace still isn't perfect. I wanted to store both the value variables and the transformed RVs as deterministics in a pymc model and sample them directly, but that doesn't appear to work -- maybe this is a bug? I ended up doing two passes, once for the constrained RVs, then a second pass for the unconstrained. It would be good to minimize that.
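As a rough illustration of why there are two sets of draws, the sketch below handles a single positive parameter with a log transform. It is a standalone toy, not the code in this PR: the function names, the scalar setup, and the numbers are all assumptions made for the example. The Laplace approximation is a normal distribution in the unconstrained space, centred on the mode with variance given by the inverse Hessian. The constrained draws come from applying the backward transform.

```python
import math
import random


def laplace_draws(mode_unconstrained, inv_hessian, n_draws, seed=0):
    """Draw from N(mode, inv_hessian) for a single scalar value variable."""
    rng = random.Random(seed)
    sd = math.sqrt(inv_hessian)
    return [rng.gauss(mode_unconstrained, sd) for _ in range(n_draws)]


def to_constrained(unconstrained_draws):
    """Map log-scale draws back to the positive RV (backward of a log transform)."""
    return [math.exp(z) for z in unconstrained_draws]


# Hypothetical MAP result for log(sigma) and its inverse Hessian.
log_sigma_draws = laplace_draws(mode_unconstrained=0.5, inv_hessian=0.04, n_draws=1000)
sigma_draws = to_constrained(log_sigma_draws)
```

Both lists describe the same posterior approximation. `log_sigma_draws` lives on the value-variable scale and `sigma_draws` on the RV scale.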

I also removed as many little options that were floating around as possible. These function signatures were already horrible.

Finally, I eliminated a lot of test parameterizations to speed the CI up, but also added a lot of new tests for functions that were previously not covered. Hopefully it's still net positive.

@jessegrabowski
Member Author

jessegrabowski commented Jun 29, 2025

One slightly janky thing is that I store the MAP result in the posterior group, but without chain and draw dimensions. That might be a bad choice, because it breaks a promise that idata typically makes. But a MAP result doesn't naturally have these dimensions either. Not sure.

This is no longer true: find_MAP now returns dummy chain and draw dims. I thought it was too much to break the arviz promise that the posterior always has chain/draw.

Another jank choice is the temp_chain, temp_draw thing in model_to_laplace_approx. Someone else might know a better way to accomplish what I'm doing here :)

@jessegrabowski jessegrabowski force-pushed the map-laplace-updates branch 3 times, most recently from 4b9ba99 to 067860f Compare June 29, 2025 06:20
@jessegrabowski jessegrabowski requested a review from Copilot June 29, 2025 09:34
@jessegrabowski jessegrabowski requested a review from Copilot July 5, 2025 04:38
@jessegrabowski jessegrabowski force-pushed the map-laplace-updates branch 3 times, most recently from 6554ad8 to 2ea85fe Compare July 5, 2025 05:50
@Copilot Copilot AI left a comment

Pull Request Overview

This PR refactors the MAP-finding and Laplace approximation routines to improve performance by caching Hessian computations, reorganizes code into a laplace_approx submodule, and standardizes return types to ArviZ InferenceData.

  • Cache and reuse Hessian subcomputations in find_MAP/fit_laplace workflows.
  • Move all Laplace-related modules under pymc_extras/inference/laplace_approx.
  • Update find_MAP to return InferenceData and simplify fit_laplace interface.

Reviewed Changes

Copilot reviewed 15 out of 18 changed files in this pull request and generated 1 comment.

File | Description
tests/inference/laplace_approx/test_find_map.py | Add tests for new find_MAP defaults, PSD helper, and JAX paths
pymc_extras/inference/laplace_approx/scipy_interface.py | New module for compiling loss/grad/Hessian for SciPy optimizers
pymc_extras/inference/laplace_approx/laplace.py | Refactor fit_laplace, cache inverse Hessian, build idata
pymc_extras/inference/laplace_approx/find_map.py | Refactor find_MAP, wrap into InferenceData, split logic
pymc_extras/inference/laplace_approx/idata.py | Helpers to add data/fit/optimizer results into InferenceData
pymc_extras/inference/pathfinder/pathfinder.py | Update import to new add_data_to_inference_data helper
pymc_extras/inference/fit.py | Route fit(method="laplace") to new Laplace submodule
pyproject.toml | Bump better-optimize dependency to ≥0.1.4

Comments suppressed due to low confidence (2)

pymc_extras/inference/laplace_approx/scipy_interface.py:101

  • The docstring lists f_fused and f_hessp as return values but the function actually returns a list of one or two Function objects. Update the doc to reflect that it returns a list[Function] (or [Function, Function]).
    f_fused: Function

pymc_extras/inference/laplace_approx/scipy_interface.py:53

  • The return statement uses a starred expression (return *loss_and_grad, hess), which is invalid syntax in Python. Wrap the unpacking in a tuple, e.g.: return (*loss_and_grad, hess).
            return *loss_and_grad, hess
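
For reference, unparenthesized starred expressions in return statements have been accepted since Python 3.8, so the flagged line parses on current interpreters. A minimal check, using a hypothetical stand-in function rather than the code in scipy_interface.py:

```python
def fused(x):
    """Toy stand-in: return loss, gradient, and Hessian of x**2 as one flat tuple."""
    loss_and_grad = (x**2, 2 * x)
    hess = 2.0
    # Unparenthesized unpacking in a return statement (Python >= 3.8).
    return *loss_and_grad, hess


result = fused(3.0)
```

Wrapping the expression in parentheses is only needed when the code has to run on Python 3.7 or earlier.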

@jessegrabowski
Member Author

This should be ready to go. Last changes:

  • fit_laplace should work in all cases, including when the value variables have a different shape than the RVs (added test for this). I brought back the old way of making the batched RVs, then combined it with the new pymc model approach to make the raveled vector. I think the result is quite nice.

  • find_MAP posterior has chain and draw, to stick fast to the arviz convention (but it can be easily squeezed away by the user)

  • I found some unused helpers and removed them. Also removed the utilities file, because I'm sick of those and I never remember where I put anything.
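
To illustrate the case where value variables and RVs have different shapes, here is a small sketch of a centred log-ratio transform for a simplex-valued variable. The function names are hypothetical, and treating this as representative of the transform PyMC applies is an assumption. The point is only that a length-K probability vector maps to K-1 unconstrained values, so the raveled optimization vector is shorter than the RV.

```python
import math


def simplex_forward(p):
    """Map a length-K simplex point to K-1 unconstrained values."""
    log_p = [math.log(v) for v in p]
    shift = sum(log_p) / len(log_p)
    # The centred log-ratios sum to zero, so the last one is redundant and dropped.
    return [v - shift for v in log_p[:-1]]


def simplex_backward(y):
    """Map K-1 unconstrained values back to a length-K simplex point."""
    full = list(y) + [-sum(y)]  # recover the dropped component
    m = max(full)  # subtract the max for numerical stability
    expd = [math.exp(v - m) for v in full]
    total = sum(expd)
    return [v / total for v in expd]


p = [0.2, 0.3, 0.5]           # RV shape: (3,)
y = simplex_forward(p)        # value-variable shape: (2,)
p_back = simplex_backward(y)  # round trip back to shape (3,)
```

Any code that samples in the unconstrained space and then reports the RVs has to account for this change in length.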

Labels: enhancements (New feature or request), maintenance