1010
1111import warnings
1212
13- from ..utils import list_to_array
1413from ..backend import get_backend
14+ from ..utils import list_to_array
15+
16+ _warning_msg = (
17+ "Convolutional Sinkhorn did not converge. "
18+ "Try a larger number of iterations `numItermax` "
19+ "or a larger entropy `reg`."
20+ )
21+
22+
23+ def _get_convol_img_fn (nx , width , height , reg , type_as , log_domain = False ):
24+ """Return the convolution operator for 2D images.
25+
26+ The function constructed is equivalent to blurring on horizontal then vertical directions."""
27+ t1 = nx .linspace (0 , 1 , width , type_as = type_as )
28+ Y1 , X1 = nx .meshgrid (t1 , t1 )
29+ M1 = - ((X1 - Y1 ) ** 2 ) / reg
30+
31+ t2 = nx .linspace (0 , 1 , height , type_as = type_as )
32+ Y2 , X2 = nx .meshgrid (t2 , t2 )
33+ M2 = - ((X2 - Y2 ) ** 2 ) / reg
34+
35+ # If normal domain is selected, we can use M1 and M2 to compute the convolution
36+ if not log_domain :
37+ K1 , K2 = nx .exp (M1 ), nx .exp (M2 )
38+
39+ def convol_imgs (imgs ):
40+ kx = nx .einsum ("...ij,kjl->kil" , K1 , imgs )
41+ kxy = nx .einsum ("...ij,klj->kli" , K2 , kx )
42+ return kxy
43+
44+ # Else, we can use M1 and M2 to compute the convolution in log-domain
45+ else :
46+
47+ def convol_imgs (log_imgs ):
48+ log_imgs = nx .logsumexp (M1 [:, :, None ] + log_imgs [None ], axis = 1 )
49+ log_imgs = nx .logsumexp (M2 [:, :, None ] + log_imgs .T [None ], axis = 1 ).T
50+ return log_imgs
51+
52+ return convol_imgs
53+
54+
55+ def _print_report (ii , err ):
56+ """Print the report of the iteration."""
57+ if ii % 200 == 0 :
58+ print ("{:5s}|{:12s}" .format ("It." , "Err" ) + "\n " + "-" * 19 )
59+ print ("{:5d}|{:8e}|" .format (ii , err ))
1560
1661
1762def convolutional_barycenter2d (
@@ -133,37 +178,26 @@ def _convolutional_barycenter2d(
133178 """
134179
135180 A = list_to_array (A )
181+ n_hists , width , height = A .shape
136182
137183 nx = get_backend (A )
138184
139185 if weights is None :
140- weights = nx .ones ((A . shape [ 0 ] ,), type_as = A ) / A . shape [ 0 ]
186+ weights = nx .ones ((n_hists ,), type_as = A ) / n_hists
141187 else :
142- assert len (weights ) == A . shape [ 0 ]
188+ assert len (weights ) == n_hists
143189
144190 if log :
145191 log = {"err" : []}
146192
147- bar = nx .ones (A . shape [ 1 :] , type_as = A )
193+ bar = nx .ones (( width , height ) , type_as = A )
148194 bar /= nx .sum (bar )
149195 U = nx .ones (A .shape , type_as = A )
150196 V = nx .ones (A .shape , type_as = A )
151197 err = 1
152198
153199 # build the convolution operator
154- # this is equivalent to blurring on horizontal then vertical directions
155- t = nx .linspace (0 , 1 , A .shape [1 ], type_as = A )
156- [Y , X ] = nx .meshgrid (t , t )
157- K1 = nx .exp (- ((X - Y ) ** 2 ) / reg )
158-
159- t = nx .linspace (0 , 1 , A .shape [2 ], type_as = A )
160- [Y , X ] = nx .meshgrid (t , t )
161- K2 = nx .exp (- ((X - Y ) ** 2 ) / reg )
162-
163- def convol_imgs (imgs ):
164- kx = nx .einsum ("...ij,kjl->kil" , K1 , imgs )
165- kxy = nx .einsum ("...ij,klj->kli" , K2 , kx )
166- return kxy
200+ convol_imgs = _get_convol_img_fn (nx , width , height , reg , type_as = A )
167201
168202 KU = convol_imgs (U )
169203 for ii in range (numItermax ):
@@ -177,24 +211,18 @@ def convol_imgs(imgs):
177211 # log and verbose print
178212 if log :
179213 log ["err" ].append (err )
180-
181214 if verbose :
182- if ii % 200 == 0 :
183- print ("{:5s}|{:12s}" .format ("It." , "Err" ) + "\n " + "-" * 19 )
184- print ("{:5d}|{:8e}|" .format (ii , err ))
215+ _print_report (ii , err )
185216 if err < stopThr :
186217 break
187218
188219 else :
189220 if warn :
190- warnings .warn (
191- "Convolutional Sinkhorn did not converge. "
192- "Try a larger number of iterations `numItermax` "
193- "or a larger entropy `reg`."
194- )
221+ warnings .warn (_warning_msg )
195222 if log :
196223 log ["niter" ] = ii
197224 log ["U" ] = U
225+ log ["V" ] = V
198226 return bar , log
199227 else :
200228 return bar
@@ -218,6 +246,8 @@ def _convolutional_barycenter2d_log(
218246 A = list_to_array (A )
219247
220248 nx = get_backend (A )
249+ # This error is raised because we are using mutable assignment in the line
250+ # `log_KU[k] = ...` which is not allowed in Jax and TF.
221251 if nx .__name__ in ("jax" , "tf" ):
222252 raise NotImplementedError (
223253 "Log-domain functions are not yet implemented"
@@ -236,19 +266,7 @@ def _convolutional_barycenter2d_log(
236266
237267 err = 1
238268 # build the convolution operator
239- # this is equivalent to blurring on horizontal then vertical directions
240- t = nx .linspace (0 , 1 , width , type_as = A )
241- [Y , X ] = nx .meshgrid (t , t )
242- M1 = - ((X - Y ) ** 2 ) / reg
243-
244- t = nx .linspace (0 , 1 , height , type_as = A )
245- [Y , X ] = nx .meshgrid (t , t )
246- M2 = - ((X - Y ) ** 2 ) / reg
247-
248- def convol_img (log_img ):
249- log_img = nx .logsumexp (M1 [:, :, None ] + log_img [None ], axis = 1 )
250- log_img = nx .logsumexp (M2 [:, :, None ] + log_img .T [None ], axis = 1 ).T
251- return log_img
269+ convol_img = _get_convol_img_fn (nx , width , height , reg , type_as = A , log_domain = True )
252270
253271 logA = nx .log (A + stabThr )
254272 log_KU , G , F = nx .zeros ((3 , * logA .shape ), type_as = A )
@@ -265,22 +283,15 @@ def convol_img(log_img):
265283 # log and verbose print
266284 if log :
267285 log ["err" ].append (err )
268-
269286 if verbose :
270- if ii % 200 == 0 :
271- print ("{:5s}|{:12s}" .format ("It." , "Err" ) + "\n " + "-" * 19 )
272- print ("{:5d}|{:8e}|" .format (ii , err ))
287+ _print_report (ii , err )
273288 if err < stopThr :
274289 break
275290 G = log_bar [None , :, :] - log_KU
276291
277292 else :
278293 if warn :
279- warnings .warn (
280- "Convolutional Sinkhorn did not converge. "
281- "Try a larger number of iterations `numItermax` "
282- "or a larger entropy `reg`."
283- )
294+ warnings .warn (_warning_msg )
284295 if log :
285296 log ["niter" ] = ii
286297 return nx .exp (log_bar ), log
@@ -417,23 +428,11 @@ def _convolutional_barycenter2d_debiased(
417428 bar /= width * height
418429 U = nx .ones (A .shape , type_as = A )
419430 V = nx .ones (A .shape , type_as = A )
420- c = nx .ones (A . shape [ 1 :] , type_as = A )
431+ c = nx .ones (( width , height ) , type_as = A )
421432 err = 1
422433
423434 # build the convolution operator
424- # this is equivalent to blurring on horizontal then vertical directions
425- t = nx .linspace (0 , 1 , width , type_as = A )
426- [Y , X ] = nx .meshgrid (t , t )
427- K1 = nx .exp (- ((X - Y ) ** 2 ) / reg )
428-
429- t = nx .linspace (0 , 1 , height , type_as = A )
430- [Y , X ] = nx .meshgrid (t , t )
431- K2 = nx .exp (- ((X - Y ) ** 2 ) / reg )
432-
433- def convol_imgs (imgs ):
434- kx = nx .einsum ("...ij,kjl->kil" , K1 , imgs )
435- kxy = nx .einsum ("...ij,klj->kli" , K2 , kx )
436- return kxy
435+ convol_imgs = _get_convol_img_fn (nx , width , height , reg , type_as = A )
437436
438437 KU = convol_imgs (U )
439438 for ii in range (numItermax ):
@@ -451,26 +450,20 @@ def convol_imgs(imgs):
451450 # log and verbose print
452451 if log :
453452 log ["err" ].append (err )
454-
455453 if verbose :
456- if ii % 200 == 0 :
457- print ("{:5s}|{:12s}" .format ("It." , "Err" ) + "\n " + "-" * 19 )
458- print ("{:5d}|{:8e}|" .format (ii , err ))
454+ _print_report (ii , err )
459455
460456 # debiased Sinkhorn does not converge monotonically
461457 # guarantee a few iterations are done before stopping
462458 if err < stopThr and ii > 20 :
463459 break
464460 else :
465461 if warn :
466- warnings .warn (
467- "Sinkhorn did not converge. You might want to "
468- "increase the number of iterations `numItermax` "
469- "or the regularization parameter `reg`."
470- )
462+ warnings .warn (_warning_msg )
471463 if log :
472464 log ["niter" ] = ii
473465 log ["U" ] = U
466+ log ["V" ] = V
474467 return bar , log
475468 else :
476469 return bar
@@ -492,6 +485,8 @@ def _convolutional_barycenter2d_debiased_log(
492485 A = list_to_array (A )
493486 n_hists , width , height = A .shape
494487 nx = get_backend (A )
488+ # This error is raised because we are using mutable assignment in the line
489+ # `log_KU[k] = ...` which is not allowed in Jax and TF.
495490 if nx .__name__ in ("jax" , "tf" ):
496491 raise NotImplementedError (
497492 "Log-domain functions are not yet implemented"
@@ -507,19 +502,7 @@ def _convolutional_barycenter2d_debiased_log(
507502
508503 err = 1
509504 # build the convolution operator
510- # this is equivalent to blurring on horizontal then vertical directions
511- t = nx .linspace (0 , 1 , width , type_as = A )
512- [Y , X ] = nx .meshgrid (t , t )
513- M1 = - ((X - Y ) ** 2 ) / reg
514-
515- t = nx .linspace (0 , 1 , height , type_as = A )
516- [Y , X ] = nx .meshgrid (t , t )
517- M2 = - ((X - Y ) ** 2 ) / reg
518-
519- def convol_img (log_img ):
520- log_img = nx .logsumexp (M1 [:, :, None ] + log_img [None ], axis = 1 )
521- log_img = nx .logsumexp (M2 [:, :, None ] + log_img .T [None ], axis = 1 ).T
522- return log_img
505+ convol_img = _get_convol_img_fn (nx , width , height , reg , type_as = A , log_domain = True )
523506
524507 logA = nx .log (A + stabThr )
525508 log_bar , c = nx .zeros ((2 , width , height ), type_as = A )
@@ -540,22 +523,15 @@ def convol_img(log_img):
540523 # log and verbose print
541524 if log :
542525 log ["err" ].append (err )
543-
544526 if verbose :
545- if ii % 200 == 0 :
546- print ("{:5s}|{:12s}" .format ("It." , "Err" ) + "\n " + "-" * 19 )
547- print ("{:5d}|{:8e}|" .format (ii , err ))
527+ _print_report (ii , err )
548528 if err < stopThr and ii > 20 :
549529 break
550530 G = log_bar [None , :, :] - log_KU
551531
552532 else :
553533 if warn :
554- warnings .warn (
555- "Convolutional Sinkhorn did not converge. "
556- "Try a larger number of iterations `numItermax` "
557- "or a larger entropy `reg`."
558- )
534+ warnings .warn (_warning_msg )
559535 if log :
560536 log ["niter" ] = ii
561537 return nx .exp (log_bar ), log
0 commit comments