
Commit e37845e

fix sinkhorn2 bug for ReverseDiff (#130)
* fix sinkhorn2 bug for ReverseDiff
* remove unnecessary files and format
* fix typo
* change eps to be larger for autodiff tests
* incorporate @devmotion's fix instead and fix deps
* test to fix CI
* format
1 parent ab9bc76 commit e37845e
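The commit makes the entropy-regularized `sinkhorn2` cost differentiable with ReverseDiff as well as ForwardDiff. Below is a minimal sketch (not taken from the repository) of the call pattern this is meant to support; the marginals, cost matrix, and ε are placeholder data:

using OptimalTransport
using ReverseDiff
using Distances: SqEuclidean, pairwise
using LogExpFunctions: softmax

# placeholder problem data: uniform marginals and a squared-Euclidean cost matrix
μ = fill(1 / 10, 10)
ν = fill(1 / 20, 20)
C = pairwise(SqEuclidean(), rand(1, 10), rand(1, 20); dims=2)
ε = 0.05

# reverse-mode gradient of the regularized Sinkhorn cost w.r.t. the (log-)weights of ν
∇ = ReverseDiff.gradient(log.(ν)) do xs
    sinkhorn2(μ, softmax(xs), C, ε; regularization=true)
end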

File tree

4 files changed, +68 −63 lines


Project.toml

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@ julia = "1"
 
 [extras]
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 PythonOT = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
@@ -39,4 +40,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Tulip = "6dd1b50a-3aae-11e9-10b5-ef983d2400fa"
 
 [targets]
-test = ["ForwardDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip", "HCubature"]
+test = ["ForwardDiff", "ReverseDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip", "HCubature"]

src/entropic/sinkhorn.jl

Lines changed: 1 addition & 1 deletion
@@ -202,7 +202,7 @@ function sinkhorn2(μ, ν, C, ε, alg::Sinkhorn; regularization=false, plan=noth
     end
     cost = if regularization
         dot_matwise(γ, C) .+
-        ε * reshape(sum(LogExpFunctions.xlogx, γ; dims=(1, 2)), size(γ)[3:end])
+        ε .* reshape(sum(LogExpFunctions.xlogx, γ; dims=(1, 2)), size(γ)[3:end])
     else
         dot_matwise(γ, C)
     end
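For reference, the patched line is the entropic term of the regularized cost, ⟨γ, C⟩ + ε ∑ᵢⱼ γᵢⱼ log γᵢⱼ; changing `*` to `.*` only turns the scaling into a broadcast, presumably so that the term composes correctly with the batched or AD-tracked arrays produced by `dot_matwise`. A standalone sketch of the non-batched quantity, using a hypothetical helper name not present in the package:

# hypothetical helper showing the quantity computed by the patched expression
# for a single transport plan γ and cost matrix C:
#   cost = ⟨γ, C⟩ + ε * Σ_ij γ_ij * log(γ_ij)
using LinearAlgebra: dot
using LogExpFunctions: xlogx

entropic_sinkhorn_cost(γ, C, ε) = dot(γ, C) + ε * sum(xlogx, γ)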

test/entropic/sinkhorn_gibbs.jl

Lines changed: 64 additions & 60 deletions
@@ -2,6 +2,7 @@ using OptimalTransport
 
 using Distances
 using ForwardDiff
+using ReverseDiff
 using LogExpFunctions
 using PythonOT: PythonOT
 
@@ -160,68 +161,71 @@ Random.seed!(100)
         # together. test against gradient computed using analytic formula of Proposition 2.3 of
         # Cuturi, Marco, and Gabriel Peyré. "A smoothed dual approach for variational Wasserstein problems." SIAM Journal on Imaging Sciences 9.1 (2016): 320-343.
         #
+        ε = 0.05 # use a larger ε to avoid having to do many iterations
         # target marginal
-        ∇ = ForwardDiff.gradient(log.(ν)) do xs
-            sinkhorn2(μ, softmax(xs), C, ε, SinkhornGibbs(); regularization=true)
+        for Diff in [ReverseDiff, ForwardDiff]
+            ∇ = Diff.gradient(log.(ν)) do xs
+                sinkhorn2(μ, softmax(xs), C, ε, SinkhornGibbs(); regularization=true)
+            end
+            ∇default = Diff.gradient(log.(ν)) do xs
+                sinkhorn2(μ, softmax(xs), C, ε; regularization=true)
+            end
+            @test ∇ == ∇default
+
+            solver = OptimalTransport.build_solver(μ, ν, C, ε, SinkhornGibbs())
+            OptimalTransport.solve!(solver)
+            # helper function
+            function dualvar_to_grad(x, ε)
+                x = -ε * log.(x)
+                x .-= sum(x) / size(x, 1)
+                return -x
+            end
+            ∇_ot = dualvar_to_grad(solver.cache.v, ε)
+            # chain rule because target measure parameterised by softmax
+            J_softmax = ForwardDiff.jacobian(log.(ν)) do xs
+                softmax(xs)
+            end
+            ∇analytic_target = J_softmax * ∇_ot
+            # check that gradient obtained by AD matches the analytic formula
+            @test ∇ ≈ ∇analytic_target rtol = 1e-6
+
+            # source marginal
+            ∇ = Diff.gradient(log.(μ)) do xs
+                sinkhorn2(softmax(xs), ν, C, ε, SinkhornGibbs(); regularization=true)
+            end
+            ∇default = Diff.gradient(log.(μ)) do xs
+                sinkhorn2(softmax(xs), ν, C, ε; regularization=true)
+            end
+            @test ∇ == ∇default
+
+            # check that gradient obtained by AD matches the analytic formula
+            solver = OptimalTransport.build_solver(μ, ν, C, ε, SinkhornGibbs())
+            OptimalTransport.solve!(solver)
+            J_softmax = ForwardDiff.jacobian(log.(μ)) do xs
+                softmax(xs)
+            end
+            ∇_ot = dualvar_to_grad(solver.cache.u, ε)
+            ∇analytic_source = J_softmax * ∇_ot
+            @test ∇ ≈ ∇analytic_source rtol = 1e-6
+
+            # both marginals
+            ∇ = Diff.gradient(log.(vcat(μ, ν))) do xs
+                sinkhorn2(
+                    softmax(xs[1:M]),
+                    softmax(xs[(M + 1):end]),
+                    C,
+                    ε,
+                    SinkhornGibbs();
+                    regularization=true,
+                )
+            end
+            ∇default = Diff.gradient(log.(vcat(μ, ν))) do xs
+                sinkhorn2(softmax(xs[1:M]), softmax(xs[(M + 1):end]), C, ε; regularization=true)
+            end
+            @test ∇ == ∇default
+            ∇analytic = vcat(∇analytic_source, ∇analytic_target)
+            @test ∇ ≈ ∇analytic rtol = 1e-6
         end
-        ∇default = ForwardDiff.gradient(log.(ν)) do xs
-            sinkhorn2(μ, softmax(xs), C, ε; regularization=true)
-        end
-        @test ∇ == ∇default
-
-        solver = OptimalTransport.build_solver(μ, ν, C, ε, SinkhornGibbs())
-        OptimalTransport.solve!(solver)
-        # helper function
-        function dualvar_to_grad(x, ε)
-            x = -ε * log.(x)
-            x .-= sum(x) / size(x, 1)
-            return -x
-        end
-        ∇_ot = dualvar_to_grad(solver.cache.v, ε)
-        # chain rule because target measure parameterised by softmax
-        J_softmax = ForwardDiff.jacobian(log.(ν)) do xs
-            softmax(xs)
-        end
-        ∇analytic_target = J_softmax * ∇_ot
-        # check that gradient obtained by AD matches the analytic formula
-        @test ∇ ≈ ∇analytic_target rtol = 1e-6
-
-        # source marginal
-        ∇ = ForwardDiff.gradient(log.(μ)) do xs
-            sinkhorn2(softmax(xs), ν, C, ε, SinkhornGibbs(); regularization=true)
-        end
-        ∇default = ForwardDiff.gradient(log.(μ)) do xs
-            sinkhorn2(softmax(xs), ν, C, ε; regularization=true)
-        end
-        @test ∇ == ∇default
-
-        # check that gradient obtained by AD matches the analytic formula
-        solver = OptimalTransport.build_solver(μ, ν, C, ε, SinkhornGibbs())
-        OptimalTransport.solve!(solver)
-        J_softmax = ForwardDiff.jacobian(log.(μ)) do xs
-            softmax(xs)
-        end
-        ∇_ot = dualvar_to_grad(solver.cache.u, ε)
-        ∇analytic_source = J_softmax * ∇_ot
-        @test ∇ ≈ ∇analytic_source rtol = 1e-6
-
-        # both marginals
-        ∇ = ForwardDiff.gradient(log.(vcat(μ, ν))) do xs
-            sinkhorn2(
-                softmax(xs[1:M]),
-                softmax(xs[(M + 1):end]),
-                C,
-                ε,
-                SinkhornGibbs();
-                regularization=true,
-            )
-        end
-        ∇default = ForwardDiff.gradient(log.(vcat(μ, ν))) do xs
-            sinkhorn2(softmax(xs[1:M]), softmax(xs[(M + 1):end]), C, ε; regularization=true)
-        end
-        @test ∇ == ∇default
-        ∇analytic = vcat(∇analytic_source, ∇analytic_target)
-        @test ∇ ≈ ∇analytic rtol = 1e-6
     end
 
     @testset "deprecations" begin
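The analytic check in this test relies on the dual characterization from the cited Cuturi & Peyré result: up to an additive constant (handled by the centering in `dualvar_to_grad`), the gradient of the entropy-regularized cost with respect to each marginal is the corresponding optimal dual potential, recovered here from the Sinkhorn scaling variables `solver.cache.u` and `solver.cache.v`. As a sketch of the relation the helper encodes,

\nabla_{\mu} \mathrm{OT}_\varepsilon(\mu,\nu) = f^\star = \varepsilon \log u, \qquad
\nabla_{\nu} \mathrm{OT}_\varepsilon(\mu,\nu) = g^\star = \varepsilon \log v
\quad \text{(each up to an additive constant).}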

test/exact.jl

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ Random.seed!(100)
 
     # compute OT plan
    γ = ot_plan(sqeuclidean, μ, ν)
-    x = randn()
+    x = 0
    @test γ(x) ≈ quantile(ν, cdf(μ, x))
 
    # compute OT cost
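The stabilized assertion evaluates the closed-form one-dimensional optimal transport map (the monotone rearrangement) at a fixed point rather than a random draw; `quantile(ν, cdf(μ, x))` is the composition of the target quantile function with the source CDF,

T(x) = F_{\nu}^{-1}\big(F_{\mu}(x)\big).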
