6161
6262
6363def plot_plan (P = None , title = "" , axis = True ):
64- plot2D_samples_mat (x1 , x2 , P )
64+ if P is not None :
65+ plot2D_samples_mat (x1 , x2 , P )
6566 pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
6667 pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
6768 if not axis :
@@ -70,10 +71,8 @@ def plot_plan(P=None, title="", axis=True):
7071
7172
7273pl .figure (1 , (4 , 4 ))
73- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
74- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
74+ plot_plan (title = "Source and target distributions" )
7575pl .legend (loc = 0 )
76- pl .title ("Source and target distributions" )
7776pl .show ()
7877
7978pl .figure (2 , (3.5 , 1.7 ))
@@ -114,10 +113,7 @@ def plot_plan(P=None, title="", axis=True):
114113pl .figure (1 , (8 , 4 ))
115114
116115pl .subplot (1 , 2 , 1 )
117- plot2D_samples_mat (x1 , x2 , P )
118- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
119- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
120- pl .title ("OT plan P loss={:.3f}" .format (loss ))
116+ plot_plan (P , "OT plan P loss={:.3f}" .format (loss ))
121117
122118pl .subplot (1 , 2 , 2 )
123119pl .scatter (x1 [:, 0 ], x1 [:, 1 ], c = alpha , cmap = "viridis" , edgecolors = "k" )
@@ -183,10 +179,7 @@ def plot_plan(P=None, title="", axis=True):
183179
184180# sphinx_gallery_start_ignore
185181pl .figure (1 , (3 , 3 ))
186- plot2D_samples_mat (x1 , x2 , P )
187- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
188- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
189- pl .title ("OT plan (Citybloc) loss={:.3f}" .format (loss_city ))
182+ plot_plan (P_city , "OT plan (Citybloc) loss={:.3f}" .format (loss_city ))
190183
191184pl .figure (2 , (3 , 1.7 ))
192185pl .imshow (P_city , cmap = "Greys" )
@@ -232,10 +225,7 @@ def plot_plan(P=None, title="", axis=True):
232225
233226# sphinx_gallery_start_ignore
234227pl .figure (1 , (3 , 3 ))
235- plot2D_samples_mat (x1 , x2 , P_sink )
236- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
237- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
238- pl .title ("Sinkhorn OT plan loss={:.3f}" .format (loss_sink ))
228+ plot_plan (P_sink , "Sinkhorn OT plan loss={:.3f}" .format (loss_sink ))
239229pl .show ()
240230
241231pl .figure (2 , (3 , 1.7 ))
@@ -263,22 +253,13 @@ def plot_plan(P=None, title="", axis=True):
263253pl .figure (1 , (9 , 3 ))
264254
265255pl .subplot (1 , 3 , 1 )
266- plot2D_samples_mat (x1 , x2 , P )
267- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
268- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
269- pl .title ("OT plan loss={:.3f}" .format (loss ))
256+ plot_plan (P , "OT plan loss={:.3f}" .format (loss ))
270257
271258pl .subplot (1 , 3 , 2 )
272- plot2D_samples_mat (x1 , x2 , P_sink )
273- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
274- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
275- pl .title ("Sinkhorn plan loss={:.3f}" .format (loss_sink ))
259+ plot_plan (P_sink , "Sinkhorn plan loss={:.3f}" .format (loss_sink ))
276260
277261pl .subplot (1 , 3 , 3 )
278- plot2D_samples_mat (x1 , x2 , P_quad )
279- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
280- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
281- pl .title ("Quadratic plan loss={:.3f}" .format (loss_quad ))
262+ plot_plan (P_quad , "Quadratic reg plan loss={:.3f}" .format (loss_quad ))
282263pl .show ()
283264# sphinx_gallery_end_ignore
284265# %%
@@ -304,10 +285,7 @@ def df(G):
304285
305286# sphinx_gallery_start_ignore
306287pl .figure (1 , (3 , 3 ))
307- plot2D_samples_mat (x1 , x2 , P_reg )
308- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
309- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
310- pl .title ("Custom reg plan" )
288+ plot_plan (P_reg , "User-defined reg plan" )
311289pl .show ()
312290# sphinx_gallery_end_ignore
313291# %%
@@ -354,23 +332,13 @@ def df(G):
354332pl .figure (1 , (9 , 3 ))
355333
356334pl .subplot (1 , 3 , 1 )
357- plot2D_samples_mat (x1 , x2 , P_unb_kl )
358- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
359- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
360- pl .title ("Unbalanced KL plan" )
335+ plot_plan (P_unb_kl , "Unbalanced KL plan" )
361336
362337pl .subplot (1 , 3 , 2 )
363- plot2D_samples_mat (x1 , x2 , P_unb_kl_reg )
364- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
365- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
366- pl .title ("Unbalanced KL + reg plan" )
338+ plot_plan (P_unb_kl_reg , "Unbalanced KL + reg plan" )
367339
368340pl .subplot (1 , 3 , 3 )
369- plot2D_samples_mat (x1 , x2 , P_unb_l2 )
370- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
371- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
372- pl .title ("Unbalanced L2 plan" )
373-
341+ plot_plan (P_unb_l2 , "Unbalanced L2 plan" )
374342pl .show ()
375343# sphinx_gallery_end_ignore
376344# %%
@@ -396,16 +364,10 @@ def df(G):
396364pl .figure (1 , (6 , 3 ))
397365
398366pl .subplot (1 , 2 , 1 )
399- plot2D_samples_mat (x1 , x2 , P_part_pen )
400- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
401- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
402- pl .title ("Partial (Unb. TV) plan" )
367+ plot_plan (P_part_pen , "Partial TV plan" )
403368
404369pl .subplot (1 , 2 , 2 )
405- plot2D_samples_mat (x1 , x2 , P_part_const )
406- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
407- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
408- pl .title ("Partial 50% mass plan" )
370+ plot_plan (P_part_const , "Partial 50% mass plan" )
409371pl .show ()
410372
411373# sphinx_gallery_end_ignore
@@ -442,16 +404,10 @@ def df(G):
442404pl .figure (1 , (6 , 3 ))
443405
444406pl .subplot (1 , 2 , 1 )
445- plot2D_samples_mat (x1 , x2 , P_gw )
446- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
447- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
448- pl .title ("GW plan" )
407+ plot_plan (P_gw , "GW plan" )
449408
450409pl .subplot (1 , 2 , 2 )
451- plot2D_samples_mat (x1 , x2 , P_egw )
452- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
453- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
454- pl .title ("Entropic GW plan" )
410+ plot_plan (P_egw , "Entropic GW plan" )
455411pl .show ()
456412# sphinx_gallery_end_ignore
457413# %%
@@ -490,16 +446,10 @@ def df(G):
490446pl .figure (1 , (6 , 3 ))
491447
492448pl .subplot (1 , 2 , 1 )
493- plot2D_samples_mat (x1 , x2 , P_fgw )
494- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
495- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
496- pl .title ("FGW plan" )
449+ plot_plan (P_fgw , "FGW plan" )
497450
498451pl .subplot (1 , 2 , 2 )
499- plot2D_samples_mat (x1 , x2 , P_efgw )
500- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
501- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
502- pl .title ("Entropic FGW plan" )
452+ plot_plan (P_efgw , "Entropic FGW plan" )
503453pl .show ()
504454
505455# sphinx_gallery_end_ignore
@@ -554,10 +504,7 @@ def df(G):
554504
555505# sphinx_gallery_start_ignore
556506pl .figure (1 , (3 , 3 ))
557- plot2D_samples_mat (x1 , x2 , P_sink_lazy_dense )
558- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
559- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
560- pl .title ("Lazy Sinkhorn OT plan" )
507+ plot_plan (P_sink_lazy_dense , "Lazy Sinkhorn OT plan" )
561508pl .show ()
562509
563510pl .figure (2 , (3 , 1.7 ))
@@ -598,16 +545,10 @@ def df(G):
598545pl .figure (1 , (6 , 3 ))
599546
600547pl .subplot (1 , 2 , 1 )
601- plot2D_samples_mat (x1 , x2 , P_fact )
602- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
603- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
604- pl .title ("Factored OT plan" )
548+ plot_plan (P_fact , "Factored OT plan" )
605549
606550pl .subplot (1 , 2 , 2 )
607- plot2D_samples_mat (x1 , x2 , P_lowrank )
608- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
609- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
610- pl .title ("Low rank OT plan" )
551+ plot_plan (P_lowrank , "Low rank OT plan" )
611552pl .show ()
612553
613554pl .figure (2 , (6 , 1.7 ))
@@ -659,87 +600,40 @@ def df(G):
659600# sphinx_gallery_start_ignore
660601pl .figure (1 , (9 , 13 ))
661602
662-
663603pl .subplot (4 , 3 , 1 )
664- plot2D_samples_mat (x1 , x2 , P )
665- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
666- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
667- pl .axis ("off" )
668- pl .title ("OT plan" )
604+ plot_plan (P , "OT plan" , axis = False )
669605
670606pl .subplot (4 , 3 , 2 )
671- plot2D_samples_mat (x1 , x2 , P_sink )
672- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
673- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
674- pl .axis ("off" )
675- pl .title ("Sinkhorn plan" )
607+ plot_plan (P_sink , "Sinkhorn plan" , axis = False )
676608
677609pl .subplot (4 , 3 , 3 )
678- plot2D_samples_mat (x1 , x2 , P_quad )
679- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
680- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
681- pl .axis ("off" )
682- pl .title ("Quadratic reg. plan" )
610+ plot_plan (P_quad , "Quadratic reg. plan" , axis = False )
683611
684612pl .subplot (4 , 3 , 4 )
685- plot2D_samples_mat (x1 , x2 , P_unb_kl )
686- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
687- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
688- pl .axis ("off" )
689- pl .title ("Unbalanced KL plan" )
613+ plot_plan (P_unb_kl , "Unbalanced KL plan" , axis = False )
690614
691615pl .subplot (4 , 3 , 5 )
692- plot2D_samples_mat (x1 , x2 , P_unb_kl_reg )
693- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
694- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
695- pl .axis ("off" )
696- pl .title ("Unbalanced KL + reg plan" )
616+ plot_plan (P_unb_kl_reg , "Unbalanced KL + reg plan" , axis = False )
697617
698618pl .subplot (4 , 3 , 6 )
699- plot2D_samples_mat (x1 , x2 , P_unb_l2 )
700- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
701- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
702- pl .axis ("off" )
703- pl .title ("Unbalanced L2 plan" )
619+ plot_plan (P_unb_l2 , "Unbalanced L2 plan" , axis = False )
704620
705621pl .subplot (4 , 3 , 7 )
706- plot2D_samples_mat (x1 , x2 , P_part_const )
707- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
708- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
709- pl .axis ("off" )
710- pl .title ("Partial 50% mass plan" )
622+ plot_plan (P_part_const , "Partial 50% mass plan" , axis = False )
711623
712624pl .subplot (4 , 3 , 8 )
713- plot2D_samples_mat (x1 , x2 , P_fact )
714- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
715- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
716- pl .axis ("off" )
717- pl .title ("Factored OT plan" )
625+ plot_plan (P_fact , "Factored OT plan" , axis = False )
718626
719627pl .subplot (4 , 3 , 9 )
720- plot2D_samples_mat (x1 , x2 , P_lowrank )
721- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
722- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
723- pl .axis ("off" )
724- pl .title ("Low rank OT plan" )
628+ plot_plan (P_lowrank , "Low rank OT plan" , axis = False )
725629
726630pl .subplot (4 , 3 , 10 )
727- plot2D_samples_mat (x1 , x2 , P_gw )
728- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
729- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
730- pl .axis ("off" )
731- pl .title ("GW plan" )
631+ plot_plan (P_gw , "GW plan" , axis = False )
732632
733633pl .subplot (4 , 3 , 11 )
734- plot2D_samples_mat (x1 , x2 , P_egw )
735- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
736- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
737- pl .axis ("off" )
738- pl .title ("Entropic GW plan" )
634+ plot_plan (P_egw , "Entropic GW plan" , axis = False )
739635
740636pl .subplot (4 , 3 , 12 )
741- plot2D_samples_mat (x1 , x2 , P_fgw )
742- pl .plot (x1 [:, 0 ], x1 [:, 1 ], "ob" , label = "Source samples" , ** style )
743- pl .plot (x2 [:, 0 ], x2 [:, 1 ], "or" , label = "Target samples" , ** style )
744- pl .axis ("off" )
745- pl .title ("Fused GW plan" )
637+ plot_plan (P_fgw , "Fused GW plan" , axis = False )
638+
639+ pl .show ()
0 commit comments