Skip to content

Commit 4a44da2

Browse files
committed
fix logical bug
1 parent 3168387 commit 4a44da2

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

ot/low_rank/_factor_relaxation.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def solve_balanced_FRLC(
178178
a=a,
179179
b=g_Q,
180180
M=grad_Q,
181+
c=Q,
181182
reg=1 / gamma_k,
182183
reg_m=[float("inf"), tau],
183184
method="sinkhorn_stabilized",
@@ -187,6 +188,7 @@ def solve_balanced_FRLC(
187188
a=b,
188189
b=g_R,
189190
M=grad_R,
191+
c=R,
190192
reg=1 / gamma_k,
191193
reg_m=[float("inf"), tau],
192194
method="sinkhorn_stabilized",
@@ -199,8 +201,14 @@ def solve_balanced_FRLC(
199201

200202
gamma_T = gamma / nx.max(nx.abs(grad_T))
201203

202-
T_new = sinkhorn(
203-
g_R, g_Q, grad_T, reg=1 / gamma_T, method="sinkhorn_log"
204+
T_new = sinkhorn_unbalanced(
205+
M=grad_T,
206+
a=g_Q,
207+
b=g_R,
208+
reg=1 / gamma_T,
209+
c=T,
210+
reg_m=[float("inf"), float("inf")],
211+
method="sinkhorn_stabilized",
204212
) # Shape (r, r)
205213

206214
X_new = nx.dot(nx.dot(nx.diag(1 / g_Q), T_new), nx.diag(1 / g_R)) # Shape (r,r)

0 commit comments

Comments
 (0)