Skip to content

Commit d62877b

Browse files
committed
Add entropic_partial_wasserstein function
I also fix tests in `sinkhorn_unbalanced` and `sinkhorn_unbalanced2`
1 parent 7c72cf8 commit d62877b

File tree

4 files changed

+70
-36
lines changed

4 files changed

+70
-36
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PythonOT"
22
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
33
authors = ["David Widmann"]
4-
version = "0.1.6"
4+
version = "0.1.7"
55

66
[deps]
77
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"

docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,9 @@ sinkhorn_unbalanced2
3737
barycenter_unbalanced
3838
mm_unbalanced
3939
```
40+
41+
## Partial optimal transport
42+
43+
```@docs
44+
entropic_partial_wasserstein
45+
```

src/PythonOT.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ export emd,
1313
sinkhorn_unbalanced,
1414
sinkhorn_unbalanced2,
1515
empirical_sinkhorn_divergence,
16-
mm_unbalanced
16+
mm_unbalanced,
17+
entropic_partial_wasserstein
1718

1819
const pot = PyCall.PyNULL()
1920

src/lib.jl

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -306,29 +306,18 @@ Python function.
306306
# Examples
307307
308308
```jldoctest sinkhorn_unbalanced
309-
julia> μ = [0.5, 0.2, 0.3];
309+
julia> μ = [0.5, 0.5];
310310
311-
julia> ν = [0.0, 1.0];
311+
julia> ν = [0.5, 0.5];
312312
313-
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
313+
julia> C = [0.0 1.0; 1.0 0.0];
314314
315-
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
316-
3×2 Matrix{Float64}:
317-
0.0 0.5
318-
0.0 0.2002
319-
0.0 0.2998
315+
julia> round.(sinkhorn_unbalanced(μ, ν, C, 1, 1); sigdigits=7)
316+
2×2 Matrix{Float64}:
317+
0.322054 0.118477
318+
0.118477 0.322054
320319
```
321320
322-
It is possible to provide multiple target marginals as columns of a matrix. In this case the
323-
optimal transport costs are returned:
324-
325-
```jldoctest sinkhorn_unbalanced
326-
julia> ν = [0.0 0.5; 1.0 0.5];
327-
328-
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
329-
2-element Vector{Float64}:
330-
0.9497
331-
0.4494
332321
```
333322
334323
See also: [`sinkhorn_unbalanced2`](@ref)
@@ -365,25 +354,14 @@ Python function.
365354
# Examples
366355
367356
```jldoctest sinkhorn_unbalanced2
368-
julia> μ = [0.5, 0.2, 0.3];
369-
370-
julia> ν = [0.0, 1.0];
371-
372-
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
357+
julia> μ = [0.5, 0.1];
373358
374-
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
375-
0.9497
376-
```
377-
378-
It is possible to provide multiple target marginals as columns of a matrix:
359+
julia> ν = [0.5, 0.5];
379360
380-
```jldoctest sinkhorn_unbalanced2
381-
julia> ν = [0.0 0.5; 1.0 0.5];
361+
julia> C = [0.0 1.0; 1.0 0.0];
382362
383-
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
384-
2-element Vector{Float64}:
385-
0.9497
386-
0.4494
363+
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 1., 1.); sigdigits=8)
364+
0.19600125
387365
```
388366
389367
See also: [`sinkhorn_unbalanced`](@ref)
@@ -566,3 +544,52 @@ julia> round.(mm_unbalanced(a, b, M, 5, div="l2"), digits=2)
566544
function mm_unbalanced(a, b, M, reg_m; kwargs...)
567545
return pot.unbalanced.mm_unbalanced(a, b, M, reg_m; kwargs...)
568546
end
547+
548+
549+
"""
550+
entropic_partial_wasserstein(a, b, M, reg; kwargs...)
551+
552+
Solves the partial optimal transport problem and returns the OT plan
553+
The function considers the following problem:
554+
555+
```math
556+
\\gamma = \\mathop{\\arg \\min}_\\gamma \\quad \\langle \\gamma,
557+
\\mathbf{M} \\rangle_F + \\mathrm{reg} \\cdot\\Omega(\\gamma)
558+
559+
s.t. \\gamma \\mathbf{1} &\\leq \\mathbf{a} \\\\
560+
\\gamma^T \\mathbf{1} &\\leq \\mathbf{b} \\\\
561+
\\gamma &\\geq 0 \\\\
562+
\\mathbf{1}^T \\gamma^T \\mathbf{1} = m
563+
&\\leq \\min\\{\\|\\mathbf{a}\\|_1, \\|\\mathbf{b}\\|_1\\} \\\\
564+
```
565+
566+
where :
567+
568+
- `M` is the metric cost matrix
569+
- ``\\Omega`` is the entropic regularization term, ``\\Omega=\\sum_{i,j} \\gamma_{i,j}\\log(\\gamma_{i,j})``
570+
- `a` and `b` are the sample weights
571+
- `m` is the amount of mass to be transported
572+
573+
This function is a wrapper of the function
574+
[`entropic_partial_wasserstein`](https://pythonot.github.io/gen_modules/ot.partial.html#ot.partial.entropic_partial_wasserstein) in the
575+
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
576+
Python function.
577+
578+
579+
# Examples
580+
581+
```jldoctest
582+
julia> a = [.1, .2];
583+
584+
julia> b = [.1, .1];
585+
586+
julia> M = [0. 1.; 2. 3.];
587+
588+
julia> round.(entropic_partial_wasserstein(a, b, M, 1, m=0.1), digits=2)
589+
2×2 Matrix{Float64}:
590+
0.06 0.02
591+
0.01 0.0
592+
"""
593+
function entropic_partial_wasserstein(a, b, M, reg; kwargs...)
594+
return pot.partial.entropic_partial_wasserstein(a, b, M, reg; kwargs...)
595+
end

0 commit comments

Comments
 (0)