[MRG] Bures-Wasserstein Gradient Descent for Bures-Wasserstein Barycenters #680
rflamary
left a comment
Small comments. I will let @antoinecollas do a proper review, since he is the expert in Riemannian optimization.
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #680      +/-   ##
==========================================
+ Coverage   97.10%   97.13%   +0.03%
==========================================
  Files         100      100
  Lines       20115    20369     +254
==========================================
+ Hits        19532    19786     +254
  Misses        583      583
rflamary
left a comment
This is great. A few tests, especially about errors, are missing.
rflamary
left a comment
Just a few questions and then we can merge
- Automatic PR labeling and release file update check (PR #704)
- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
- Fix documentation in the module `ot.gaussian` (PR #718)
- Refactored `ot.bregman._convolutional` to improve readability (PR #709)
I don't see that in the PR.
Mmh, I think I made a mistake when merging with master at some point. (It was deleted from line 46 of Releases.md, and it seemed to be in the wrong release of POT.)
  def trace(self, a):
-     return np.trace(a)
+     return np.einsum("...ii", a)
Is that faster or slower? We need an idea.
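One way to get a rough idea is a small timing script. The sketch below is not part of the PR: the array sizes and repeat counts are arbitrary assumptions, and timings will vary by machine and NumPy version. It also shows that `np.trace` needs explicit `axis1`/`axis2` arguments to handle a batch of matrices, whereas `np.einsum("...ii", a)` handles both cases with the same call.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
single = rng.standard_normal((50, 50))      # one matrix (arbitrary size)
batch = rng.standard_normal((100, 50, 50))  # batch of matrices (arbitrary size)

# Both calls give the same value on a single matrix.
single_trace = np.trace(single)
single_einsum = np.einsum("...ii", single)

# On a batch, np.trace needs the last two axes spelled out,
# while the einsum expression is unchanged.
batch_trace = np.trace(batch, axis1=-2, axis2=-1)
batch_einsum = np.einsum("...ii", batch)

# Rough timings; the numbers depend on the machine and NumPy build.
n_repeat = 2000
t_trace = timeit.timeit(lambda: np.trace(single), number=n_repeat)
t_einsum = timeit.timeit(lambda: np.einsum("...ii", single), number=n_repeat)
t_trace_b = timeit.timeit(
    lambda: np.trace(batch, axis1=-2, axis2=-1), number=n_repeat
)
t_einsum_b = timeit.timeit(lambda: np.einsum("...ii", batch), number=n_repeat)

print(f"single: np.trace {t_trace:.4f}s, np.einsum {t_einsum:.4f}s")
print(f"batch:  np.trace {t_trace_b:.4f}s, np.einsum {t_einsum_b:.4f}s")
```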
ot/gaussian.py
Outdated
  Returns
  -------
- W : float
+ W : float if ms and md of shape (d,), array-like (n,) if ms of shape (n,d), mt of shape (d,), array-like (m,) if ms of shape (d,) and mt of shape (m,d), array-like (n,m) if ms of shape (n,d) and mt of shape (m,d)
This API is too complicated. Return a float if the inputs have shape (d,), and for the rest use a parameter that returns paired or cross distances.
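A minimal sketch of what such a parameter could look like, using plain NumPy. The function name `bw_distance_sketch`, the `paired` flag, and the helper `_sqrtm_psd` are hypothetical and are not taken from the PR or from POT's API; the sketch only illustrates switching between paired distances of shape (n,) and cross distances of shape (n, m) through broadcasting.

```python
import numpy as np


def _sqrtm_psd(a):
    """Square root of symmetric PSD matrices via eigendecomposition (batched)."""
    w, v = np.linalg.eigh(a)
    w = np.clip(w, 0.0, None)  # guard against tiny negative eigenvalues
    return (v * np.sqrt(w)[..., None, :]) @ np.swapaxes(v, -1, -2)


def bw_distance_sketch(ms, mt, Cs, Ct, paired=True):
    """Bures-Wasserstein distances between Gaussians (hypothetical API).

    ms: (n, d), Cs: (n, d, d), mt: (m, d), Ct: (m, d, d).
    paired=True  -> shape (n,), requires n == m.
    paired=False -> shape (n, m), all cross distances.
    """
    ms, mt = np.asarray(ms, dtype=float), np.asarray(mt, dtype=float)
    Cs, Ct = np.asarray(Cs, dtype=float), np.asarray(Ct, dtype=float)
    if not paired:
        # Insert broadcast axes so every source meets every target.
        ms, Cs = ms[:, None], Cs[:, None]
        mt, Ct = mt[None], Ct[None]
    Cs12 = _sqrtm_psd(Cs)
    cross = _sqrtm_psd(Cs12 @ Ct @ Cs12)

    def tr(a):
        return np.einsum("...ii", a)

    bures2 = tr(Cs) + tr(Ct) - 2.0 * tr(cross)
    w2 = np.sum((ms - mt) ** 2, axis=-1) + bures2
    return np.sqrt(np.maximum(w2, 0.0))  # clip rounding noise below zero


# Two sources and two targets in dimension 2.
ms = np.array([[0.0, 0.0], [0.0, 0.0]])
mt = np.array([[1.0, 0.0], [0.0, 0.0]])
Cs = np.stack([np.eye(2), np.eye(2)])
Ct = np.stack([4.0 * np.eye(2), np.eye(2)])

paired_dist = bw_distance_sketch(ms, mt, Cs, Ct, paired=True)
cross_dist = bw_distance_sketch(ms, mt, Cs, Ct, paired=False)
```

With this layout the return shape depends on one flag rather than on the combination of input shapes, and the paired result is the diagonal of the cross result when both sets have the same size.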
Types of changes
This PR adds the Bures-Wasserstein gradient descent solver for computing Bures-Wasserstein barycenters (see e.g. "Gradient descent algorithms for Bures-Wasserstein barycenters" or "Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent").
- `ot.gaussian.bures_wasserstein_barycenter`, to allow the use of different methods
- `ot.gaussian.bures_barycenter_fixpoint`
- `ot.gaussian.bures_barycenter_gradient_descent`
- `test_bures_wasserstein_barycenter`
- `test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter`
- `ot.gaussian.bures_wasserstein_distance`
Motivation and context / Related issue
Bures-Wasserstein gradient descent comes with convergence guarantees for computing Bures-Wasserstein barycenters. Moreover, it can be used in a stochastic way when there are too many Gaussians. It is therefore a good alternative to the fixed-point algorithm currently implemented.
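For readers unfamiliar with the method, here is a minimal NumPy sketch of the gradient descent update on covariance matrices. It is an illustration under stated assumptions, not the code added by this PR: the names `sqrtm_psd` and `bw_barycenter_gd_sketch` are hypothetical, the Gaussians are taken as centered (the barycenter mean is simply the weighted average of the means), and the identity matrix is an arbitrary starting point. Each step forms the optimal transport maps $T_i$ from the current iterate $\Sigma$ to each $\Sigma_i$, then updates $\Sigma \leftarrow M \Sigma M$ with $M = (1-\eta) I + \eta \sum_i w_i T_i$. With $\eta = 1$ this coincides with the fixed-point iteration.

```python
import numpy as np


def sqrtm_psd(a):
    """Square root of symmetric PSD matrices via eigendecomposition (batched)."""
    w, v = np.linalg.eigh(a)
    w = np.clip(w, 0.0, None)  # guard against tiny negative eigenvalues
    return (v * np.sqrt(w)[..., None, :]) @ np.swapaxes(v, -1, -2)


def bw_barycenter_gd_sketch(covs, weights=None, step_size=1.0,
                            n_iter=100, tol=1e-10):
    """Bures-Wasserstein gradient descent on covariances (hypothetical helper).

    covs: (k, d, d) SPD covariance matrices of centered Gaussians.
    """
    covs = np.asarray(covs, dtype=float)
    k, d, _ = covs.shape
    if weights is None:
        weights = np.full(k, 1.0 / k)
    weights = np.asarray(weights, dtype=float)
    eye = np.eye(d)
    sigma = eye.copy()  # starting point: an arbitrary choice for this sketch
    for _ in range(n_iter):
        s12 = sqrtm_psd(sigma)
        s12_inv = np.linalg.inv(s12)
        # T_i: optimal transport map from N(0, sigma) to N(0, covs[i]).
        maps = s12_inv @ sqrtm_psd(s12 @ covs @ s12) @ s12_inv
        mean_map = np.einsum("k,kij->ij", weights, maps)
        grad = eye - mean_map  # Bures-Wasserstein gradient direction
        step = eye - step_size * grad
        sigma = step @ sigma @ step
        if np.linalg.norm(grad) < tol:
            break
    return sigma


# Commuting (diagonal) covariances: the barycenter is known in closed form,
# since its square root is the weighted average of the square roots.
covs = np.stack([np.diag([1.0, 4.0]), np.diag([9.0, 16.0])])
bary_full_step = bw_barycenter_gd_sketch(covs, step_size=1.0)
bary_half_step = bw_barycenter_gd_sketch(covs, step_size=0.5)
```

A stochastic variant would draw a single covariance (or a mini-batch) at each iteration instead of averaging all the maps, which is what makes the approach attractive when there are many Gaussians.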
How has this been tested (if it applies)
I added a test, `test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter`, to check that both methods return the same barycenter. I also added itertools to `test_bures_wasserstein_barycenter`.
PR checklist