Skip to content

[mypyc] feat: true_dict_rprimitive for optimized dict fastpath usage #19499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
RUnion,
bytes_rprimitive,
dict_rprimitive,
exact_dict_rprimitive,
int_rprimitive,
is_float_rprimitive,
is_object_rprimitive,
Expand Down Expand Up @@ -176,6 +177,7 @@ def check_op_sources_valid(fn: FuncIR) -> list[FnError]:
int_rprimitive.name,
bytes_rprimitive.name,
str_rprimitive.name,
exact_dict_rprimitive.name,
dict_rprimitive.name,
list_rprimitive.name,
set_rprimitive.name,
Expand Down
2 changes: 1 addition & 1 deletion mypyc/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def function_annotations(func_ir: FuncIR, tree: MypyFile) -> dict[int, list[Anno
ann = "Dynamic method call."
elif name in op_hints:
ann = op_hints[name]
elif name in ("CPyDict_GetItem", "CPyDict_SetItem"):
elif name in ("CPyDict_GetItemUnsafe", "PyDict_SetItem"):
if (
isinstance(op.args[0], LoadStatic)
and isinstance(op.args[1], LoadLiteral)
Expand Down
18 changes: 16 additions & 2 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,15 @@ def __hash__(self) -> int:
"builtins.list", is_unboxed=False, is_refcounted=True, may_be_immortal=False
)

# Python dict object (or an instance of a subclass of dict).
# Python dict object.
exact_dict_rprimitive: Final = RPrimitive(
"builtins.dict[exact]", is_unboxed=False, is_refcounted=True
)
"""A primitive for dicts that are confirmed to be actual instances of builtins.dict, not a subclass."""

# Python dict object (or an instance of a subclass of dict).
dict_rprimitive: Final = RPrimitive("builtins.dict", is_unboxed=False, is_refcounted=True)
"""A primitive that represents instances of builtins.dict or subclasses of dict."""

# Python set object (or an instance of a subclass of set).
set_rprimitive: Final = RPrimitive("builtins.set", is_unboxed=False, is_refcounted=True)
Expand Down Expand Up @@ -599,7 +606,14 @@ def is_list_rprimitive(rtype: RType) -> bool:


def is_dict_rprimitive(rtype: RType) -> bool:
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.dict"
return isinstance(rtype, RPrimitive) and rtype.name in (
"builtins.dict",
"builtins.dict[exact]",
)


def is_exact_dict_rprimitive(rtype: RType) -> bool:
    """Return True if rtype is the RPrimitive named "builtins.dict[exact]".

    Unlike is_dict_rprimitive, the plain "builtins.dict" primitive does not
    match here.
    """
    if not isinstance(rtype, RPrimitive):
        return False
    return rtype.name == "builtins.dict[exact]"


def is_set_rprimitive(rtype: RType) -> bool:
Expand Down
12 changes: 7 additions & 5 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
RUnion,
bitmap_rprimitive,
c_pyssize_t_rprimitive,
dict_rprimitive,
exact_dict_rprimitive,
int_rprimitive,
is_float_rprimitive,
is_list_rprimitive,
Expand Down Expand Up @@ -124,7 +124,7 @@
)
from mypyc.irbuild.util import bytes_from_str, is_constant
from mypyc.options import CompilerOptions
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
from mypyc.primitives.dict_ops import dict_set_item_op, exact_dict_get_item_op
from mypyc.primitives.generic_ops import iter_op, next_op, py_setattr_op
from mypyc.primitives.list_ops import list_get_item_unsafe_op, list_pop_last, to_list
from mypyc.primitives.misc_ops import check_unpack_count_op, get_module_dict_op, import_op
Expand Down Expand Up @@ -435,6 +435,8 @@ def add_to_non_ext_dict(
) -> None:
# Add an attribute entry into the class dict of a non-extension class.
key_unicode = self.load_str(key)
# Must use `dict_set_item_op` rather than `exact_dict_set_item_op` here:
# taking the exact-dict fast path breaks enums, and probably other things too.
self.primitive_op(dict_set_item_op, [non_ext.dict, key_unicode, val], line)

# It's important that accessing class dictionary items from multiple threads
Expand Down Expand Up @@ -470,7 +472,7 @@ def get_module(self, module: str, line: int) -> Value:
# Python 3.7 has a nice 'PyImport_GetModule' function that we can't use :(
mod_dict = self.call_c(get_module_dict_op, [], line)
# Get module object from modules dict.
return self.primitive_op(dict_get_item_op, [mod_dict, self.load_str(module)], line)
return self.primitive_op(exact_dict_get_item_op, [mod_dict, self.load_str(module)], line)

def get_module_attr(self, module: str, attr: str, line: int) -> Value:
"""Look up an attribute of a module without storing it in the local namespace.
Expand Down Expand Up @@ -1378,10 +1380,10 @@ def load_global(self, expr: NameExpr) -> Value:
def load_global_str(self, name: str, line: int) -> Value:
_globals = self.load_globals_dict()
reg = self.load_str(name)
return self.primitive_op(dict_get_item_op, [_globals, reg], line)
return self.primitive_op(exact_dict_get_item_op, [_globals, reg], line)

def load_globals_dict(self) -> Value:
return self.add(LoadStatic(dict_rprimitive, "globals", self.module_name))
return self.add(LoadStatic(exact_dict_rprimitive, "globals", self.module_name))

def load_module_attr_by_fullname(self, fullname: str, line: int) -> Value:
module, _, name = fullname.rpartition(".")
Expand Down
14 changes: 8 additions & 6 deletions mypyc/irbuild/classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from mypyc.ir.rtypes import (
RType,
bool_rprimitive,
dict_rprimitive,
exact_dict_rprimitive,
is_none_rprimitive,
is_object_rprimitive,
is_optional_type,
Expand All @@ -66,7 +66,7 @@
)
from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME
from mypyc.irbuild.util import dataclass_type, get_func_def, is_constant, is_dataclass_decorator
from mypyc.primitives.dict_ops import dict_new_op, dict_set_item_op
from mypyc.primitives.dict_ops import dict_new_op, exact_dict_set_item_op
from mypyc.primitives.generic_ops import (
iter_op,
next_op,
Expand Down Expand Up @@ -272,7 +272,7 @@ def finalize(self, ir: ClassIR) -> None:

# Add the non-extension class to the dict
self.builder.primitive_op(
dict_set_item_op,
exact_dict_set_item_op,
[
self.builder.load_globals_dict(),
self.builder.load_str(self.cdef.name),
Expand Down Expand Up @@ -488,7 +488,9 @@ def allocate_class(builder: IRBuilder, cdef: ClassDef) -> Value:

# Add it to the dict
builder.primitive_op(
dict_set_item_op, [builder.load_globals_dict(), builder.load_str(cdef.name), tp], cdef.line
exact_dict_set_item_op,
[builder.load_globals_dict(), builder.load_str(cdef.name), tp],
cdef.line,
)

return tp
Expand Down Expand Up @@ -609,7 +611,7 @@ def setup_non_ext_dict(
py_hasattr_op, [metaclass, builder.load_str("__prepare__")], cdef.line
)

non_ext_dict = Register(dict_rprimitive)
non_ext_dict = Register(exact_dict_rprimitive)

true_block, false_block, exit_block = BasicBlock(), BasicBlock(), BasicBlock()
builder.add_bool_branch(has_prepare, true_block, false_block)
Expand Down Expand Up @@ -672,7 +674,7 @@ def add_non_ext_class_attr_ann(
typ = builder.add(LoadAddress(type_object_op.type, type_object_op.src, stmt.line))

key = builder.load_str(lvalue.name)
builder.primitive_op(dict_set_item_op, [non_ext.anns, key, typ], stmt.line)
builder.primitive_op(exact_dict_set_item_op, [non_ext.anns, key, typ], stmt.line)


def add_non_ext_class_attr(
Expand Down
6 changes: 3 additions & 3 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
)
from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization
from mypyc.primitives.bytes_ops import bytes_slice_op
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op
from mypyc.primitives.dict_ops import dict_new_op, exact_dict_get_item_op, exact_dict_set_item_op
from mypyc.primitives.generic_ops import iter_op
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
Expand Down Expand Up @@ -183,7 +183,7 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
# instead load the module separately on each access.
mod_dict = builder.call_c(get_module_dict_op, [], expr.line)
obj = builder.primitive_op(
dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line
exact_dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line
)
return obj
else:
Expand Down Expand Up @@ -1030,7 +1030,7 @@ def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehe
def gen_inner_stmts() -> None:
k = builder.accept(o.key)
v = builder.accept(o.value)
builder.primitive_op(dict_set_item_op, [builder.read(d), k, v], o.line)
builder.primitive_op(exact_dict_set_item_op, [builder.read(d), k, v], o.line)

comprehension_helper(builder, loop_params, gen_inner_stmts, o.line)
return builder.read(d)
Expand Down
54 changes: 48 additions & 6 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
c_pyssize_t_rprimitive,
int_rprimitive,
is_dict_rprimitive,
is_exact_dict_rprimitive,
is_fixed_width_rtype,
is_list_rprimitive,
is_sequence_rprimitive,
Expand All @@ -69,6 +70,11 @@
dict_next_key_op,
dict_next_value_op,
dict_value_iter_op,
exact_dict_check_size_op,
exact_dict_iter_fast_path_op,
exact_dict_next_item_op,
exact_dict_next_key_op,
exact_dict_next_value_op,
)
from mypyc.primitives.exc_ops import no_err_occurred_op, propagate_if_error_op
from mypyc.primitives.generic_ops import aiter_op, anext_op, iter_op, next_op
Expand Down Expand Up @@ -415,8 +421,10 @@ def make_for_loop_generator(
# Special case "for k in <dict>".
expr_reg = builder.accept(expr)
target_type = builder.get_dict_key_type(expr)

for_dict = ForDictionaryKeys(builder, index, body_block, loop_exit, line, nested)
for_loop_cls = (
ForExactDictionaryKeys if is_exact_dict_rprimitive(rtyp) else ForDictionaryKeys
)
for_dict = for_loop_cls(builder, index, body_block, loop_exit, line, nested)
for_dict.init(expr_reg, target_type)
return for_dict

Expand Down Expand Up @@ -498,13 +506,22 @@ def make_for_loop_generator(
for_dict_type: type[ForGenerator] | None = None
if expr.callee.name == "keys":
target_type = builder.get_dict_key_type(expr.callee.expr)
for_dict_type = ForDictionaryKeys
if is_exact_dict_rprimitive(rtype):
for_dict_type = ForExactDictionaryKeys
else:
for_dict_type = ForDictionaryKeys
elif expr.callee.name == "values":
target_type = builder.get_dict_value_type(expr.callee.expr)
for_dict_type = ForDictionaryValues
if is_exact_dict_rprimitive(rtype):
for_dict_type = ForExactDictionaryValues
else:
for_dict_type = ForDictionaryValues
else:
target_type = builder.get_dict_item_type(expr.callee.expr)
for_dict_type = ForDictionaryItems
if is_exact_dict_rprimitive(rtype):
for_dict_type = ForExactDictionaryItems
else:
for_dict_type = ForDictionaryItems
for_dict_gen = for_dict_type(builder, index, body_block, loop_exit, line, nested)
for_dict_gen.init(expr_reg, target_type)
return for_dict_gen
Expand Down Expand Up @@ -867,6 +884,7 @@ class ForDictionaryCommon(ForGenerator):

dict_next_op: ClassVar[CFunctionDescription]
dict_iter_op: ClassVar[CFunctionDescription]
dict_size_op: ClassVar[CFunctionDescription] = dict_check_size_op

def need_cleanup(self) -> bool:
# Technically, a dict subclass can raise an unrelated exception
Expand Down Expand Up @@ -913,7 +931,7 @@ def gen_step(self) -> None:
line = self.line
# Technically, we don't need a new primitive for this, but it is simpler.
builder.call_c(
dict_check_size_op,
self.dict_size_op,
[builder.read(self.expr_target, line), builder.read(self.size, line)],
line,
)
Expand Down Expand Up @@ -991,6 +1009,30 @@ def begin_body(self) -> None:
builder.assign(target, rvalue, line)


class ForExactDictionaryKeys(ForDictionaryKeys):
    """Generate optimized IR for a for loop over the keys of an exact dict.

    Inherits the loop generation of ForDictionaryKeys and swaps in the
    exact_dict_* primitives. make_for_loop_generator picks this class when
    the iterated expression's rtype satisfies is_exact_dict_rprimitive.
    """

    # Exact-dict counterparts of the primitives used by the base class.
    dict_next_op = exact_dict_next_key_op
    dict_iter_op = exact_dict_iter_fast_path_op
    dict_size_op = exact_dict_check_size_op


class ForExactDictionaryValues(ForDictionaryValues):
    """Generate optimized IR for a for loop over the values of an exact dict.

    Inherits the loop generation of ForDictionaryValues and swaps in the
    exact_dict_* primitives. make_for_loop_generator picks this class for
    "<dict>.values()" when the dict expression's rtype satisfies
    is_exact_dict_rprimitive.
    """

    # Exact-dict counterparts of the primitives used by the base class.
    dict_next_op = exact_dict_next_value_op
    dict_iter_op = exact_dict_iter_fast_path_op
    dict_size_op = exact_dict_check_size_op


class ForExactDictionaryItems(ForDictionaryItems):
    """Generate optimized IR for a for loop over the items of an exact dict.

    Reuses the loop generation of ForDictionaryItems; only the primitives
    below are replaced with their exact_dict_* variants.
    """

    # Iterator setup, per-step size check, then item retrieval.
    dict_iter_op = exact_dict_iter_fast_path_op
    dict_size_op = exact_dict_check_size_op
    dict_next_op = exact_dict_next_item_op


class ForRange(ForGenerator):
"""Generate optimized IR for a for loop over an integer range."""

Expand Down
28 changes: 17 additions & 11 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from mypyc.ir.rtypes import (
RInstance,
bool_rprimitive,
dict_rprimitive,
exact_dict_rprimitive,
int_rprimitive,
object_rprimitive,
)
Expand All @@ -76,7 +76,11 @@
)
from mypyc.irbuild.generator import gen_generator_func, gen_generator_func_body
from mypyc.irbuild.targets import AssignmentTarget
from mypyc.primitives.dict_ops import dict_get_method_with_none, dict_new_op, dict_set_item_op
from mypyc.primitives.dict_ops import (
dict_new_op,
exact_dict_get_method_with_none,
exact_dict_set_item_op,
)
from mypyc.primitives.generic_ops import py_setattr_op
from mypyc.primitives.misc_ops import register_function
from mypyc.primitives.registry import builtin_names
Expand Down Expand Up @@ -124,7 +128,7 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
if decorated_func is not None:
# Set the callable object representing the decorated function as a global.
builder.primitive_op(
dict_set_item_op,
exact_dict_set_item_op,
[builder.load_globals_dict(), builder.load_str(dec.func.name), decorated_func],
decorated_func.line,
)
Expand Down Expand Up @@ -797,10 +801,12 @@ def generate_singledispatch_dispatch_function(

arg_type = builder.builder.get_type_of_obj(arg_info.args[0], line)
dispatch_cache = builder.builder.get_attr(
dispatch_func_obj, "dispatch_cache", dict_rprimitive, line
dispatch_func_obj, "dispatch_cache", exact_dict_rprimitive, line
)
call_find_impl, use_cache, call_func = BasicBlock(), BasicBlock(), BasicBlock()
get_result = builder.primitive_op(dict_get_method_with_none, [dispatch_cache, arg_type], line)
get_result = builder.primitive_op(
exact_dict_get_method_with_none, [dispatch_cache, arg_type], line
)
is_not_none = builder.translate_is_op(get_result, builder.none_object(), "is not", line)
impl_to_use = Register(object_rprimitive)
builder.add_bool_branch(is_not_none, use_cache, call_find_impl)
Expand All @@ -813,7 +819,7 @@ def generate_singledispatch_dispatch_function(
find_impl = builder.load_module_attr_by_fullname("functools._find_impl", line)
registry = load_singledispatch_registry(builder, dispatch_func_obj, line)
uncached_impl = builder.py_call(find_impl, [arg_type, registry], line)
builder.primitive_op(dict_set_item_op, [dispatch_cache, arg_type, uncached_impl], line)
builder.primitive_op(exact_dict_set_item_op, [dispatch_cache, arg_type, uncached_impl], line)
builder.assign(impl_to_use, uncached_impl, line)
builder.goto(call_func)

Expand Down Expand Up @@ -877,8 +883,8 @@ def gen_dispatch_func_ir(
"""
builder.enter(FuncInfo(fitem, dispatch_name))
setup_callable_class(builder)
builder.fn_info.callable_class.ir.attributes["registry"] = dict_rprimitive
builder.fn_info.callable_class.ir.attributes["dispatch_cache"] = dict_rprimitive
builder.fn_info.callable_class.ir.attributes["registry"] = exact_dict_rprimitive
builder.fn_info.callable_class.ir.attributes["dispatch_cache"] = exact_dict_rprimitive
builder.fn_info.callable_class.ir.has_dict = True
builder.fn_info.callable_class.ir.needs_getseters = True
generate_singledispatch_callable_class_ctor(builder)
Expand Down Expand Up @@ -941,7 +947,7 @@ def add_register_method_to_callable_class(builder: IRBuilder, fn_info: FuncInfo)


def load_singledispatch_registry(builder: IRBuilder, dispatch_func_obj: Value, line: int) -> Value:
return builder.builder.get_attr(dispatch_func_obj, "registry", dict_rprimitive, line)
return builder.builder.get_attr(dispatch_func_obj, "registry", exact_dict_rprimitive, line)


def singledispatch_main_func_name(orig_name: str) -> str:
Expand Down Expand Up @@ -990,9 +996,9 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None:
registry = load_singledispatch_registry(builder, dispatch_func_obj, line)
for typ in types:
loaded_type = load_type(builder, typ, None, line)
builder.primitive_op(dict_set_item_op, [registry, loaded_type, to_insert], line)
builder.primitive_op(exact_dict_set_item_op, [registry, loaded_type, to_insert], line)
dispatch_cache = builder.builder.get_attr(
dispatch_func_obj, "dispatch_cache", dict_rprimitive, line
dispatch_func_obj, "dispatch_cache", exact_dict_rprimitive, line
)
builder.gen_method_call(dispatch_cache, "clear", [], None, line)

Expand Down
Loading
Loading