Skip to content

Conversation

manuelgloeckler
Copy link
Contributor

@manuelgloeckler manuelgloeckler commented Jan 30, 2025

Completes the missing features based on score estimation #1226.

  • IID interface
  • IID util functions
  • FNPE
  • GAUSS
  • JAC
  • test

@manuelgloeckler
Copy link
Contributor Author

manuelgloeckler commented Feb 17, 2025

Okey, everything should be implemented now. This acutally became quite a big PR now. A few more points:

  • Check if batch jacobian with torch.func.vmap actually works correctly
  • Check if the above or other cause some performance degradation in jac_gauss (although this can be sensitive to how the network is preconditioned)
  • Add an API to pass hyperparameters to the IID method (and make iid_methods more customizable i.e. auto_gauss)
  • Multivariate priors
  • General Empirical prior support for automatic denoising and marginalization (then auto_gauss should become default)

Copy link

codecov bot commented Feb 17, 2025

Codecov Report

Attention: Patch coverage is 81.63265% with 81 lines in your changes missing coverage. Please review.

Project coverage is 34.65%. Comparing base (7b6cab3) to head (9fa0c41).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
sbi/inference/potentials/score_fn_iid.py 69.38% 60 Missing ⚠️
sbi/utils/score_utils.py 90.62% 18 Missing ⚠️
sbi/samplers/score/diffuser.py 66.66% 3 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (7b6cab3) and HEAD (9fa0c41). Click for more details.

HEAD has 2 uploads less than BASE
Flag BASE (7b6cab3) HEAD (9fa0c41)
unittests 3 1
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1381       +/-   ##
===========================================
- Coverage   89.38%   34.65%   -54.74%     
===========================================
  Files         119      121        +2     
  Lines        8905     9338      +433     
===========================================
- Hits         7960     3236     -4724     
- Misses        945     6102     +5157     
Flag Coverage Δ
unittests 34.65% <81.63%> (-54.74%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
sbi/inference/posteriors/score_posterior.py 81.18% <100.00%> (-10.82%) ⬇️
sbi/inference/potentials/score_based_potential.py 84.00% <100.00%> (-12.81%) ⬇️
sbi/samplers/score/correctors.py 98.18% <100.00%> (+46.00%) ⬆️
sbi/samplers/score/diffuser.py 86.66% <66.66%> (+1.48%) ⬆️
sbi/utils/score_utils.py 90.62% <90.62%> (ø)
sbi/inference/potentials/score_fn_iid.py 69.38% <69.38%> (ø)

... and 89 files with indirect coverage changes

🚀 New features to boost your workflow:
  • Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@manuelgloeckler
Copy link
Contributor Author

This is now basically done. With the review, one should probably wait until the other score branch and type fixes are merged.

But the major changes are:

  • ScoreFnIID classes which manage the score composition
  • ScoreUtil, which has a bunch of helpers for "automatic" marginalization and denoising of PyTorch distributions (i.e., what the user can pass as the prior). If there is no analytic solution (or the user does not pass a prior) it will fall back to a rather good MoG approximation.

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, great effort! 👏 Thanks for adding all those methods!

Looks great overall, but I was a bit confused by the class structure in the score_fn_iid.py and added a couple of comments.
Also, the tests can be refactored a bit.

Please note that you might have to rebase or merge again with main once #1404 is merged.

@manuelgloeckler
Copy link
Contributor Author

Thanks for the review. I addressed most of the points above and left open what I wasn't sure about.

Summarizing points:

  • Rebased on main.
  • Fixed the abstract classes and structure/typing issues.
  • Added more documentation to introduced methods.
  • Changed the testing - Now the training is a "module" level fixture; hence, it will only be executed once this module is tested.

@manuelgloeckler manuelgloeckler requested a review from janfb March 12, 2025 07:27
Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update @manuelgloeckler, almost done!

Added a couple of final questions and pushed small docs fixes myself.

@janfb
Copy link
Contributor

janfb commented Mar 13, 2025

all CD tests are passing!

@manuelgloeckler manuelgloeckler requested a review from janfb March 14, 2025 16:52
Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great, thanks for the update!
fixed some typos and will update the changelog to include this into release.

@janfb janfb merged commit 1ee577e into main Mar 14, 2025
2 of 7 checks passed
@janfb janfb deleted the 1226-score-based-iid branch March 14, 2025 17:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

missing features and todos for score estimation
3 participants