Skip to content

Add mm_unbalanced function #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PythonOT"
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
authors = ["David Widmann"]
version = "0.1.5"
version = "0.1.6"

[deps]
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ PythonOT.Smooth.smooth_ot_dual
sinkhorn_unbalanced
sinkhorn_unbalanced2
barycenter_unbalanced
mm_unbalanced
```
3 changes: 2 additions & 1 deletion src/PythonOT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ export emd,
barycenter_unbalanced,
sinkhorn_unbalanced,
sinkhorn_unbalanced2,
empirical_sinkhorn_divergence
empirical_sinkhorn_divergence,
mm_unbalanced

const pot = PyCall.PyNULL()

Expand Down
57 changes: 44 additions & 13 deletions src/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,11 @@ julia> ν = [0.0, 1.0];

julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];

julia> sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000)
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
3×2 Matrix{Float64}:
0.0 0.499964
0.0 0.200188
0.0 0.29983
0.0 0.5
0.0 0.2002
0.0 0.2998
```

It is possible to provide multiple target marginals as columns of a matrix. In this case the
Expand All @@ -325,10 +325,10 @@ optimal transport costs are returned:
```jldoctest sinkhorn_unbalanced
julia> ν = [0.0 0.5; 1.0 0.5];

julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=6)
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
2-element Vector{Float64}:
0.949709
0.449411
0.9497
0.4494
```

See also: [`sinkhorn_unbalanced2`](@ref)
Expand Down Expand Up @@ -371,20 +371,19 @@ julia> ν = [0.0, 1.0];

julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];

julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
1-element Vector{Float64}:
0.949709
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
0.9497
```

It is possible to provide multiple target marginals as columns of a matrix:

```jldoctest sinkhorn_unbalanced2
julia> ν = [0.0 0.5; 1.0 0.5];

julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
2-element Vector{Float64}:
0.949709
0.449411
0.9497
0.4494
```

See also: [`sinkhorn_unbalanced`](@ref)
Expand Down Expand Up @@ -516,3 +515,35 @@ Python function.
function entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss="square_loss"; kwargs...)
return pot.gromov.entropic_gromov_wasserstein(Cμ, Cν, μ, ν, loss, ε; kwargs...)
end

"""
mm_unbalanced(a, b, M, reg_m; kwargs...)

Solve the unbalanced optimal transport problem and return the OT plan.
The function solves the following optimization problem:

```math
W = \\min_\\gamma \\quad \\langle \\gamma, \\mathbf{M} \\rangle_F +
\\mathrm{reg_{m1}} \\cdot \\mathrm{div}(\\gamma \\mathbf{1}, \\mathbf{a}) +
\\mathrm{reg_{m2}} \\cdot \\mathrm{div}(\\gamma^T \\mathbf{1}, \\mathbf{b}) +
\\mathrm{reg} \\cdot \\mathrm{div}(\\gamma, \\mathbf{c}) \\\\

s.t.
\\gamma \\geq 0
```

where:

- ``\\mathbf{M}`` is the (``dim_a``, ``dim_b``) metric cost matrix.
- ``\\mathbf{a}`` and ``\\mathbf{b}`` are source and target unbalanced distributions.
- ``\\mathbf{c}`` is a reference distribution for the regularization.
- ``\\mathrm{reg_m}`` is the marginal relaxation term

This function is a wrapper of the function
[`mm_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.mm_unbalanced) in the
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
Python function.
"""
function mm_unbalanced(a, b, M, reg_m; kwargs...)
return pot.unbalanced.mm_unbalanced(a, b, M, reg_m; kwargs...)
end
Loading