Skip to content

[mypyc] feat: unwrap NewType types to their base types for optimized code paths #19497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
18 commits merged on Aug 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def type_to_rtype(self, typ: Type | None) -> RType:

typ = get_proper_type(typ)
if isinstance(typ, Instance):
if typ.type.is_newtype:
# Unwrap NewType to its base type for rprimitive mapping
assert len(typ.type.bases) == 1, typ.type.bases
return self.type_to_rtype(typ.type.bases[0])
if typ.type.fullname == "builtins.int":
return int_rprimitive
elif typ.type.fullname == "builtins.float":
Expand Down
1 change: 1 addition & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def rpartition(self, sep: str, /) -> Tuple[str, str, str]: ...
def removeprefix(self, prefix: str, /) -> str: ...
def removesuffix(self, suffix: str, /) -> str: ...
def islower(self) -> bool: ...
def count(self, substr: str, start: Optional[int] = None, end: Optional[int] = None) -> int: pass

class float:
def __init__(self, x: object) -> None: pass
Expand Down
149 changes: 108 additions & 41 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
[case testStrSplit]
from typing import Optional, List
from typing import NewType, Optional, List, Union
NewStr = NewType("NewStr", str)

def do_split(s: str, sep: Optional[str] = None, max_split: Optional[int] = None) -> List[str]:
def do_split(s: Union[str, NewStr], sep: Optional[str] = None, max_split: Optional[int] = None) -> List[str]:
if sep is not None:
if max_split is not None:
return s.split(sep, max_split)
else:
return s.split(sep)
return s.split()
[typing fixtures/typing-full.pyi]
[out]
def do_split(s, sep, max_split):
s :: str
Expand Down Expand Up @@ -56,12 +58,15 @@ L9:


[case testStrEquality]
from typing import NewType, Union
NewStr = NewType("NewStr", str)
def eq(x: str, y: str) -> bool:
return x == y

def neq(x: str, y: str) -> bool:
def neq(x: str, y: Union[str, NewStr]) -> bool:
return x != y

[typing fixtures/typing-full.pyi]
[out]
def eq(x, y):
x, y :: str
Expand All @@ -79,13 +84,14 @@ L0:
return r1

[case testStrReplace]
from typing import Optional

def do_replace(s: str, old_substr: str, new_substr: str, max_count: Optional[int] = None) -> str:
from typing import NewType, Optional, Union
NewStr = NewType("NewStr", str)
def do_replace(s: Union[str, NewStr], old_substr: str, new_substr: str, max_count: Optional[int] = None) -> str:
if max_count is not None:
return s.replace(old_substr, new_substr, max_count)
else:
return s.replace(old_substr, new_substr)
[typing fixtures/typing-full.pyi]
[out]
def do_replace(s, old_substr, new_substr, max_count):
s, old_substr, new_substr :: str
Expand Down Expand Up @@ -114,17 +120,19 @@ L5:
unreachable

[case testStrStartswithEndswithTuple]
from typing import Tuple
from typing import NewType, Tuple, Union
NewStr = NewType("NewStr", str)

def do_startswith(s1: str, s2: Tuple[str, ...]) -> bool:
def do_startswith(s1: Union[str, NewStr], s2: Tuple[str, ...]) -> bool:
return s1.startswith(s2)

def do_endswith(s1: str, s2: Tuple[str, ...]) -> bool:
def do_endswith(s1: Union[str, NewStr], s2: Tuple[str, ...]) -> bool:
return s1.endswith(s2)

def do_tuple_literal_args(s1: str) -> None:
def do_tuple_literal_args(s1: Union[str, NewStr]) -> None:
x = s1.startswith(("a", "b"))
y = s1.endswith(("a", "b"))
[typing fixtures/typing-full.pyi]
[out]
def do_startswith(s1, s2):
s1 :: str
Expand Down Expand Up @@ -165,11 +173,14 @@ L0:
return 1

[case testStrToBool]
def is_true(x: str) -> bool:
from typing import NewType, Union
NewStr = NewType("NewStr", str)
def is_true(x: Union[str, NewStr]) -> bool:
if x:
return True
else:
return False
[typing fixtures/typing-full.pyi]
[out]
def is_true(x):
x :: str
Expand All @@ -185,11 +196,14 @@ L3:
unreachable

[case testStringFormatMethod]
def f(s: str, num: int) -> None:
from typing import NewType, Union
NewStr = NewType("NewStr", str)
def f(s: Union[str, NewStr], num: int) -> None:
s1 = "Hi! I'm {}, and I'm {} years old.".format(s, num)
s2 = ''.format()
s3 = 'abc'.format()
s4 = '}}{}{{{}}}{{{}'.format(num, num, num)
[typing fixtures/typing-full.pyi]
[out]
def f(s, num):
s :: str
Expand Down Expand Up @@ -217,11 +231,14 @@ L0:
return 1

[case testFStrings_64bit]
def f(var: str, num: int) -> None:
from typing import NewType, Union
NewStr = NewType("NewStr", str)
def f(var: Union[str, NewStr], num: int) -> None:
s1 = f"Hi! I'm {var}. I am {num} years old."
s2 = f'Hello {var:>{num}}'
s3 = f''
s4 = f'abc'
[typing fixtures/typing-full.pyi]
[out]
def f(var, num):
var :: str
Expand Down Expand Up @@ -267,7 +284,9 @@ L0:
return 1

[case testStringFormattingCStyle]
def f(var: str, num: int) -> None:
from typing import NewType, Union
NewStr = NewType("NewStr", str)
def f(var: Union[str, NewStr], num: int) -> None:
s1 = "Hi! I'm %s." % var
s2 = "I am %d years old." % num
s3 = "Hi! I'm %s. I am %d years old." % (var, num)
Expand Down Expand Up @@ -322,7 +341,9 @@ L0:
return 1

[case testEncode_64bit]
def f(s: str) -> None:
from typing import NewType, Union
NewStr = NewType("NewStr", str)
def f(s: Union[str, NewStr]) -> None:
s.encode()
s.encode('utf-8')
s.encode('utf8', 'strict')
Expand All @@ -340,6 +361,7 @@ def f(s: str) -> None:
s.encode(encoding=encoding, errors=errors)
s.encode('latin2')

[typing fixtures/typing-full.pyi]
[out]
def f(s):
s :: str
Expand Down Expand Up @@ -410,7 +432,9 @@ L0:
return 1

[case testOrd]
def str_ord(x: str) -> int:
from typing import NewType, Union
NewStr = NewType("NewStr", str)
def str_ord(x: Union[str, NewStr]) -> int:
return ord(x)
def str_ord_literal() -> int:
return ord("a")
Expand All @@ -420,6 +444,7 @@ def bytes_ord_literal() -> int:
return ord(b"a")
def any_ord(x) -> int:
return ord(x)
[typing fixtures/typing-full.pyi]
[out]
def str_ord(x):
x :: str
Expand Down Expand Up @@ -459,13 +484,16 @@ L0:
return r6

[case testStrip]
def do_strip(s: str) -> None:
from typing import NewType, Union
NewStr = NewType("NewStr", str)
def do_strip(s: Union[str, NewStr]) -> None:
s.lstrip("x")
s.strip("y")
s.rstrip("z")
s.lstrip()
s.strip()
s.rstrip()
[typing fixtures/typing-full.pyi]
[out]
def do_strip(s):
s, r0, r1, r2, r3, r4, r5, r6, r7, r8 :: str
Expand All @@ -481,60 +509,99 @@ L0:
r8 = CPyStr_RStrip(s, 0)
return 1

[case testCountAll]
[case testCountAll_64bit]
from typing import NewType, Union
NewStr = NewType("NewStr", str)
def do_count(s: str) -> int:
return s.count("x") # type: ignore [attr-defined]
return s.count("x")
[typing fixtures/typing-full.pyi]
[out]
def do_count(s):
s, r0 :: str
r1 :: native_int
r2 :: bit
r3 :: object
r4 :: int
r2, r3, r4 :: bit
r5, r6, r7 :: int
L0:
r0 = 'x'
r1 = CPyStr_Count(s, r0, 0)
r2 = r1 >= 0 :: signed
r3 = box(native_int, r1)
r4 = unbox(int, r3)
return r4
r3 = r1 <= 4611686018427387903 :: signed
if r3 goto L1 else goto L2 :: bool
L1:
r4 = r1 >= -4611686018427387904 :: signed
if r4 goto L3 else goto L2 :: bool
L2:
r5 = CPyTagged_FromInt64(r1)
r6 = r5
goto L4
L3:
r7 = r1 << 1
r6 = r7
L4:
return r6

[case testCountStart]
[case testCountStart_64bit]
from typing import NewType, Union
NewStr = NewType("NewStr", str)
def do_count(s: str, start: int) -> int:
return s.count("x", start) # type: ignore [attr-defined]
return s.count("x", start)
[typing fixtures/typing-full.pyi]
[out]
def do_count(s, start):
s :: str
start :: int
r0 :: str
r1 :: native_int
r2 :: bit
r3 :: object
r4 :: int
r2, r3, r4 :: bit
r5, r6, r7 :: int
L0:
r0 = 'x'
r1 = CPyStr_Count(s, r0, start)
r2 = r1 >= 0 :: signed
r3 = box(native_int, r1)
r4 = unbox(int, r3)
return r4
r3 = r1 <= 4611686018427387903 :: signed
if r3 goto L1 else goto L2 :: bool
L1:
r4 = r1 >= -4611686018427387904 :: signed
if r4 goto L3 else goto L2 :: bool
L2:
r5 = CPyTagged_FromInt64(r1)
r6 = r5
goto L4
L3:
r7 = r1 << 1
r6 = r7
L4:
return r6

[case testCountStartEnd]
[case testCountStartEnd_64bit]
from typing import NewType, Union
NewStr = NewType("NewStr", str)
def do_count(s: str, start: int, end: int) -> int:
return s.count("x", start, end) # type: ignore [attr-defined]
return s.count("x", start, end)
[typing fixtures/typing-full.pyi]
[out]
def do_count(s, start, end):
s :: str
start, end :: int
r0 :: str
r1 :: native_int
r2 :: bit
r3 :: object
r4 :: int
r2, r3, r4 :: bit
r5, r6, r7 :: int
L0:
r0 = 'x'
r1 = CPyStr_CountFull(s, r0, start, end)
r2 = r1 >= 0 :: signed
r3 = box(native_int, r1)
r4 = unbox(int, r3)
return r4
r3 = r1 <= 4611686018427387903 :: signed
if r3 goto L1 else goto L2 :: bool
L1:
r4 = r1 >= -4611686018427387904 :: signed
if r4 goto L3 else goto L2 :: bool
L2:
r5 = CPyTagged_FromInt64(r1)
r6 = r5
goto L4
L3:
r7 = r1 << 1
r6 = r7
L4:
return r6
Loading