Skip to content

Commit 59223df

Browse files
authored
add sinkhorn_divergence to runtests (#146)
* add sinkhorn_divergence to runtests * fix issues with julia 1.0 compat * add more compat tag * remove usage of eachcol
1 parent 803274d commit 59223df

File tree

4 files changed

+15
-10
lines changed

4 files changed

+15
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ PythonOT = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
2727
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2828
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2929
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
30-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3130
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
31+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3232

3333
[targets]
3434
test = ["Distances", "ForwardDiff", "ReverseDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "StatsBase"]

src/utils.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ end
2525
dot_matwise(x::AbstractMatrix, y::AbstractArray) = dot_matwise(y, x)
2626

2727
function dot_vecwise(x::AbstractMatrix, y::AbstractMatrix)
28-
return [dot(u, v) for (u, v) in zip(eachcol(x), eachcol(y))]
28+
return [
29+
dot(u, v) for (u, v) in
30+
zip((view(x, :, i) for i in axes(x, 2)), (view(y, :, i) for i in axes(y, 2)))
31+
]
2932
end
3033

3134
dot_vecwise(x::AbstractMatrix, y::AbstractVector) = x' * y

test/entropic/sinkhorn_divergence.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ using ForwardDiff
55
using ReverseDiff
66
using LogExpFunctions
77
using PythonOT: PythonOT
8-
using StatsBase
98
using LinearAlgebra
109
using Random
1110
using Test
@@ -54,18 +53,18 @@ Random.seed!(100)
5453
for reg in (true, false)
5554
loss_batch = sinkhorn_divergence(μ, ν, C, ε; regularization=reg)
5655
@test loss_batch [
57-
sinkhorn_divergence(x, y, C, ε; regularization=reg) for
58-
(x, y) in zip(eachcol(μ), eachcol(ν))
56+
sinkhorn_divergence(μ[:, i], ν[:, i], C, ε; regularization=reg) for
57+
i in 1:M
5958
]
6059
loss_batch_μ = sinkhorn_divergence(μ, ν[:, 1], C, ε; regularization=reg)
6160
@test loss_batch_μ [
62-
sinkhorn_divergence(x, ν[:, 1], C, ε; regularization=reg) for
63-
x in eachcol(μ)
61+
sinkhorn_divergence(μ[:, i], ν[:, 1], C, ε; regularization=reg) for
62+
i in 1:M
6463
]
6564
loss_batch_ν = sinkhorn_divergence(μ[:, 1], ν, C, ε; regularization=reg)
6665
@test loss_batch_ν [
67-
sinkhorn_divergence(μ[:, 1], y, C, ε; regularization=reg) for
68-
y in eachcol(ν)
66+
sinkhorn_divergence(μ[:, 1], ν[:, i], C, ε; regularization=reg) for
67+
i in 1:M
6968
]
7069
end
7170
end
@@ -98,7 +97,7 @@ Random.seed!(100)
9897
Cμν = pairwise(SqEuclidean(), μ_spt', ν_spt'; dims=2)
9998
= pairwise(SqEuclidean(), μ_spt'; dims=2)
10099
= pairwise(SqEuclidean(), ν_spt'; dims=2)
101-
ε = 0.1 * max(mean(Cμν), mean(Cμ), mean(Cν))
100+
ε = 1.0
102101

103102
@testset "basic" begin
104103
for reg in (true, false)

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ const GROUP = get(ENV, "GROUP", "All")
2828
@safetestset "Sinkhorn barycenter" begin
2929
include(joinpath("entropic", "sinkhorn_barycenter.jl"))
3030
end
31+
@safetestset "Sinkhorn divergence" begin
32+
include(joinpath("entropic", "sinkhorn_divergence.jl"))
33+
end
3134
end
3235

3336
@safetestset "Quadratically regularized OT" begin

0 commit comments

Comments
 (0)