diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 62ac9b8d48e4..e6a10317a495 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -99,6 +99,9 @@ def __init__(self, label: int = -1) -> None: self.error_handler: BasicBlock | None = None self.referenced = False + def __repr__(self) -> str: + return f"{type(self).__name__}(label={self.label}, ops={self.ops})" + @property def terminated(self) -> bool: """Does the block end with a jump, branch or return? diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 5cf89f579ec4..e9b0fce05821 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -7,19 +7,24 @@ from __future__ import annotations -from typing import Callable, ClassVar +from typing import Any, Callable, ClassVar from mypy.nodes import ( ARG_POS, + BytesExpr, CallExpr, DictionaryComprehension, Expression, + FloatExpr, GeneratorExpr, + IntExpr, + ListExpr, Lvalue, MemberExpr, NameExpr, RefExpr, SetExpr, + StrExpr, TupleExpr, TypeAlias, ) @@ -90,6 +95,7 @@ def for_loop_helper( else_insts: GenFunc | None, is_async: bool, line: int, + can_unroll: bool = False, ) -> None: """Generate IR for a loop. @@ -98,6 +104,7 @@ def for_loop_helper( expr: the expression to iterate over body_insts: a function that generates the body of the loop else_insts: a function that generates the else block instructions + can_unroll: whether unrolling is allowed (for semantic safety) """ # Body of the loop body_block = BasicBlock() @@ -112,9 +119,28 @@ def for_loop_helper( normal_loop_exit = else_block if else_insts is not None else exit_block for_gen = make_for_loop_generator( - builder, index, expr, body_block, normal_loop_exit, line, is_async=is_async + builder, + index, + expr, + body_block, + normal_loop_exit, + line, + is_async=is_async, + body_insts=body_insts, + can_unroll=can_unroll, ) + is_literal_loop: bool = getattr(for_gen, "handles_body_insts", False) + + # Only call body_insts if not handled by unrolled generator + if is_literal_loop: + try: + for_gen.begin_body() + return + except AssertionError: + # For whatever reason, we can't unpack the loop in this case. + pass + builder.push_loop_stack(step_block, exit_block) condition_block = BasicBlock() builder.goto_and_activate(condition_block) @@ -377,6 +403,30 @@ def is_range_ref(expr: RefExpr) -> bool: ) +def is_literal_expr(expr: Expression) -> bool: + # Add other literal types as needed + if isinstance(expr, (IntExpr, StrExpr, FloatExpr, BytesExpr)): + return True + if isinstance(expr, NameExpr) and expr.fullname in { + "builtins.None", + "builtins.True", + "builtins.False", + }: + return True + return False + + +def is_iterable_expr_with_literal_mambers(expr: Expression) -> bool: + return ( + isinstance(expr, (ListExpr, SetExpr, TupleExpr)) + and not isinstance(expr, MemberExpr) + and all( + is_literal_expr(item) or is_iterable_expr_with_literal_mambers(item) + for item in expr.items + ) + ) + + def make_for_loop_generator( builder: IRBuilder, index: Lvalue, @@ -386,10 +436,13 @@ def make_for_loop_generator( line: int, is_async: bool = False, nested: bool = False, + body_insts: GenFunc | None = None, + can_unroll: bool = True, ) -> ForGenerator: """Return helper object for generating a for loop over an iterable. If "nested" is True, this is a nested iterator such as "e" in "enumerate(e)". + can_unroll: whether unrolling is allowed (for semantic safety) """ # Do an async loop if needed. async is always generic @@ -402,6 +455,33 @@ def make_for_loop_generator( return async_obj rtyp = builder.node_type(expr) + + if can_unroll and body_insts is not None: + # Special case: tuple/list/set literal (unroll the loop) + if is_iterable_expr_with_literal_mambers(expr): + return ForUnrolledSequenceLiteral( + builder, index, body_block, loop_exit, line, expr, body_insts # type: ignore [arg-type] + ) + + # Special case: RTuple (known-length tuple, struct field iteration) + if isinstance(rtyp, RTuple): + expr_reg = builder.accept(expr) + return ForUnrolledRTuple( + builder, index, body_block, loop_exit, line, rtyp, expr_reg, expr, body_insts + ) + + # Special case: string literal (unroll the loop) + if isinstance(expr, StrExpr): + return ForUnrolledStringLiteral( + builder, index, body_block, loop_exit, line, expr.value, expr, body_insts + ) + + # Special case: bytes literal (unroll the loop) + if isinstance(expr, BytesExpr): + return ForUnrolledBytesLiteral( + builder, index, body_block, loop_exit, line, expr.value.encode(), expr, body_insts + ) + if is_sequence_rprimitive(rtyp): # Special case "for x in ". expr_reg = builder.accept(expr) @@ -764,6 +844,181 @@ def gen_step(self) -> None: pass +class _ForUnrolled(ForGenerator): + """Generate IR for a for loop over a value known at compile time by unrolling the loop. + + This class emits the loop body for each element of the value literal directly, + avoiding any runtime iteration logic and generator handling. + """ + + def __init__(self, *args: Any, **kwargs: Any): + if type(self) is _ForUnrolled: + raise NotImplementedError( + "This is a base class and should not be initialized directly." + ) + super().__init__(*args, **kwargs) + + def gen_condition(self) -> None: + # Unrolled: nothing to do here. + pass + + def gen_step(self) -> None: + # Unrolled: nothing to do here. + pass + + +class ForUnrolledSequenceLiteral(_ForUnrolled): + """Generate IR for a for loop over a tuple literal by unrolling the loop. + + This class emits the loop body for each element of the tuple literal directly, + avoiding any runtime iteration logic. + """ + + handles_body_insts = True + + def __init__( + self, + builder: IRBuilder, + index: Lvalue, + body_block: BasicBlock, + loop_exit: BasicBlock, + line: int, + expr: ListExpr | SetExpr | TupleExpr, + body_insts: GenFunc, + ) -> None: + super().__init__(builder, index, body_block, loop_exit, line, nested=False) + self.expr = expr + self.items = expr.items + self.body_insts = body_insts + self.index_target = builder.maybe_spill_assignable(Integer(0, c_pyssize_t_rprimitive)) + + def gen_condition(self) -> None: + # For unrolled loops, immediately jump to the body if there are items, + # otherwise jump to the loop exit. + # NOTE: this method is not used when the loop is fully unrolled, but is when the loop is a component of another loop, ie a ForZip + builder = self.builder + comparison = builder.binary_op( + builder.read(self.index_target, self.line), Integer(len(self.items)), "<", self.line + ) + builder.add_bool_branch(comparison, self.body_block, self.loop_exit) + + def gen_step(self) -> None: + # NOTE: this method is not used when the loop is fully unrolled, but is when the loop is a component of another loop, ie a ForZip + builder = self.builder + add = builder.builder.int_add(builder.read(self.index_target, self.line), 1) + builder.assign(self.index_target, add, self.line) + + def begin_body(self) -> None: + builder = self.builder + for expr in self.items: + value = builder.accept(expr) + # value = builder.coerce(value, builder.node_type(expr), self.line) + builder.assign(builder.get_assignment_target(self.index), value, self.line) + self.body_insts() + + +class ForUnrolledStringLiteral(_ForUnrolled): + """Generate IR for a for loop over a string literal by unrolling the loop. + + This class emits the loop body for each character of the string literal directly, + avoiding any runtime iteration logic. + """ + + handles_body_insts = True + + def __init__( + self, + builder: IRBuilder, + index: Lvalue, + body_block: BasicBlock, + loop_exit: BasicBlock, + line: int, + value: str, + expr: Expression, + body_insts: GenFunc, + ) -> None: + super().__init__(builder, index, body_block, loop_exit, line, nested=False) + self.value = value + self.expr = expr + self.body_insts = body_insts + + def begin_body(self) -> None: + builder = self.builder + for c in self.value: + builder.assign( + builder.get_assignment_target(self.index), builder.accept(StrExpr(c)), self.line + ) + self.body_insts() + + +class ForUnrolledBytesLiteral(_ForUnrolled): + """Generate IR for a for loop over a string literal by unrolling the loop. + + This class emits the loop body for each character of the string literal directly, + avoiding any runtime iteration logic. + """ + + handles_body_insts = True + + def __init__( + self, + builder: IRBuilder, + index: Lvalue, + body_block: BasicBlock, + loop_exit: BasicBlock, + line: int, + value: bytes, + expr: Expression, + body_insts: GenFunc, + ) -> None: + super().__init__(builder, index, body_block, loop_exit, line, nested=False) + self.value = value + self.expr = expr + self.body_insts = body_insts + + def begin_body(self) -> None: + builder = self.builder + for c in self.value: + builder.assign( + builder.get_assignment_target(self.index), builder.accept(IntExpr(c)), self.line + ) + self.body_insts() + + +class ForUnrolledRTuple(_ForUnrolled): + """Generate IR for a for loop over an RTuple by directly accessing struct fields.""" + + handles_body_insts = True + + def __init__( + self, + builder: IRBuilder, + index: Lvalue, + body_block: BasicBlock, + loop_exit: BasicBlock, + line: int, + rtuple_type: RTuple, + expr_reg: Value, + expr: Expression, + body_insts: GenFunc, + ) -> None: + super().__init__(builder, index, body_block, loop_exit, line, nested=False) + self.rtuple_type = rtuple_type + self.expr_reg = expr_reg + self.expr = expr + self.body_insts = body_insts + + def begin_body(self) -> None: + builder = self.builder + line = self.line + for i, item_type in enumerate(self.rtuple_type.types): + # Directly access the struct field for each RTuple element + value = builder.add(TupleGet(self.expr_reg, i, line)) + # value = builder.coerce(value, item_type, line) + builder.assign(builder.get_assignment_target(self.index), value, line) + self.body_insts() + + def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) -> Value: """Emit a potentially unsafe index into a target.""" # This doesn't really fit nicely into any of our data-driven frameworks diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 79ad4cc62822..6a9a13b2570b 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -273,7 +273,7 @@ def goto(self, target: BasicBlock) -> None: def activate_block(self, block: BasicBlock) -> None: """Add a basic block and make it the active one (target of adds).""" if self.blocks: - assert self.blocks[-1].terminated + assert self.blocks[-1].terminated, self.blocks[-1] block.error_handler = self.error_handlers[-1] self.blocks.append(block) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index eeeb40ac672f..e2399ef21504 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -13,6 +13,7 @@ from typing import Callable import mypy.nodes +import mypy.traverser from mypy.nodes import ( ARG_NAMED, ARG_POS, @@ -130,6 +131,30 @@ ValueGenFunc = Callable[[], Value] +class ControlFlowDetector(mypy.traverser.TraverserVisitor): + """ + A Visitor class that detects whether a block contains any of the following control flow statements: + + - return + - continue + - break + + """ + + def __init__(self) -> None: + super().__init__() + self.has_control_flow = False + + def visit_break_stmt(self, o: BreakStmt) -> None: + self.has_control_flow = True + + def visit_continue_stmt(self, o: ContinueStmt) -> None: + self.has_control_flow = True + + def visit_return_stmt(self, o: ReturnStmt) -> None: + self.has_control_flow = True + + def transform_block(builder: IRBuilder, block: Block) -> None: if not block.is_unreachable: builder.block_reachable_stack.append(True) @@ -434,6 +459,11 @@ def transform_while_stmt(builder: IRBuilder, s: WhileStmt) -> None: def transform_for_stmt(builder: IRBuilder, s: ForStmt) -> None: + # Check for control flow statements in the loop body + detector = ControlFlowDetector() + s.body.accept(detector) + can_unroll = s.else_body is None and not detector.has_control_flow + def body() -> None: builder.accept(s.body) @@ -441,9 +471,25 @@ def else_block() -> None: assert s.else_body is not None builder.accept(s.else_body) - for_loop_helper( - builder, s.index, s.expr, body, else_block if s.else_body else None, s.is_async, s.line - ) + while True: + try: + # for whatever reason this can blow up with can_unroll=True + # but in those cases we can just fall back to the existing for loop logic. + for_loop_helper( + builder, + s.index, + s.expr, + body, + else_block if s.else_body else None, + s.is_async, + s.line, + can_unroll=can_unroll, + ) + return + except AssertionError: + if not can_unroll: + raise + can_unroll = False def transform_break_stmt(builder: IRBuilder, node: BreakStmt) -> None: diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 4a7d315ec836..83f79bd15065 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -3546,3 +3546,141 @@ L0: r2 = PyObject_Vectorcall(r1, 0, 0, 0) r3 = box(None, 1) return r3 + +[case testForOverTupleLiteral] +def f() -> int: + s = 0 + for x in (1, 2, 3): + s += x + return s +[out] +def f(): + s, x, r0, r1, r2 :: int +L0: + s = 0 + x = 2 + r0 = CPyTagged_Add(s, x) + s = r0 + x = 4 + r1 = CPyTagged_Add(s, x) + s = r1 + x = 6 + r2 = CPyTagged_Add(s, x) + s = r2 + return s + +[case testForOverStringLiteral] +def f() -> str: + out = "" + for c in "abc": + out += c + return out +[out] +def f(): + r0, out, r1, c, r2, r3, r4, r5, r6 :: str +L0: + r0 = '' + out = r0 + r1 = 'a' + c = r1 + r2 = CPyStr_Append(out, c) + out = r2 + r3 = 'b' + c = r3 + r4 = CPyStr_Append(out, c) + out = r4 + r5 = 'c' + c = r5 + r6 = CPyStr_Append(out, c) + out = r6 + return out + +[case testForOverRTuple] +from typing import Tuple +def f(t: Tuple[int, int]) -> int: + s = 0 + for x in t: + s += x + return s +[out] +def f(t): + t :: tuple[int, int] + s, r0, x, r1, r2, r3 :: int +L0: + s = 0 + r0 = t[0] + x = r0 + r1 = CPyTagged_Add(s, x) + s = r1 + r2 = t[1] + x = r2 + r3 = CPyTagged_Add(s, x) + s = r3 + return s + +[case testForOverStringVar] +def f(s: str) -> str: + out = "" + for c in s: + out += c + return out +[out] +def f(s): + s, r0, out :: str + r1, r2 :: native_int + r3, r4 :: bit + r5, c, r6 :: str + r7 :: native_int + out, c :: str +L0: + r0 = '' + out = r0 + r1 = 0 +L1: + r2 = CPyStr_Size_size_t(s) + r3 = r2 >= 0 :: signed + r4 = r1 < r2 :: signed + if r4 goto L2 else goto L2 :: bool +L2: + r5 = CPyStr_GetItemUnsafe(s, r1) + c = r5 + r6 = CPyStr_Append(out, c) + out = r6 +L3: + r7 = r1 + 1 + r1 = r7 + goto L1 +L4: + return out + +[case TestForOverCompledTupleExpr] +def f() -> None: + abc = (1, 2, str(3)) + for x in abc: + y = x +[out] +def f(): + r0 :: str + r1, abc :: tuple[int, int, str] + r2 :: int + r3 :: object + x, y :: union[int, str] + r4 :: int + r5 :: object + r6 :: str +L0: + r0 = CPyTagged_Str(6) + r1 = (2, 4, r0) + abc = r1 + r2 = abc[0] + r3 = box(int, r2) + x = r3 + y = x + r4 = abc[1] + r5 = box(int, r4) + x = r5 + y = x + r6 = abc[2] + x = r6 + y = x + return 1 diff --git a/mypyc/test-data/irbuild-set.test b/mypyc/test-data/irbuild-set.test index 5586a2bf4cfb..aea36385fb13 100644 --- a/mypyc/test-data/irbuild-set.test +++ b/mypyc/test-data/irbuild-set.test @@ -129,33 +129,41 @@ L4: def test2(): r0, tmp_tuple :: tuple[int, int, int] r1 :: set - r2, r3, r4 :: object - r5, x, r6 :: int - r7 :: object - r8 :: i32 - r9, r10 :: bit + r2, x, r3 :: int + r4 :: object + r5 :: i32 + r6 :: bit + r7, r8 :: int + r9 :: object + r10 :: i32 + r11 :: bit + r12, r13 :: int + r14 :: object + r15 :: i32 + r16 :: bit b :: set L0: r0 = (2, 6, 10) tmp_tuple = r0 r1 = PySet_New(0) - r2 = box(tuple[int, int, int], tmp_tuple) - r3 = PyObject_GetIter(r2) -L1: - r4 = PyIter_Next(r3) - if is_error(r4) goto L4 else goto L2 -L2: - r5 = unbox(int, r4) - x = r5 - r6 = f(x) - r7 = box(int, r6) - r8 = PySet_Add(r1, r7) - r9 = r8 >= 0 :: signed -L3: - goto L1 -L4: - r10 = CPy_NoErrOccurred() -L5: + r2 = tmp_tuple[0] + x = r2 + r3 = f(x) + r4 = box(int, r3) + r5 = PySet_Add(r1, r4) + r6 = r5 >= 0 :: signed + r7 = tmp_tuple[1] + x = r7 + r8 = f(x) + r9 = box(int, r8) + r10 = PySet_Add(r1, r9) + r11 = r10 >= 0 :: signed + r12 = tmp_tuple[2] + x = r12 + r13 = f(x) + r14 = box(int, r13) + r15 = PySet_Add(r1, r14) + r16 = r15 >= 0 :: signed b = r1 return 1 def test3(): @@ -727,25 +735,16 @@ def not_precomputed() -> None: [out] def precomputed(): - r0 :: set - r1, r2 :: object - r3 :: str + r0 :: str _ :: object - r4 :: bit + r1, r2 :: str L0: - r0 = frozenset({'False', 'None', 'True'}) - r1 = PyObject_GetIter(r0) -L1: - r2 = PyIter_Next(r1) - if is_error(r2) goto L4 else goto L2 -L2: - r3 = cast(str, r2) - _ = r3 -L3: - goto L1 -L4: - r4 = CPy_NoErrOccurred() -L5: + r0 = 'None' + _ = r0 + r1 = 'True' + _ = r1 + r2 = 'False' + _ = r2 return 1 def precomputed2(): r0 :: set diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 72c00a3b9b1c..74f95286c712 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1131,9 +1131,9 @@ async def main() -> None: reveal_type(a_y) reveal_type(asyncio.gather(*[asyncio.sleep(1), asyncio.sleep(1)])) [out] -_testAsyncioGatherPreciseType.py:9: note: Revealed type is "builtins.str" _testAsyncioGatherPreciseType.py:10: note: Revealed type is "builtins.str" -_testAsyncioGatherPreciseType.py:11: note: Revealed type is "asyncio.futures.Future[builtins.list[Any]]" +_testAsyncioGatherPreciseType.py:11: note: Revealed type is "builtins.str" +_testAsyncioGatherPreciseType.py:12: note: Revealed type is "asyncio.futures.Future[builtins.list[None]]" [case testMultipleInheritanceWorksWithTupleTypeGeneric] from typing import SupportsAbs, NamedTuple diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 161f14e8aea7..48ed600b864c 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -988,14 +988,14 @@ class RegularClass: from typing import NamedTuple # TODO: make sure that nested classes in `NamedTuple` are supported: -class NamedTupleWithNestedClass(NamedTuple): +class NamedTupleWithNestedClass(NamedTuple): ... class Nested: x: int y: str = 'a' [out] from typing import NamedTuple -class NamedTupleWithNestedClass(NamedTuple): +class NamedTupleWithNestedClass(NamedTuple): ... class Nested: x: int y: str