Skip to content

Commit 4d38602

Browse files
committed
feat: true dict rprimitive
1 parent 2e5d7ee commit 4d38602

27 files changed

+994
-467
lines changed

mypyc/annotate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def function_annotations(func_ir: FuncIR, tree: MypyFile) -> dict[int, list[Anno
215215
ann = "Dynamic method call."
216216
elif name in op_hints:
217217
ann = op_hints[name]
218-
elif name in ("CPyDict_GetItem", "CPyDict_SetItem"):
218+
elif name in ("CPyDict_GetItemUnsafe", "PyDict_SetItem"):
219219
if (
220220
isinstance(op.args[0], LoadStatic)
221221
and isinstance(op.args[1], LoadLiteral)

mypyc/ir/rtypes.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,8 +487,15 @@ def __hash__(self) -> int:
487487
"builtins.list", is_unboxed=False, is_refcounted=True, may_be_immortal=False
488488
)
489489

490-
# Python dict object (or an instance of a subclass of dict).
490+
# Python dict object.
491+
true_dict_rprimitive: Final = RPrimitive(
492+
"builtins.dict[confirmed]", is_unboxed=False, is_refcounted=True
493+
)
494+
"""A primitive for dicts that are confirmed to be actual instances of builtins.dict, not a subclass."""
495+
496+
# An instance of a subclass of dict.
491497
dict_rprimitive: Final = RPrimitive("builtins.dict", is_unboxed=False, is_refcounted=True)
498+
"""A primitive that represents instances of builtins.dict or subclasses of dict."""
492499

493500
# Python set object (or an instance of a subclass of set).
494501
set_rprimitive: Final = RPrimitive("builtins.set", is_unboxed=False, is_refcounted=True)
@@ -599,7 +606,14 @@ def is_list_rprimitive(rtype: RType) -> bool:
599606

600607

601608
def is_dict_rprimitive(rtype: RType) -> bool:
602-
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.dict"
609+
return isinstance(rtype, RPrimitive) and rtype.name in (
610+
"builtins.dict",
611+
"builtins.dict[confirmed]",
612+
)
613+
614+
615+
def is_true_dict_rprimitive(rtype: RType) -> bool:
616+
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.dict[confirmed]"
603617

604618

605619
def is_set_rprimitive(rtype: RType) -> bool:

mypyc/irbuild/builder.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@
9292
RUnion,
9393
bitmap_rprimitive,
9494
c_pyssize_t_rprimitive,
95-
dict_rprimitive,
9695
int_rprimitive,
9796
is_float_rprimitive,
9897
is_list_rprimitive,
@@ -103,6 +102,7 @@
103102
none_rprimitive,
104103
object_rprimitive,
105104
str_rprimitive,
105+
true_dict_rprimitive,
106106
)
107107
from mypyc.irbuild.context import FuncInfo, ImplicitClass
108108
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
@@ -124,7 +124,7 @@
124124
)
125125
from mypyc.irbuild.util import bytes_from_str, is_constant
126126
from mypyc.options import CompilerOptions
127-
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
127+
from mypyc.primitives.dict_ops import dict_set_item_op, true_dict_get_item_op
128128
from mypyc.primitives.generic_ops import iter_op, next_op, py_setattr_op
129129
from mypyc.primitives.list_ops import list_get_item_unsafe_op, list_pop_last, to_list
130130
from mypyc.primitives.misc_ops import check_unpack_count_op, get_module_dict_op, import_op
@@ -431,6 +431,8 @@ def add_to_non_ext_dict(
431431
) -> None:
432432
# Add an attribute entry into the class dict of a non-extension class.
433433
key_unicode = self.load_str(key)
434+
# must use `dict_set_item_op` instead of `true_dict_set_item_op` because
435+
# it breaks enums, and probably other stuff, if we take the fast path.
434436
self.primitive_op(dict_set_item_op, [non_ext.dict, key_unicode, val], line)
435437

436438
def gen_import(self, id: str, line: int) -> None:
@@ -462,7 +464,7 @@ def get_module(self, module: str, line: int) -> Value:
462464
# Python 3.7 has a nice 'PyImport_GetModule' function that we can't use :(
463465
mod_dict = self.call_c(get_module_dict_op, [], line)
464466
# Get module object from modules dict.
465-
return self.primitive_op(dict_get_item_op, [mod_dict, self.load_str(module)], line)
467+
return self.primitive_op(true_dict_get_item_op, [mod_dict, self.load_str(module)], line)
466468

467469
def get_module_attr(self, module: str, attr: str, line: int) -> Value:
468470
"""Look up an attribute of a module without storing it in the local namespace.
@@ -1370,10 +1372,10 @@ def load_global(self, expr: NameExpr) -> Value:
13701372
def load_global_str(self, name: str, line: int) -> Value:
13711373
_globals = self.load_globals_dict()
13721374
reg = self.load_str(name)
1373-
return self.primitive_op(dict_get_item_op, [_globals, reg], line)
1375+
return self.primitive_op(true_dict_get_item_op, [_globals, reg], line)
13741376

13751377
def load_globals_dict(self) -> Value:
1376-
return self.add(LoadStatic(dict_rprimitive, "globals", self.module_name))
1378+
return self.add(LoadStatic(true_dict_rprimitive, "globals", self.module_name))
13771379

13781380
def load_module_attr_by_fullname(self, fullname: str, line: int) -> Value:
13791381
module, _, name = fullname.rpartition(".")

mypyc/irbuild/classdef.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@
5050
from mypyc.ir.rtypes import (
5151
RType,
5252
bool_rprimitive,
53-
dict_rprimitive,
5453
is_none_rprimitive,
5554
is_object_rprimitive,
5655
is_optional_type,
5756
object_rprimitive,
57+
true_dict_rprimitive,
5858
)
5959
from mypyc.irbuild.builder import IRBuilder, create_type_params
6060
from mypyc.irbuild.function import (
@@ -66,7 +66,7 @@
6666
)
6767
from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME
6868
from mypyc.irbuild.util import dataclass_type, get_func_def, is_constant, is_dataclass_decorator
69-
from mypyc.primitives.dict_ops import dict_new_op, dict_set_item_op
69+
from mypyc.primitives.dict_ops import dict_new_op, true_dict_set_item_op
7070
from mypyc.primitives.generic_ops import (
7171
iter_op,
7272
next_op,
@@ -269,7 +269,7 @@ def finalize(self, ir: ClassIR) -> None:
269269

270270
# Add the non-extension class to the dict
271271
self.builder.primitive_op(
272-
dict_set_item_op,
272+
true_dict_set_item_op,
273273
[
274274
self.builder.load_globals_dict(),
275275
self.builder.load_str(self.cdef.name),
@@ -480,7 +480,9 @@ def allocate_class(builder: IRBuilder, cdef: ClassDef) -> Value:
480480

481481
# Add it to the dict
482482
builder.primitive_op(
483-
dict_set_item_op, [builder.load_globals_dict(), builder.load_str(cdef.name), tp], cdef.line
483+
true_dict_set_item_op,
484+
[builder.load_globals_dict(), builder.load_str(cdef.name), tp],
485+
cdef.line,
484486
)
485487

486488
return tp
@@ -601,7 +603,7 @@ def setup_non_ext_dict(
601603
py_hasattr_op, [metaclass, builder.load_str("__prepare__")], cdef.line
602604
)
603605

604-
non_ext_dict = Register(dict_rprimitive)
606+
non_ext_dict = Register(true_dict_rprimitive)
605607

606608
true_block, false_block, exit_block = BasicBlock(), BasicBlock(), BasicBlock()
607609
builder.add_bool_branch(has_prepare, true_block, false_block)
@@ -664,7 +666,7 @@ def add_non_ext_class_attr_ann(
664666
typ = builder.add(LoadAddress(type_object_op.type, type_object_op.src, stmt.line))
665667

666668
key = builder.load_str(lvalue.name)
667-
builder.primitive_op(dict_set_item_op, [non_ext.anns, key, typ], stmt.line)
669+
builder.primitive_op(true_dict_set_item_op, [non_ext.anns, key, typ], stmt.line)
668670

669671

670672
def add_non_ext_class_attr(

mypyc/irbuild/expression.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@
9797
)
9898
from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization
9999
from mypyc.primitives.bytes_ops import bytes_slice_op
100-
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op
100+
from mypyc.primitives.dict_ops import dict_new_op, true_dict_get_item_op, true_dict_set_item_op
101101
from mypyc.primitives.generic_ops import iter_op
102102
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
103103
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
@@ -183,7 +183,7 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
183183
# instead load the module separately on each access.
184184
mod_dict = builder.call_c(get_module_dict_op, [], expr.line)
185185
obj = builder.primitive_op(
186-
dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line
186+
true_dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line
187187
)
188188
return obj
189189
else:
@@ -1030,7 +1030,7 @@ def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehe
10301030
def gen_inner_stmts() -> None:
10311031
k = builder.accept(o.key)
10321032
v = builder.accept(o.value)
1033-
builder.primitive_op(dict_set_item_op, [builder.read(d), k, v], o.line)
1033+
builder.primitive_op(true_dict_set_item_op, [builder.read(d), k, v], o.line)
10341034

10351035
comprehension_helper(builder, loop_params, gen_inner_stmts, o.line)
10361036
return builder.read(d)

mypyc/irbuild/for_helpers.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
is_sequence_rprimitive,
5353
is_short_int_rprimitive,
5454
is_str_rprimitive,
55+
is_true_dict_rprimitive,
5556
is_tuple_rprimitive,
5657
object_pointer_rprimitive,
5758
object_rprimitive,
@@ -69,6 +70,11 @@
6970
dict_next_key_op,
7071
dict_next_value_op,
7172
dict_value_iter_op,
73+
true_dict_check_size_op,
74+
true_dict_iter_fast_path_op,
75+
true_dict_next_item_op,
76+
true_dict_next_key_op,
77+
true_dict_next_value_op,
7278
)
7379
from mypyc.primitives.exc_ops import no_err_occurred_op, propagate_if_error_op
7480
from mypyc.primitives.generic_ops import aiter_op, anext_op, iter_op, next_op
@@ -415,8 +421,10 @@ def make_for_loop_generator(
415421
# Special case "for k in <dict>".
416422
expr_reg = builder.accept(expr)
417423
target_type = builder.get_dict_key_type(expr)
418-
419-
for_dict = ForDictionaryKeys(builder, index, body_block, loop_exit, line, nested)
424+
for_loop_cls = (
425+
ForTrueDictionaryKeys if is_true_dict_rprimitive(rtyp) else ForDictionaryKeys
426+
)
427+
for_dict = for_loop_cls(builder, index, body_block, loop_exit, line, nested)
420428
for_dict.init(expr_reg, target_type)
421429
return for_dict
422430

@@ -498,13 +506,22 @@ def make_for_loop_generator(
498506
for_dict_type: type[ForGenerator] | None = None
499507
if expr.callee.name == "keys":
500508
target_type = builder.get_dict_key_type(expr.callee.expr)
501-
for_dict_type = ForDictionaryKeys
509+
if is_true_dict_rprimitive(rtype):
510+
for_dict_type = ForTrueDictionaryKeys
511+
else:
512+
for_dict_type = ForDictionaryKeys
502513
elif expr.callee.name == "values":
503514
target_type = builder.get_dict_value_type(expr.callee.expr)
504-
for_dict_type = ForDictionaryValues
515+
if is_true_dict_rprimitive(rtype):
516+
for_dict_type = ForTrueDictionaryValues
517+
else:
518+
for_dict_type = ForDictionaryValues
505519
else:
506520
target_type = builder.get_dict_item_type(expr.callee.expr)
507-
for_dict_type = ForDictionaryItems
521+
if is_true_dict_rprimitive(rtype):
522+
for_dict_type = ForTrueDictionaryItems
523+
else:
524+
for_dict_type = ForDictionaryItems
508525
for_dict_gen = for_dict_type(builder, index, body_block, loop_exit, line, nested)
509526
for_dict_gen.init(expr_reg, target_type)
510527
return for_dict_gen
@@ -867,6 +884,7 @@ class ForDictionaryCommon(ForGenerator):
867884

868885
dict_next_op: ClassVar[CFunctionDescription]
869886
dict_iter_op: ClassVar[CFunctionDescription]
887+
dict_size_op: ClassVar[CFunctionDescription] = dict_check_size_op
870888

871889
def need_cleanup(self) -> bool:
872890
# Technically, a dict subclass can raise an unrelated exception
@@ -913,7 +931,7 @@ def gen_step(self) -> None:
913931
line = self.line
914932
# Technically, we don't need a new primitive for this, but it is simpler.
915933
builder.call_c(
916-
dict_check_size_op,
934+
self.dict_size_op,
917935
[builder.read(self.expr_target, line), builder.read(self.size, line)],
918936
line,
919937
)
@@ -991,6 +1009,30 @@ def begin_body(self) -> None:
9911009
builder.assign(target, rvalue, line)
9921010

9931011

1012+
class ForTrueDictionaryKeys(ForDictionaryKeys):
1013+
"""Generate optimized IR for a for loop over dictionary items without type checks."""
1014+
1015+
dict_next_op = true_dict_next_key_op
1016+
dict_iter_op = true_dict_iter_fast_path_op
1017+
dict_size_op = true_dict_check_size_op
1018+
1019+
1020+
class ForTrueDictionaryValues(ForDictionaryValues):
1021+
"""Generate optimized IR for a for loop over dictionary items without type checks."""
1022+
1023+
dict_next_op = true_dict_next_value_op
1024+
dict_iter_op = true_dict_iter_fast_path_op
1025+
dict_size_op = true_dict_check_size_op
1026+
1027+
1028+
class ForTrueDictionaryItems(ForDictionaryItems):
1029+
"""Generate optimized IR for a for loop over dictionary items without type checks."""
1030+
1031+
dict_next_op = true_dict_next_item_op
1032+
dict_iter_op = true_dict_iter_fast_path_op
1033+
dict_size_op = true_dict_check_size_op
1034+
1035+
9941036
class ForRange(ForGenerator):
9951037
"""Generate optimized IR for a for loop over an integer range."""
9961038

mypyc/irbuild/function.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@
5656
from mypyc.ir.rtypes import (
5757
RInstance,
5858
bool_rprimitive,
59-
dict_rprimitive,
6059
int_rprimitive,
6160
object_rprimitive,
61+
true_dict_rprimitive,
6262
)
6363
from mypyc.irbuild.builder import IRBuilder, calculate_arg_defaults, gen_arg_defaults
6464
from mypyc.irbuild.callable_class import (
@@ -76,7 +76,11 @@
7676
)
7777
from mypyc.irbuild.generator import gen_generator_func, gen_generator_func_body
7878
from mypyc.irbuild.targets import AssignmentTarget
79-
from mypyc.primitives.dict_ops import dict_get_method_with_none, dict_new_op, dict_set_item_op
79+
from mypyc.primitives.dict_ops import (
80+
dict_new_op,
81+
true_dict_get_method_with_none,
82+
true_dict_set_item_op,
83+
)
8084
from mypyc.primitives.generic_ops import py_setattr_op
8185
from mypyc.primitives.misc_ops import register_function
8286
from mypyc.primitives.registry import builtin_names
@@ -124,7 +128,7 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
124128
if decorated_func is not None:
125129
# Set the callable object representing the decorated function as a global.
126130
builder.primitive_op(
127-
dict_set_item_op,
131+
true_dict_set_item_op,
128132
[builder.load_globals_dict(), builder.load_str(dec.func.name), decorated_func],
129133
decorated_func.line,
130134
)
@@ -797,10 +801,12 @@ def generate_singledispatch_dispatch_function(
797801

798802
arg_type = builder.builder.get_type_of_obj(arg_info.args[0], line)
799803
dispatch_cache = builder.builder.get_attr(
800-
dispatch_func_obj, "dispatch_cache", dict_rprimitive, line
804+
dispatch_func_obj, "dispatch_cache", true_dict_rprimitive, line
801805
)
802806
call_find_impl, use_cache, call_func = BasicBlock(), BasicBlock(), BasicBlock()
803-
get_result = builder.primitive_op(dict_get_method_with_none, [dispatch_cache, arg_type], line)
807+
get_result = builder.primitive_op(
808+
true_dict_get_method_with_none, [dispatch_cache, arg_type], line
809+
)
804810
is_not_none = builder.translate_is_op(get_result, builder.none_object(), "is not", line)
805811
impl_to_use = Register(object_rprimitive)
806812
builder.add_bool_branch(is_not_none, use_cache, call_find_impl)
@@ -813,7 +819,7 @@ def generate_singledispatch_dispatch_function(
813819
find_impl = builder.load_module_attr_by_fullname("functools._find_impl", line)
814820
registry = load_singledispatch_registry(builder, dispatch_func_obj, line)
815821
uncached_impl = builder.py_call(find_impl, [arg_type, registry], line)
816-
builder.primitive_op(dict_set_item_op, [dispatch_cache, arg_type, uncached_impl], line)
822+
builder.primitive_op(true_dict_set_item_op, [dispatch_cache, arg_type, uncached_impl], line)
817823
builder.assign(impl_to_use, uncached_impl, line)
818824
builder.goto(call_func)
819825

@@ -877,8 +883,8 @@ def gen_dispatch_func_ir(
877883
"""
878884
builder.enter(FuncInfo(fitem, dispatch_name))
879885
setup_callable_class(builder)
880-
builder.fn_info.callable_class.ir.attributes["registry"] = dict_rprimitive
881-
builder.fn_info.callable_class.ir.attributes["dispatch_cache"] = dict_rprimitive
886+
builder.fn_info.callable_class.ir.attributes["registry"] = true_dict_rprimitive
887+
builder.fn_info.callable_class.ir.attributes["dispatch_cache"] = true_dict_rprimitive
882888
builder.fn_info.callable_class.ir.has_dict = True
883889
builder.fn_info.callable_class.ir.needs_getseters = True
884890
generate_singledispatch_callable_class_ctor(builder)
@@ -941,7 +947,7 @@ def add_register_method_to_callable_class(builder: IRBuilder, fn_info: FuncInfo)
941947

942948

943949
def load_singledispatch_registry(builder: IRBuilder, dispatch_func_obj: Value, line: int) -> Value:
944-
return builder.builder.get_attr(dispatch_func_obj, "registry", dict_rprimitive, line)
950+
return builder.builder.get_attr(dispatch_func_obj, "registry", true_dict_rprimitive, line)
945951

946952

947953
def singledispatch_main_func_name(orig_name: str) -> str:
@@ -990,9 +996,9 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None:
990996
registry = load_singledispatch_registry(builder, dispatch_func_obj, line)
991997
for typ in types:
992998
loaded_type = load_type(builder, typ, None, line)
993-
builder.primitive_op(dict_set_item_op, [registry, loaded_type, to_insert], line)
999+
builder.primitive_op(true_dict_set_item_op, [registry, loaded_type, to_insert], line)
9941000
dispatch_cache = builder.builder.get_attr(
995-
dispatch_func_obj, "dispatch_cache", dict_rprimitive, line
1001+
dispatch_func_obj, "dispatch_cache", true_dict_rprimitive, line
9961002
)
9971003
builder.gen_method_call(dispatch_cache, "clear", [], None, line)
9981004

0 commit comments

Comments
 (0)