@@ -206,26 +206,26 @@ def lbfgsb_unbalanced(
206206 loss matrix
207207 reg: float
208208 regularization term >=0
209- c : array-like (dim_a, dim_b), optional (default = None)
210- Reference measure for the regularization.
211- If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
212209 reg_m: float or indexable object of length 1 or 2
213210 Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
214211 If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
215212 then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
216213 If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
217- reg_div: string, optional
214+ c : array-like (dim_a, dim_b), optional (default = None)
215+ Reference measure for the regularization.
216+ If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
217+ reg_div: string or pair of callable functions, optional (default = 'kl')
218218 Divergence used for regularization.
219219 Can be one of three strings, 'entropy' (negative entropy),
220220 'kl' (Kullback-Leibler) or 'l2' (half-squared), or a tuple
221- of two calable functions returning the reg term and its derivative.
221+ of two callable functions returning the reg term and its derivative.
222222 Note that the callable functions should be able to handle Numpy arrays
223- and not tesors from the backend
224- regm_div: string, optional
223+ and not tensors from the backend
224+ regm_div: string, optional (default = 'kl')
225225 Divergence to quantify the difference between the marginals.
226226 Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
227- G0: array-like (dim_a, dim_b)
228- Initialization of the transport matrix
227+ G0: array-like (dim_a, dim_b), optional (default = None)
228+ Initialization of the transport matrix. If None, then use :math:`\mathbf{a} \mathbf{b}^T`.
229229 numItermax : int, optional
230230 Max number of iterations
231231 stopThr : float, optional
@@ -267,26 +267,14 @@ def lbfgsb_unbalanced(
267267 ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
268268 """
269269
270- # wrap the callable function to handle numpy arrays
271- if isinstance (reg_div , tuple ):
272- f0 , df0 = reg_div
273- try :
274- f0 (G0 )
275- df0 (G0 )
276- except BaseException :
277- warnings .warn (
278- "The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead"
279- )
280-
281- def f (x ):
282- return nx .to_numpy (f0 (nx .from_numpy (x , type_as = M0 )))
283-
284- def df (x ):
285- return nx .to_numpy (df0 (nx .from_numpy (x , type_as = M0 )))
286-
287- reg_div = (f , df )
270+ # test settings
271+ regm_div = regm_div .lower ()
272+ if regm_div not in ["kl" , "l2" , "tv" ]:
273+ raise ValueError (
274+ "Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'" .format (regm_div )
275+ )
288276
289- else :
277+ if isinstance ( reg_div , str ) :
290278 reg_div = reg_div .lower ()
291279 if reg_div not in ["entropy" , "kl" , "l2" ]:
292280 raise ValueError (
@@ -295,16 +283,11 @@ def df(x):
295283 )
296284 )
297285
298- regm_div = regm_div .lower ()
299- if regm_div not in ["kl" , "l2" , "tv" ]:
300- raise ValueError (
301- "Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'" .format (regm_div )
302- )
303-
286+ # convert all inputs to numpy arrays
304287 reg_m1 , reg_m2 = get_parameter_pair (reg_m )
305288
306289 M , a , b = list_to_array (M , a , b )
307- nx = get_backend (M , a , b )
290+ nx = get_backend (M , a , b , G0 )
308291 M0 = M
309292
310293 dim_a , dim_b = M .shape
@@ -315,10 +298,33 @@ def df(x):
315298 b = nx .ones (dim_b , type_as = M ) / dim_b
316299
317300 # convert to numpy
318- a , b , M , reg_m1 , reg_m2 , reg = nx .to_numpy (a , b , M , reg_m1 , reg_m2 , reg )
301+ if nx .__name__ == "numpy" : # remaining parameters which can be arrays
302+ reg_m1 , reg_m2 , reg = nx .to_numpy (reg_m1 , reg_m2 , reg )
303+ else :
304+ a , b , M , reg_m1 , reg_m2 , reg = nx .to_numpy (a , b , M , reg_m1 , reg_m2 , reg )
305+
319306 G0 = a [:, None ] * b [None , :] if G0 is None else nx .to_numpy (G0 )
320307 c = a [:, None ] * b [None , :] if c is None else nx .to_numpy (c )
321308
309+ # wrap the callable function to handle numpy arrays
310+ if isinstance (reg_div , tuple ):
311+ f0 , df0 = reg_div
312+ try :
313+ f0 (G0 )
314+ df0 (G0 )
315+ except BaseException :
316+ warnings .warn (
317+ "The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead"
318+ )
319+
320+ def f (x ):
321+ return nx .to_numpy (f0 (nx .from_numpy (x , type_as = M0 )))
322+
323+ def df (x ):
324+ return nx .to_numpy (df0 (nx .from_numpy (x , type_as = M0 )))
325+
326+ reg_div = (f , df )
327+
322328 _func = _get_loss_unbalanced (a , b , c , M , reg , reg_m1 , reg_m2 , reg_div , regm_div )
323329
324330 res = minimize (
@@ -411,9 +417,9 @@ def lbfgsb_unbalanced2(
411417 Divergence used for regularization.
412418 Can be one of three strings, 'entropy' (negative entropy),
413419 'kl' (Kullback-Leibler) or 'l2' (half-squared), or a tuple
414- of two calable functions returning the reg term and its derivative.
420+ of two callable functions returning the reg term and its derivative.
415421 Note that the callable functions should be able to handle Numpy arrays
416- and not tesors from the backend
422+ and not tensors from the backend
417423 regm_div: string, optional
418424 Divergence to quantify the difference between the marginals.
419425 Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
0 commit comments