Skip to content

Commit 268c837

Browse files
[mypyc] feat: unwrap NewType types to their base types for optimized code paths (#19497)
This PR adds special case logic for unwrapping NewType types to their actual type. This logic is currently working. a `NewType("name", str)` now generates the same code as a `str`. I wasn't entirely sure of the best way to test this, so I just tweaked the str tests to use a union of str and newtype str and validated that the IR still uses `str` and not `object` Almost all of my tests are running fine, but I get a strange mypy error in the str.count tests saying that `str` objects do not have a .count method? Of course a str object has a .count method. Do you think this might related to the typeshed stuff again? --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 68657d2 commit 268c837

File tree

3 files changed

+113
-41
lines changed

3 files changed

+113
-41
lines changed

mypyc/irbuild/mapper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def type_to_rtype(self, typ: Type | None) -> RType:
7373

7474
typ = get_proper_type(typ)
7575
if isinstance(typ, Instance):
76+
if typ.type.is_newtype:
77+
# Unwrap NewType to its base type for rprimitive mapping
78+
assert len(typ.type.bases) == 1, typ.type.bases
79+
return self.type_to_rtype(typ.type.bases[0])
7680
if typ.type.fullname == "builtins.int":
7781
return int_rprimitive
7882
elif typ.type.fullname == "builtins.float":

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def rpartition(self, sep: str, /) -> Tuple[str, str, str]: ...
122122
def removeprefix(self, prefix: str, /) -> str: ...
123123
def removesuffix(self, suffix: str, /) -> str: ...
124124
def islower(self) -> bool: ...
125+
def count(self, substr: str, start: Optional[int] = None, end: Optional[int] = None) -> int: pass
125126

126127
class float:
127128
def __init__(self, x: object) -> None: pass

mypyc/test-data/irbuild-str.test

Lines changed: 108 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
[case testStrSplit]
2-
from typing import Optional, List
2+
from typing import NewType, Optional, List, Union
3+
NewStr = NewType("NewStr", str)
34

4-
def do_split(s: str, sep: Optional[str] = None, max_split: Optional[int] = None) -> List[str]:
5+
def do_split(s: Union[str, NewStr], sep: Optional[str] = None, max_split: Optional[int] = None) -> List[str]:
56
if sep is not None:
67
if max_split is not None:
78
return s.split(sep, max_split)
89
else:
910
return s.split(sep)
1011
return s.split()
12+
[typing fixtures/typing-full.pyi]
1113
[out]
1214
def do_split(s, sep, max_split):
1315
s :: str
@@ -56,12 +58,15 @@ L9:
5658

5759

5860
[case testStrEquality]
61+
from typing import NewType, Union
62+
NewStr = NewType("NewStr", str)
5963
def eq(x: str, y: str) -> bool:
6064
return x == y
6165

62-
def neq(x: str, y: str) -> bool:
66+
def neq(x: str, y: Union[str, NewStr]) -> bool:
6367
return x != y
6468

69+
[typing fixtures/typing-full.pyi]
6570
[out]
6671
def eq(x, y):
6772
x, y :: str
@@ -79,13 +84,14 @@ L0:
7984
return r1
8085

8186
[case testStrReplace]
82-
from typing import Optional
83-
84-
def do_replace(s: str, old_substr: str, new_substr: str, max_count: Optional[int] = None) -> str:
87+
from typing import NewType, Optional, Union
88+
NewStr = NewType("NewStr", str)
89+
def do_replace(s: Union[str, NewStr], old_substr: str, new_substr: str, max_count: Optional[int] = None) -> str:
8590
if max_count is not None:
8691
return s.replace(old_substr, new_substr, max_count)
8792
else:
8893
return s.replace(old_substr, new_substr)
94+
[typing fixtures/typing-full.pyi]
8995
[out]
9096
def do_replace(s, old_substr, new_substr, max_count):
9197
s, old_substr, new_substr :: str
@@ -114,17 +120,19 @@ L5:
114120
unreachable
115121

116122
[case testStrStartswithEndswithTuple]
117-
from typing import Tuple
123+
from typing import NewType, Tuple, Union
124+
NewStr = NewType("NewStr", str)
118125

119-
def do_startswith(s1: str, s2: Tuple[str, ...]) -> bool:
126+
def do_startswith(s1: Union[str, NewStr], s2: Tuple[str, ...]) -> bool:
120127
return s1.startswith(s2)
121128

122-
def do_endswith(s1: str, s2: Tuple[str, ...]) -> bool:
129+
def do_endswith(s1: Union[str, NewStr], s2: Tuple[str, ...]) -> bool:
123130
return s1.endswith(s2)
124131

125-
def do_tuple_literal_args(s1: str) -> None:
132+
def do_tuple_literal_args(s1: Union[str, NewStr]) -> None:
126133
x = s1.startswith(("a", "b"))
127134
y = s1.endswith(("a", "b"))
135+
[typing fixtures/typing-full.pyi]
128136
[out]
129137
def do_startswith(s1, s2):
130138
s1 :: str
@@ -165,11 +173,14 @@ L0:
165173
return 1
166174

167175
[case testStrToBool]
168-
def is_true(x: str) -> bool:
176+
from typing import NewType, Union
177+
NewStr = NewType("NewStr", str)
178+
def is_true(x: Union[str, NewStr]) -> bool:
169179
if x:
170180
return True
171181
else:
172182
return False
183+
[typing fixtures/typing-full.pyi]
173184
[out]
174185
def is_true(x):
175186
x :: str
@@ -185,11 +196,14 @@ L3:
185196
unreachable
186197

187198
[case testStringFormatMethod]
188-
def f(s: str, num: int) -> None:
199+
from typing import NewType, Union
200+
NewStr = NewType("NewStr", str)
201+
def f(s: Union[str, NewStr], num: int) -> None:
189202
s1 = "Hi! I'm {}, and I'm {} years old.".format(s, num)
190203
s2 = ''.format()
191204
s3 = 'abc'.format()
192205
s4 = '}}{}{{{}}}{{{}'.format(num, num, num)
206+
[typing fixtures/typing-full.pyi]
193207
[out]
194208
def f(s, num):
195209
s :: str
@@ -217,11 +231,14 @@ L0:
217231
return 1
218232

219233
[case testFStrings_64bit]
220-
def f(var: str, num: int) -> None:
234+
from typing import NewType, Union
235+
NewStr = NewType("NewStr", str)
236+
def f(var: Union[str, NewStr], num: int) -> None:
221237
s1 = f"Hi! I'm {var}. I am {num} years old."
222238
s2 = f'Hello {var:>{num}}'
223239
s3 = f''
224240
s4 = f'abc'
241+
[typing fixtures/typing-full.pyi]
225242
[out]
226243
def f(var, num):
227244
var :: str
@@ -267,7 +284,9 @@ L0:
267284
return 1
268285

269286
[case testStringFormattingCStyle]
270-
def f(var: str, num: int) -> None:
287+
from typing import NewType, Union
288+
NewStr = NewType("NewStr", str)
289+
def f(var: Union[str, NewStr], num: int) -> None:
271290
s1 = "Hi! I'm %s." % var
272291
s2 = "I am %d years old." % num
273292
s3 = "Hi! I'm %s. I am %d years old." % (var, num)
@@ -322,7 +341,9 @@ L0:
322341
return 1
323342

324343
[case testEncode_64bit]
325-
def f(s: str) -> None:
344+
from typing import NewType, Union
345+
NewStr = NewType("NewStr", str)
346+
def f(s: Union[str, NewStr]) -> None:
326347
s.encode()
327348
s.encode('utf-8')
328349
s.encode('utf8', 'strict')
@@ -340,6 +361,7 @@ def f(s: str) -> None:
340361
s.encode(encoding=encoding, errors=errors)
341362
s.encode('latin2')
342363

364+
[typing fixtures/typing-full.pyi]
343365
[out]
344366
def f(s):
345367
s :: str
@@ -410,7 +432,9 @@ L0:
410432
return 1
411433

412434
[case testOrd]
413-
def str_ord(x: str) -> int:
435+
from typing import NewType, Union
436+
NewStr = NewType("NewStr", str)
437+
def str_ord(x: Union[str, NewStr]) -> int:
414438
return ord(x)
415439
def str_ord_literal() -> int:
416440
return ord("a")
@@ -420,6 +444,7 @@ def bytes_ord_literal() -> int:
420444
return ord(b"a")
421445
def any_ord(x) -> int:
422446
return ord(x)
447+
[typing fixtures/typing-full.pyi]
423448
[out]
424449
def str_ord(x):
425450
x :: str
@@ -459,13 +484,16 @@ L0:
459484
return r6
460485

461486
[case testStrip]
462-
def do_strip(s: str) -> None:
487+
from typing import NewType, Union
488+
NewStr = NewType("NewStr", str)
489+
def do_strip(s: Union[str, NewStr]) -> None:
463490
s.lstrip("x")
464491
s.strip("y")
465492
s.rstrip("z")
466493
s.lstrip()
467494
s.strip()
468495
s.rstrip()
496+
[typing fixtures/typing-full.pyi]
469497
[out]
470498
def do_strip(s):
471499
s, r0, r1, r2, r3, r4, r5, r6, r7, r8 :: str
@@ -481,60 +509,99 @@ L0:
481509
r8 = CPyStr_RStrip(s, 0)
482510
return 1
483511

484-
[case testCountAll]
512+
[case testCountAll_64bit]
513+
from typing import NewType, Union
514+
NewStr = NewType("NewStr", str)
485515
def do_count(s: str) -> int:
486-
return s.count("x") # type: ignore [attr-defined]
516+
return s.count("x")
517+
[typing fixtures/typing-full.pyi]
487518
[out]
488519
def do_count(s):
489520
s, r0 :: str
490521
r1 :: native_int
491-
r2 :: bit
492-
r3 :: object
493-
r4 :: int
522+
r2, r3, r4 :: bit
523+
r5, r6, r7 :: int
494524
L0:
495525
r0 = 'x'
496526
r1 = CPyStr_Count(s, r0, 0)
497527
r2 = r1 >= 0 :: signed
498-
r3 = box(native_int, r1)
499-
r4 = unbox(int, r3)
500-
return r4
528+
r3 = r1 <= 4611686018427387903 :: signed
529+
if r3 goto L1 else goto L2 :: bool
530+
L1:
531+
r4 = r1 >= -4611686018427387904 :: signed
532+
if r4 goto L3 else goto L2 :: bool
533+
L2:
534+
r5 = CPyTagged_FromInt64(r1)
535+
r6 = r5
536+
goto L4
537+
L3:
538+
r7 = r1 << 1
539+
r6 = r7
540+
L4:
541+
return r6
501542

502-
[case testCountStart]
543+
[case testCountStart_64bit]
544+
from typing import NewType, Union
545+
NewStr = NewType("NewStr", str)
503546
def do_count(s: str, start: int) -> int:
504-
return s.count("x", start) # type: ignore [attr-defined]
547+
return s.count("x", start)
548+
[typing fixtures/typing-full.pyi]
505549
[out]
506550
def do_count(s, start):
507551
s :: str
508552
start :: int
509553
r0 :: str
510554
r1 :: native_int
511-
r2 :: bit
512-
r3 :: object
513-
r4 :: int
555+
r2, r3, r4 :: bit
556+
r5, r6, r7 :: int
514557
L0:
515558
r0 = 'x'
516559
r1 = CPyStr_Count(s, r0, start)
517560
r2 = r1 >= 0 :: signed
518-
r3 = box(native_int, r1)
519-
r4 = unbox(int, r3)
520-
return r4
561+
r3 = r1 <= 4611686018427387903 :: signed
562+
if r3 goto L1 else goto L2 :: bool
563+
L1:
564+
r4 = r1 >= -4611686018427387904 :: signed
565+
if r4 goto L3 else goto L2 :: bool
566+
L2:
567+
r5 = CPyTagged_FromInt64(r1)
568+
r6 = r5
569+
goto L4
570+
L3:
571+
r7 = r1 << 1
572+
r6 = r7
573+
L4:
574+
return r6
521575

522-
[case testCountStartEnd]
576+
[case testCountStartEnd_64bit]
577+
from typing import NewType, Union
578+
NewStr = NewType("NewStr", str)
523579
def do_count(s: str, start: int, end: int) -> int:
524-
return s.count("x", start, end) # type: ignore [attr-defined]
580+
return s.count("x", start, end)
581+
[typing fixtures/typing-full.pyi]
525582
[out]
526583
def do_count(s, start, end):
527584
s :: str
528585
start, end :: int
529586
r0 :: str
530587
r1 :: native_int
531-
r2 :: bit
532-
r3 :: object
533-
r4 :: int
588+
r2, r3, r4 :: bit
589+
r5, r6, r7 :: int
534590
L0:
535591
r0 = 'x'
536592
r1 = CPyStr_CountFull(s, r0, start, end)
537593
r2 = r1 >= 0 :: signed
538-
r3 = box(native_int, r1)
539-
r4 = unbox(int, r3)
540-
return r4
594+
r3 = r1 <= 4611686018427387903 :: signed
595+
if r3 goto L1 else goto L2 :: bool
596+
L1:
597+
r4 = r1 >= -4611686018427387904 :: signed
598+
if r4 goto L3 else goto L2 :: bool
599+
L2:
600+
r5 = CPyTagged_FromInt64(r1)
601+
r6 = r5
602+
goto L4
603+
L3:
604+
r7 = r1 << 1
605+
r6 = r7
606+
L4:
607+
return r6

0 commit comments

Comments
 (0)