Skip to content

Commit 62efd4a

Browse files
init commit
1 parent 60d1295 commit 62efd4a

File tree

1 file changed

+44
-38
lines changed

1 file changed

+44
-38
lines changed

ot/unbalanced/_lbfgs.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -206,26 +206,26 @@ def lbfgsb_unbalanced(
206206
loss matrix
207207
reg: float
208208
regularization term >=0
209-
c : array-like (dim_a, dim_b), optional (default = None)
210-
Reference measure for the regularization.
211-
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
212209
reg_m: float or indexable object of length 1 or 2
213210
Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
214211
If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
215212
then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
216213
If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
217-
reg_div: string, optional
214+
c : array-like (dim_a, dim_b), optional (default = None)
215+
Reference measure for the regularization.
216+
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
217+
reg_div: string or pair of callable functions, optional (default = 'kl')
218218
Divergence used for regularization.
219219
Can take one of three string values: 'entropy' (negative entropy), or
220220
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
221-
of two calable functions returning the reg term and its derivative.
221+
of two callable functions returning the reg term and its derivative.
222222
Note that the callable functions should be able to handle Numpy arrays
223-
and not tesors from the backend
224-
regm_div: string, optional
223+
and not tensors from the backend
224+
regm_div: string, optional (default = 'kl')
225225
Divergence to quantify the difference between the marginals.
226226
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
227-
G0: array-like (dim_a, dim_b)
228-
Initialization of the transport matrix
227+
G0: array-like (dim_a, dim_b), optional (default = None)
228+
Initialization of the transport matrix. If None, then use :math:`\mathbf{a} \mathbf{b}^T`.
229229
numItermax : int, optional
230230
Max number of iterations
231231
stopThr : float, optional
@@ -267,26 +267,14 @@ def lbfgsb_unbalanced(
267267
ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
268268
"""
269269

270-
# wrap the callable function to handle numpy arrays
271-
if isinstance(reg_div, tuple):
272-
f0, df0 = reg_div
273-
try:
274-
f0(G0)
275-
df0(G0)
276-
except BaseException:
277-
warnings.warn(
278-
"The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead"
279-
)
280-
281-
def f(x):
282-
return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0)))
283-
284-
def df(x):
285-
return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0)))
286-
287-
reg_div = (f, df)
270+
# validate the divergence settings (regm_div and reg_div)
271+
regm_div = regm_div.lower()
272+
if regm_div not in ["kl", "l2", "tv"]:
273+
raise ValueError(
274+
"Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)
275+
)
288276

289-
else:
277+
if isinstance(reg_div, str):
290278
reg_div = reg_div.lower()
291279
if reg_div not in ["entropy", "kl", "l2"]:
292280
raise ValueError(
@@ -295,16 +283,11 @@ def df(x):
295283
)
296284
)
297285

298-
regm_div = regm_div.lower()
299-
if regm_div not in ["kl", "l2", "tv"]:
300-
raise ValueError(
301-
"Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)
302-
)
303-
286+
# unpack reg_m, turn list inputs into arrays and detect the backend
304287
reg_m1, reg_m2 = get_parameter_pair(reg_m)
305288

306289
M, a, b = list_to_array(M, a, b)
307-
nx = get_backend(M, a, b)
290+
nx = get_backend(M, a, b, G0)
308291
M0 = M
309292

310293
dim_a, dim_b = M.shape
@@ -315,10 +298,33 @@ def df(x):
315298
b = nx.ones(dim_b, type_as=M) / dim_b
316299

317300
# convert to numpy
318-
a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg)
301+
if nx.__name__ == "numpy": # remaining parameters which can be arrays
302+
reg_m1, reg_m2, reg = nx.to_numpy(reg_m1, reg_m2, reg)
303+
else:
304+
a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg)
305+
319306
G0 = a[:, None] * b[None, :] if G0 is None else nx.to_numpy(G0)
320307
c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c)
321308

309+
# wrap the callable function to handle numpy arrays
310+
if isinstance(reg_div, tuple):
311+
f0, df0 = reg_div
312+
try:
313+
f0(G0)
314+
df0(G0)
315+
except BaseException:
316+
warnings.warn(
317+
"The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead"
318+
)
319+
320+
def f(x):
321+
return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0)))
322+
323+
def df(x):
324+
return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0)))
325+
326+
reg_div = (f, df)
327+
322328
_func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div)
323329

324330
res = minimize(
@@ -411,9 +417,9 @@ def lbfgsb_unbalanced2(
411417
Divergence used for regularization.
412418
Can take one of three string values: 'entropy' (negative entropy), or
413419
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
414-
of two calable functions returning the reg term and its derivative.
420+
of two callable functions returning the reg term and its derivative.
415421
Note that the callable functions should be able to handle Numpy arrays
416-
and not tesors from the backend
422+
and not tensors from the backend
417423
regm_div: string, optional
418424
Divergence to quantify the difference between the marginals.
419425
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)

0 commit comments

Comments
 (0)