Skip to content

Commit 671788d

Browse files
trying to fix tests
1 parent 46c4638 commit 671788d

File tree

3 files changed

+32
-11
lines changed

3 files changed

+32
-11
lines changed

ot/solvers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2032,7 +2032,7 @@ def _bary_sample_bcd(
20322032
list_res = [
20332033
inner_solver(X_s[k], X, a_s[k], b, None, None) for k in range(n_samples)
20342034
]
2035-
2035+
print("inv_b:", inv_b)
20362036
# Update the estimated barycenter weights in unbalanced cases
20372037
if update_masses:
20382038
b = sum([w_s[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)])

ot/unbalanced/_lbfgs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div
4646
Divergence used for regularization.
4747
Can take three values: 'entropy' (negative entropy), or
4848
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
49-
of two calable functions returning the reg term and its derivative.
49+
of two callable functions returning the reg term and its derivative.
5050
Note that the callable functions should be able to handle Numpy arrays
51-
and not tesors from the backend
51+
and not tensors from the backend
5252
regm_div: string, optional
5353
Divergence to quantify the difference between the marginals.
5454
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
@@ -218,9 +218,9 @@ def lbfgsb_unbalanced(
218218
Divergence used for regularization.
219219
Can take three values: 'entropy' (negative entropy), or
220220
'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple
221-
of two calable functions returning the reg term and its derivative.
221+
of two callable functions returning the reg term and its derivative.
222222
Note that the callable functions should be able to handle Numpy arrays
223-
and not tesors from the backend
223+
and not tensors from the backend
224224
regm_div: string, optional
225225
Divergence to quantify the difference between the marginals.
226226
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)

test/test_solvers.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,12 @@ def assert_allclose_bary_sol(sol1, sol2):
743743
@pytest.mark.parametrize(
744744
"reg,reg_type,unbalanced,unbalanced_type,warmstart",
745745
itertools.product(
746-
lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False]
746+
lst_reg,
747+
["tuple"],
748+
lst_unbalanced,
749+
lst_unbalanced_type,
750+
[True, False],
751+
# lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False]
747752
),
748753
)
749754
def test_bary_sample_free_support(
@@ -774,7 +779,7 @@ def df(G):
774779
return 2 * G
775780

776781
reg_type = (f, df)
777-
782+
# print('test reg_type:', reg_type[0](None), reg_type[1](None))
778783
# solve default None weights
779784
sol0 = ot.bary_sample(
780785
X_s,
@@ -790,8 +795,10 @@ def df(G):
790795
tol_bary=1e-3,
791796
verbose=True,
792797
)
798+
print("------ [done] sol0 - no backend")
793799

794800
# solve provided uniform weights
801+
795802
sol = ot.bary_sample(
796803
X_s,
797804
n,
@@ -808,6 +815,7 @@ def df(G):
808815
tol_bary=1e-3,
809816
verbose=True,
810817
)
818+
print("------ [done] sol - no backend")
811819

812820
assert_allclose_bary_sol(sol0, sol)
813821

@@ -816,14 +824,25 @@ def df(G):
816824
a_sb = nx.from_numpy(*a_s)
817825
w_sb, bb = nx.from_numpy(w_s, b)
818826

819-
if isinstance(reg_type, tuple):
827+
if reg_type == "tuple":
820828

821-
def f(G):
822-
return nx.sum(G**2)
829+
def fb(G):
830+
return nx.sum(
831+
G**2
832+
) # otherwise we keep previously defined (f, df) as required by inner solver
823833

824-
def df(G):
834+
def dfb(G):
825835
return 2 * G
826836

837+
"""
838+
if (
839+
unbalanced_type.lower() in ["kl", "l2", "tv"]) and (
840+
unbalanced is not None) and (
841+
reg is not None
842+
):
843+
reg_type = (f, df)
844+
else:
845+
"""
827846
reg_type = (f, df)
828847

829848
solb = ot.bary_sample(
@@ -842,6 +861,8 @@ def df(G):
842861
tol_bary=1e-3,
843862
verbose=True,
844863
)
864+
print("------ [done] sol - with backend")
865+
845866
assert_allclose_bary_sol(sol, solb)
846867

847868
except NotImplementedError:

0 commit comments

Comments
 (0)