3131
3232
3333# %%
34- # Data generation
35- # --------------
34+ # 2D data example
35+ # ---------------
3636#
3737# We first generate two sets of samples in 2D that 25 and 50
3838# samples respectively located on circles. The weights of the samples are
5353x2 = np .random .randn (n2 , 2 )
5454x2 /= np .sqrt (np .sum (x2 ** 2 , 1 , keepdims = True )) / 4
5555
56+ # Compute the cost matrix
57+ C = ot .dist (x1 , x2 ) # Squared Euclidean cost matrix by default
58+
5659# sphinx_gallery_start_ignore
5760style = {"markeredgecolor" : "k" }
5861
62+
63+ def plot_plan (P = None , title = "" , axis = True ):
64+ plot2D_samples_mat (x1 , x2 , P )
65+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
66+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
67+ if not axis :
68+ pl .axis ("off" )
69+ pl .title (title )
70+
71+
5972pl .figure (1 , (4 , 4 ))
6073pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
6174pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
6275pl .legend (loc = 0 )
6376pl .title ("Source and target distributions" )
6477pl .show ()
78+
79+ pl .figure (2 , (3.5 , 1.7 ))
80+ pl .imshow (C )
81+ pl .colorbar ()
82+ pl .title ("Cost matrix C" )
83+
6584# sphinx_gallery_end_ignore
6685
6786# %%
139158
140159
141160# %%
142- # Solve the Optimal Transport problem with a custom cost matrix
143- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
161+ # Optimal Transport problem with a custom cost matrix
162+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
144163#
145164# The cost matrix can be customized by passing it to the more general
146165# :func:`ot.solve` function. The cost matrix should be a matrix of size
150169# In this example, we use the Citybloc distance as the cost matrix.
151170
152171# Compute the cost matrix
153- C = ot .dist (x1 , x2 , metric = "cityblock" )
172+ C_city = ot .dist (x1 , x2 , metric = "cityblock" )
154173
155174# Solve the OT problem with the custom cost matrix
156- P_city = ot .solve (C ). plan
175+ sol = ot .solve (C_city )
157176# the parameters a and b are not provided so uniform weights are assumed
177+ P_city = sol .plan
178+ # on empirical data the same can be done with ot.solve_sample :
179+ # sol = ot.solve_sample(x1, x2, metric='cityblock')
158180
159181# Compute the OT loss (equivalent to ot.solve(C).value)
160- loss_city = np .sum (P_city * C )
182+ loss_city = sol . value # same as np.sum(P_city * C)
161183
162184# sphinx_gallery_start_ignore
163185pl .figure (1 , (3 , 3 ))
192214# P = ot.emd(a, b, C)
193215# loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b
194216#
195-
196-
197- # %%
217+ #
198218# Sinkhorn and Regularized OT
199219# ---------------------------
200220#
229249# The Sinkhorn algorithm can be faster than the exact OT solver for large
230250# regularization strength but the solution is only an approximation of the
231251# exact OT problem and the OT plan is not sparse.
232-
233- # %%
252+ #
234253# Quadratic Regularized OT
235254# ~~~~~~~~~~~~~~~~~~~~~~~~~
236255#
@@ -281,8 +300,7 @@ def df(G):
281300 return G
282301
283302
284- P_reg = ot .solve_sample (x1 , x2 , a , b , reg = 1e2 , reg_type = (f , df )).plan
285-
303+ P_reg = ot .solve_sample (x1 , x2 , a , b , reg = 3 , reg_type = (f , df )).plan
286304
287305# sphinx_gallery_start_ignore
288306pl .figure (1 , (3 , 3 ))
@@ -312,7 +330,7 @@ def df(G):
312330# Unbalanced and Partial OT
313331# ----------------------------
314332#
315- # Solve the Unbalanced OT problem
333+ # Unbalanced Optimal Transport
316334# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
317335#
318336# Unbalanced OT relaxes the marginal constraints and allows for the source and
@@ -393,10 +411,10 @@ def df(G):
393411# sphinx_gallery_end_ignore
394412# %%
395413#
396- # Gromov-Wasserstein and Fused GW
414+ # Gromov-Wasserstein and Fused Gromov-Wasserstein
397415# -------------------------------------
398416#
399- # Solve the Gromov-Wasserstein problem
417+ # Gromov-Wasserstein and Entropic GW
400418# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401419#
402420# The Gromov-Wasserstein distance is a similarity measure between metric
@@ -414,8 +432,7 @@ def df(G):
414432# Solve the Gromov-Wasserstein problem
415433sol_gw = ot .solve_gromov (C1 , C2 , a = a , b = b )
416434P_gw = sol_gw .plan
417- loss_gw = sol_gw .value
418- loss_gw_linear = sol_gw .value_linear # linear part of loss
435+ loss_gw = sol_gw .value # quadratic + reg if reg>0
419436loss_gw_quad = sol_gw .value_quad # quadratic part of loss
420437
421438# Solve the Entropic Gromov-Wasserstein problem
@@ -460,9 +477,13 @@ def df(G):
460477M = C / np .max (C )
461478
462479# Solve FGW problem with alpha=0.1
463- P_fgw = ot .solve_gromov (C1 , C2 , M , a = a , b = b , alpha = 0.1 ).plan # C is cost across spaces
480+ sol = ot .solve_gromov (C1 , C2 , M , a = a , b = b , alpha = 0.1 )
481+ P_fgw = sol .plan
482+ loss_fgw = sol .value
483+ loss_fgw_linear = sol .value_linear # linear part of loss (wrt M)
484+ loss_fgw_quad = sol .value_quad # quadratic part of loss (wrt C1 and C2)
464485
465- # SOlve entropic FGW problem with alpha=0.1
486+ # Solve entropic FGW problem with alpha=0.1
466487P_efgw = ot .solve_gromov (C1 , C2 , M , a = a , b = b , alpha = 0.1 , reg = 1e-3 ).plan
467488
468489# sphinx_gallery_start_ignore
@@ -497,35 +518,6 @@ def df(G):
497518# loss_fgw = ot.gromov.fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1)
498519# loss_efgw = ot.gromov.entropic_fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1, epsilon=reg)
499520#
500-
501- # # Unbalanced Gromov-Wasserstein
502- # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
503- # #
504- #
505- # # Solve the Unbalanced Gromov-Wasserstein problem
506- # P_gw_unb = ot.solve_gromov(C1, C2, a=a, b=b, unbalanced=1e-2).plan
507- #
508- # # Solve the Unbalanced Entropic Gromov-Wasserstein problem
509- # P_egw_unb = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2, reg_type='KL', unbalanced=1e-2).plan
510- #
511- # # sphinx_gallery_start_ignore
512- # pl.figure(1, (6, 3))
513- #
514- # pl.subplot(1, 2, 1)
515- # plot2D_samples_mat(x1, x2, P_gw_unb)
516- # pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
517- # pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
518- # pl.title("Unbalanced GW plan")
519- #
520- # pl.subplot(1, 2, 2)
521- # plot2D_samples_mat(x1, x2, P_egw_unb)
522- # pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
523- # pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
524- # pl.title("Unbalanced Entropic GW plan")
525- # pl.show()
526- # # sphinx_gallery_end_ignore
527- # %%
528- #
529521# Large scale OT
530522# --------------
531523#
@@ -557,9 +549,8 @@ def df(G):
557549# recover values for Lazy plan
558550P12 = P_sink_lazy [1 , 2 ]
559551P1dots = P_sink_lazy [1 , :]
560- P_sink_lazy_dense = P_sink_lazy [
561- :
562- ] # convert to dense matrix !!warning this can be memory consuming
552+ # convert to dense matrix !!warning this can be memory consuming
553+ P_sink_lazy_dense = P_sink_lazy [:]
563554
564555# sphinx_gallery_start_ignore
565556pl .figure (1 , (3 , 3 ))
@@ -575,8 +566,13 @@ def df(G):
575566pl .show ()
576567
577568# sphinx_gallery_end_ignore
578- #
579569# %%
570+ # .. note::
571+ # The lazy Sinkhorn algorithm can be found in the old API with the
572+ # :func:`ot.bregman.empirical_sinkhorn` function with parameter
573+ # :code:`lazy=True`. Similarly the geoloss implementation is available
574+ # with the :func:`ot.bregman.empirical_sinkhorn2_geomloss`.
575+ #
580576#
581577# the first example shows how to solve the Sinkhorn problem in a lazy way with
582578# the default POT implementation. The second example shows how to solve the
@@ -585,7 +581,7 @@ def df(G):
585581# samples.
586582#
587583# Factored and Low rank OT
588- # ------------------------
584+ # ~~~~~~~~~~~~~~~~~~~~~~~~
589585#
590586# The Sinkhorn algorithm can be implemented in a low rank version that
591587# approximates the OT plan with a low rank matrix. This can be useful to
@@ -594,9 +590,9 @@ def df(G):
594590#
595591
596592# Solve the Factored OT problem (use lazy=True for large scale)
597- P_fact = ot .solve_sample (x1 , x2 , a , b , method = "factored" , rank = 8 ).plan
593+ P_fact = ot .solve_sample (x1 , x2 , a , b , method = "factored" , rank = 15 ).plan
598594
599- P_lowrank = ot .solve_sample (x1 , x2 , a , b , reg = 0.1 , method = "lowrank" , rank = 8 ).plan
595+ P_lowrank = ot .solve_sample (x1 , x2 , a , b , reg = 0.1 , method = "lowrank" , rank = 10 ).plan
600596
601597# sphinx_gallery_start_ignore
602598pl .figure (1 , (6 , 3 ))
@@ -626,8 +622,11 @@ def df(G):
626622pl .show ()
627623
628624# sphinx_gallery_end_ignore
629-
630625# %%
626+ # .. note::
627+ # The factored OT problem can be solved with the old API using the
628+ # :func:`ot.factored.factored_optimal_transport` function and the low rank
629+ # OT problem can be solved with the :func:`ot.lowrank.lowrank_sinkhorn` function.
631630#
632631# Gaussian OT with Bures-Wasserstein
633632# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -641,4 +640,106 @@ def df(G):
641640# Compute the Bures-Wasserstein distance
642641bw_value = ot .solve_sample (x1 , x2 , a , b , method = "gaussian" ).value
643642
643+ print (f"Exact OT loss = { loss :1.3f} " )
644644print (f"Bures-Wasserstein distance = { bw_value :1.3f} " )
645+
646+ # %%
647+ # .. note::
648+ # The Gaussian Wasserstein problem can be solved with the old API using the
649+ # :func:`ot.gaussian.empirical_bures_wasserstein_distance` function.
650+ #
651+ # All OT plans
652+ # ------------
653+ #
654+ # The figure below shows all the OT plans computed in this example.
655+ # The color intensity represents the amount of mass transported
656+ # between the samples.
657+ #
658+
659+ # sphinx_gallery_start_ignore
660+ pl .figure (1 , (9 , 13 ))
661+
662+
663+ pl .subplot (4 , 3 , 1 )
664+ plot2D_samples_mat (x1 , x2 , P )
665+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
666+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
667+ pl .axis ("off" )
668+ pl .title ("OT plan" )
669+
670+ pl .subplot (4 , 3 , 2 )
671+ plot2D_samples_mat (x1 , x2 , P_sink )
672+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
673+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
674+ pl .axis ("off" )
675+ pl .title ("Sinkhorn plan" )
676+
677+ pl .subplot (4 , 3 , 3 )
678+ plot2D_samples_mat (x1 , x2 , P_quad )
679+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
680+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
681+ pl .axis ("off" )
682+ pl .title ("Quadratic reg. plan" )
683+
684+ pl .subplot (4 , 3 , 4 )
685+ plot2D_samples_mat (x1 , x2 , P_unb_kl )
686+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
687+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
688+ pl .axis ("off" )
689+ pl .title ("Unbalanced KL plan" )
690+
691+ pl .subplot (4 , 3 , 5 )
692+ plot2D_samples_mat (x1 , x2 , P_unb_kl_reg )
693+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
694+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
695+ pl .axis ("off" )
696+ pl .title ("Unbalanced KL + reg plan" )
697+
698+ pl .subplot (4 , 3 , 6 )
699+ plot2D_samples_mat (x1 , x2 , P_unb_l2 )
700+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
701+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
702+ pl .axis ("off" )
703+ pl .title ("Unbalanced L2 plan" )
704+
705+ pl .subplot (4 , 3 , 7 )
706+ plot2D_samples_mat (x1 , x2 , P_part_const )
707+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
708+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
709+ pl .axis ("off" )
710+ pl .title ("Partial 50% mass plan" )
711+
712+ pl .subplot (4 , 3 , 8 )
713+ plot2D_samples_mat (x1 , x2 , P_fact )
714+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
715+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
716+ pl .axis ("off" )
717+ pl .title ("Factored OT plan" )
718+
719+ pl .subplot (4 , 3 , 9 )
720+ plot2D_samples_mat (x1 , x2 , P_lowrank )
721+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
722+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
723+ pl .axis ("off" )
724+ pl .title ("Low rank OT plan" )
725+
726+ pl .subplot (4 , 3 , 10 )
727+ plot2D_samples_mat (x1 , x2 , P_gw )
728+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
729+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
730+ pl .axis ("off" )
731+ pl .title ("GW plan" )
732+
733+ pl .subplot (4 , 3 , 11 )
734+ plot2D_samples_mat (x1 , x2 , P_egw )
735+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
736+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
737+ pl .axis ("off" )
738+ pl .title ("Entropic GW plan" )
739+
740+ pl .subplot (4 , 3 , 12 )
741+ plot2D_samples_mat (x1 , x2 , P_fgw )
742+ pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
743+ pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
744+ pl .axis ("off" )
745+ pl .title ("Fused GW plan" )
0 commit comments