-
Notifications
You must be signed in to change notification settings - Fork 10
Entropy-regularised Gromov-Wasserstein #165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 7 commits
2ef3e2b
11efd8c
3273976
0956c3b
c22d7e7
ff1a92c
267dfad
21609b0
9699e04
8510397
2f2428f
20d5885
df41c28
a7c1a38
19e4cab
56c4f9b
6e3ac4c
5c376ae
af2a493
6bc3127
a806f0f
f704397
71351b9
f2acc56
0635305
c3efe5a
39f0b36
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Gromov-Wasserstein solver | ||
|
||
abstract type EntropicGromovWasserstein end | ||
|
||
struct EntropicGromovWassersteinGibbs <: EntropicGromovWasserstein | ||
alg_step::Sinkhorn | ||
end | ||
|
||
function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real, | ||
alg::EntropicGromovWasserstein = EntropicGromovWassersteinGibbs(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) | ||
T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) | ||
C = similar(Cμ, T, size(μ, 1), size(ν, 1)) | ||
tmp = similar(C) | ||
plan = similar(C) | ||
@. plan = μ * ν' | ||
plan_prev = similar(C) | ||
plan_prev .= plan | ||
norm_plan = sum(plan) | ||
|
||
_atol = atol === nothing ? 0 : atol | ||
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol | ||
|
||
function get_new_cost!(C, plan, tmp, Cμ, Cν) | ||
A_batched_mul_B!(tmp, Cμ, plan) | ||
A_batched_mul_B!(C, tmp, -4Cν) | ||
# seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? | ||
# added the factor of 4 here to ensure reproducibility for the same value of ε. | ||
# https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 | ||
end | ||
|
||
get_new_cost!(C, plan, tmp, Cμ, Cν) | ||
to_check_step = check_convergence | ||
|
||
isconverged = false | ||
for iter in 1:maxiter | ||
# perform Sinkhorn algorithm | ||
solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) | ||
solve!(solver) | ||
# compute optimal transport plan | ||
plan = sinkhorn_plan(solver) | ||
|
||
to_check_step -= 1 | ||
if to_check_step == 0 || iter == maxiter | ||
# reset counter | ||
to_check_step = check_convergence | ||
isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan) | ||
if isconverged | ||
@debug "$Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged" | ||
break | ||
end | ||
plan_prev .= plan | ||
end | ||
get_new_cost!(C, plan, tmp, Cμ, Cν) | ||
end | ||
|
||
return plan | ||
end |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,57 @@ | ||||||||
# Gromov-Wasserstein solver | ||||||||
|
||||||||
abstract type EntropicGromovWasserstein end | ||||||||
|
||||||||
struct EntropicGromovWassersteinSinkhorn <: EntropicGromovWasserstein | ||||||||
alg_step::Sinkhorn | ||||||||
end | ||||||||
|
||||||||
function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real, | ||||||||
alg::EntropicGromovWasserstein = EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) | ||||||||
zsteve marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) | ||||||||
C = similar(Cμ, T, size(μ, 1), size(ν, 1)) | ||||||||
tmp = similar(C) | ||||||||
plan = similar(C) | ||||||||
@. plan = μ * ν' | ||||||||
plan_prev = similar(C) | ||||||||
plan_prev .= plan | ||||||||
norm_plan = sum(plan) | ||||||||
|
||||||||
_atol = atol === nothing ? 0 : atol | ||||||||
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol | ||||||||
|
||||||||
function get_new_cost!(C, plan, tmp, Cμ, Cν) | ||||||||
A_batched_mul_B!(tmp, Cμ, plan) | ||||||||
A_batched_mul_B!(C, tmp, -4Cν) | ||||||||
zsteve marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
# seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? | ||||||||
# added the factor of 4 here to ensure reproducibility for the same value of ε. | ||||||||
# https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 | ||||||||
end | ||||||||
|
||||||||
get_new_cost!(C, plan, tmp, Cμ, Cν) | ||||||||
to_check_step = check_convergence | ||||||||
|
||||||||
isconverged = false | ||||||||
for iter in 1:maxiter | ||||||||
# perform Sinkhorn algorithm | ||||||||
solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) | ||||||||
solve!(solver) | ||||||||
# compute optimal transport plan | ||||||||
plan = sinkhorn_plan(solver) | ||||||||
|
||||||||
to_check_step -= 1 | ||||||||
if to_check_step == 0 || iter == maxiter | ||||||||
# reset counter | ||||||||
to_check_step = check_convergence | ||||||||
isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe also avoid allocations here by writing:
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The initial plan is taken to be the independent coupling and here we only consider the balanced problem, so There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Good catch, done |
||||||||
if isconverged | ||||||||
@debug "Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged" | ||||||||
break | ||||||||
end | ||||||||
plan_prev .= plan | ||||||||
end | ||||||||
get_new_cost!(C, plan, tmp, Cμ, Cν) | ||||||||
end | ||||||||
|
||||||||
return plan | ||||||||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[deps] | ||
zsteve marked this conversation as resolved.
Show resolved
Hide resolved
|
||
OptimalTransport = "7e02d93a-ae51-4f58-b602-d97af76e3b33" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
using OptimalTransport | ||
|
||
using Distances | ||
using PythonOT: PythonOT | ||
|
||
using Random | ||
using Test | ||
using LinearAlgebra | ||
|
||
const POT = PythonOT | ||
|
||
Random.seed!(100) | ||
|
||
@testset "gromov.jl" begin | ||
@testset "entropic_gromov_wasserstein" begin | ||
M, N = 250, 200 | ||
|
||
μ = fill(1/M, M) | ||
zsteve marked this conversation as resolved.
Show resolved
Hide resolved
|
||
μ_spt = rand(M) | ||
ν = fill(1/N, N) | ||
zsteve marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ν_spt = rand(N) | ||
|
||
Cμ = pairwise(SqEuclidean(), μ_spt) | ||
Cν = pairwise(SqEuclidean(), ν_spt) | ||
|
||
γ = entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence = 10) | ||
zsteve marked this conversation as resolved.
Show resolved
Hide resolved
|
||
γ_pot = PythonOT.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01) | ||
|
||
@test γ ≈ γ_pot rtol = 1e-6 | ||
end | ||
end |
Uh oh!
There was an error while loading. Please reload this page.