Skip to content

Commit 4342fb0

Browse files
fix solvers
1 parent 8e79b24 commit 4342fb0

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

ot/solvers.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,9 +1040,17 @@ def solve_gromov(
10401040
# potentials = (log['u'], log['v']) TODO
10411041

10421042
else: # partial FGW
1043-
if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
1044-
raise (ValueError("Partial FGW mass given in reg is too large"))
1043+
if unbalanced is None:
1044+
raise (
1045+
ValueError(
1046+
"Partial GW mass given in `unbalanced` must be float and not None"
1047+
)
1048+
)
10451049

1050+
elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
1051+
raise (
1052+
ValueError("Partial GW mass given in `unbalanced` is too large")
1053+
)
10461054
# default values for solver
10471055
if max_iter is None:
10481056
max_iter = 1000

0 commit comments

Comments
 (0)