Skip to content

[wip] [mypyc] feat: quasi-constant folding for DictExpr and TupleExpr #19542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
Draft
6 changes: 5 additions & 1 deletion mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def load_type_var(self, name: str, line: int) -> Value:
)
)

def load_literal_value(self, val: int | str | bytes | float | complex | bool) -> Value:
def load_literal_value(self, val: int | str | bytes | float | complex | bool | tuple[Any, ...], dict[Any, Any]) -> Value:
"""Load value of a final name, class-level attribute, or constant folded expression."""
if isinstance(val, bool):
if val:
Expand All @@ -618,6 +618,10 @@ def load_literal_value(self, val: int | str | bytes | float | complex | bool) ->
return self.builder.load_bytes(val)
elif isinstance(val, complex):
return self.builder.load_complex(val)
elif isinstance(val, tuple):
return self.builder.load_tuple(val)
elif isinstance(val, dict):
return self.builder.load_dict(val)
else:
assert False, "Unsupported literal value"

Expand Down
29 changes: 25 additions & 4 deletions mypyc/irbuild/constant_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,30 @@

from __future__ import annotations

from typing import Final, Union
from typing import Any, Final, Union

from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op
from mypy.nodes import (
BytesExpr,
ComplexExpr,
DictExpr,
Expression,
FloatExpr,
IntExpr,
MemberExpr,
NameExpr,
OpExpr,
StrExpr,
TupleExpr,
UnaryExpr,
Var,
)
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.util import bytes_from_str

# All possible result types of constant folding
ConstantValue = Union[int, float, complex, str, bytes]
CONST_TYPES: Final = (int, float, complex, str, bytes)
ConstantValue = Union[int, float, complex, str, bytes, tuple[Any, ...], dict[Any, Any]]
CONST_TYPES: Final = (int, float, complex, str, bytes, tuple, dict)


def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None:
Expand Down Expand Up @@ -72,6 +74,25 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue |
value = constant_fold_expr(builder, expr.expr)
if value is not None and not isinstance(value, bytes):
return constant_fold_unary_op(expr.op, value)
elif isinstance(expr, TupleExpr):
folded = tuple(constant_fold_expr(builder, item_expr) for item_expr in expr.items)
if None not in folded:
return folded
elif isinstance(expr, DictExpr):
# NOTE: the builder can't simply use a dict constant like it can with other constants, since dicts are mutable.
# TODO: make the builder load the dict 'constant' by calling copy on a prebuilt constant template instead of building from scratch each time
folded = {
constant_fold_expr(builder, key_expr): constant_fold_expr(builder, value_expr)
for key_expr, value_expr in expr.items
}
if (
len(folded) == len(expr.items)
and None not in folded.keys()
and None not in folded.values()
):
return folded

# TODO use a placeholder instead of None so we can include None in folded tuples/dicts
return None


Expand All @@ -82,7 +103,7 @@ def constant_fold_binary_op_extended(

mypy cannot use constant folded bytes easily so it's simpler to only support them in mypyc.
"""
if not isinstance(left, bytes) and not isinstance(right, bytes):
if not isinstance(left, (bytes, tuple, dict)) and not isinstance(right, (bytes, tuple, dict)):
return constant_fold_binary_op(op, left, right)

if op == "+" and isinstance(left, bytes) and isinstance(right, bytes):
Expand Down
11 changes: 10 additions & 1 deletion mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import sys
from collections.abc import Sequence
from typing import Callable, Final, Optional
from typing import Any, Callable, Final, Optional

from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
Expand Down Expand Up @@ -121,6 +121,7 @@
pointer_rprimitive,
short_int_rprimitive,
str_rprimitive,
tuple_rprimitive,
)
from mypyc.irbuild.util import concrete_arg_kind
from mypyc.options import CompilerOptions
Expand Down Expand Up @@ -1281,6 +1282,14 @@ def load_complex(self, value: complex) -> Value:
"""Load a complex literal value."""
return self.add(LoadLiteral(value, object_rprimitive))

def load_tuple(
self, value: tuple[Any, ...]
) -> Value: # should this be RTuple? conditional RTuple when length is known?
return self.add(LoadLiteral(value, tuple_rprimitive))

def load_dict(self, value: dict[Any, Any]) -> Value:
return self.add(LoadLiteral(value, dict_rprimitive))

def load_static_checked(
self,
typ: RType,
Expand Down
Loading