Skip to content

Commit 6ac4350

Browse files
committed
Optimize is_ascii
1 parent 213d946 commit 6ac4350

File tree

4 files changed

+82
-222
lines changed

4 files changed

+82
-222
lines changed

library/core/src/slice/ascii.rs

Lines changed: 66 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
use core::ascii::EscapeDefault;
44

55
use crate::fmt::{self, Write};
6-
#[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))]
76
use crate::intrinsics::const_eval_select;
87
use crate::{ascii, iter, ops};
98

@@ -327,175 +326,93 @@ impl<'a> fmt::Debug for EscapeAscii<'a> {
327326
}
328327
}
329328

330-
/// ASCII test *without* the chunk-at-a-time optimizations.
331-
///
332-
/// This is carefully structured to produce nice small code -- it's smaller in
333-
/// `-O` than what the "obvious" ways produces under `-C opt-level=s`. If you
334-
/// touch it, be sure to run (and update if needed) the assembly test.
335-
#[unstable(feature = "str_internals", issue = "none")]
336-
#[doc(hidden)]
337-
#[inline]
338-
pub const fn is_ascii_simple(mut bytes: &[u8]) -> bool {
339-
while let [rest @ .., last] = bytes {
340-
if !last.is_ascii() {
341-
break;
342-
}
343-
bytes = rest;
344-
}
345-
bytes.is_empty()
346-
}
347-
348-
/// Optimized ASCII test that will use usize-at-a-time operations instead of
349-
/// byte-at-a-time operations (when possible).
350-
///
351-
/// The algorithm we use here is pretty simple. If `s` is too short, we just
352-
/// check each byte and be done with it. Otherwise:
353-
///
354-
/// - Read the first word with an unaligned load.
355-
/// - Align the pointer, read subsequent words until end with aligned loads.
356-
/// - Read the last `usize` from `s` with an unaligned load.
357-
///
358-
/// If any of these loads produces something for which `contains_nonascii`
359-
/// (above) returns true, then we know the answer is false.
360-
#[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))]
361329
#[inline]
362330
#[rustc_allow_const_fn_unstable(const_eval_select)] // fallback impl has same behavior
363-
const fn is_ascii(s: &[u8]) -> bool {
331+
const fn is_ascii(bytes: &[u8]) -> bool {
364332
// The runtime version behaves the same as the compiletime version, it's
365333
// just more optimized.
366334
const_eval_select!(
367-
@capture { s: &[u8] } -> bool:
335+
@capture { bytes: &[u8] } -> bool:
368336
if const {
369-
is_ascii_simple(s)
337+
is_ascii_const(bytes)
370338
} else {
371-
/// Returns `true` if any byte in the word `v` is nonascii (>= 128). Snarfed
372-
/// from `../str/mod.rs`, which does something similar for utf8 validation.
373-
const fn contains_nonascii(v: usize) -> bool {
374-
const NONASCII_MASK: usize = usize::repeat_u8(0x80);
375-
(NONASCII_MASK & v) != 0
376-
}
377-
378-
const USIZE_SIZE: usize = size_of::<usize>();
379-
380-
let len = s.len();
381-
let align_offset = s.as_ptr().align_offset(USIZE_SIZE);
382-
383-
// If we wouldn't gain anything from the word-at-a-time implementation, fall
384-
// back to a scalar loop.
385-
//
386-
// We also do this for architectures where `size_of::<usize>()` isn't
387-
// sufficient alignment for `usize`, because it's a weird edge case.
388-
if len < USIZE_SIZE || len < align_offset || USIZE_SIZE < align_of::<usize>() {
389-
return is_ascii_simple(s);
339+
if cfg!(all(target_arch = "x86_64", target_feature = "sse2")) {
340+
is_ascii_swar::<4>(bytes)
341+
} else if cfg!(target_arch = "aarch64") {
342+
is_ascii_simd::<{ 2 * size_of::<usize>() }>(bytes)
343+
} else {
344+
is_ascii_swar::<2>(bytes)
390345
}
346+
}
347+
)
348+
}
391349

392-
// We always read the first word unaligned, which means `align_offset` is
393-
// 0, we'd read the same value again for the aligned read.
394-
let offset_to_aligned = if align_offset == 0 { USIZE_SIZE } else { align_offset };
350+
#[inline]
351+
const fn is_ascii_const(mut bytes: &[u8]) -> bool {
352+
while let [first, rest @ ..] = bytes {
353+
if !first.is_ascii() {
354+
break;
355+
}
356+
bytes = rest;
357+
}
358+
bytes.is_empty()
359+
}
395360

396-
let start = s.as_ptr();
397-
// SAFETY: We verify `len < USIZE_SIZE` above.
398-
let first_word = unsafe { (start as *const usize).read_unaligned() };
361+
#[inline(always)]
362+
fn is_ascii_scalar(bytes: &[u8]) -> bool {
363+
bytes.iter().all(u8::is_ascii)
364+
}
399365

400-
if contains_nonascii(first_word) {
401-
return false;
402-
}
403-
// We checked this above, somewhat implicitly. Note that `offset_to_aligned`
404-
// is either `align_offset` or `USIZE_SIZE`, both of are explicitly checked
405-
// above.
406-
debug_assert!(offset_to_aligned <= len);
407-
408-
// SAFETY: word_ptr is the (properly aligned) usize ptr we use to read the
409-
// middle chunk of the slice.
410-
let mut word_ptr = unsafe { start.add(offset_to_aligned) as *const usize };
411-
412-
// `byte_pos` is the byte index of `word_ptr`, used for loop end checks.
413-
let mut byte_pos = offset_to_aligned;
414-
415-
// Paranoia check about alignment, since we're about to do a bunch of
416-
// unaligned loads. In practice this should be impossible barring a bug in
417-
// `align_offset` though.
418-
// While this method is allowed to spuriously fail in CTFE, if it doesn't
419-
// have alignment information it should have given a `usize::MAX` for
420-
// `align_offset` earlier, sending things through the scalar path instead of
421-
// this one, so this check should pass if it's reachable.
422-
debug_assert!(word_ptr.is_aligned_to(align_of::<usize>()));
423-
424-
// Read subsequent words until the last aligned word, excluding the last
425-
// aligned word by itself to be done in tail check later, to ensure that
426-
// tail is always one `usize` at most to extra branch `byte_pos == len`.
427-
while byte_pos < len - USIZE_SIZE {
428-
// Sanity check that the read is in bounds
429-
debug_assert!(byte_pos + USIZE_SIZE <= len);
430-
// And that our assumptions about `byte_pos` hold.
431-
debug_assert!(word_ptr.cast::<u8>() == start.wrapping_add(byte_pos));
432-
433-
// SAFETY: We know `word_ptr` is properly aligned (because of
434-
// `align_offset`), and we know that we have enough bytes between `word_ptr` and the end
435-
let word = unsafe { word_ptr.read() };
436-
if contains_nonascii(word) {
437-
return false;
438-
}
439-
440-
byte_pos += USIZE_SIZE;
441-
// SAFETY: We know that `byte_pos <= len - USIZE_SIZE`, which means that
442-
// after this `add`, `word_ptr` will be at most one-past-the-end.
443-
word_ptr = unsafe { word_ptr.add(1) };
444-
}
366+
#[inline(always)]
367+
fn is_ascii_word(word: usize) -> bool {
368+
word & usize::repeat_u8(0x80) == 0
369+
}
445370

446-
// Sanity check to ensure there really is only one `usize` left. This should
447-
// be guaranteed by our loop condition.
448-
debug_assert!(byte_pos <= len && len - byte_pos <= USIZE_SIZE);
371+
/// Check `bytes` are ASCII by reading `UNROLL_FACTOR` words at a time.
372+
#[inline(always)]
373+
#[unstable(feature = "str_internals", issue = "none")]
374+
pub fn is_ascii_swar<const UNROLL_FACTOR: usize>(bytes: &[u8]) -> bool {
375+
if bytes.len() < size_of::<usize>() {
376+
return is_ascii_scalar(bytes);
377+
}
449378

450-
// SAFETY: This relies on `len >= USIZE_SIZE`, which we check at the start.
451-
let last_word = unsafe { (start.add(len - USIZE_SIZE) as *const usize).read_unaligned() };
379+
// SAFETY: Casting between `u8` and `usize` is fine.
380+
let (_, words, _) = unsafe { bytes.align_to::<usize>() };
381+
let crate::ops::Range { start, end } = bytes.as_ptr_range();
452382

453-
!contains_nonascii(last_word)
454-
}
455-
)
456-
}
383+
// SAFETY: checked above that `len >= size_of::<usize>()`.
384+
let first_word = unsafe { start.cast::<usize>().read_unaligned() };
385+
if !is_ascii_word(first_word) {
386+
return false;
387+
}
457388

458-
/// ASCII test optimized to use the `pmovmskb` instruction available on `x86-64`
459-
/// platforms.
460-
///
461-
/// Other platforms are not likely to benefit from this code structure, so they
462-
/// use SWAR techniques to test for ASCII in `usize`-sized chunks.
463-
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
464-
#[inline]
465-
const fn is_ascii(bytes: &[u8]) -> bool {
466-
// Process chunks of 32 bytes at a time in the fast path to enable
467-
// auto-vectorization and use of `pmovmskb`. Two 128-bit vector registers
468-
// can be OR'd together and then the resulting vector can be tested for
469-
// non-ASCII bytes.
470-
const CHUNK_SIZE: usize = 32;
471-
472-
let mut i = 0;
473-
474-
while i + CHUNK_SIZE <= bytes.len() {
475-
let chunk_end = i + CHUNK_SIZE;
476-
477-
// Get LLVM to produce a `pmovmskb` instruction on x86-64 which
478-
// creates a mask from the most significant bit of each byte.
479-
// ASCII bytes are less than 128 (0x80), so their most significant
480-
// bit is unset.
481-
let mut count = 0;
482-
while i < chunk_end {
483-
count += bytes[i].is_ascii() as u8;
484-
i += 1;
389+
let (chunks, remainder) = words.as_chunks::<UNROLL_FACTOR>();
390+
for chunk in chunks {
391+
let word = chunk.iter().fold(0, |acc, word| word | acc);
392+
if !is_ascii_word(word) {
393+
return false;
485394
}
395+
}
486396

487-
// All bytes should be <= 127 so count is equal to chunk size.
488-
if count != CHUNK_SIZE as u8 {
397+
for word in remainder {
398+
if !is_ascii_word(*word) {
489399
return false;
490400
}
491401
}
492402

493-
// Process the remaining `bytes.len() % N` bytes.
494-
let mut is_ascii = true;
495-
while i < bytes.len() {
496-
is_ascii &= bytes[i].is_ascii();
497-
i += 1;
403+
// SAFETY: checked above that `len >= size_of::<usize>()`.
404+
let last_word = unsafe { end.cast::<usize>().read_unaligned() };
405+
if !is_ascii_word(last_word) {
406+
return false;
498407
}
499408

500-
is_ascii
409+
true
410+
}
411+
412+
/// Check `bytes` are ASCII by reading `CHUNK_SIZE` bytes at a time.
413+
#[inline(always)]
414+
#[unstable(feature = "str_internals", issue = "none")]
415+
pub fn is_ascii_simd<const CHUNK_SIZE: usize>(bytes: &[u8]) -> bool {
416+
let (chunks, remainder) = bytes.as_chunks::<CHUNK_SIZE>();
417+
chunks.iter().all(|chunk| is_ascii_scalar(chunk)) && is_ascii_scalar(remainder)
501418
}

library/core/src/slice/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ mod specialize;
4545
pub use ascii::EscapeAscii;
4646
#[unstable(feature = "str_internals", issue = "none")]
4747
#[doc(hidden)]
48-
pub use ascii::is_ascii_simple;
48+
pub use ascii::{is_ascii_simd, is_ascii_swar};
4949
#[stable(feature = "slice_get_slice", since = "1.28.0")]
5050
pub use index::SliceIndex;
5151
#[unstable(feature = "slice_range", issue = "76393")]

library/coretests/benches/ascii/is_ascii.rs

Lines changed: 13 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ macro_rules! benches {
66
($( fn $name: ident($arg: ident: &[u8]) $body: block )+) => {
77
benches!(mod short SHORT[..] $($name $arg $body)+);
88
benches!(mod medium MEDIUM[..] $($name $arg $body)+);
9+
benches!(mod medium_15 MEDIUM[..=15] $($name $arg $body)+);
910
benches!(mod long LONG[..] $($name $arg $body)+);
1011
// Ensure we benchmark cases where the functions are called with strings
1112
// that are not perfectly aligned or have a length which is not a
@@ -37,87 +38,27 @@ macro_rules! benches {
3738
}
3839

3940
benches! {
40-
fn case00_libcore(bytes: &[u8]) {
41-
bytes.is_ascii()
41+
fn is_ascii_swar_1(bytes: &[u8]) {
42+
core::slice::is_ascii_swar::<1>(bytes)
4243
}
4344

44-
fn case01_iter_all(bytes: &[u8]) {
45-
bytes.iter().all(|b| b.is_ascii())
45+
fn is_ascii_swar_2(bytes: &[u8]) {
46+
core::slice::is_ascii_swar::<2>(bytes)
4647
}
4748

48-
fn case02_align_to(bytes: &[u8]) {
49-
is_ascii_align_to(bytes)
49+
fn is_ascii_swar_4(bytes: &[u8]) {
50+
core::slice::is_ascii_swar::<4>(bytes)
5051
}
5152

52-
fn case03_align_to_unrolled(bytes: &[u8]) {
53-
is_ascii_align_to_unrolled(bytes)
53+
fn is_ascii_simd_08(bytes: &[u8]) {
54+
core::slice::is_ascii_simd::<8>(bytes)
5455
}
5556

56-
fn case04_while_loop(bytes: &[u8]) {
57-
// Process chunks of 32 bytes at a time in the fast path to enable
58-
// auto-vectorization and use of `pmovmskb`. Two 128-bit vector registers
59-
// can be OR'd together and then the resulting vector can be tested for
60-
// non-ASCII bytes.
61-
const CHUNK_SIZE: usize = 32;
62-
63-
let mut i = 0;
64-
65-
while i + CHUNK_SIZE <= bytes.len() {
66-
let chunk_end = i + CHUNK_SIZE;
67-
68-
// Get LLVM to produce a `pmovmskb` instruction on x86-64 which
69-
// creates a mask from the most significant bit of each byte.
70-
// ASCII bytes are less than 128 (0x80), so their most significant
71-
// bit is unset.
72-
let mut count = 0;
73-
while i < chunk_end {
74-
count += bytes[i].is_ascii() as u8;
75-
i += 1;
76-
}
77-
78-
// All bytes should be <= 127 so count is equal to chunk size.
79-
if count != CHUNK_SIZE as u8 {
80-
return false;
81-
}
82-
}
83-
84-
// Process the remaining `bytes.len() % N` bytes.
85-
let mut is_ascii = true;
86-
while i < bytes.len() {
87-
is_ascii &= bytes[i].is_ascii();
88-
i += 1;
89-
}
90-
91-
is_ascii
57+
fn is_ascii_simd_16(bytes: &[u8]) {
58+
core::slice::is_ascii_simd::<16>(bytes)
9259
}
93-
}
9460

95-
// These are separate since it's easier to debug errors if they don't go through
96-
// macro expansion first.
97-
fn is_ascii_align_to(bytes: &[u8]) -> bool {
98-
if bytes.len() < size_of::<usize>() {
99-
return bytes.iter().all(|b| b.is_ascii());
61+
fn is_ascii_simd_32(bytes: &[u8]) {
62+
core::slice::is_ascii_simd::<32>(bytes)
10063
}
101-
// SAFETY: transmuting a sequence of `u8` to `usize` is always fine
102-
let (head, body, tail) = unsafe { bytes.align_to::<usize>() };
103-
head.iter().all(|b| b.is_ascii())
104-
&& body.iter().all(|w| !contains_nonascii(*w))
105-
&& tail.iter().all(|b| b.is_ascii())
106-
}
107-
108-
fn is_ascii_align_to_unrolled(bytes: &[u8]) -> bool {
109-
if bytes.len() < size_of::<usize>() {
110-
return bytes.iter().all(|b| b.is_ascii());
111-
}
112-
// SAFETY: transmuting a sequence of `u8` to `[usize; 2]` is always fine
113-
let (head, body, tail) = unsafe { bytes.align_to::<[usize; 2]>() };
114-
head.iter().all(|b| b.is_ascii())
115-
&& body.iter().all(|w| !contains_nonascii(w[0] | w[1]))
116-
&& tail.iter().all(|b| b.is_ascii())
117-
}
118-
119-
#[inline]
120-
fn contains_nonascii(v: usize) -> bool {
121-
const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; size_of::<usize>()]);
122-
(NONASCII_MASK & v) != 0
12364
}

library/coretests/benches/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#![feature(iter_array_chunks)]
99
#![feature(iter_next_chunk)]
1010
#![feature(iter_advance_by)]
11+
#![feature(str_internals)]
12+
#![allow(internal_features)]
1113

1214
extern crate test;
1315

0 commit comments

Comments
 (0)