@@ -306,29 +306,18 @@ Python function.
306
306
# Examples
307
307
308
308
```jldoctest sinkhorn_unbalanced
309
- julia> μ = [0.5, 0.2, 0.3 ];
309
+ julia> μ = [0.5, 0.5 ];
310
310
311
- julia> ν = [0.0, 1.0 ];
311
+ julia> ν = [0.5, 0.5 ];
312
312
313
- julia> C = [0.0 1.0; 2 .0 0.0; 0.5 1.5 ];
313
+ julia> C = [0.0 1.0; 1 .0 0.0];
314
314
315
- julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
316
- 3×2 Matrix{Float64}:
317
- 0.0 0.5
318
- 0.0 0.2002
319
- 0.0 0.2998
315
+ julia> round.(sinkhorn_unbalanced(μ, ν, C, 1, 1); sigdigits=7)
316
+ 2×2 Matrix{Float64}:
317
+ 0.322054 0.118477
318
+ 0.118477 0.322054
320
319
```
321
320
322
- It is possible to provide multiple target marginals as columns of a matrix. In this case the
323
- optimal transport costs are returned:
324
-
325
- ```jldoctest sinkhorn_unbalanced
326
- julia> ν = [0.0 0.5; 1.0 0.5];
327
-
328
- julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
329
- 2-element Vector{Float64}:
330
- 0.9497
331
- 0.4494
332
321
```
333
322
334
323
See also: [`sinkhorn_unbalanced2`](@ref)
@@ -365,25 +354,14 @@ Python function.
365
354
# Examples
366
355
367
356
```jldoctest sinkhorn_unbalanced2
368
- julia> μ = [0.5, 0.2, 0.3];
369
-
370
- julia> ν = [0.0, 1.0];
371
-
372
- julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
357
+ julia> μ = [0.5, 0.1];
373
358
374
- julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
375
- 0.9497
376
- ```
377
-
378
- It is possible to provide multiple target marginals as columns of a matrix:
359
+ julia> ν = [0.5, 0.5];
379
360
380
- ```jldoctest sinkhorn_unbalanced2
381
- julia> ν = [0.0 0.5; 1.0 0.5];
361
+ julia> C = [0.0 1.0; 1.0 0.0];
382
362
383
- julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
384
- 2-element Vector{Float64}:
385
- 0.9497
386
- 0.4494
363
+ julia> round.(sinkhorn_unbalanced2(μ, ν, C, 1., 1.); sigdigits=8)
364
+ 0.19600125
387
365
```
388
366
389
367
See also: [`sinkhorn_unbalanced`](@ref)
@@ -566,3 +544,52 @@ julia> round.(mm_unbalanced(a, b, M, 5, div="l2"), digits=2)
566
544
function mm_unbalanced (a, b, M, reg_m; kwargs... )
567
545
return pot. unbalanced. mm_unbalanced (a, b, M, reg_m; kwargs... )
568
546
end
547
+
548
+
549
+ """
550
+ entropic_partial_wasserstein(a, b, M, reg; kwargs...)
551
+
552
+ Solves the partial optimal transport problem and returns the OT plan
553
+ The function considers the following problem:
554
+
555
+ ```math
556
+ \\ gamma = \\ mathop{\\ arg \\ min}_\\ gamma \\ quad \\ langle \\ gamma,
557
+ \\ mathbf{M} \\ rangle_F + \\ mathrm{reg} \\ cdot\\ Omega(\\ gamma)
558
+
559
+ s.t. \\ gamma \\ mathbf{1} &\\ leq \\ mathbf{a} \\\\
560
+ \\ gamma^T \\ mathbf{1} &\\ leq \\ mathbf{b} \\\\
561
+ \\ gamma &\\ geq 0 \\\\
562
+ \\ mathbf{1}^T \\ gamma^T \\ mathbf{1} = m
563
+ &\\ leq \\ min\\ {\\ |\\ mathbf{a}\\ |_1, \\ |\\ mathbf{b}\\ |_1\\ } \\\\
564
+ ```
565
+
566
+ where :
567
+
568
+ - `M` is the metric cost matrix
569
+ - ``\\ Omega`` is the entropic regularization term, ``\\ Omega=\\ sum_{i,j} \\ gamma_{i,j}\\ log(\\ gamma_{i,j})``
570
+ - `a` and `b` are the sample weights
571
+ - `m` is the amount of mass to be transported
572
+
573
+ This function is a wrapper of the function
574
+ [`entropic_partial_wasserstein`](https://pythonot.github.io/gen_modules/ot.partial.html#ot.partial.entropic_partial_wasserstein) in the
575
+ Python Optimal Transport package. Keyword arguments are listed in the documentation of the
576
+ Python function.
577
+
578
+
579
+ # Examples
580
+
581
+ ```jldoctest
582
+ julia> a = [.1, .2];
583
+
584
+ julia> b = [.1, .1];
585
+
586
+ julia> M = [0. 1.; 2. 3.];
587
+
588
+ julia> round.(entropic_partial_wasserstein(a, b, M, 1, m=0.1), digits=2)
589
+ 2×2 Matrix{Float64}:
590
+ 0.06 0.02
591
+ 0.01 0.0
592
+ """
593
+ function entropic_partial_wasserstein (a, b, M, reg; kwargs... )
594
+ return pot. partial. entropic_partial_wasserstein (a, b, M, reg; kwargs... )
595
+ end
0 commit comments