@@ -541,6 +541,7 @@ def bures_barycenter_gradient_descent(
541541 log = False ,
542542 step_size = 1 ,
543543 batch_size = None ,
544+ averaged = False ,
544545 nx = None ,
545546):
546547 r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions.
@@ -570,6 +571,8 @@ def bures_barycenter_gradient_descent(
570571 step size for the gradient descent, 1 by default
571572 batch_size : int, optional
572573 batch size if use a stochastic gradient descent
574+ averaged : bool, optional
575+ if True, use the averaged procedure of :ref:`[74] <references-OT-bures-barycenter-gradient_descent>`
573576 nx : module, optional
574577 The numerical backend module to use. If not provided, the backend will
575578 be fetched from the input matrices `C`.
@@ -607,7 +610,9 @@ def bures_barycenter_gradient_descent(
607610 Cb = nx .mean (C * weights [:, None , None ], axis = 0 )
608611 Id = nx .eye (C .shape [- 1 ], type_as = Cb )
609612
610- L_grads = []
613+ L_diff = []
614+
615+ Cb_averaged = nx .copy (Cb )
611616
612617 for it in range (num_iter ):
613618 Cb12 = nx .sqrtm (Cb )
@@ -627,40 +632,38 @@ def bures_barycenter_gradient_descent(
627632
628633 # step size from [74] (page 15)
629634 step_size = 2 / (0.7 * (it + 2 / 0.7 + 1 ))
630-
631- # TODO: Add one where we take samples in order, + averaging? cf [74]
632635 else : # gradient descent
633636 M = nx .sqrtm (nx .einsum ("ij,njk,kl -> nil" , Cb12 , C , Cb12 ))
634637 ot_maps = nx .einsum ("ij,njk,kl -> nil" , Cb12_ , M , Cb12_ )
635638 grad_bw = Id - nx .sum (ot_maps * weights [:, None , None ], axis = 0 )
636639
637640 Cnew = exp_bures (Cb , - step_size * grad_bw , nx = nx )
638641
642+ if averaged :
643+ # OT map between Cb_averaged and Cnew
644+ Cb_averaged12 = nx .sqrtm (Cb_averaged )
645+ Cb_averaged12inv = nx .inv (Cb_averaged12 )
646+ M = nx .sqrtm (nx .einsum ("ij,jk,kl->il" , Cb_averaged12 , Cnew , Cb_averaged12 ))
647+ ot_map = nx .einsum ("ij,jk,kl->il" , Cb_averaged12inv , M , Cb_averaged12inv )
648+ map = Id * step_size / (step_size + 1 ) + ot_map / (step_size + 1 )
649+ Cb_averaged = nx .einsum ("ij,jk,kl->il" , map , Cb_averaged , map )
650+
639651 # check convergence
640- if batch_size is not None and batch_size < n :
641- # TODO: criteria for SGD: on gradients? + test SGD
642- # TOO slow, test with value? (but don't want to compute the full barycenter)
643- # + need to make bures_wasserstein_distance batchable (TODO)
644- L_grads .append (nx .sum (grad_bw ** 2 ))
645- diff = np .mean (L_grads )
646-
647- # L_values.append(nx.norm(Cb - Cnew))
648- # print(diff, np.mean(L_values))
649- else :
650- diff = nx .norm (Cb - Cnew )
652+ L_diff .append (nx .norm (Cb - Cnew ))
651653
652- if diff <= eps :
654+ # Stopping criterion: mean of the last 100 iterate differences
655+ if np .mean (L_diff [- 100 :]) <= eps :
653656 break
654657
655658 Cb = Cnew
656659
657- if diff > eps :
658- print ( "Dit not converge." )
660+ if averaged :
661+ Cb = Cb_averaged
659662
660663 if log :
661664 dict_log = {}
662665 dict_log ["num_iter" ] = it
663- dict_log ["final_diff" ] = diff
666+ dict_log ["final_diff" ] = L_diff [ - 1 ]
664667 return Cb , dict_log
665668 else :
666669 return Cb
@@ -708,7 +711,8 @@ def bures_wasserstein_barycenter(
708711 weights : array-like (k), optional
709712 weights for each distribution
710713 method : str
711- method used for the solver, either 'fixed_point' or 'gradient_descent'
714+ method used for the solver, either 'fixed_point', 'gradient_descent', 'stochastic_gradient_descent' or
715+ 'averaged_stochastic_gradient_descent'
712716 num_iter : int, optional
713717 number of iteration for the fixed point algorithm
714718 eps : float, optional
@@ -756,15 +760,35 @@ def bures_wasserstein_barycenter(
756760 # Compute the mean barycenter
757761 mb = nx .sum (m * weights [:, None ], axis = 0 )
758762
759- if method == "gradient_descent" or batch_size is not None :
763+ if method == "gradient_descent" :
760764 out = bures_barycenter_gradient_descent (
761765 C ,
762766 weights = weights ,
763767 num_iter = num_iter ,
764768 eps = eps ,
765769 log = log ,
766770 step_size = step_size ,
767- batch_size = batch_size ,
771+ nx = nx ,
772+ )
773+ elif method == "stochastic_gradient_descent" :
774+ out = bures_barycenter_gradient_descent (
775+ C ,
776+ weights = weights ,
777+ num_iter = num_iter ,
778+ eps = eps ,
779+ log = log ,
780+ batch_size = 1 if batch_size is None else batch_size ,
781+ nx = nx ,
782+ )
783+ elif method == "averaged_stochastic_gradient_descent" :
784+ out = bures_barycenter_gradient_descent (
785+ C ,
786+ weights = weights ,
787+ num_iter = num_iter ,
788+ eps = eps ,
789+ log = log ,
790+ batch_size = 1 if batch_size is None else batch_size ,
791+ averaged = True ,
768792 nx = nx ,
769793 )
770794 elif method == "fixed_point" :
0 commit comments