diff --git a/mip/entities.py b/mip/entities.py index ef15d73c..fcd3ecd7 100644 --- a/mip/entities.py +++ b/mip/entities.py @@ -65,6 +65,9 @@ class LinExpr: a = 10*x1 + 7*x4 print(a.x) + .. warning:: + Do not pass identical objects in the ``variables`` argument when constructing + a LinExpr manually. """ __slots__ = ["__const", "__expr", "__sense"] @@ -542,6 +545,8 @@ def __add__( self, other: Union["mip.Var", LinExpr, numbers.Real] ) -> Union["mip.Var", LinExpr]: if isinstance(other, Var): + if id(self) == id(other): + return LinExpr([self], [2]) return LinExpr([self, other], [1, 1]) if isinstance(other, LinExpr): return other.__add__(self) @@ -561,6 +566,8 @@ def __sub__( self, other: Union["mip.Var", LinExpr, numbers.Real] ) -> Union["mip.Var", LinExpr]: if isinstance(other, Var): + if id(self) == id(other): + return LinExpr([self], [0]) return LinExpr([self, other], [1, -1]) if isinstance(other, LinExpr): return (-other).__add__(self) @@ -575,6 +582,8 @@ def __rsub__( self, other: Union["mip.Var", LinExpr, numbers.Real] ) -> Union["mip.Var", LinExpr]: if isinstance(other, Var): + if id(self) == id(other): + return LinExpr([self], [0]) return LinExpr([self, other], [-1, 1]) if isinstance(other, LinExpr): return other.__sub__(self) @@ -603,6 +612,8 @@ def __neg__(self) -> LinExpr: def __eq__(self, other) -> LinExpr: if isinstance(other, Var): + if id(self) == id(other): + return LinExpr([self], [0], sense="=") return LinExpr([self, other], [1, -1], sense="=") if isinstance(other, LinExpr): return LinExpr([self], [1]) == other @@ -613,6 +624,8 @@ def __eq__(self, other) -> LinExpr: def __le__(self, other: Union["mip.Var", LinExpr, numbers.Real]) -> LinExpr: if isinstance(other, Var): + if id(self) == id(other): + return LinExpr([self], [0], sense="<") return LinExpr([self, other], [1, -1], sense="<") if isinstance(other, LinExpr): return LinExpr([self], [1]) <= other @@ -623,6 +636,8 @@ def __le__(self, other: Union["mip.Var", LinExpr, numbers.Real]) -> LinExpr: def __ge__(self, other: Union["mip.Var", LinExpr, numbers.Real]) -> LinExpr: if isinstance(other, Var): + if id(self) == id(other): + return LinExpr([self], [0], sense=">") return LinExpr([self, other], [1, -1], sense=">") if isinstance(other, LinExpr): return LinExpr([self], [1]) >= other diff --git a/test/mip_test.py b/test/mip_test.py index 10ab6cb0..b205df14 100644 --- a/test/mip_test.py +++ b/test/mip_test.py @@ -4,6 +4,7 @@ from os import environ import networkx as nx +from mip.entities import LinExpr import mip.gurobi import mip.highs from mip import Model, xsum, OptimizationStatus, MAXIMIZE, BINARY, INTEGER @@ -578,6 +579,28 @@ def test_obj_const2(self, solver: str): assert model.objective_const == 1 +@skip_on(NotImplementedError) +@pytest.mark.parametrize("solver", SOLVERS) +@pytest.mark.parametrize("constraint, lb, ub", [ + (lambda x: x + x >= 3, 1, 2), + (lambda x: x - x >= 0, 1, 2), + (lambda x: x == x, 1, 2), + (lambda x: x >= x, 1, 2), + (lambda x: x <= x, -2, -1), + # (lambda x: LinExpr([x, x, x], [2, -1, -1], sense="="), 1, 2), +]) +def test_identical_vars(solver: str, constraint, lb, ub): + """Try if constraints are correctly added when variables are identical""" + m = Model(solver_name=solver) + x = m.add_var(name="x", lb=lb, ub=ub, obj=1) + + m.add_constr(constraint(x)) + + m.optimize() + assert m.status == OptimizationStatus.OPTIMAL + assert lb - TOL <= x.x <= ub + TOL + + @skip_on(NotImplementedError) @pytest.mark.parametrize("val", range(1, 4)) @pytest.mark.parametrize("solver", SOLVERS)