Skip to content

Add mm_unbalanced function #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PythonOT"
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
authors = ["David Widmann"]
version = "0.1.5"
version = "0.1.6"

[deps]
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ PythonOT.Smooth.smooth_ot_dual
sinkhorn_unbalanced
sinkhorn_unbalanced2
barycenter_unbalanced
mm_unbalanced
```
3 changes: 2 additions & 1 deletion src/PythonOT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ export emd,
barycenter_unbalanced,
sinkhorn_unbalanced,
sinkhorn_unbalanced2,
empirical_sinkhorn_divergence
empirical_sinkhorn_divergence,
mm_unbalanced

const pot = PyCall.PyNULL()

Expand Down
55 changes: 42 additions & 13 deletions src/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,11 @@

julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];

julia> sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000)
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
3×2 Matrix{Float64}:
0.0 0.499964
0.0 0.200188
0.0 0.29983
0.0 0.5
0.0 0.2002
0.0 0.2998
```

It is possible to provide multiple target marginals as columns of a matrix. In this case the
Expand All @@ -325,10 +325,10 @@
```jldoctest sinkhorn_unbalanced
julia> ν = [0.0 0.5; 1.0 0.5];

julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=6)
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
2-element Vector{Float64}:
0.949709
0.449411
0.9497
0.4494
```

See also: [`sinkhorn_unbalanced2`](@ref)
Expand Down Expand Up @@ -371,20 +371,19 @@

julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];

julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
1-element Vector{Float64}:
0.949709
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
0.9497
```

It is possible to provide multiple target marginals as columns of a matrix:

```jldoctest sinkhorn_unbalanced2
julia> ν = [0.0 0.5; 1.0 0.5];

julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
2-element Vector{Float64}:
0.949709
0.449411
0.9497
0.4494
```

See also: [`sinkhorn_unbalanced`](@ref)
Expand Down Expand Up @@ -516,3 +515,33 @@
function entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss="square_loss"; kwargs...)
return pot.gromov.entropic_gromov_wasserstein(Cμ, Cν, μ, ν, loss, ε; kwargs...)
end

"""
mm_unbalanced(a, b, M, reg_m; reg=0, c=a*b', kwargs...)

Solve the unbalanced optimal transport problem and return the OT plan.
The function solves the following optimization problem:

```math
W = \\min_{\\gamma \\geq 0} \\langle \\gamma, M \\rangle_F +
\\mathrm{reg_{m1}} \\cdot \\operatorname{div}(\\gamma \\mathbf{1}, a) +
\\mathrm{reg_{m2}} \\cdot \\operatorname{div}(\\gamma^\\mathsf{T} \\mathbf{1}, b) +
\\mathrm{reg} \\cdot \\operatorname{div}(\\gamma, c)
```

where

- `M` is the metric cost matrix,
- `a` and `b` are source and target unbalanced distributions,
- `c` is a reference distribution for the regularization,
- `reg_m` is the marginal relaxation term (if it is a scalar or an indexable object of length 1, then the same term is applied to both marginal relaxations), and
- `reg` is a regularization term.

This function is a wrapper of the function
[`mm_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.mm_unbalanced) in the
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
Python function.
"""
function mm_unbalanced(a, b, M, reg_m; kwargs...)
return pot.unbalanced.mm_unbalanced(a, b, M, reg_m; kwargs...)

Check warning on line 546 in src/lib.jl

View check run for this annotation

Codecov / codecov/patch

src/lib.jl#L545-L546

Added lines #L545 - L546 were not covered by tests
end
Loading