From ceb747d8ba743a3c22daddc976539d8f7208d50e Mon Sep 17 00:00:00 2001 From: Pierre Navaro Date: Fri, 5 Jul 2024 13:59:06 +0200 Subject: [PATCH 01/10] Add mm_unbalanced function --- Project.toml | 2 +- src/PythonOT.jl | 3 ++- src/lib.jl | 57 ++++++++++++++++++++++++++++++++++++++----------- 3 files changed, 47 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index a21c83a..d8894ac 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PythonOT" uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef" authors = ["David Widmann"] -version = "0.1.5" +version = "0.1.6" [deps] PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" diff --git a/src/PythonOT.jl b/src/PythonOT.jl index 1f67513..d801f3d 100644 --- a/src/PythonOT.jl +++ b/src/PythonOT.jl @@ -12,7 +12,8 @@ export emd, barycenter_unbalanced, sinkhorn_unbalanced, sinkhorn_unbalanced2, - empirical_sinkhorn_divergence + empirical_sinkhorn_divergence, + mm_unbalanced const pot = PyCall.PyNULL() diff --git a/src/lib.jl b/src/lib.jl index c13db7a..1d9317d 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -312,11 +312,11 @@ julia> ν = [0.0, 1.0]; julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5]; -julia> sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000) +julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4) 3×2 Matrix{Float64}: - 0.0 0.499964 - 0.0 0.200188 - 0.0 0.29983 + 0.0 0.5 + 0.0 0.2002 + 0.0 0.2998 ``` It is possible to provide multiple target marginals as columns of a matrix. In this case the @@ -325,10 +325,10 @@ optimal transport costs are returned: ```jldoctest sinkhorn_unbalanced julia> ν = [0.0 0.5; 1.0 0.5]; -julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=6) +julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4) 2-element Vector{Float64}: - 0.949709 - 0.449411 + 0.9497 + 0.4494 ``` See also: [`sinkhorn_unbalanced2`](@ref) @@ -371,9 +371,8 @@ julia> ν = [0.0, 1.0]; julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5]; -julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6) -1-element Vector{Float64}: - 0.949709 +julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4) +0.9497 ``` It is possible to provide multiple target marginals as columns of a matrix: @@ -381,10 +380,10 @@ It is possible to provide multiple target marginals as columns of a matrix: ```jldoctest sinkhorn_unbalanced2 julia> ν = [0.0 0.5; 1.0 0.5]; -julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6) +julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4) 2-element Vector{Float64}: - 0.949709 - 0.449411 + 0.9497 + 0.4494 ``` See also: [`sinkhorn_unbalanced`](@ref) @@ -516,3 +515,35 @@ Python function. function entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss="square_loss"; kwargs...) return pot.gromov.entropic_gromov_wasserstein(Cμ, Cν, μ, ν, loss, ε; kwargs...) end + +""" + mm_unbalanced(a, b, M, reg_m; kwargs...) + +Solve the unbalanced optimal transport problem and return the OT plan. +The function solves the following optimization problem: + +```math + W = \\min_\\gamma \\quad \\langle \\gamma, \\mathbf{M} \\rangle_F + + \\mathrm{reg_{m1}} \\cdot \\mathrm{div}(\\gamma \\mathbf{1}, \\mathbf{a}) + + \\mathrm{reg_{m2}} \\cdot \\mathrm{div}(\\gamma^T \\mathbf{1}, \\mathbf{b}) + + \\mathrm{reg} \\cdot \\mathrm{div}(\\gamma, \\mathbf{c}) \\\\ + + s.t. + \\gamma \\geq 0 +``` + +where: + +- ``\\mathbf{M}`` is the (``dim_a``, ``dim_b``) metric cost matrix. +- ``\\mathbf{a}`` and ``\\mathbf{b}`` are source and target unbalanced distributions. +- ``\\mathbf{c}`` is a reference distribution for the regularization. +- ``\\mathrm{reg_m}`` is the marginal relaxation term + +This function is a wrapper of the function +[`mm_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.mm_unbalanced) in the +Python Optimal Transport package. Keyword arguments are listed in the documentation of the +Python function. +""" +function mm_unbalanced(a, b, M, reg_m; kwargs...) + return pot.unbalanced.mm_unbalanced(a, b, M, reg_m; kwargs...) +end From a10a3a9227f1d9e670d009ba2352acf74cf4e481 Mon Sep 17 00:00:00 2001 From: Pierre Navaro Date: Fri, 5 Jul 2024 14:02:44 +0200 Subject: [PATCH 02/10] Update api.md --- docs/src/api.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/api.md b/docs/src/api.md index 2730efe..00409b1 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -35,4 +35,5 @@ PythonOT.Smooth.smooth_ot_dual sinkhorn_unbalanced sinkhorn_unbalanced2 barycenter_unbalanced +mm_unbalanced ``` From d983109c7694324787cfef85ee40b2d2066b2e74 Mon Sep 17 00:00:00 2001 From: Pierre Navaro Date: Sat, 6 Jul 2024 14:27:37 +0200 Subject: [PATCH 03/10] Update src/lib.jl Co-authored-by: David Widmann --- src/lib.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/lib.jl b/src/lib.jl index 1d9317d..fb77645 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -523,13 +523,10 @@ Solve the unbalanced optimal transport problem and return the OT plan. The function solves the following optimization problem: ```math - W = \\min_\\gamma \\quad \\langle \\gamma, \\mathbf{M} \\rangle_F + - \\mathrm{reg_{m1}} \\cdot \\mathrm{div}(\\gamma \\mathbf{1}, \\mathbf{a}) + - \\mathrm{reg_{m2}} \\cdot \\mathrm{div}(\\gamma^T \\mathbf{1}, \\mathbf{b}) + - \\mathrm{reg} \\cdot \\mathrm{div}(\\gamma, \\mathbf{c}) \\\\ - - s.t. - \\gamma \\geq 0 +W = \\min_{\\gamma \\geq 0} \\langle \\gamma, M \\rangle_F + + \\mathrm{reg_{m1}} \\cdot \\operatorname{div}(\\gamma \\mathbf{1}, a) + + \\mathrm{reg_{m2}} \\cdot \\operatorname{div}(\\gamma^\\mathsf{T} \\mathbf{1}, b) + + \\mathrm{reg} \\cdot \\operatorname{div}(\\gamma, c) ``` where: From f78496999ffa6a3ae1e99660ce1c592d7116ea4e Mon Sep 17 00:00:00 2001 From: Pierre Navaro Date: Sat, 6 Jul 2024 14:27:52 +0200 Subject: [PATCH 04/10] Update src/lib.jl Co-authored-by: David Widmann --- src/lib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.jl b/src/lib.jl index fb77645..9a902c8 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -529,7 +529,7 @@ W = \\min_{\\gamma \\geq 0} \\langle \\gamma, M \\rangle_F + \\mathrm{reg} \\cdot \\operatorname{div}(\\gamma, c) ``` -where: +where - ``\\mathbf{M}`` is the (``dim_a``, ``dim_b``) metric cost matrix. - ``\\mathbf{a}`` and ``\\mathbf{b}`` are source and target unbalanced distributions. From 2a75060f014389cd106d25993241f064e9e47e9a Mon Sep 17 00:00:00 2001 From: Pierre Navaro Date: Sat, 6 Jul 2024 14:28:20 +0200 Subject: [PATCH 05/10] Update src/lib.jl Co-authored-by: David Widmann --- src/lib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.jl b/src/lib.jl index 9a902c8..fc5d5e5 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -531,7 +531,7 @@ W = \\min_{\\gamma \\geq 0} \\langle \\gamma, M \\rangle_F + where -- ``\\mathbf{M}`` is the (``dim_a``, ``dim_b``) metric cost matrix. +- `M` is the metric cost matrix, - ``\\mathbf{a}`` and ``\\mathbf{b}`` are source and target unbalanced distributions. - ``\\mathbf{c}`` is a reference distribution for the regularization. - ``\\mathrm{reg_m}`` is the marginal relaxation term From d861202fafc6fb599ee8dd64076c741609ed4aad Mon Sep 17 00:00:00 2001 From: Pierre Navaro Date: Sat, 6 Jul 2024 14:28:40 +0200 Subject: [PATCH 06/10] Update src/lib.jl Co-authored-by: David Widmann --- src/lib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.jl b/src/lib.jl index fc5d5e5..41f62a3 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -532,7 +532,7 @@ W = \\min_{\\gamma \\geq 0} \\langle \\gamma, M \\rangle_F + where - `M` is the metric cost matrix, -- ``\\mathbf{a}`` and ``\\mathbf{b}`` are source and target unbalanced distributions. +- `a` and `b` are source and target unbalanced distributions, - ``\\mathbf{c}`` is a reference distribution for the regularization. - ``\\mathrm{reg_m}`` is the marginal relaxation term From b1938c0d573b8b2df6ed07fedcacd8f8c2dacfff Mon Sep 17 00:00:00 2001 From: Pierre Navaro Date: Sat, 6 Jul 2024 14:28:49 +0200 Subject: [PATCH 07/10] Update src/lib.jl Co-authored-by: David Widmann --- src/lib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.jl b/src/lib.jl index 41f62a3..8d6d2e0 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -533,7 +533,7 @@ where - `M` is the metric cost matrix, - `a` and `b` are source and target unbalanced distributions, -- ``\\mathbf{c}`` is a reference distribution for the regularization. +- `c` is a reference distribution for the regularization, - ``\\mathrm{reg_m}`` is the marginal relaxation term This function is a wrapper of the function From f62cbc349b33967a1786115c2d719e6a03792418 Mon Sep 17 00:00:00 2001 From: Pierre Navaro Date: Sat, 6 Jul 2024 14:29:19 +0200 Subject: [PATCH 08/10] Update src/lib.jl Co-authored-by: David Widmann --- src/lib.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib.jl b/src/lib.jl index 8d6d2e0..6e76db2 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -534,7 +534,8 @@ where - `M` is the metric cost matrix, - `a` and `b` are source and target unbalanced distributions, - `c` is a reference distribution for the regularization, -- ``\\mathrm{reg_m}`` is the marginal relaxation term +- `reg_m` is the marginal relaxation term (if it is a scalar or an indexable object of length 1, then the same term is applied to both marginal relaxations), and +- `reg` is a regularization term. This function is a wrapper of the function [`mm_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.mm_unbalanced) in the From 1ff7c0283f6c9bd641d09e35ca82fca7b047cfb5 Mon Sep 17 00:00:00 2001 From: Pierre Navaro Date: Sat, 6 Jul 2024 14:29:41 +0200 Subject: [PATCH 09/10] Update src/lib.jl Co-authored-by: David Widmann --- src/lib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.jl b/src/lib.jl index 6e76db2..b6c037c 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -517,7 +517,7 @@ function entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss="square_loss"; k end """ - mm_unbalanced(a, b, M, reg_m; kwargs...) + mm_unbalanced(a, b, M, reg_m; reg=0, c=a*b', kwargs...) Solve the unbalanced optimal transport problem and return the OT plan. The function solves the following optimization problem: From 9ac76014d46d0261da135542bd47a3768890f7c3 Mon Sep 17 00:00:00 2001 From: Pierre Navaro Date: Wed, 10 Jul 2024 08:44:56 +0200 Subject: [PATCH 10/10] Add doctest in mm_unbalanced function --- src/lib.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/lib.jl b/src/lib.jl index b6c037c..07bb9af 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -541,6 +541,27 @@ This function is a wrapper of the function [`mm_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.mm_unbalanced) in the Python Optimal Transport package. Keyword arguments are listed in the documentation of the Python function. + +# Examples + +```jldoctest +julia> a=[.5, .5]; + +julia> b=[.5, .5]; + +julia> M=[1. 36.; 9. 4.]; + +julia> round.(mm_unbalanced(a, b, M, 5, div="kl"), digits=2) +2×2 Matrix{Float64}: + 0.45 0.0 + 0.0 0.34 + +julia> round.(mm_unbalanced(a, b, M, 5, div="l2"), digits=2) +2×2 Matrix{Float64}: + 0.4 0.0 + 0.0 0.1 +``` + """ function mm_unbalanced(a, b, M, reg_m; kwargs...) return pot.unbalanced.mm_unbalanced(a, b, M, reg_m; kwargs...)