-
Notifications
You must be signed in to change notification settings - Fork 10
Entropy-regularised Gromov-Wasserstein #165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
zsteve
wants to merge
27
commits into
master
Choose a base branch
from
gromov
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 26 commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
2ef3e2b
first attempt at gromov-wasserstein
zsteve 11efd8c
update
zsteve 3273976
Merge branch 'master' into gromov
zsteve 0956c3b
fixed computation of entropic gromov-wasserstein
zsteve c22d7e7
fixed computation of entropic gromov-wasserstein
zsteve ff1a92c
Merge branch 'gromov' of https://github.com/JuliaOptimalTransport/Opt…
zsteve 267dfad
exports and tests
zsteve 21609b0
formatting
zsteve 9699e04
Update test/gpu/simple_gpu.jl
zsteve 8510397
update docstrings
zsteve 2f2428f
Merge branch 'gromov' of https://github.com/JuliaOptimalTransport/Opt…
zsteve 20d5885
delete cache file
zsteve df41c28
add docs and format
zsteve a7c1a38
remove unnecessary Logging import
zsteve 19e4cab
fix missing power of 2
zsteve 56c4f9b
pull changes from master
zsteve 6e3ac4c
update version number
zsteve 5c376ae
add docs workflow
zsteve af2a493
add Gromov-Wasserstein to readme
zsteve 6bc3127
bump Julia ver for CI
zsteve a806f0f
minor edit to runtests
zsteve f704397
Update .github/workflows/CI.yml
zsteve 71351b9
Update test/runtests.jl
zsteve f2acc56
delete junk files/dirs
zsteve 0635305
revert runtests.jl
zsteve c3efe5a
avoid unnecessary allocations
zsteve 39f0b36
format
zsteve File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Gromov-Wasserstein solver | ||
|
||
abstract type EntropicGromovWasserstein end | ||
|
||
struct EntropicGromovWassersteinSinkhorn <: EntropicGromovWasserstein | ||
alg_step::Sinkhorn | ||
end | ||
|
||
""" | ||
entropic_gromov_wasserstein( | ||
μ, ν, Cμ, Cν, ε, alg=EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); | ||
atol = nothing, rtol = nothing, check_convergence = 10, maxiter = 1_000, kwargs... | ||
) | ||
|
||
Computes the transport map for the entropically regularized Gromov-Wasserstein optimal transport problem with source and target | ||
marginals `μ` and `ν` and corresponding cost matrices `Cμ` and `Cν`. That is, we seek `γ` a local minimizer of | ||
```math | ||
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\sum_{i, j, i', j'} |C^{(\\mu)}_{i,i'} - C^{(\\nu)}_{j,j'}|^2 \\gamma_{i,j} \\gamma_{i',j'} + \\varepsilon \\Omega(\\gamma), | ||
``` | ||
where ``\\Omega(\\gamma)`` is the entropic regularization term, see e.g. [`sinkhorn`](@ref). | ||
|
||
This function employs the iterative method described in (Section 10.6.4, [^PC19]), which solves a series of Sinkhorn iteration sub-problems to arrive at a solution. Note that the Gromov-Wasserstein problem is non-convex owing to the cross-terms in the | ||
objective function, and thus in general one is guaranteed to arrive at a local optimum. | ||
|
||
Every `check_convergence` steps, the current iteration of `γ` is compared with `γ_prev` (the previous iteration from `check_convergence` ago). | ||
The quantity ``\\| \\gamma - \\gamma_\\text{prev} \\|_1`` is compared against `atol` and `rtol`. | ||
|
||
[^PC19]: Peyré, G. and Cuturi, M., 2019. Computational optimal transport: With applications to data science. Foundations and Trends® in Machine Learning, 11(5-6), pp.355-607. | ||
|
||
See also: [`sinkhorn`](@ref) | ||
""" | ||
function entropic_gromov_wasserstein( | ||
μ::AbstractVector, | ||
ν::AbstractVector, | ||
Cμ::AbstractMatrix, | ||
Cν::AbstractMatrix, | ||
ε::Real, | ||
alg::EntropicGromovWasserstein=EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); | ||
atol=nothing, | ||
rtol=nothing, | ||
check_convergence=10, | ||
maxiter::Int=1_000, | ||
kwargs..., | ||
) | ||
T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) | ||
C = similar(Cμ, T, size(μ, 1), size(ν, 1)) | ||
tmp = similar(C) | ||
plan = similar(C) | ||
@. plan = μ * ν' | ||
plan_prev = similar(C) | ||
plan_prev .= plan | ||
norm_plan = sum(plan) | ||
|
||
_atol = atol === nothing ? 0 : atol | ||
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol | ||
|
||
function get_new_cost!(C, plan, tmp, Cμ, Cν) | ||
A_batched_mul_B!(tmp, Cμ, plan) | ||
lmul!(-4, tmp) | ||
return A_batched_mul_B!(C, tmp, Cν) | ||
# seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? | ||
# added the factor of 4 here to ensure reproducibility for the same value of ε. | ||
# https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 | ||
end | ||
|
||
get_new_cost!(C, plan, tmp, Cμ, Cν) | ||
to_check_step = check_convergence | ||
|
||
isconverged = false | ||
for iter in 1:maxiter | ||
# perform Sinkhorn algorithm | ||
solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) | ||
solve!(solver) | ||
# compute optimal transport plan | ||
plan = sinkhorn_plan(solver) | ||
|
||
to_check_step -= 1 | ||
if to_check_step == 0 || iter == maxiter | ||
# reset counter | ||
to_check_step = check_convergence | ||
plan_prev .-= plan | ||
isconverged = sum(abs, plan_prev) < max(_atol, _rtol * norm_plan) | ||
if isconverged | ||
@debug "Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged" | ||
break | ||
end | ||
plan_prev .= plan | ||
end | ||
get_new_cost!(C, plan, tmp, Cμ, Cν) | ||
end | ||
|
||
return plan | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
using OptimalTransport | ||
|
||
using Distances | ||
using PythonOT: PythonOT | ||
|
||
using Random | ||
using Test | ||
using LinearAlgebra | ||
|
||
const POT = PythonOT | ||
|
||
Random.seed!(100) | ||
|
||
@testset "gromov.jl" begin | ||
@testset "entropic_gromov_wasserstein" begin | ||
M, N = 250, 200 | ||
|
||
μ = fill(1 / M, M) | ||
μ_spt = rand(M) | ||
ν = fill(1 / N, N) | ||
ν_spt = rand(N) | ||
|
||
Cμ = pairwise(SqEuclidean(), μ_spt) | ||
Cν = pairwise(SqEuclidean(), ν_spt) | ||
|
||
γ = entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence=10) | ||
γ_pot = PythonOT.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01) | ||
|
||
@test γ ≈ γ_pot rtol = 1e-6 | ||
end | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.