Skip to content

Commit a61aa32

Browse files
complete solve_gromov
1 parent c76b9a2 commit a61aa32

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

ot/solvers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,8 @@ def solve_gromov(
863863

864864
if reg is None or reg == 0: # exact OT
865865
if unbalanced is None and unbalanced_type.lower() not in [
866-
"semirelaxed"
866+
"semirelaxed",
867+
"partial",
867868
]: # Exact balanced OT
868869
if M is None or alpha == 1: # Gromov-Wasserstein problem
869870
# default values for solver

test/test_solvers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,10 @@ def test_solve_gromov_not_implemented(nx):
518518
ot.solve_gromov(Ca, Cb, unbalanced_type="partial", unbalanced=1.5)
519519
with pytest.raises(ValueError):
520520
ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type="partial", unbalanced=1.5)
521+
with pytest.raises(ValueError):
522+
ot.solve_gromov(Ca, Cb, M, unbalanced_type="partial", unbalanced=1.5)
523+
with pytest.raises(ValueError):
524+
ot.solve_gromov(Ca, Cb, M, reg=1, unbalanced_type="partial", unbalanced=1.5)
521525

522526

523527
def test_solve_sample(nx):

0 commit comments

Comments
 (0)