From a48015a3d60f0bafe59fc7f83ecf80d02b7f99d4 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 24 Jul 2025 18:53:52 +0100 Subject: [PATCH 1/6] Try simple-minded call expression cache --- mypy/binder.py | 3 +++ mypy/checker.py | 2 ++ mypy/checkexpr.py | 48 +++++++++++++++++++++++++++++++++++++++++++---- mypy/errors.py | 6 ++++-- 4 files changed, 53 insertions(+), 6 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index d3482d1dad4f..9456867c1df1 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -137,6 +137,7 @@ def __init__(self, options: Options) -> None: # is added to the binder. This allows more precise narrowing and more # flexible inference of variable types (--allow-redefinition-new). self.bind_all = options.allow_redefinition_new + self.version = 0 def _get_id(self) -> int: self.next_id += 1 @@ -158,6 +159,7 @@ def push_frame(self, conditional_frame: bool = False) -> Frame: return f def _put(self, key: Key, type: Type, from_assignment: bool, index: int = -1) -> None: + self.version += 1 self.frames[index].types[key] = CurrentType(type, from_assignment) def _get(self, key: Key, index: int = -1) -> CurrentType | None: @@ -185,6 +187,7 @@ def put(self, expr: Expression, typ: Type, *, from_assignment: bool = True) -> N self._put(key, typ, from_assignment) def unreachable(self) -> None: + self.version += 1 self.frames[-1].unreachable = True def suppress_unreachable_warnings(self) -> None: diff --git a/mypy/checker.py b/mypy/checker.py index 7579c36a97d0..4f0b6016d515 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3018,6 +3018,7 @@ def visit_block(self, b: Block) -> None: self.msg.unreachable_statement(s) break else: + self.expr_checker.expr_cache.clear() self.accept(s) def should_report_unreachable_issues(self) -> bool: @@ -4659,6 +4660,7 @@ def replace_partial_type( ) -> None: """Replace the partial type of var with a non-partial type.""" var.type = new_type + self.binder.version += 1 del partial_types[var] if self.options.allow_redefinition_new: # When using --allow-redefinition-new, binder tracks all types of diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 86f8d9410476..22955729aa5f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -19,7 +19,7 @@ from mypy.checkmember import analyze_member_access, has_operator from mypy.checkstrformat import StringFormatterChecker from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars -from mypy.errors import ErrorWatcher, report_internal_error +from mypy.errors import ErrorInfo, ErrorWatcher, report_internal_error from mypy.expandtype import ( expand_type, expand_type_by_instance, @@ -355,9 +355,15 @@ def __init__( type_state.infer_polymorphic = not self.chk.options.old_type_inference self._arg_infer_context_cache = None + self.expr_cache: dict[ + Expression, + dict[Type | None, tuple[int, Type, list[ErrorInfo], dict[Expression, Type]]], + ] = defaultdict(dict) + self.in_lambda_expr = False def reset(self) -> None: self.resolved_type = {} + self.expr_cache.clear() def visit_name_expr(self, e: NameExpr) -> Type: """Type check a name expression. @@ -5404,6 +5410,8 @@ def find_typeddict_context( def visit_lambda_expr(self, e: LambdaExpr) -> Type: """Type check lambda expression.""" + old_in_lambda = self.in_lambda_expr + self.in_lambda_expr = True self.chk.check_default_args(e, body_is_trivial=False) inferred_type, type_override = self.infer_lambda_type_using_context(e) if not inferred_type: @@ -5424,6 +5432,7 @@ def visit_lambda_expr(self, e: LambdaExpr) -> Type: ret_type = self.accept(e.expr(), allow_none_return=True) fallback = self.named_type("builtins.function") self.chk.return_types.pop() + self.in_lambda_expr = old_in_lambda return callable_type(e, fallback, ret_type) else: # Type context available. @@ -5436,6 +5445,7 @@ def visit_lambda_expr(self, e: LambdaExpr) -> Type: self.accept(e.expr(), allow_none_return=True) ret_type = self.chk.lookup_type(e.expr()) self.chk.return_types.pop() + self.in_lambda_expr = old_in_lambda return replace_callable_return_type(inferred_type, ret_type) def infer_lambda_type_using_context( @@ -5972,14 +5982,28 @@ def accept( old_is_callee = self.is_callee self.is_callee = is_callee try: - if allow_none_return and isinstance(node, CallExpr): - typ = self.visit_call_expr(node, allow_none_return=True) - elif allow_none_return and isinstance(node, YieldFromExpr): + if allow_none_return and isinstance(node, YieldFromExpr): typ = self.visit_yield_from_expr(node, allow_none_return=True) elif allow_none_return and isinstance(node, ConditionalExpr): typ = self.visit_conditional_expr(node, allow_none_return=True) elif allow_none_return and isinstance(node, AwaitExpr): typ = self.visit_await_expr(node, allow_none_return=True) + elif isinstance(node, CallExpr) and not self.in_lambda_expr: + if node in self.expr_cache and type_context in self.expr_cache[node]: + binder_version, typ, messages, type_map = self.expr_cache[node][type_context] + if binder_version == self.chk.binder.version: + self.chk.store_types(type_map) + self.msg.add_errors(messages) + else: + typ = self.accept_maybe_cache( + node, type_context=type_context, allow_none_return=allow_none_return + ) + else: + typ = self.accept_maybe_cache( + node, type_context=type_context, allow_none_return=allow_none_return + ) + elif isinstance(node, CallExpr): + typ = self.visit_call_expr(node, allow_none_return=allow_none_return) else: typ = node.accept(self) except Exception as err: @@ -6010,6 +6034,22 @@ def accept( self.in_expression = False return result + def accept_maybe_cache( + self, node: CallExpr, type_context: Type | None = None, allow_none_return: bool = False + ) -> Type: + binder_version = self.chk.binder.version + type_map: dict[Expression, Type] = {} + self.chk._type_maps.append(type_map) + with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as w: + typ = self.visit_call_expr(node, allow_none_return=allow_none_return) + messages = w.filtered_errors() + if binder_version == self.chk.binder.version and not self.chk.current_node_deferred: + self.expr_cache[node][type_context] = (binder_version, typ, messages, type_map) + self.chk._type_maps.pop() + self.chk.store_types(type_map) + self.msg.add_errors(messages) + return typ + def named_type(self, name: str) -> Instance: """Return an instance type with type given by the name and no type arguments. Alias for TypeChecker.named_type. diff --git a/mypy/errors.py b/mypy/errors.py index 5c135146bcb7..d75c1c62a1ed 100644 --- a/mypy/errors.py +++ b/mypy/errors.py @@ -390,7 +390,7 @@ class Errors: # in some cases to avoid reporting huge numbers of errors. seen_import_error = False - _watchers: list[ErrorWatcher] = [] + _watchers: list[ErrorWatcher] def __init__( self, @@ -421,6 +421,7 @@ def initialize(self) -> None: self.scope = None self.target_module = None self.seen_import_error = False + self._watchers = [] def reset(self) -> None: self.initialize() @@ -931,7 +932,8 @@ def prefer_simple_messages(self) -> bool: if self.file in self.ignored_files: # Errors ignored, so no point generating fancy messages return True - for _watcher in self._watchers: + if self._watchers: + _watcher = self._watchers[-1] if _watcher._filter is True and _watcher._filtered is None: # Errors are filtered return True From f2d15c47f9785c3b26e1b02839db018cdc3baa11 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 25 Jul 2025 09:52:21 +0100 Subject: [PATCH 2/6] Re-organize cache --- mypy/checkexpr.py | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 22955729aa5f..ae1f2205f6bc 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -356,9 +356,9 @@ def __init__( self._arg_infer_context_cache = None self.expr_cache: dict[ - Expression, - dict[Type | None, tuple[int, Type, list[ErrorInfo], dict[Expression, Type]]], - ] = defaultdict(dict) + tuple[Expression, Type | None], + tuple[int, Type, list[ErrorInfo], dict[Expression, Type]], + ] = {} self.in_lambda_expr = False def reset(self) -> None: @@ -5982,28 +5982,24 @@ def accept( old_is_callee = self.is_callee self.is_callee = is_callee try: - if allow_none_return and isinstance(node, YieldFromExpr): + if allow_none_return and isinstance(node, CallExpr): + typ = self.visit_call_expr(node, allow_none_return=True) + elif allow_none_return and isinstance(node, YieldFromExpr): typ = self.visit_yield_from_expr(node, allow_none_return=True) elif allow_none_return and isinstance(node, ConditionalExpr): typ = self.visit_conditional_expr(node, allow_none_return=True) elif allow_none_return and isinstance(node, AwaitExpr): typ = self.visit_await_expr(node, allow_none_return=True) elif isinstance(node, CallExpr) and not self.in_lambda_expr: - if node in self.expr_cache and type_context in self.expr_cache[node]: - binder_version, typ, messages, type_map = self.expr_cache[node][type_context] + if (node, type_context) in self.expr_cache: + binder_version, typ, messages, type_map = self.expr_cache[(node, type_context)] if binder_version == self.chk.binder.version: self.chk.store_types(type_map) self.msg.add_errors(messages) else: - typ = self.accept_maybe_cache( - node, type_context=type_context, allow_none_return=allow_none_return - ) + typ = self.accept_maybe_cache(node, type_context=type_context) else: - typ = self.accept_maybe_cache( - node, type_context=type_context, allow_none_return=allow_none_return - ) - elif isinstance(node, CallExpr): - typ = self.visit_call_expr(node, allow_none_return=allow_none_return) + typ = self.accept_maybe_cache(node, type_context=type_context) else: typ = node.accept(self) except Exception as err: @@ -6034,17 +6030,15 @@ def accept( self.in_expression = False return result - def accept_maybe_cache( - self, node: CallExpr, type_context: Type | None = None, allow_none_return: bool = False - ) -> Type: + def accept_maybe_cache(self, node: CallExpr, type_context: Type | None = None) -> Type: binder_version = self.chk.binder.version type_map: dict[Expression, Type] = {} self.chk._type_maps.append(type_map) - with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as w: - typ = self.visit_call_expr(node, allow_none_return=allow_none_return) - messages = w.filtered_errors() + with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as msg: + typ = self.visit_call_expr(node) + messages = msg.filtered_errors() if binder_version == self.chk.binder.version and not self.chk.current_node_deferred: - self.expr_cache[node][type_context] = (binder_version, typ, messages, type_map) + self.expr_cache[(node, type_context)] = (binder_version, typ, messages, type_map) self.chk._type_maps.pop() self.chk.store_types(type_map) self.msg.add_errors(messages) From c836f53fda5435b300fd8cad117058fafe8c1bc6 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 25 Jul 2025 10:38:22 +0100 Subject: [PATCH 3/6] Try caching also lists and tuples --- mypy/checkexpr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ae1f2205f6bc..d96f7cacacd8 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5990,7 +5990,7 @@ def accept( typ = self.visit_conditional_expr(node, allow_none_return=True) elif allow_none_return and isinstance(node, AwaitExpr): typ = self.visit_await_expr(node, allow_none_return=True) - elif isinstance(node, CallExpr) and not self.in_lambda_expr: + elif isinstance(node, (CallExpr, ListExpr, TupleExpr)) and not self.in_lambda_expr: if (node, type_context) in self.expr_cache: binder_version, typ, messages, type_map = self.expr_cache[(node, type_context)] if binder_version == self.chk.binder.version: @@ -6030,12 +6030,12 @@ def accept( self.in_expression = False return result - def accept_maybe_cache(self, node: CallExpr, type_context: Type | None = None) -> Type: + def accept_maybe_cache(self, node: Expression, type_context: Type | None = None) -> Type: binder_version = self.chk.binder.version type_map: dict[Expression, Type] = {} self.chk._type_maps.append(type_map) with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as msg: - typ = self.visit_call_expr(node) + typ = node.accept(self) messages = msg.filtered_errors() if binder_version == self.chk.binder.version and not self.chk.current_node_deferred: self.expr_cache[(node, type_context)] = (binder_version, typ, messages, type_map) From 32827a88dc8ab925ef11ba06d36fe785419ca4de Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 25 Jul 2025 11:28:34 +0100 Subject: [PATCH 4/6] Skip cache immediately if deferred --- mypy/checker.py | 1 - mypy/checkexpr.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 4f0b6016d515..3760b86564f6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -448,7 +448,6 @@ def reset(self) -> None: self.binder = ConditionalTypeBinder(self.options) self._type_maps[1:] = [] self._type_maps[0].clear() - self.temp_type_map = None self.expr_checker.reset() self.deferred_nodes = [] self.partial_types = [] diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d96f7cacacd8..9648420708f6 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5990,7 +5990,9 @@ def accept( typ = self.visit_conditional_expr(node, allow_none_return=True) elif allow_none_return and isinstance(node, AwaitExpr): typ = self.visit_await_expr(node, allow_none_return=True) - elif isinstance(node, (CallExpr, ListExpr, TupleExpr)) and not self.in_lambda_expr: + elif isinstance(node, (CallExpr, ListExpr, TupleExpr)) and not ( + self.in_lambda_expr or self.chk.current_node_deferred + ): if (node, type_context) in self.expr_cache: binder_version, typ, messages, type_map = self.expr_cache[(node, type_context)] if binder_version == self.chk.binder.version: From 721faccc27f6f18c26c5da124c7b8ca1b39b7e48 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 25 Jul 2025 11:57:44 +0100 Subject: [PATCH 5/6] Add some comments --- mypy/binder.py | 3 +++ mypy/checker.py | 4 +++- mypy/checkexpr.py | 5 +++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mypy/binder.py b/mypy/binder.py index 9456867c1df1..2ae58dad1fe0 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -137,6 +137,9 @@ def __init__(self, options: Options) -> None: # is added to the binder. This allows more precise narrowing and more # flexible inference of variable types (--allow-redefinition-new). self.bind_all = options.allow_redefinition_new + + # This tracks any externally visible changes in binder to invalidate + # expression caches when needed. self.version = 0 def _get_id(self) -> int: diff --git a/mypy/checker.py b/mypy/checker.py index 3760b86564f6..3e2d5f23d876 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3017,8 +3017,9 @@ def visit_block(self, b: Block) -> None: self.msg.unreachable_statement(s) break else: - self.expr_checker.expr_cache.clear() self.accept(s) + # Clear expression cache after each statement to avoid unlimited growth. + self.expr_checker.expr_cache.clear() def should_report_unreachable_issues(self) -> bool: return ( @@ -4659,6 +4660,7 @@ def replace_partial_type( ) -> None: """Replace the partial type of var with a non-partial type.""" var.type = new_type + # Updating a partial type should invalidate expression caches. self.binder.version += 1 del partial_types[var] if self.options.allow_redefinition_new: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9648420708f6..3944c580da74 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5990,6 +5990,10 @@ def accept( typ = self.visit_conditional_expr(node, allow_none_return=True) elif allow_none_return and isinstance(node, AwaitExpr): typ = self.visit_await_expr(node, allow_none_return=True) + # Deeply nested generic calls can deteriorate performance dramatically. + # Although in most cases caching makes little difference, in worst case + # it avoids exponential complexity. + # TODO: figure out why caching within lambdas is fragile. elif isinstance(node, (CallExpr, ListExpr, TupleExpr)) and not ( self.in_lambda_expr or self.chk.current_node_deferred ): @@ -6034,6 +6038,7 @@ def accept( def accept_maybe_cache(self, node: Expression, type_context: Type | None = None) -> Type: binder_version = self.chk.binder.version + # Micro-optimization: inline local_type_map() as it is somewhat slow in mypyc. type_map: dict[Expression, Type] = {} self.chk._type_maps.append(type_map) with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as msg: From 344a84a14d1dfda69478c69ed9ff128ea34a92da Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 25 Jul 2025 15:14:07 +0100 Subject: [PATCH 6/6] Fix bug in multiassign from union --- mypy/checker.py | 2 +- mypy/checkexpr.py | 4 +++- test-data/unit/check-overloading.test | 23 ++++++++++++++++++++++ test-data/unit/fixtures/isinstancelist.pyi | 2 ++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 3e2d5f23d876..4e78755aa153 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4001,7 +4001,7 @@ def check_multi_assignment_from_union( for t, lv in zip(transposed, self.flatten_lvalues(lvalues)): # We can access _type_maps directly since temporary type maps are # only created within expressions. - t.append(self._type_maps[0].pop(lv, AnyType(TypeOfAny.special_form))) + t.append(self._type_maps[-1].pop(lv, AnyType(TypeOfAny.special_form))) union_types = tuple(make_simplified_union(col) for col in transposed) for expr, items in assignments.items(): # Bind a union of types collected in 'assignments' to every expression. diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3944c580da74..ed0f9900c2f4 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5993,7 +5993,9 @@ def accept( # Deeply nested generic calls can deteriorate performance dramatically. # Although in most cases caching makes little difference, in worst case # it avoids exponential complexity. - # TODO: figure out why caching within lambdas is fragile. + # We cannot use cache inside lambdas, because they skip immediate type + # context, and use enclosing one, see infer_lambda_type_using_context(). + # TODO: consider using cache for more expression kinds. elif isinstance(node, (CallExpr, ListExpr, TupleExpr)) and not ( self.in_lambda_expr or self.chk.current_node_deferred ): diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 0f0fc8747223..22221416f151 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6801,3 +6801,26 @@ class D(Generic[T]): a: D[str] # E: Type argument "str" of "D" must be a subtype of "C" reveal_type(a.f(1)) # N: Revealed type is "builtins.int" reveal_type(a.f("x")) # N: Revealed type is "builtins.str" + +[case testMultiAssignFromUnionInOverloadCached] +from typing import Iterable, overload, Union, Optional + +@overload +def always_bytes(str_or_bytes: None) -> None: ... +@overload +def always_bytes(str_or_bytes: Union[str, bytes]) -> bytes: ... +def always_bytes(str_or_bytes: Union[None, str, bytes]) -> Optional[bytes]: + pass + +class Headers: + def __init__(self, iter: Iterable[tuple[bytes, bytes]]) -> None: ... + +headers: Union[Headers, dict[Union[str, bytes], Union[str, bytes]], Iterable[tuple[bytes, bytes]]] + +if isinstance(headers, dict): + headers = Headers( + (always_bytes(k), always_bytes(v)) for k, v in headers.items() + ) + +reveal_type(headers) # N: Revealed type is "Union[__main__.Headers, typing.Iterable[tuple[builtins.bytes, builtins.bytes]]]" +[builtins fixtures/isinstancelist.pyi] diff --git a/test-data/unit/fixtures/isinstancelist.pyi b/test-data/unit/fixtures/isinstancelist.pyi index 0ee5258ff74b..2a43606f361a 100644 --- a/test-data/unit/fixtures/isinstancelist.pyi +++ b/test-data/unit/fixtures/isinstancelist.pyi @@ -26,6 +26,7 @@ class bool(int): pass class str: def __add__(self, x: str) -> str: pass def __getitem__(self, x: int) -> str: pass +class bytes: pass T = TypeVar('T') KT = TypeVar('KT') @@ -52,6 +53,7 @@ class dict(Mapping[KT, VT]): def __setitem__(self, k: KT, v: VT) -> None: pass def __iter__(self) -> Iterator[KT]: pass def update(self, a: Mapping[KT, VT]) -> None: pass + def items(self) -> Iterable[Tuple[KT, VT]]: pass class set(Generic[T]): def __iter__(self) -> Iterator[T]: pass