Skip to content

Commit d69419c

Browse files
[mypyc] feat: extend get_expr_length to work with RTuple [2/4] (#19929)
This PR extends `get_expr_length` to work with type information from RTuple types.
1 parent 8e57622 commit d69419c

File tree

2 files changed

+69
-74
lines changed

2 files changed

+69
-74
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,18 +1203,18 @@ def gen_cleanup(self) -> None:
12031203
gen.gen_cleanup()
12041204

12051205

1206-
def get_expr_length(expr: Expression) -> int | None:
1206+
def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None:
12071207
if isinstance(expr, (StrExpr, BytesExpr)):
12081208
return len(expr.value)
12091209
elif isinstance(expr, (ListExpr, TupleExpr)):
12101210
# if there are no star expressions, or we know the length of them,
12111211
# we know the length of the expression
1212-
stars = [get_expr_length(i) for i in expr.items if isinstance(i, StarExpr)]
1212+
stars = [get_expr_length(builder, i) for i in expr.items if isinstance(i, StarExpr)]
12131213
if None not in stars:
12141214
other = sum(not isinstance(i, StarExpr) for i in expr.items)
12151215
return other + sum(stars) # type: ignore [arg-type]
12161216
elif isinstance(expr, StarExpr):
1217-
return get_expr_length(expr.expr)
1217+
return get_expr_length(builder, expr.expr)
12181218
elif (
12191219
isinstance(expr, RefExpr)
12201220
and isinstance(expr.node, Var)
@@ -1227,6 +1227,11 @@ def get_expr_length(expr: Expression) -> int | None:
12271227
# performance boost and can be (sometimes) figured out pretty easily. set and dict
12281228
# comps *can* be done as well but will need special logic to consider the possibility
12291229
# of key conflicts. Range, enumerate, zip are all simple logic.
1230+
1231+
# we might still be able to get the length directly from the type
1232+
rtype = builder.node_type(expr)
1233+
if isinstance(rtype, RTuple):
1234+
return len(rtype.types)
12301235
return None
12311236

12321237

@@ -1235,7 +1240,7 @@ def get_expr_length_value(
12351240
) -> Value:
12361241
rtype = builder.node_type(expr)
12371242
assert is_sequence_rprimitive(rtype) or isinstance(rtype, RTuple), rtype
1238-
length = get_expr_length(expr)
1243+
length = get_expr_length(builder, expr)
12391244
if length is None:
12401245
# We cannot compute the length at compile time, so we will fetch it.
12411246
return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t)

mypyc/test-data/irbuild-tuple.test

Lines changed: 60 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -694,51 +694,46 @@ L0:
694694
return r1
695695
def test():
696696
r0, source :: tuple[int, int, int]
697-
r1 :: object
698-
r2 :: native_int
699-
r3 :: bit
700-
r4, r5, r6 :: int
701-
r7, r8, r9 :: object
702-
r10, r11 :: tuple
703-
r12 :: native_int
704-
r13 :: bit
697+
r1, r2, r3 :: int
698+
r4, r5, r6 :: object
699+
r7, r8 :: tuple
700+
r9 :: native_int
701+
r10 :: bit
702+
r11 :: object
703+
r12, x :: int
704+
r13 :: bool
705705
r14 :: object
706-
r15, x :: int
707-
r16 :: bool
708-
r17 :: object
709-
r18 :: native_int
706+
r15 :: native_int
710707
a :: tuple
711708
L0:
712709
r0 = (2, 4, 6)
713710
source = r0
714-
r1 = box(tuple[int, int, int], source)
715-
r2 = PyObject_Size(r1)
716-
r3 = r2 >= 0 :: signed
717-
r4 = source[0]
718-
r5 = source[1]
719-
r6 = source[2]
720-
r7 = box(int, r4)
721-
r8 = box(int, r5)
722-
r9 = box(int, r6)
723-
r10 = PyTuple_Pack(3, r7, r8, r9)
724-
r11 = PyTuple_New(r2)
725-
r12 = 0
711+
r1 = source[0]
712+
r2 = source[1]
713+
r3 = source[2]
714+
r4 = box(int, r1)
715+
r5 = box(int, r2)
716+
r6 = box(int, r3)
717+
r7 = PyTuple_Pack(3, r4, r5, r6)
718+
r8 = PyTuple_New(3)
719+
r9 = 0
720+
goto L2
726721
L1:
727-
r13 = r12 < r2 :: signed
728-
if r13 goto L2 else goto L4 :: bool
722+
r10 = r9 < 3 :: signed
723+
if r10 goto L2 else goto L4 :: bool
729724
L2:
730-
r14 = CPySequenceTuple_GetItemUnsafe(r10, r12)
731-
r15 = unbox(int, r14)
732-
x = r15
733-
r16 = f(x)
734-
r17 = box(bool, r16)
735-
CPySequenceTuple_SetItemUnsafe(r11, r12, r17)
725+
r11 = CPySequenceTuple_GetItemUnsafe(r7, r9)
726+
r12 = unbox(int, r11)
727+
x = r12
728+
r13 = f(x)
729+
r14 = box(bool, r13)
730+
CPySequenceTuple_SetItemUnsafe(r8, r9, r14)
736731
L3:
737-
r18 = r12 + 1
738-
r12 = r18
732+
r15 = r9 + 1
733+
r9 = r15
739734
goto L1
740735
L4:
741-
a = r11
736+
a = r8
742737
return 1
743738

744739
[case testTupleBuiltFromFinalFixedLengthTuple]
@@ -762,19 +757,16 @@ L0:
762757
def test():
763758
r0 :: tuple[int, int, int]
764759
r1 :: bool
765-
r2 :: object
766-
r3 :: native_int
767-
r4 :: bit
768-
r5, r6, r7 :: int
769-
r8, r9, r10 :: object
770-
r11, r12 :: tuple
771-
r13 :: native_int
772-
r14 :: bit
760+
r2, r3, r4 :: int
761+
r5, r6, r7 :: object
762+
r8, r9 :: tuple
763+
r10 :: native_int
764+
r11 :: bit
765+
r12 :: object
766+
r13, x :: int
767+
r14 :: bool
773768
r15 :: object
774-
r16, x :: int
775-
r17 :: bool
776-
r18 :: object
777-
r19 :: native_int
769+
r16 :: native_int
778770
a :: tuple
779771
L0:
780772
r0 = __main__.source :: static
@@ -783,34 +775,32 @@ L1:
783775
r1 = raise NameError('value for final name "source" was not set')
784776
unreachable
785777
L2:
786-
r2 = box(tuple[int, int, int], r0)
787-
r3 = PyObject_Size(r2)
788-
r4 = r3 >= 0 :: signed
789-
r5 = r0[0]
790-
r6 = r0[1]
791-
r7 = r0[2]
792-
r8 = box(int, r5)
793-
r9 = box(int, r6)
794-
r10 = box(int, r7)
795-
r11 = PyTuple_Pack(3, r8, r9, r10)
796-
r12 = PyTuple_New(r3)
797-
r13 = 0
778+
r2 = r0[0]
779+
r3 = r0[1]
780+
r4 = r0[2]
781+
r5 = box(int, r2)
782+
r6 = box(int, r3)
783+
r7 = box(int, r4)
784+
r8 = PyTuple_Pack(3, r5, r6, r7)
785+
r9 = PyTuple_New(3)
786+
r10 = 0
787+
goto L4
798788
L3:
799-
r14 = r13 < r3 :: signed
800-
if r14 goto L4 else goto L6 :: bool
789+
r11 = r10 < 3 :: signed
790+
if r11 goto L4 else goto L6 :: bool
801791
L4:
802-
r15 = CPySequenceTuple_GetItemUnsafe(r11, r13)
803-
r16 = unbox(int, r15)
804-
x = r16
805-
r17 = f(x)
806-
r18 = box(bool, r17)
807-
CPySequenceTuple_SetItemUnsafe(r12, r13, r18)
792+
r12 = CPySequenceTuple_GetItemUnsafe(r8, r10)
793+
r13 = unbox(int, r12)
794+
x = r13
795+
r14 = f(x)
796+
r15 = box(bool, r14)
797+
CPySequenceTuple_SetItemUnsafe(r9, r10, r15)
808798
L5:
809-
r19 = r13 + 1
810-
r13 = r19
799+
r16 = r10 + 1
800+
r10 = r16
811801
goto L3
812802
L6:
813-
a = r12
803+
a = r9
814804
return 1
815805

816806
[case testTupleBuiltFromVariableLengthTuple]

0 commit comments

Comments
 (0)