diff --git a/ChangeLog b/ChangeLog index 4225d5f5e..0c3c97c8c 100644 --- a/ChangeLog +++ b/ChangeLog @@ -7,6 +7,12 @@ What's New in astroid 4.1.0? ============================ Release date: TBA +* Add support for type constraints (`isinstance(x, y)`) in inference. + + Closes pylint-dev/pylint#1162 + Closes pylint-dev/pylint#4635 + Closes pylint-dev/pylint#10469 + * Make `type.__new__()` raise clear errors instead of returning `None` * Move object dunder methods from ``FunctionModel`` to ``ObjectModel`` to make them diff --git a/astroid/brain/brain_builtin_inference.py b/astroid/brain/brain_builtin_inference.py index e21d36141..a2ca95514 100644 --- a/astroid/brain/brain_builtin_inference.py +++ b/astroid/brain/brain_builtin_inference.py @@ -763,7 +763,7 @@ def infer_issubclass(callnode, context: InferenceContext | None = None): # The right hand argument is the class(es) that the given # object is to be checked against. try: - class_container = _class_or_tuple_to_container( + class_container = helpers.class_or_tuple_to_container( class_or_tuple_node, context=context ) except InferenceError as exc: @@ -798,7 +798,7 @@ def infer_isinstance( # The right hand argument is the class(es) that the given # obj is to be check is an instance of try: - class_container = _class_or_tuple_to_container( + class_container = helpers.class_or_tuple_to_container( class_or_tuple_node, context=context ) except InferenceError as exc: @@ -814,30 +814,6 @@ def infer_isinstance( return nodes.Const(isinstance_bool) -def _class_or_tuple_to_container( - node: InferenceResult, context: InferenceContext | None = None -) -> list[InferenceResult]: - # Move inferences results into container - # to simplify later logic - # raises InferenceError if any of the inferences fall through - try: - node_infer = next(node.infer(context=context)) - except StopIteration as e: - raise InferenceError(node=node, context=context) from e - # arg2 MUST be a type or a TUPLE of types - # for isinstance - if isinstance(node_infer, nodes.Tuple): - try: - class_container = [ - next(node.infer(context=context)) for node in node_infer.elts - ] - except StopIteration as e: - raise InferenceError(node=node, context=context) from e - else: - class_container = [node_infer] - return class_container - - def infer_len(node, context: InferenceContext | None = None) -> nodes.Const: """Infer length calls. diff --git a/astroid/constraint.py b/astroid/constraint.py index 692d22d03..693de59b9 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -10,7 +10,8 @@ from collections.abc import Iterator from typing import TYPE_CHECKING -from astroid import nodes, util +from astroid import helpers, nodes, util +from astroid.exceptions import AstroidTypeError, InferenceError, MroError from astroid.typing import InferenceResult if sys.version_info >= (3, 11): @@ -77,7 +78,7 @@ def match( def satisfied_by(self, inferred: InferenceResult) -> bool: """Return True if this constraint is satisfied by the given inferred value.""" # Assume true if uninferable - if isinstance(inferred, util.UninferableBase): + if inferred is util.Uninferable: return True # Return the XOR of self.negate and matches(inferred, self.CONST_NONE) @@ -117,14 +118,61 @@ def satisfied_by(self, inferred: InferenceResult) -> bool: - negate=True: satisfied if boolean value is False """ inferred_booleaness = inferred.bool_value() - if isinstance(inferred, util.UninferableBase) or isinstance( - inferred_booleaness, util.UninferableBase - ): + if inferred is util.Uninferable or inferred_booleaness is util.Uninferable: return True return self.negate ^ inferred_booleaness +class TypeConstraint(Constraint): + """Represents an "isinstance(x, y)" constraint.""" + + def __init__( + self, node: nodes.NodeNG, classinfo: nodes.NodeNG, negate: bool + ) -> None: + super().__init__(node=node, negate=negate) + self.classinfo = classinfo + + @classmethod + def match( + cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False + ) -> Self | None: + """Return a new constraint for node if expr matches the + "isinstance(x, y)" pattern. Else, return None. + """ + is_instance_call = ( + isinstance(expr, nodes.Call) + and isinstance(expr.func, nodes.Name) + and expr.func.name == "isinstance" + and not expr.keywords + and len(expr.args) == 2 + ) + if is_instance_call and _matches(expr.args[0], node): + return cls(node=node, classinfo=expr.args[1], negate=negate) + + return None + + def satisfied_by(self, inferred: InferenceResult) -> bool: + """Return True for uninferable results, or depending on negate flag: + + - negate=False: satisfied when inferred is an instance of the checked types. + - negate=True: satisfied when inferred is not an instance of the checked types. + """ + if inferred is util.Uninferable: + return True + + try: + types = helpers.class_or_tuple_to_container(self.classinfo) + matches_checked_types = helpers.object_isinstance(inferred, types) + + if matches_checked_types is util.Uninferable: + return True + + return self.negate ^ matches_checked_types + except (InferenceError, AstroidTypeError, MroError): + return True + + def get_constraints( expr: _NameNodes, frame: nodes.LocalsDictNodeNG ) -> dict[nodes.If | nodes.IfExp, set[Constraint]]: @@ -159,6 +207,7 @@ def get_constraints( ( NoneConstraint, BooleanConstraint, + TypeConstraint, ) ) """All supported constraint types.""" diff --git a/astroid/helpers.py b/astroid/helpers.py index 9c370aa32..deef3d9fc 100644 --- a/astroid/helpers.py +++ b/astroid/helpers.py @@ -170,6 +170,30 @@ def object_issubclass( return _object_type_is_subclass(node, class_or_seq, context=context) +def class_or_tuple_to_container( + node: InferenceResult, context: InferenceContext | None = None +) -> list[InferenceResult]: + # Move inferences results into container + # to simplify later logic + # raises InferenceError if any of the inferences fall through + try: + node_infer = next(node.infer(context=context)) + except StopIteration as e: # pragma: no cover + raise InferenceError(node=node, context=context) from e + # arg2 MUST be a type or a TUPLE of types + # for isinstance + if isinstance(node_infer, nodes.Tuple): + try: + class_container = [ + next(node.infer(context=context)) for node in node_infer.elts + ] + except StopIteration as e: # pragma: no cover + raise InferenceError(node=node, context=context) from e + else: + class_container = [node_infer] + return class_container + + def has_known_bases(klass, context: InferenceContext | None = None) -> bool: """Return whether all base classes of a class could be inferred.""" try: diff --git a/tests/test_constraint.py b/tests/test_constraint.py index 4859d4241..f69e0e496 100644 --- a/tests/test_constraint.py +++ b/tests/test_constraint.py @@ -5,9 +5,12 @@ """Tests for inference involving constraints.""" from __future__ import annotations +from unittest.mock import patch + import pytest from astroid import builder, nodes +from astroid.bases import Instance from astroid.util import Uninferable @@ -19,6 +22,8 @@ def common_params(node: str) -> pytest.MarkDecorator: (f"{node} is not None", 3, None), (f"{node}", 3, None), (f"not {node}", None, 3), + (f"isinstance({node}, int)", 3, None), + (f"isinstance({node}, (int, str))", 3, None), ), ) @@ -773,3 +778,295 @@ def method(self, x = {fail_val}): assert isinstance(inferred[0], nodes.Const) assert inferred[0].value == fail_val assert inferred[1].value is Uninferable + + +def test_isinstance_equal_types() -> None: + """Test constraint for an object whose type is equal to the checked type.""" + node = builder.extract_node( + """ + class A: + pass + + x = A() + + if isinstance(x, A): + x #@ + """ + ) + + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], Instance) + assert isinstance(inferred[0]._proxied, nodes.ClassDef) + assert inferred[0].name == "A" + + +def test_isinstance_subtype() -> None: + """Test constraint for an object whose type is a strict subtype of the checked type.""" + node = builder.extract_node( + """ + class A: + pass + + class B(A): + pass + + x = B() + + if isinstance(x, A): + x #@ + """ + ) + + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], Instance) + assert isinstance(inferred[0]._proxied, nodes.ClassDef) + assert inferred[0].name == "B" + + +def test_isinstance_unrelated_types(): + """Test constraint for an object whose type is not related to the checked type.""" + node = builder.extract_node( + """ + class A: + pass + + class B: + pass + + x = A() + + if isinstance(x, B): + x #@ + """ + ) + + inferred = node.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + +def test_isinstance_supertype(): + """Test constraint for an object whose type is a strict supertype of the checked type.""" + node = builder.extract_node( + """ + class A: + pass + + class B(A): + pass + + x = A() + + if isinstance(x, B): + x #@ + """ + ) + + inferred = node.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + +def test_isinstance_multiple_inheritance(): + """Test constraint for an object that inherits from more than one parent class.""" + n1, n2, n3 = builder.extract_node( + """ + class A: + pass + + class B: + pass + + class C(A, B): + pass + + x = C() + + if isinstance(x, C): + x #@ + + if isinstance(x, A): + x #@ + + if isinstance(x, B): + x #@ + """ + ) + + for node in (n1, n2, n3): + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], Instance) + assert isinstance(inferred[0]._proxied, nodes.ClassDef) + assert inferred[0].name == "C" + + +def test_isinstance_diamond_inheritance(): + """Test constraint for an object that inherits from parent classes + in diamond inheritance. + """ + n1, n2, n3, n4 = builder.extract_node( + """ + class A(): + pass + + class B(A): + pass + + class C(A): + pass + + class D(B, C): + pass + + x = D() + + if isinstance(x, D): + x #@ + + if isinstance(x, B): + x #@ + + if isinstance(x, C): + x #@ + + if isinstance(x, A): + x #@ + """ + ) + + for node in (n1, n2, n3, n4): + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], Instance) + assert isinstance(inferred[0]._proxied, nodes.ClassDef) + assert inferred[0].name == "D" + + +def test_isinstance_keyword_arguments(): + """Test that constraint does not apply when `isinstance` is called + with keyword arguments. + """ + n1, n2 = builder.extract_node( + """ + x = 3 + + if isinstance(object=x, classinfo=str): + x #@ + + if isinstance(x, str, object=x, classinfo=str): + x #@ + """ + ) + + for node in (n1, n2): + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == 3 + + +def test_isinstance_extra_argument(): + """Test that constraint does not apply when `isinstance` is called + with more than two positional arguments. + """ + node = builder.extract_node( + """ + x = 3 + + if isinstance(x, str, bool): + x #@ + """ + ) + + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == 3 + + +def test_isinstance_classinfo_inference_error(): + """Test that constraint is satisfied when `isinstance` is called with + classinfo that raises an inference error. + """ + node = builder.extract_node( + """ + x = 3 + + if isinstance(x, undefined_type): + x #@ + """ + ) + + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == 3 + + +def test_isinstance_uninferable_classinfo(): + """Test that constraint is satisfied when `isinstance` is called with + uninferable classinfo. + """ + node = builder.extract_node( + """ + def f(classinfo): + x = 3 + + if isinstance(x, classinfo): + x #@ + """ + ) + + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == 3 + + +def test_isinstance_mro_error(): + """Test that constraint is satisfied when computing the object's + method resolution order raises an MRO error. + """ + node = builder.extract_node( + """ + class A(): + pass + + class B(A, A): + pass + + x = B() + + if isinstance(x, A): + x #@ + """ + ) + + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], Instance) + assert isinstance(inferred[0]._proxied, nodes.ClassDef) + assert inferred[0].name == "B" + + +def test_isinstance_uninferable(): + """Test that constraint is satisfied when `isinstance` inference returns Uninferable.""" + node = builder.extract_node( + """ + x = 3 + + if isinstance(x, str): + x #@ + """ + ) + + with patch( + "astroid.constraint.helpers.object_isinstance", return_value=Uninferable + ): + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == 3 diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 4df145bab..65d978038 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -10,7 +10,7 @@ from astroid import builder, helpers, manager, nodes, raw_building, util from astroid.builder import AstroidBuilder from astroid.const import IS_PYPY -from astroid.exceptions import _NonDeducibleTypeHierarchy +from astroid.exceptions import InferenceError, _NonDeducibleTypeHierarchy from astroid.nodes.node_classes import UNATTACHED_UNKNOWN @@ -275,3 +275,70 @@ def test_safe_infer_shim() -> None: "Import safe_infer from astroid.util; this shim in astroid.helpers will be removed." in records[0].message.args[0] ) + + +def test_class_to_container() -> None: + node = builder.extract_node("""isinstance(3, int)""") + + container = helpers.class_or_tuple_to_container(node.args[1]) + + assert len(container) == 1 + assert isinstance(container[0], nodes.ClassDef) + assert container[0].name == "int" + + +def test_tuple_to_container() -> None: + node = builder.extract_node("""isinstance(3, (int, str))""") + + container = helpers.class_or_tuple_to_container(node.args[1]) + + assert len(container) == 2 + + assert isinstance(container[0], nodes.ClassDef) + assert container[0].name == "int" + + assert isinstance(container[1], nodes.ClassDef) + assert container[1].name == "str" + + +def test_class_to_container_uninferable() -> None: + node = builder.extract_node( + """ + def f(x): + isinstance(3, x) #@ + """ + ) + + container = helpers.class_or_tuple_to_container(node.args[1]) + + assert len(container) == 1 + assert container[0] is util.Uninferable + + +def test_tuple_to_container_uninferable() -> None: + node = builder.extract_node( + """ + def f(x, y): + isinstance(3, (x, y)) #@ + """ + ) + + container = helpers.class_or_tuple_to_container(node.args[1]) + + assert len(container) == 2 + assert container[0] is util.Uninferable + assert container[1] is util.Uninferable + + +def test_class_to_container_inference_error() -> None: + node = builder.extract_node("""isinstance(3, undefined_type)""") + + with pytest.raises(InferenceError): + helpers.class_or_tuple_to_container(node.args[1]) + + +def test_tuple_to_container_inference_error() -> None: + node = builder.extract_node("""isinstance(3, (int, undefined_type))""") + + with pytest.raises(InferenceError): + helpers.class_or_tuple_to_container(node.args[1])