Skip to content

Commit 53806cf

Browse files
zstevedevmotion
andauthored
Unbalanced barycenter (#11)
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent b41007d commit 53806cf

File tree

4 files changed

+55
-2
lines changed

4 files changed

+55
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PythonOT"
22
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
33
authors = ["David Widmann"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,5 @@ PythonOT.Smooth.smooth_ot_dual
3131
```@docs
3232
sinkhorn_unbalanced
3333
sinkhorn_unbalanced2
34+
barycenter_unbalanced
3435
```

src/PythonOT.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@ module PythonOT
22

33
using PyCall: PyCall
44

5-
export emd, emd2, sinkhorn, sinkhorn2, barycenter, sinkhorn_unbalanced, sinkhorn_unbalanced2
5+
export emd,
6+
emd2,
7+
sinkhorn,
8+
sinkhorn2,
9+
barycenter,
10+
barycenter_unbalanced,
11+
sinkhorn_unbalanced,
12+
sinkhorn_unbalanced2
613

714
const pot = PyCall.PyNULL()
815

src/lib.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,48 @@ true
322322
```
323323
"""
324324
barycenter(A, C, ε; kwargs...) = pot.barycenter(A, C, ε; kwargs...)
325+
326+
"""
327+
barycenter_unbalanced(A, C, ε, λ; kwargs...)
328+
329+
Compute the entropically regularized unbalanced Wasserstein barycenter with histograms `A`, cost matrix
330+
`C`, entropic regularization parameter `ε` and marginal relaxation parameter `λ`.
331+
332+
The Wasserstein barycenter is a histogram and solves
333+
```math
334+
\\inf_{a} \\sum_{i} W_{\\varepsilon,C,\\lambda}(a, a_i),
335+
```
336+
where the histograms ``a_i`` are columns of matrix `A` and ``W_{\\varepsilon,C,\\lambda}(a, a_i)}``
337+
is the optimal transport cost for the entropically regularized optimal transport problem
338+
with marginals ``a`` and ``a_i``, cost matrix ``C``, entropic regularization parameter
339+
``\\varepsilon`` and marginal relaxation parameter ``\\lambda``. Optionally, weights of the histograms ``a_i`` can be provided with the
340+
keyword argument `weights`.
341+
342+
This function is a wrapper of the function
343+
[`barycenter_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.barycenter_unbalanced) in the
344+
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
345+
Python function.
346+
347+
# Examples
348+
349+
```jldoctest
350+
julia> A = rand(10, 3);
351+
352+
julia> A ./= sum(A; dims=1);
353+
354+
julia> C = rand(10, 10);
355+
356+
julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1; method="sinkhorn_stabilized")), 1; atol=1e-4)
357+
false
358+
359+
julia> isapprox(sum(barycenter_unbalanced(
360+
A, C, 0.01, 10_000; method="sinkhorn_stabilized", numItermax=5_000
361+
)), 1; atol=1e-4)
362+
true
363+
```
364+
365+
See also: [`barycenter`](@ref)
366+
"""
367+
function barycenter_unbalanced(A, C, ε, λ; kwargs...)
368+
return pot.barycenter_unbalanced(A, C, ε, λ; kwargs...)
369+
end

0 commit comments

Comments
 (0)