Skip to content

Commit 8b1f8be

Browse files
committed
better verison quickstart guide
1 parent 1d74e71 commit 8b1f8be

File tree

1 file changed

+36
-142
lines changed

1 file changed

+36
-142
lines changed

examples/plot_quickstart_guide.py

Lines changed: 36 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@
6161

6262

6363
def plot_plan(P=None, title="", axis=True):
64-
plot2D_samples_mat(x1, x2, P)
64+
if P is not None:
65+
plot2D_samples_mat(x1, x2, P)
6566
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
6667
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
6768
if not axis:
@@ -70,10 +71,8 @@ def plot_plan(P=None, title="", axis=True):
7071

7172

7273
pl.figure(1, (4, 4))
73-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
74-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
74+
plot_plan(title="Source and target distributions")
7575
pl.legend(loc=0)
76-
pl.title("Source and target distributions")
7776
pl.show()
7877

7978
pl.figure(2, (3.5, 1.7))
@@ -114,10 +113,7 @@ def plot_plan(P=None, title="", axis=True):
114113
pl.figure(1, (8, 4))
115114

116115
pl.subplot(1, 2, 1)
117-
plot2D_samples_mat(x1, x2, P)
118-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
119-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
120-
pl.title("OT plan P loss={:.3f}".format(loss))
116+
plot_plan(P, "OT plan P loss={:.3f}".format(loss))
121117

122118
pl.subplot(1, 2, 2)
123119
pl.scatter(x1[:, 0], x1[:, 1], c=alpha, cmap="viridis", edgecolors="k")
@@ -183,10 +179,7 @@ def plot_plan(P=None, title="", axis=True):
183179

184180
# sphinx_gallery_start_ignore
185181
pl.figure(1, (3, 3))
186-
plot2D_samples_mat(x1, x2, P)
187-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
188-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
189-
pl.title("OT plan (Citybloc) loss={:.3f}".format(loss_city))
182+
plot_plan(P_city, "OT plan (Citybloc) loss={:.3f}".format(loss_city))
190183

191184
pl.figure(2, (3, 1.7))
192185
pl.imshow(P_city, cmap="Greys")
@@ -232,10 +225,7 @@ def plot_plan(P=None, title="", axis=True):
232225

233226
# sphinx_gallery_start_ignore
234227
pl.figure(1, (3, 3))
235-
plot2D_samples_mat(x1, x2, P_sink)
236-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
237-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
238-
pl.title("Sinkhorn OT plan loss={:.3f}".format(loss_sink))
228+
plot_plan(P_sink, "Sinkhorn OT plan loss={:.3f}".format(loss_sink))
239229
pl.show()
240230

241231
pl.figure(2, (3, 1.7))
@@ -263,22 +253,13 @@ def plot_plan(P=None, title="", axis=True):
263253
pl.figure(1, (9, 3))
264254

265255
pl.subplot(1, 3, 1)
266-
plot2D_samples_mat(x1, x2, P)
267-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
268-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
269-
pl.title("OT plan loss={:.3f}".format(loss))
256+
plot_plan(P, "OT plan loss={:.3f}".format(loss))
270257

271258
pl.subplot(1, 3, 2)
272-
plot2D_samples_mat(x1, x2, P_sink)
273-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
274-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
275-
pl.title("Sinkhorn plan loss={:.3f}".format(loss_sink))
259+
plot_plan(P_sink, "Sinkhorn plan loss={:.3f}".format(loss_sink))
276260

277261
pl.subplot(1, 3, 3)
278-
plot2D_samples_mat(x1, x2, P_quad)
279-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
280-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
281-
pl.title("Quadratic plan loss={:.3f}".format(loss_quad))
262+
plot_plan(P_quad, "Quadratic reg plan loss={:.3f}".format(loss_quad))
282263
pl.show()
283264
# sphinx_gallery_end_ignore
284265
# %%
@@ -304,10 +285,7 @@ def df(G):
304285

305286
# sphinx_gallery_start_ignore
306287
pl.figure(1, (3, 3))
307-
plot2D_samples_mat(x1, x2, P_reg)
308-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
309-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
310-
pl.title("Custom reg plan")
288+
plot_plan(P_reg, "User-defined reg plan")
311289
pl.show()
312290
# sphinx_gallery_end_ignore
313291
# %%
@@ -354,23 +332,13 @@ def df(G):
354332
pl.figure(1, (9, 3))
355333

356334
pl.subplot(1, 3, 1)
357-
plot2D_samples_mat(x1, x2, P_unb_kl)
358-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
359-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
360-
pl.title("Unbalanced KL plan")
335+
plot_plan(P_unb_kl, "Unbalanced KL plan")
361336

362337
pl.subplot(1, 3, 2)
363-
plot2D_samples_mat(x1, x2, P_unb_kl_reg)
364-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
365-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
366-
pl.title("Unbalanced KL + reg plan")
338+
plot_plan(P_unb_kl_reg, "Unbalanced KL + reg plan")
367339

368340
pl.subplot(1, 3, 3)
369-
plot2D_samples_mat(x1, x2, P_unb_l2)
370-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
371-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
372-
pl.title("Unbalanced L2 plan")
373-
341+
plot_plan(P_unb_l2, "Unbalanced L2 plan")
374342
pl.show()
375343
# sphinx_gallery_end_ignore
376344
# %%
@@ -396,16 +364,10 @@ def df(G):
396364
pl.figure(1, (6, 3))
397365

398366
pl.subplot(1, 2, 1)
399-
plot2D_samples_mat(x1, x2, P_part_pen)
400-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
401-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
402-
pl.title("Partial (Unb. TV) plan")
367+
plot_plan(P_part_pen, "Partial TV plan")
403368

404369
pl.subplot(1, 2, 2)
405-
plot2D_samples_mat(x1, x2, P_part_const)
406-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
407-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
408-
pl.title("Partial 50% mass plan")
370+
plot_plan(P_part_const, "Partial 50% mass plan")
409371
pl.show()
410372

411373
# sphinx_gallery_end_ignore
@@ -442,16 +404,10 @@ def df(G):
442404
pl.figure(1, (6, 3))
443405

444406
pl.subplot(1, 2, 1)
445-
plot2D_samples_mat(x1, x2, P_gw)
446-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
447-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
448-
pl.title("GW plan")
407+
plot_plan(P_gw, "GW plan")
449408

450409
pl.subplot(1, 2, 2)
451-
plot2D_samples_mat(x1, x2, P_egw)
452-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
453-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
454-
pl.title("Entropic GW plan")
410+
plot_plan(P_egw, "Entropic GW plan")
455411
pl.show()
456412
# sphinx_gallery_end_ignore
457413
# %%
@@ -490,16 +446,10 @@ def df(G):
490446
pl.figure(1, (6, 3))
491447

492448
pl.subplot(1, 2, 1)
493-
plot2D_samples_mat(x1, x2, P_fgw)
494-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
495-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
496-
pl.title("FGW plan")
449+
plot_plan(P_fgw, "FGW plan")
497450

498451
pl.subplot(1, 2, 2)
499-
plot2D_samples_mat(x1, x2, P_efgw)
500-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
501-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
502-
pl.title("Entropic FGW plan")
452+
plot_plan(P_efgw, "Entropic FGW plan")
503453
pl.show()
504454

505455
# sphinx_gallery_end_ignore
@@ -554,10 +504,7 @@ def df(G):
554504

555505
# sphinx_gallery_start_ignore
556506
pl.figure(1, (3, 3))
557-
plot2D_samples_mat(x1, x2, P_sink_lazy_dense)
558-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
559-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
560-
pl.title("Lazy Sinkhorn OT plan")
507+
plot_plan(P_sink_lazy_dense, "Lazy Sinkhorn OT plan")
561508
pl.show()
562509

563510
pl.figure(2, (3, 1.7))
@@ -598,16 +545,10 @@ def df(G):
598545
pl.figure(1, (6, 3))
599546

600547
pl.subplot(1, 2, 1)
601-
plot2D_samples_mat(x1, x2, P_fact)
602-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
603-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
604-
pl.title("Factored OT plan")
548+
plot_plan(P_fact, "Factored OT plan")
605549

606550
pl.subplot(1, 2, 2)
607-
plot2D_samples_mat(x1, x2, P_lowrank)
608-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
609-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
610-
pl.title("Low rank OT plan")
551+
plot_plan(P_lowrank, "Low rank OT plan")
611552
pl.show()
612553

613554
pl.figure(2, (6, 1.7))
@@ -659,87 +600,40 @@ def df(G):
659600
# sphinx_gallery_start_ignore
660601
pl.figure(1, (9, 13))
661602

662-
663603
pl.subplot(4, 3, 1)
664-
plot2D_samples_mat(x1, x2, P)
665-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
666-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
667-
pl.axis("off")
668-
pl.title("OT plan")
604+
plot_plan(P, "OT plan", axis=False)
669605

670606
pl.subplot(4, 3, 2)
671-
plot2D_samples_mat(x1, x2, P_sink)
672-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
673-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
674-
pl.axis("off")
675-
pl.title("Sinkhorn plan")
607+
plot_plan(P_sink, "Sinkhorn plan", axis=False)
676608

677609
pl.subplot(4, 3, 3)
678-
plot2D_samples_mat(x1, x2, P_quad)
679-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
680-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
681-
pl.axis("off")
682-
pl.title("Quadratic reg. plan")
610+
plot_plan(P_quad, "Quadratic reg. plan", axis=False)
683611

684612
pl.subplot(4, 3, 4)
685-
plot2D_samples_mat(x1, x2, P_unb_kl)
686-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
687-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
688-
pl.axis("off")
689-
pl.title("Unbalanced KL plan")
613+
plot_plan(P_unb_kl, "Unbalanced KL plan", axis=False)
690614

691615
pl.subplot(4, 3, 5)
692-
plot2D_samples_mat(x1, x2, P_unb_kl_reg)
693-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
694-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
695-
pl.axis("off")
696-
pl.title("Unbalanced KL + reg plan")
616+
plot_plan(P_unb_kl_reg, "Unbalanced KL + reg plan", axis=False)
697617

698618
pl.subplot(4, 3, 6)
699-
plot2D_samples_mat(x1, x2, P_unb_l2)
700-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
701-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
702-
pl.axis("off")
703-
pl.title("Unbalanced L2 plan")
619+
plot_plan(P_unb_l2, "Unbalanced L2 plan", axis=False)
704620

705621
pl.subplot(4, 3, 7)
706-
plot2D_samples_mat(x1, x2, P_part_const)
707-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
708-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
709-
pl.axis("off")
710-
pl.title("Partial 50% mass plan")
622+
plot_plan(P_part_const, "Partial 50% mass plan", axis=False)
711623

712624
pl.subplot(4, 3, 8)
713-
plot2D_samples_mat(x1, x2, P_fact)
714-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
715-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
716-
pl.axis("off")
717-
pl.title("Factored OT plan")
625+
plot_plan(P_fact, "Factored OT plan", axis=False)
718626

719627
pl.subplot(4, 3, 9)
720-
plot2D_samples_mat(x1, x2, P_lowrank)
721-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
722-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
723-
pl.axis("off")
724-
pl.title("Low rank OT plan")
628+
plot_plan(P_lowrank, "Low rank OT plan", axis=False)
725629

726630
pl.subplot(4, 3, 10)
727-
plot2D_samples_mat(x1, x2, P_gw)
728-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
729-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
730-
pl.axis("off")
731-
pl.title("GW plan")
631+
plot_plan(P_gw, "GW plan", axis=False)
732632

733633
pl.subplot(4, 3, 11)
734-
plot2D_samples_mat(x1, x2, P_egw)
735-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
736-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
737-
pl.axis("off")
738-
pl.title("Entropic GW plan")
634+
plot_plan(P_egw, "Entropic GW plan", axis=False)
739635

740636
pl.subplot(4, 3, 12)
741-
plot2D_samples_mat(x1, x2, P_fgw)
742-
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style)
743-
pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style)
744-
pl.axis("off")
745-
pl.title("Fused GW plan")
637+
plot_plan(P_fgw, "Fused GW plan", axis=False)
638+
639+
pl.show()

0 commit comments

Comments
 (0)