@@ -199,14 +199,12 @@ def free_support_barycenter(
199199 measures_weights : list of N (k_i,) array-like
200200 Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
201201 representing the weights of each discrete input measure
202-
203202 X_init : (k,d) array-like
204203 Initialization of the support locations (on `k` atoms) of the barycenter
205204 b : (k,) array-like
206205 Initialization of the weights of the barycenter (non-negatives, sum to 1)
207206 weights : (N,) array-like
208207 Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
209-
210208 numItermax : int, optional
211209 Max number of iterations
212210 stopThr : float, optional
@@ -219,13 +217,11 @@ def free_support_barycenter(
219217 If compiled with OpenMP, chooses the number of threads to parallelize.
220218 "max" selects the highest number possible.
221219
222-
223220 Returns
224221 -------
225222 X : (k,d) array-like
226223 Support locations (on k atoms) of the barycenter
227224
228-
229225 .. _references-free-support-barycenter:
230226 References
231227 ----------
@@ -428,20 +424,20 @@ def generalized_free_support_barycenter(
428424 return Y
429425
430426
431- class StoppingCriterionReached (Exception ):
432- pass
433-
434-
435427def free_support_barycenter_generic_costs (
436428 measure_locations ,
437429 measure_weights ,
438430 X_init ,
439431 cost_list ,
440- B ,
432+ ground_bary = None ,
441433 a = None ,
442434 numItermax = 100 ,
443435 stopThr = 1e-5 ,
444436 log = False ,
437+ ground_bary_lr = 1e-2 ,
438+ ground_bary_numItermax = 100 ,
439+ ground_bary_stopThr = 1e-5 ,
440+ ground_bary_solver = "SGD" ,
445441):
446442 r"""
447443 Solves the OT barycenter problem for generic costs using the fixed point
@@ -507,14 +503,15 @@ def free_support_barycenter_generic_costs(
507503 List of K arrays of measure weights, each of shape (m_k).
508504 X_init : array-like
509505 Array of shape (n, d) representing initial barycenter points.
510- cost_list : list of callable
506+ cost_list : list of callable or callable
511507 List of K cost functions :math:`c_k: \mathbb{R}^{n\times
512508 d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times
513- m_k}`.
514- B : callable
509+ m_k}`. If cost_list is a single callable, the same cost is used K times.
510+ ground_bary : callable or None, optional
515511 Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays
516512 of shape (n\times d_K), computing the ground barycenters (broadcasted
517- over n).
513+ over n). If not provided, done with Adam on PyTorch (requires PyTorch
514+ backend)
518515 a : array-like, optional
519516 Array of shape (n,) representing weights of the barycenter
520517 measure.Defaults to uniform.
@@ -524,6 +521,16 @@ def free_support_barycenter_generic_costs(
524521 If the iterations move less than this, terminate (default is 1e-5).
525522 log : bool, optional
526523 Whether to return the log dictionary (default is False).
524+ ground_bary_lr : float, optional
525+ Learning rate for the ground barycenter solver (if auto is used).
526+ ground_bary_numItermax : int, optional
527+ Maximum number of iterations for the ground barycenter solver (if auto
528+ is used).
529+ ground_bary_stopThr : float, optional
530+ Stop threshold for the ground barycenter solver (if auto is used).
531+ ground_bary_solver : str, optional
532+ Solver for auto ground bary solver (torch SGD or Adam). Default is
533+ "SGD".
527534
528535 Returns
529536 -------
@@ -549,49 +556,85 @@ def free_support_barycenter_generic_costs(
549556 See Also
550557 --------
551558 ot.lp.free_support_barycenter : Free support solver for the case where
552- :math:`c_k(x,y) = \|x-y\|_2^2`. ot.lp.generalized_free_support_barycenter :
553- Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2`
554- with :math:`P_k` linear.
559+ :math:`c_k(x,y) = \lambda_k\ |x-y\|_2^2`.
560+ ot.lp.generalized_free_support_barycenter : Free support solver for the case
561+ where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
555562 """
556563 nx = get_backend (X_init , measure_locations [0 ])
557564 K = len (measure_locations )
558565 n = X_init .shape [0 ]
559566 if a is None :
560567 a = nx .ones (n , type_as = X_init ) / n
568+ if callable (cost_list ): # use the given cost for all K pairs
569+ cost_list = [cost_list ] * K
570+ auto_ground_bary = False
571+
572+ if ground_bary is None :
573+ auto_ground_bary = True
574+ assert str (nx ) == "torch" , (
575+ f"Backend { str (nx )} is not compatible with ground_bary=None, it"
576+ "must be provided if not using PyTorch backend"
577+ )
578+ try :
579+ import torch
580+ from torch .optim import Adam , SGD
581+
582+ def ground_bary (y , x_init ):
583+ x = x_init .clone ().detach ().requires_grad_ (True )
584+ solver = Adam if ground_bary_solver == "Adam" else SGD
585+ opt = solver ([x ], lr = ground_bary_lr )
586+ for _ in range (ground_bary_numItermax ):
587+ x_prev = x .data .clone ()
588+ opt .zero_grad ()
589+ # inefficient cost computation but compatible
590+ # with the choice of cost_list[k] giving the cost matrix
591+ loss = torch .sum (
592+ torch .stack (
593+ [torch .diag (cost_list [k ](x , y [k ])) for k in range (K )]
594+ )
595+ )
596+ loss .backward ()
597+ opt .step ()
598+ diff = torch .sum ((x .data - x_prev ) ** 2 )
599+ if diff < ground_bary_stopThr :
600+ break
601+ return x .detach ()
602+
603+ except ImportError :
604+ raise ImportError ("PyTorch is required to use ground_bary=None" )
605+
561606 X_list = [X_init ] if log else [] # store the iterations
562607 X = X_init
563608 dX_list = [] # store the displacement squared norms
564- exit_status = "Unknown"
565-
566- try :
567- for _ in range (numItermax ):
568- pi_list = [ # compute the pairwise transport plans
569- emd (a , measure_weights [k ], cost_list [k ](X , measure_locations [k ]))
570- for k in range (K )
571- ]
572- Y_perm = []
573- for k in range (K ): # compute barycentric projections
574- Y_perm .append (n * pi_list [k ] @ measure_locations [k ])
575- X_next = B (Y_perm )
576-
577- if log :
578- X_list .append (X_next )
609+ exit_status = "Max iterations reached"
610+
611+ for _ in range (numItermax ):
612+ pi_list = [ # compute the pairwise transport plans
613+ emd (a , measure_weights [k ], cost_list [k ](X , measure_locations [k ]))
614+ for k in range (K )
615+ ]
616+ Y_perm = []
617+ for k in range (K ): # compute barycentric projections
618+ Y_perm .append (n * pi_list [k ] @ measure_locations [k ])
619+ if auto_ground_bary : # use previous position as initialization
620+ X_next = ground_bary (Y_perm , X )
621+ else :
622+ X_next = ground_bary (Y_perm )
579623
580- # stationary criterion: move less than the threshold
581- dX = nx .sum ((X - X_next ) ** 2 )
582- X = X_next
624+ if log :
625+ X_list .append (X_next )
583626
584- if log :
585- dX_list .append (dX )
627+ # stationary criterion: move less than the threshold
628+ dX = nx .sum ((X - X_next ) ** 2 )
629+ X = X_next
586630
587- if dX < stopThr :
588- exit_status = "Stationary Point"
589- raise StoppingCriterionReached
631+ if log :
632+ dX_list .append (dX )
590633
591- exit_status = "Max iterations reached"
592- raise StoppingCriterionReached
634+ if dX < stopThr :
635+ exit_status = "Stationary Point"
636+ break
593637
594- except StoppingCriterionReached :
595- if log :
596- return X , {"X_list" : X_list , "exit_status" : exit_status , "dX_list" : dX_list }
597- return X
638+ if log :
639+ return X , {"X_list" : X_list , "exit_status" : exit_status , "dX_list" : dX_list }
640+ return X
0 commit comments