Update and Refactor find_MAP and fit_laplace #531
Conversation
Force-pushed from c18ee5a to 2fc4b45
This is no longer true; `find_MAP` returns dummy chain, draw dims now. I thought it was too much to break the ArviZ promise that the posterior always has chain/draw. Another jank choice is the temp_chain, temp_draw thing in …
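For readers skimming the thread, a hedged illustration of what those dummy dims mean for downstream code; the variable name `mu` and the exact call are assumptions, not taken from the diff:

```python
idata = find_MAP()  # called inside a model context

# The MAP point estimate is wrapped in singleton chain/draw dims so the
# posterior group keeps the shape ArviZ expects from every other sampler.
assert idata.posterior["mu"].dims == ("chain", "draw")
assert idata.posterior.sizes["chain"] == 1
assert idata.posterior.sizes["draw"] == 1
```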
Force-pushed from 4b9ba99 to 067860f
Force-pushed from d79d642 to 1af7049
Force-pushed from 6554ad8 to 2ea85fe
Force-pushed from 2ea85fe to 48b74f7
Pull Request Overview
This PR refactors the MAP-finding and Laplace approximation routines to improve performance by caching Hessian computations, reorganizes code into a `laplace_approx` submodule, and standardizes return types to ArviZ `InferenceData`.

- Cache and reuse Hessian subcomputations in `find_MAP`/`fit_laplace` workflows.
- Move all Laplace-related modules under `pymc_extras/inference/laplace_approx`.
- Update `find_MAP` to return `InferenceData` and simplify the `fit_laplace` interface.
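As a quick orientation, a minimal usage sketch of the refactored entry points; the toy model and the no-argument calls are illustrative assumptions, not part of this PR:

```python
import pymc as pm
from pymc_extras.inference.laplace_approx.find_map import find_MAP
from pymc_extras.inference.laplace_approx.laplace import fit_laplace

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=[0.1, -0.3, 0.2])

    # find_MAP now returns an ArviZ InferenceData instead of a dict
    map_idata = find_MAP()

    # fit_laplace builds the normal approximation around the MAP,
    # reusing the inverse Hessian stored by find_MAP when available
    laplace_idata = fit_laplace()
```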
Reviewed Changes
Copilot reviewed 15 out of 18 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tests/inference/laplace_approx/test_find_map.py | Add tests for new `find_MAP` defaults, PSD helper, and JAX paths |
| pymc_extras/inference/laplace_approx/scipy_interface.py | New module for compiling loss/grad/Hessian for SciPy optimizers |
| pymc_extras/inference/laplace_approx/laplace.py | Refactor `fit_laplace`, cache inverse Hessian, build idata |
| pymc_extras/inference/laplace_approx/find_map.py | Refactor `find_MAP`, wrap into `InferenceData`, split logic |
| pymc_extras/inference/laplace_approx/idata.py | Helpers to add data/fit/optimizer results into `InferenceData` |
| pymc_extras/inference/pathfinder/pathfinder.py | Update import to new `add_data_to_inference_data` helper |
| pymc_extras/inference/fit.py | Route `fit(method="laplace")` to new Laplace submodule |
| pyproject.toml | Bump better-optimize dependency to ≥0.1.4 |
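For context on the `fit.py` row, a hedged sketch of what the routing change means at the call site, assuming `fit` is exposed at the package level as it is for other methods:

```python
import pymc_extras as pmx

with model:
    # dispatched to the new laplace_approx submodule
    idata = pmx.fit(method="laplace")
```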
Comments suppressed due to low confidence (2)
pymc_extras/inference/laplace_approx/scipy_interface.py:101

- The docstring lists `f_fused` and `f_hessp` as return values, but the function actually returns a list of one or two `Function` objects. Update the doc to reflect that it returns a `list[Function]` (or `[Function, Function]`).

```python
f_fused: Function
```
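A hedged sketch of how the corrected Returns section might read, based only on the behavior described in the comment; the function name is a placeholder, not the one in the diff:

```python
def compile_functions():  # hypothetical stand-in for the real function
    """
    Returns
    -------
    functions : list[Function]
        Compiled functions: ``[f_fused]`` by default, or
        ``[f_fused, f_hessp]`` when a Hessian-vector product is also
        compiled.
    """
```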
pymc_extras/inference/laplace_approx/scipy_interface.py:53

- The return statement uses a starred expression (`return *loss_and_grad, hess`), which is a syntax error only on Python versions before 3.8; on the versions this project supports it is valid. Still, wrapping the unpacking in an explicit tuple, e.g. `return (*loss_and_grad, hess)`, is portable and easier to read.

```python
return *loss_and_grad, hess
```
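A minimal, self-contained illustration of the two spellings; the names and toy values are placeholders, not the code under review:

```python
def fused_loss(x):
    loss_and_grad = (x ** 2, 2 * x)  # placeholder (loss, grad) pair
    hess = 2.0                       # placeholder Hessian

    # return *loss_and_grad, hess   # SyntaxError on Python < 3.8
    return (*loss_and_grad, hess)   # explicit tuple: valid everywhere

loss, grad, hess = fused_loss(3.0)
```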
This should be ready to go. Last changes: …
Force-pushed from 48b74f7 to 3f2aa8b
I just made an update to better-optimize that uses hessian matrix caching for better performance. This is something we can immediately take advantage of with `find_MAP`, re-using sub-computations from the loss or gradient in the hessian computation. This PR updates the functions generated by `find_MAP` to take advantage of this (see the sketch below).
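To make the caching idea concrete, here is a minimal sketch of the pattern, not the actual better-optimize or pymc-extras implementation: a fused function computes loss, gradient, and Hessian together, and the Hessian is cached so the optimizer's separate `hess` callback can reuse it instead of recomputing from scratch.

```python
import numpy as np

class FusedHessianCache:
    """Cache the Hessian produced by a fused (loss, grad, hess) function."""

    def __init__(self, f_fused):
        self.f_fused = f_fused  # f_fused(x) -> (loss, grad, hess)
        self._x = None
        self._hess = None

    def loss_and_grad(self, x):
        loss, grad, hess = self.f_fused(x)
        self._x, self._hess = np.copy(x), hess  # stash for the hess callback
        return loss, grad

    def hess(self, x):
        if self._x is not None and np.array_equal(x, self._x):
            return self._hess  # reuse the work done in loss_and_grad
        return self.f_fused(x)[2]
```

With `scipy.optimize.minimize(cache.loss_and_grad, x0, jac=True, hess=cache.hess)`, a method that requests the Hessian at the same point as the objective avoids a second fused evaluation.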
While I was at it, I went ahead and did some cleanup and reorganization of the code. In particular:

- Everything Laplace-related now lives in a `laplace_approx` submodule.
- `find_MAP` now returns an idata. This is more consistent with all the other PyMC sampling functions -- it's weird to get back a dictionary in this one case.
- `find_MAP` will now always store the inverse hessian. This is done to try to avoid an extra function compilation when it is used in conjunction with `fit_laplace`.
- `fit_laplace` was a really dumb function that was inexplicably sampling from scipy distributions. This required a ton of unnecessary work. If only we had a PPL that could help sample from complicated distributions...
- `fit_laplace` still isn't perfect. I wanted to store both the value variables and the transformed RVs as deterministics in a PyMC model and sample them directly, but that doesn't appear to work -- maybe this is a bug? I ended up doing two passes, once for the constrained RVs, then a second pass for the unconstrained. It would be good to minimize that.

I also removed as many little options that were floating around as possible. These function signatures were already horrible.
Finally, I eliminated a lot of test parameterizations to speed up CI, but also added a lot of new tests for functions that were previously not covered. Hopefully it's still a net positive.