Skip to content

Commit 843d133

Browse files
More precise return types for TypedDict.get (#19897)
Fixes #19896, #19902 - `TypedDict.get` now ignores the type of the default when the key is required. - `reveal_type(d.get)` now gives an appropriate list of overloads - I kept the special casing for `get(subdict, {})`, but this is not visible in the overloads. Implementing this via overloads is blocked by #19895 Some additional changes: - I added some code that ensures that the default type always appears last in the union (relevant when a union of multiple keys is given) - I ensure that the original value-type is use instead of its `proper_type`. This simplifies the return in `testRecursiveTypedDictMethods`. --------- Co-authored-by: Ivan Levkivskyi <levkivskyi@gmail.com>
1 parent b8f57fd commit 843d133

File tree

8 files changed

+533
-57
lines changed

8 files changed

+533
-57
lines changed

mypy/checkexpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1502,7 +1502,7 @@ def check_call_expr_with_callee_type(
15021502
def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type:
15031503
"""Type check calling a member expression where the base type is a union."""
15041504
res: list[Type] = []
1505-
for typ in object_type.relevant_items():
1505+
for typ in flatten_nested_unions(object_type.relevant_items()):
15061506
# Member access errors are already reported when visiting the member expression.
15071507
with self.msg.filter_errors():
15081508
item = analyze_member_access(

mypy/plugins/default.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import mypy.errorcodes as codes
77
from mypy import message_registry
8-
from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
8+
from mypy.nodes import DictExpr, Expression, IntExpr, StrExpr, UnaryExpr
99
from mypy.plugin import (
1010
AttributeContext,
1111
ClassDefContext,
@@ -263,30 +263,40 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
263263
if keys is None:
264264
return ctx.default_return_type
265265

266+
default_type: Type
267+
default_arg: Expression | None
268+
if len(ctx.arg_types) <= 1 or not ctx.arg_types[1]:
269+
default_arg = None
270+
default_type = NoneType()
271+
elif len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
272+
default_arg = ctx.args[1][0]
273+
default_type = ctx.arg_types[1][0]
274+
else:
275+
return ctx.default_return_type
276+
266277
output_types: list[Type] = []
267278
for key in keys:
268-
value_type = get_proper_type(ctx.type.items.get(key))
279+
value_type: Type | None = ctx.type.items.get(key)
269280
if value_type is None:
270281
return ctx.default_return_type
271282

272-
if len(ctx.arg_types) == 1:
283+
if key in ctx.type.required_keys:
273284
output_types.append(value_type)
274-
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
275-
default_arg = ctx.args[1][0]
285+
else:
286+
# HACK to deal with get(key, {})
276287
if (
277288
isinstance(default_arg, DictExpr)
278289
and len(default_arg.items) == 0
279-
and isinstance(value_type, TypedDictType)
290+
and isinstance(vt := get_proper_type(value_type), TypedDictType)
280291
):
281-
# Special case '{}' as the default for a typed dict type.
282-
output_types.append(value_type.copy_modified(required_keys=set()))
292+
output_types.append(vt.copy_modified(required_keys=set()))
283293
else:
284294
output_types.append(value_type)
285-
output_types.append(ctx.arg_types[1][0])
286-
287-
if len(ctx.arg_types) == 1:
288-
output_types.append(NoneType())
295+
output_types.append(default_type)
289296

297+
# for nicer reveal_type, put default at the end, if it is present
298+
if default_type in output_types:
299+
output_types = [t for t in output_types if t != default_type] + [default_type]
290300
return make_simplified_union(output_types)
291301
return ctx.default_return_type
292302

test-data/unit/check-incremental.test

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7311,3 +7311,204 @@ x = 2
73117311
[out]
73127312
[rechecked bar]
73137313
[stale]
7314+
7315+
7316+
[case testIncrementalTypedDictGetMethodTotalFalse]
7317+
import impl
7318+
[file lib.py]
7319+
from typing import TypedDict
7320+
class Unrelated: pass
7321+
D = TypedDict('D', {'x': int, 'y': str}, total=False)
7322+
[file impl.py]
7323+
pass
7324+
[file impl.py.2]
7325+
from typing import Literal
7326+
from lib import D, Unrelated
7327+
d: D
7328+
u: Unrelated
7329+
x: Literal['x']
7330+
y: Literal['y']
7331+
z: Literal['z']
7332+
x_or_y: Literal['x', 'y']
7333+
x_or_z: Literal['x', 'z']
7334+
x_or_y_or_z: Literal['x', 'y', 'z']
7335+
7336+
# test with literal expression
7337+
reveal_type(d.get('x'))
7338+
reveal_type(d.get('y'))
7339+
reveal_type(d.get('z'))
7340+
reveal_type(d.get('x', u))
7341+
reveal_type(d.get('x', 1))
7342+
reveal_type(d.get('y', None))
7343+
7344+
# test with literal type / union of literal types with implicit default
7345+
reveal_type(d.get(x))
7346+
reveal_type(d.get(y))
7347+
reveal_type(d.get(z))
7348+
reveal_type(d.get(x_or_y))
7349+
reveal_type(d.get(x_or_z))
7350+
reveal_type(d.get(x_or_y_or_z))
7351+
7352+
# test with literal type / union of literal types with explicit default
7353+
reveal_type(d.get(x, u))
7354+
reveal_type(d.get(y, u))
7355+
reveal_type(d.get(z, u))
7356+
reveal_type(d.get(x_or_y, u))
7357+
reveal_type(d.get(x_or_z, u))
7358+
reveal_type(d.get(x_or_y_or_z, u))
7359+
[builtins fixtures/dict.pyi]
7360+
[typing fixtures/typing-typeddict.pyi]
7361+
[out]
7362+
[out2]
7363+
tmp/impl.py:13: note: Revealed type is "Union[builtins.int, None]"
7364+
tmp/impl.py:14: note: Revealed type is "Union[builtins.str, None]"
7365+
tmp/impl.py:15: note: Revealed type is "builtins.object"
7366+
tmp/impl.py:16: note: Revealed type is "Union[builtins.int, lib.Unrelated]"
7367+
tmp/impl.py:17: note: Revealed type is "builtins.int"
7368+
tmp/impl.py:18: note: Revealed type is "Union[builtins.str, None]"
7369+
tmp/impl.py:21: note: Revealed type is "Union[builtins.int, None]"
7370+
tmp/impl.py:22: note: Revealed type is "Union[builtins.str, None]"
7371+
tmp/impl.py:23: note: Revealed type is "builtins.object"
7372+
tmp/impl.py:24: note: Revealed type is "Union[builtins.int, builtins.str, None]"
7373+
tmp/impl.py:25: note: Revealed type is "builtins.object"
7374+
tmp/impl.py:26: note: Revealed type is "builtins.object"
7375+
tmp/impl.py:29: note: Revealed type is "Union[builtins.int, lib.Unrelated]"
7376+
tmp/impl.py:30: note: Revealed type is "Union[builtins.str, lib.Unrelated]"
7377+
tmp/impl.py:31: note: Revealed type is "builtins.object"
7378+
tmp/impl.py:32: note: Revealed type is "Union[builtins.int, builtins.str, lib.Unrelated]"
7379+
tmp/impl.py:33: note: Revealed type is "builtins.object"
7380+
tmp/impl.py:34: note: Revealed type is "builtins.object"
7381+
7382+
[case testIncrementalTypedDictGetMethodTotalTrue]
7383+
import impl
7384+
[file lib.py]
7385+
from typing import TypedDict
7386+
class Unrelated: pass
7387+
D = TypedDict('D', {'x': int, 'y': str}, total=True)
7388+
[file impl.py]
7389+
pass
7390+
[file impl.py.2]
7391+
from typing import Literal
7392+
from lib import D, Unrelated
7393+
d: D
7394+
u: Unrelated
7395+
x: Literal['x']
7396+
y: Literal['y']
7397+
z: Literal['z']
7398+
x_or_y: Literal['x', 'y']
7399+
x_or_z: Literal['x', 'z']
7400+
x_or_y_or_z: Literal['x', 'y', 'z']
7401+
7402+
# test with literal expression
7403+
reveal_type(d.get('x'))
7404+
reveal_type(d.get('y'))
7405+
reveal_type(d.get('z'))
7406+
reveal_type(d.get('x', u))
7407+
reveal_type(d.get('x', 1))
7408+
reveal_type(d.get('y', None))
7409+
7410+
# test with literal type / union of literal types with implicit default
7411+
reveal_type(d.get(x))
7412+
reveal_type(d.get(y))
7413+
reveal_type(d.get(z))
7414+
reveal_type(d.get(x_or_y))
7415+
reveal_type(d.get(x_or_z))
7416+
reveal_type(d.get(x_or_y_or_z))
7417+
7418+
# test with literal type / union of literal types with explicit default
7419+
reveal_type(d.get(x, u))
7420+
reveal_type(d.get(y, u))
7421+
reveal_type(d.get(z, u))
7422+
reveal_type(d.get(x_or_y, u))
7423+
reveal_type(d.get(x_or_z, u))
7424+
reveal_type(d.get(x_or_y_or_z, u))
7425+
[builtins fixtures/dict.pyi]
7426+
[typing fixtures/typing-typeddict.pyi]
7427+
[out]
7428+
[out2]
7429+
tmp/impl.py:13: note: Revealed type is "builtins.int"
7430+
tmp/impl.py:14: note: Revealed type is "builtins.str"
7431+
tmp/impl.py:15: note: Revealed type is "builtins.object"
7432+
tmp/impl.py:16: note: Revealed type is "builtins.int"
7433+
tmp/impl.py:17: note: Revealed type is "builtins.int"
7434+
tmp/impl.py:18: note: Revealed type is "builtins.str"
7435+
tmp/impl.py:21: note: Revealed type is "builtins.int"
7436+
tmp/impl.py:22: note: Revealed type is "builtins.str"
7437+
tmp/impl.py:23: note: Revealed type is "builtins.object"
7438+
tmp/impl.py:24: note: Revealed type is "Union[builtins.int, builtins.str]"
7439+
tmp/impl.py:25: note: Revealed type is "builtins.object"
7440+
tmp/impl.py:26: note: Revealed type is "builtins.object"
7441+
tmp/impl.py:29: note: Revealed type is "builtins.int"
7442+
tmp/impl.py:30: note: Revealed type is "builtins.str"
7443+
tmp/impl.py:31: note: Revealed type is "builtins.object"
7444+
tmp/impl.py:32: note: Revealed type is "Union[builtins.int, builtins.str]"
7445+
tmp/impl.py:33: note: Revealed type is "builtins.object"
7446+
tmp/impl.py:34: note: Revealed type is "builtins.object"
7447+
7448+
7449+
[case testIncrementalTypedDictGetMethodTotalMixed]
7450+
import impl
7451+
[file lib.py]
7452+
from typing import TypedDict
7453+
from typing_extensions import Required, NotRequired
7454+
class Unrelated: pass
7455+
D = TypedDict('D', {'x': Required[int], 'y': NotRequired[str]})
7456+
[file impl.py]
7457+
pass
7458+
[file impl.py.2]
7459+
from typing import Literal
7460+
from lib import D, Unrelated
7461+
d: D
7462+
u: Unrelated
7463+
x: Literal['x']
7464+
y: Literal['y']
7465+
z: Literal['z']
7466+
x_or_y: Literal['x', 'y']
7467+
x_or_z: Literal['x', 'z']
7468+
x_or_y_or_z: Literal['x', 'y', 'z']
7469+
7470+
# test with literal expression
7471+
reveal_type(d.get('x'))
7472+
reveal_type(d.get('y'))
7473+
reveal_type(d.get('z'))
7474+
reveal_type(d.get('x', u))
7475+
reveal_type(d.get('x', 1))
7476+
reveal_type(d.get('y', None))
7477+
7478+
# test with literal type / union of literal types with implicit default
7479+
reveal_type(d.get(x))
7480+
reveal_type(d.get(y))
7481+
reveal_type(d.get(z))
7482+
reveal_type(d.get(x_or_y))
7483+
reveal_type(d.get(x_or_z))
7484+
reveal_type(d.get(x_or_y_or_z))
7485+
7486+
# test with literal type / union of literal types with explicit default
7487+
reveal_type(d.get(x, u))
7488+
reveal_type(d.get(y, u))
7489+
reveal_type(d.get(z, u))
7490+
reveal_type(d.get(x_or_y, u))
7491+
reveal_type(d.get(x_or_z, u))
7492+
reveal_type(d.get(x_or_y_or_z, u))
7493+
[builtins fixtures/dict.pyi]
7494+
[typing fixtures/typing-typeddict.pyi]
7495+
[out]
7496+
[out2]
7497+
tmp/impl.py:13: note: Revealed type is "builtins.int"
7498+
tmp/impl.py:14: note: Revealed type is "Union[builtins.str, None]"
7499+
tmp/impl.py:15: note: Revealed type is "builtins.object"
7500+
tmp/impl.py:16: note: Revealed type is "builtins.int"
7501+
tmp/impl.py:17: note: Revealed type is "builtins.int"
7502+
tmp/impl.py:18: note: Revealed type is "Union[builtins.str, None]"
7503+
tmp/impl.py:21: note: Revealed type is "builtins.int"
7504+
tmp/impl.py:22: note: Revealed type is "Union[builtins.str, None]"
7505+
tmp/impl.py:23: note: Revealed type is "builtins.object"
7506+
tmp/impl.py:24: note: Revealed type is "Union[builtins.int, builtins.str, None]"
7507+
tmp/impl.py:25: note: Revealed type is "builtins.object"
7508+
tmp/impl.py:26: note: Revealed type is "builtins.object"
7509+
tmp/impl.py:29: note: Revealed type is "builtins.int"
7510+
tmp/impl.py:30: note: Revealed type is "Union[builtins.str, lib.Unrelated]"
7511+
tmp/impl.py:31: note: Revealed type is "builtins.object"
7512+
tmp/impl.py:32: note: Revealed type is "Union[builtins.int, builtins.str, lib.Unrelated]"
7513+
tmp/impl.py:33: note: Revealed type is "builtins.object"
7514+
tmp/impl.py:34: note: Revealed type is "builtins.object"

test-data/unit/check-literal.test

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,7 +1884,7 @@ reveal_type(d[a_key]) # N: Revealed type is "builtins.int"
18841884
reveal_type(d[b_key]) # N: Revealed type is "builtins.str"
18851885
d[c_key] # E: TypedDict "Outer" has no key "c"
18861886

1887-
reveal_type(d.get(a_key, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
1887+
reveal_type(d.get(a_key, u)) # N: Revealed type is "builtins.int"
18881888
reveal_type(d.get(b_key, u)) # N: Revealed type is "Union[builtins.str, __main__.Unrelated]"
18891889
reveal_type(d.get(c_key, u)) # N: Revealed type is "builtins.object"
18901890

@@ -1928,7 +1928,7 @@ u: Unrelated
19281928
reveal_type(a[int_key_good]) # N: Revealed type is "builtins.int"
19291929
reveal_type(b[int_key_good]) # N: Revealed type is "builtins.int"
19301930
reveal_type(c[str_key_good]) # N: Revealed type is "builtins.int"
1931-
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
1931+
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "builtins.int"
19321932
reveal_type(c.get(str_key_bad, u)) # N: Revealed type is "builtins.object"
19331933

19341934
a[int_key_bad] # E: Tuple index out of range
@@ -1993,8 +1993,8 @@ optional_keys: Literal["d", "e"]
19931993
bad_keys: Literal["a", "bad"]
19941994

19951995
reveal_type(test[good_keys]) # N: Revealed type is "Union[__main__.A, __main__.B]"
1996-
reveal_type(test.get(good_keys)) # N: Revealed type is "Union[__main__.A, __main__.B, None]"
1997-
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, Literal[3]?, __main__.B]"
1996+
reveal_type(test.get(good_keys)) # N: Revealed type is "Union[__main__.A, __main__.B]"
1997+
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, __main__.B]"
19981998
reveal_type(test.pop(optional_keys)) # N: Revealed type is "Union[__main__.D, __main__.E]"
19991999
reveal_type(test.pop(optional_keys, 3)) # N: Revealed type is "Union[__main__.D, __main__.E, Literal[3]?]"
20002000
reveal_type(test.setdefault(good_keys, AAndB())) # N: Revealed type is "Union[__main__.A, __main__.B]"
@@ -2037,15 +2037,18 @@ class D2(TypedDict):
20372037
d: D
20382038

20392039
x: Union[D1, D2]
2040-
bad_keys: Literal['a', 'b', 'c', 'd']
20412040
good_keys: Literal['b', 'c']
2041+
mixed_keys: Literal['a', 'b', 'c', 'd']
2042+
bad_keys: Literal['e', 'f']
20422043

2043-
x[bad_keys] # E: TypedDict "D1" has no key "d" \
2044+
x[mixed_keys] # E: TypedDict "D1" has no key "d" \
20442045
# E: TypedDict "D2" has no key "a"
20452046

20462047
reveal_type(x[good_keys]) # N: Revealed type is "Union[__main__.B, __main__.C]"
2047-
reveal_type(x.get(good_keys)) # N: Revealed type is "Union[__main__.B, __main__.C, None]"
2048-
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, Literal[3]?, __main__.C]"
2048+
reveal_type(x.get(good_keys)) # N: Revealed type is "Union[__main__.B, __main__.C]"
2049+
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, __main__.C]"
2050+
reveal_type(x.get(mixed_keys)) # N: Revealed type is "builtins.object"
2051+
reveal_type(x.get(mixed_keys, 3)) # N: Revealed type is "builtins.object"
20492052
reveal_type(x.get(bad_keys)) # N: Revealed type is "builtins.object"
20502053
reveal_type(x.get(bad_keys, 3)) # N: Revealed type is "builtins.object"
20512054

test-data/unit/check-recursive-types.test

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,10 +690,11 @@ class TD(TypedDict, total=False):
690690
y: TD
691691

692692
td: TD
693+
reveal_type(td.get("y")) # N: Revealed type is "Union[TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: ...}), None]"
693694
td["y"] = {"x": 0, "y": {}}
694695
td["y"] = {"x": 0, "y": {"x": 0, "y": 42}} # E: Incompatible types (expression has type "int", TypedDict item "y" has type "TD")
695696

696-
reveal_type(td.get("y")) # N: Revealed type is "Union[TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: ...})}), None]"
697+
reveal_type(td.get("y")) # N: Revealed type is "Union[TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: ...}), None]"
697698
s: str = td.get("y") # E: Incompatible types in assignment (expression has type "Optional[TD]", variable has type "str")
698699

699700
td.update({"x": 0, "y": {"x": 1, "y": {}}})

0 commit comments

Comments
 (0)