Skip to content

Commit 1d74e71

Browse files
committed
better version quickstart guide
1 parent db9942f commit 1d74e71

File tree

1 file changed

+159
-58
lines changed

1 file changed

+159
-58
lines changed

examples/plot_quickstart_guide.py

Lines changed: 159 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131

3232

3333
# %%
34-
# Data generation
35-
# --------------
34+
# 2D data example
35+
# ---------------
3636
#
3737
# We first generate two sets of samples in 2D that contain 25 and 50
3838
# samples respectively located on circles. The weights of the samples are
@@ -53,15 +53,34 @@
5353
x2 = np.random.randn(n2, 2)
5454
x2 /= np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4
5555

56+
# Compute the cost matrix
57+
C = ot.dist(x1, x2) # Squared Euclidean cost matrix by default
58+
5659
# sphinx_gallery_start_ignore
5760
style = {"markeredgecolor": "k"}
5861

62+
63+
def plot_plan(P=None, title="", axis=True):
64+
plot2D_samples_mat(x1, x2, P)
65+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
66+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
67+
if not axis:
68+
pl.axis("off")
69+
pl.title(title)
70+
71+
5972
pl.figure(1, (4, 4))
6073
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
6174
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
6275
pl.legend(loc=0)
6376
pl.title("Source and target distributions")
6477
pl.show()
78+
79+
pl.figure(2, (3.5, 1.7))
80+
pl.imshow(C)
81+
pl.colorbar()
82+
pl.title("Cost matrix C")
83+
6584
# sphinx_gallery_end_ignore
6685

6786
# %%
@@ -139,8 +158,8 @@
139158

140159

141160
# %%
142-
# Solve the Optimal Transport problem with a custom cost matrix
143-
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
161+
# Optimal Transport problem with a custom cost matrix
162+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
144163
#
145164
# The cost matrix can be customized by passing it to the more general
146165
# :func:`ot.solve` function. The cost matrix should be a matrix of size
@@ -150,14 +169,17 @@
150169
# In this example, we use the Cityblock distance as the cost matrix.
151170

152171
# Compute the cost matrix
153-
C = ot.dist(x1, x2, metric="cityblock")
172+
C_city = ot.dist(x1, x2, metric="cityblock")
154173

155174
# Solve the OT problem with the custom cost matrix
156-
P_city = ot.solve(C).plan
175+
sol = ot.solve(C_city)
157176
# the parameters a and b are not provided so uniform weights are assumed
177+
P_city = sol.plan
178+
# on empirical data the same can be done with ot.solve_sample :
179+
# sol = ot.solve_sample(x1, x2, metric='cityblock')
158180

159181
# Compute the OT loss (equivalent to ot.solve(C_city).value)
160-
loss_city = np.sum(P_city * C)
182+
loss_city = sol.value # same as np.sum(P_city * C_city)
161183

162184
# sphinx_gallery_start_ignore
163185
pl.figure(1, (3, 3))
@@ -192,9 +214,7 @@
192214
# P = ot.emd(a, b, C)
193215
# loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b
194216
#
195-
196-
197-
# %%
217+
#
198218
# Sinkhorn and Regularized OT
199219
# ---------------------------
200220
#
@@ -229,8 +249,7 @@
229249
# The Sinkhorn algorithm can be faster than the exact OT solver for large
230250
# regularization strength but the solution is only an approximation of the
231251
# exact OT problem and the OT plan is not sparse.
232-
233-
# %%
252+
#
234253
# Quadratic Regularized OT
235254
# ~~~~~~~~~~~~~~~~~~~~~~~~~
236255
#
@@ -281,8 +300,7 @@ def df(G):
281300
return G
282301

283302

284-
P_reg = ot.solve_sample(x1, x2, a, b, reg=1e2, reg_type=(f, df)).plan
285-
303+
P_reg = ot.solve_sample(x1, x2, a, b, reg=3, reg_type=(f, df)).plan
286304

287305
# sphinx_gallery_start_ignore
288306
pl.figure(1, (3, 3))
@@ -312,7 +330,7 @@ def df(G):
312330
# Unbalanced and Partial OT
313331
# ----------------------------
314332
#
315-
# Solve the Unbalanced OT problem
333+
# Unbalanced Optimal Transport
316334
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
317335
#
318336
# Unbalanced OT relaxes the marginal constraints and allows for the source and
@@ -393,10 +411,10 @@ def df(G):
393411
# sphinx_gallery_end_ignore
394412
# %%
395413
#
396-
# Gromov-Wasserstein and Fused GW
414+
# Gromov-Wasserstein and Fused Gromov-Wasserstein
397415
# -----------------------------------------------
398416
#
399-
# Solve the Gromov-Wasserstein problem
417+
# Gromov-Wasserstein and Entropic GW
400418
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401419
#
402420
# The Gromov-Wasserstein distance is a similarity measure between metric
@@ -414,8 +432,7 @@ def df(G):
414432
# Solve the Gromov-Wasserstein problem
415433
sol_gw = ot.solve_gromov(C1, C2, a=a, b=b)
416434
P_gw = sol_gw.plan
417-
loss_gw = sol_gw.value
418-
loss_gw_linear = sol_gw.value_linear # linear part of loss
435+
loss_gw = sol_gw.value # quadratic + reg if reg>0
419436
loss_gw_quad = sol_gw.value_quad # quadratic part of loss
420437

421438
# Solve the Entropic Gromov-Wasserstein problem
@@ -460,9 +477,13 @@ def df(G):
460477
M = C / np.max(C)
461478

462479
# Solve FGW problem with alpha=0.1
463-
P_fgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1).plan # C is cost across spaces
480+
sol = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1)
481+
P_fgw = sol.plan
482+
loss_fgw = sol.value
483+
loss_fgw_linear = sol.value_linear # linear part of loss (wrt M)
484+
loss_fgw_quad = sol.value_quad # quadratic part of loss (wrt C1 and C2)
464485

465-
# SOlve entropic FGW problem with alpha=0.1
486+
# Solve entropic FGW problem with alpha=0.1
466487
P_efgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1, reg=1e-3).plan
467488

468489
# sphinx_gallery_start_ignore
@@ -497,35 +518,6 @@ def df(G):
497518
# loss_fgw = ot.gromov.fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1)
498519
# loss_efgw = ot.gromov.entropic_fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1, epsilon=reg)
499520
#
500-
501-
# # Unbalanced Gromov-Wasserstein
502-
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
503-
# #
504-
#
505-
# # Solve the Unbalanced Gromov-Wasserstein problem
506-
# P_gw_unb = ot.solve_gromov(C1, C2, a=a, b=b, unbalanced=1e-2).plan
507-
#
508-
# # Solve the Unbalanced Entropic Gromov-Wasserstein problem
509-
# P_egw_unb = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2, reg_type='KL', unbalanced=1e-2).plan
510-
#
511-
# # sphinx_gallery_start_ignore
512-
# pl.figure(1, (6, 3))
513-
#
514-
# pl.subplot(1, 2, 1)
515-
# plot2D_samples_mat(x1, x2, P_gw_unb)
516-
# pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
517-
# pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
518-
# pl.title("Unbalanced GW plan")
519-
#
520-
# pl.subplot(1, 2, 2)
521-
# plot2D_samples_mat(x1, x2, P_egw_unb)
522-
# pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
523-
# pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
524-
# pl.title("Unbalanced Entropic GW plan")
525-
# pl.show()
526-
# # sphinx_gallery_end_ignore
527-
# %%
528-
#
529521
# Large scale OT
530522
# --------------
531523
#
@@ -557,9 +549,8 @@ def df(G):
557549
# recover values for Lazy plan
558550
P12 = P_sink_lazy[1, 2]
559551
P1dots = P_sink_lazy[1, :]
560-
P_sink_lazy_dense = P_sink_lazy[
561-
:
562-
] # convert to dense matrix !!warning this can be memory consuming
552+
# convert to dense matrix !!warning this can be memory consuming
553+
P_sink_lazy_dense = P_sink_lazy[:]
563554

564555
# sphinx_gallery_start_ignore
565556
pl.figure(1, (3, 3))
@@ -575,8 +566,13 @@ def df(G):
575566
pl.show()
576567

577568
# sphinx_gallery_end_ignore
578-
#
579569
# %%
570+
# .. note::
571+
# The lazy Sinkhorn algorithm can be found in the old API with the
572+
# :func:`ot.bregman.empirical_sinkhorn` function with parameter
573+
# :code:`lazy=True`. Similarly the geomloss implementation is available
574+
# with the :func:`ot.bregman.empirical_sinkhorn2_geomloss` function.
575+
#
580576
#
581577
# The first example shows how to solve the Sinkhorn problem in a lazy way with
582578
# the default POT implementation. The second example shows how to solve the
@@ -585,7 +581,7 @@ def df(G):
585581
# samples.
586582
#
587583
# Factored and Low rank OT
588-
# ------------------------
584+
# ~~~~~~~~~~~~~~~~~~~~~~~~
589585
#
590586
# The Sinkhorn algorithm can be implemented in a low rank version that
591587
# approximates the OT plan with a low rank matrix. This can be useful to
@@ -594,9 +590,9 @@ def df(G):
594590
#
595591

596592
# Solve the Factored OT problem (use lazy=True for large scale)
597-
P_fact = ot.solve_sample(x1, x2, a, b, method="factored", rank=8).plan
593+
P_fact = ot.solve_sample(x1, x2, a, b, method="factored", rank=15).plan
598594

599-
P_lowrank = ot.solve_sample(x1, x2, a, b, reg=0.1, method="lowrank", rank=8).plan
595+
P_lowrank = ot.solve_sample(x1, x2, a, b, reg=0.1, method="lowrank", rank=10).plan
600596

601597
# sphinx_gallery_start_ignore
602598
pl.figure(1, (6, 3))
@@ -626,8 +622,11 @@ def df(G):
626622
pl.show()
627623

628624
# sphinx_gallery_end_ignore
629-
630625
# %%
626+
# .. note::
627+
# The factored OT problem can be solved with the old API using the
628+
# :func:`ot.factored.factored_optimal_transport` function and the low rank
629+
# OT problem can be solved with the :func:`ot.lowrank.lowrank_sinkhorn` function.
631630
#
632631
# Gaussian OT with Bures-Wasserstein
633632
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -641,4 +640,106 @@ def df(G):
641640
# Compute the Bures-Wasserstein distance
642641
bw_value = ot.solve_sample(x1, x2, a, b, method="gaussian").value
643642

643+
print(f"Exact OT loss = {loss:1.3f}")
644644
print(f"Bures-Wasserstein distance = {bw_value:1.3f}")
645+
646+
# %%
647+
# .. note::
648+
# The Gaussian Wasserstein problem can be solved with the old API using the
649+
# :func:`ot.gaussian.empirical_bures_wasserstein_distance` function.
650+
#
651+
# All OT plans
652+
# ------------
653+
#
654+
# The figure below shows all the OT plans computed in this example.
655+
# The color intensity represents the amount of mass transported
656+
# between the samples.
657+
#
658+
659+
# sphinx_gallery_start_ignore
660+
pl.figure(1, (9, 13))
661+
662+
663+
pl.subplot(4, 3, 1)
664+
plot2D_samples_mat(x1, x2, P)
665+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
666+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
667+
pl.axis("off")
668+
pl.title("OT plan")
669+
670+
pl.subplot(4, 3, 2)
671+
plot2D_samples_mat(x1, x2, P_sink)
672+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
673+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
674+
pl.axis("off")
675+
pl.title("Sinkhorn plan")
676+
677+
pl.subplot(4, 3, 3)
678+
plot2D_samples_mat(x1, x2, P_quad)
679+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
680+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
681+
pl.axis("off")
682+
pl.title("Quadratic reg. plan")
683+
684+
pl.subplot(4, 3, 4)
685+
plot2D_samples_mat(x1, x2, P_unb_kl)
686+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
687+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
688+
pl.axis("off")
689+
pl.title("Unbalanced KL plan")
690+
691+
pl.subplot(4, 3, 5)
692+
plot2D_samples_mat(x1, x2, P_unb_kl_reg)
693+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
694+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
695+
pl.axis("off")
696+
pl.title("Unbalanced KL + reg plan")
697+
698+
pl.subplot(4, 3, 6)
699+
plot2D_samples_mat(x1, x2, P_unb_l2)
700+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
701+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
702+
pl.axis("off")
703+
pl.title("Unbalanced L2 plan")
704+
705+
pl.subplot(4, 3, 7)
706+
plot2D_samples_mat(x1, x2, P_part_const)
707+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
708+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
709+
pl.axis("off")
710+
pl.title("Partial 50% mass plan")
711+
712+
pl.subplot(4, 3, 8)
713+
plot2D_samples_mat(x1, x2, P_fact)
714+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
715+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
716+
pl.axis("off")
717+
pl.title("Factored OT plan")
718+
719+
pl.subplot(4, 3, 9)
720+
plot2D_samples_mat(x1, x2, P_lowrank)
721+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
722+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
723+
pl.axis("off")
724+
pl.title("Low rank OT plan")
725+
726+
pl.subplot(4, 3, 10)
727+
plot2D_samples_mat(x1, x2, P_gw)
728+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
729+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
730+
pl.axis("off")
731+
pl.title("GW plan")
732+
733+
pl.subplot(4, 3, 11)
734+
plot2D_samples_mat(x1, x2, P_egw)
735+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
736+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
737+
pl.axis("off")
738+
pl.title("Entropic GW plan")
739+
740+
pl.subplot(4, 3, 12)
741+
plot2D_samples_mat(x1, x2, P_fgw)
742+
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
743+
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
744+
pl.axis("off")
745+
pl.title("Fused GW plan")

0 commit comments

Comments
 (0)