Skip to content

Rewriteis_ascii using slice::as_chunks #144837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 26 additions & 150 deletions library/core/src/slice/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
use core::ascii::EscapeDefault;

use crate::fmt::{self, Write};
#[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))]
use crate::intrinsics::const_eval_select;
use crate::{ascii, iter, ops};

Expand Down Expand Up @@ -327,175 +326,52 @@ impl<'a> fmt::Debug for EscapeAscii<'a> {
}
}

/// ASCII test *without* the chunk-at-a-time optimizations.
///
/// This is carefully structured to produce nice small code -- it's smaller in
/// `-O` than what the "obvious" ways produces under `-C opt-level=s`. If you
/// touch it, be sure to run (and update if needed) the assembly test.
#[unstable(feature = "str_internals", issue = "none")]
#[doc(hidden)]
#[inline]
pub const fn is_ascii_simple(mut bytes: &[u8]) -> bool {
while let [rest @ .., last] = bytes {
if !last.is_ascii() {
const fn is_ascii_const(mut bytes: &[u8]) -> bool {
while let [first, rest @ ..] = bytes {
if !first.is_ascii() {
break;
}
bytes = rest;
}
bytes.is_empty()
}

/// The implementation using iterators produces a tighter loop than the
/// implementation using pattern-matching when inlined into `is_ascii_chunked`.
/// So we have duplicate implementations of the scalar case until iterators are
/// usable in const contexts.
#[inline(always)]
fn is_ascii_scalar(bytes: &[u8]) -> bool {
bytes.iter().all(u8::is_ascii)
}

/// Optimized ASCII test that will use usize-at-a-time operations instead of
/// byte-at-a-time operations (when possible).
///
/// The algorithm we use here is pretty simple. If `s` is too short, we just
/// check each byte and be done with it. Otherwise:
///
/// - Read the first word with an unaligned load.
/// - Align the pointer, read subsequent words until end with aligned loads.
/// - Read the last `usize` from `s` with an unaligned load.
///
/// If any of these loads produces something for which `contains_nonascii`
/// (above) returns true, then we know the answer is false.
#[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))]
#[inline]
#[rustc_allow_const_fn_unstable(const_eval_select)] // fallback impl has same behavior
const fn is_ascii(s: &[u8]) -> bool {
const fn is_ascii(bytes: &[u8]) -> bool {
// The runtime version behaves the same as the compiletime version, it's
// just more optimized.
const_eval_select!(
@capture { s: &[u8] } -> bool:
@capture { bytes: &[u8] } -> bool:
if const {
is_ascii_simple(s)
is_ascii_const(bytes)
} else {
/// Returns `true` if any byte in the word `v` is nonascii (>= 128). Snarfed
/// from `../str/mod.rs`, which does something similar for utf8 validation.
const fn contains_nonascii(v: usize) -> bool {
const NONASCII_MASK: usize = usize::repeat_u8(0x80);
(NONASCII_MASK & v) != 0
}

const USIZE_SIZE: usize = size_of::<usize>();

let len = s.len();
let align_offset = s.as_ptr().align_offset(USIZE_SIZE);

// If we wouldn't gain anything from the word-at-a-time implementation, fall
// back to a scalar loop.
//
// We also do this for architectures where `size_of::<usize>()` isn't
// sufficient alignment for `usize`, because it's a weird edge case.
if len < USIZE_SIZE || len < align_offset || USIZE_SIZE < align_of::<usize>() {
return is_ascii_simple(s);
}

// We always read the first word unaligned, which means `align_offset` is
// 0, we'd read the same value again for the aligned read.
let offset_to_aligned = if align_offset == 0 { USIZE_SIZE } else { align_offset };

let start = s.as_ptr();
// SAFETY: We verify `len < USIZE_SIZE` above.
let first_word = unsafe { (start as *const usize).read_unaligned() };

if contains_nonascii(first_word) {
return false;
}
// We checked this above, somewhat implicitly. Note that `offset_to_aligned`
// is either `align_offset` or `USIZE_SIZE`, both of are explicitly checked
// above.
debug_assert!(offset_to_aligned <= len);

// SAFETY: word_ptr is the (properly aligned) usize ptr we use to read the
// middle chunk of the slice.
let mut word_ptr = unsafe { start.add(offset_to_aligned) as *const usize };

// `byte_pos` is the byte index of `word_ptr`, used for loop end checks.
let mut byte_pos = offset_to_aligned;

// Paranoia check about alignment, since we're about to do a bunch of
// unaligned loads. In practice this should be impossible barring a bug in
// `align_offset` though.
// While this method is allowed to spuriously fail in CTFE, if it doesn't
// have alignment information it should have given a `usize::MAX` for
// `align_offset` earlier, sending things through the scalar path instead of
// this one, so this check should pass if it's reachable.
debug_assert!(word_ptr.is_aligned_to(align_of::<usize>()));

// Read subsequent words until the last aligned word, excluding the last
// aligned word by itself to be done in tail check later, to ensure that
// tail is always one `usize` at most to extra branch `byte_pos == len`.
while byte_pos < len - USIZE_SIZE {
// Sanity check that the read is in bounds
debug_assert!(byte_pos + USIZE_SIZE <= len);
// And that our assumptions about `byte_pos` hold.
debug_assert!(word_ptr.cast::<u8>() == start.wrapping_add(byte_pos));

// SAFETY: We know `word_ptr` is properly aligned (because of
// `align_offset`), and we know that we have enough bytes between `word_ptr` and the end
let word = unsafe { word_ptr.read() };
if contains_nonascii(word) {
return false;
}

byte_pos += USIZE_SIZE;
// SAFETY: We know that `byte_pos <= len - USIZE_SIZE`, which means that
// after this `add`, `word_ptr` will be at most one-past-the-end.
word_ptr = unsafe { word_ptr.add(1) };
}

// Sanity check to ensure there really is only one `usize` left. This should
// be guaranteed by our loop condition.
debug_assert!(byte_pos <= len && len - byte_pos <= USIZE_SIZE);

// SAFETY: This relies on `len >= USIZE_SIZE`, which we check at the start.
let last_word = unsafe { (start.add(len - USIZE_SIZE) as *const usize).read_unaligned() };

!contains_nonascii(last_word)
const CHUNK_SIZE: usize = if cfg!(all(target_arch = "x86_64", target_feature = "sse2")) {
4 * size_of::<usize>()
} else {
2 * size_of::<usize>()
};
is_ascii_chunked::<CHUNK_SIZE>(bytes)
}
)
}

/// ASCII test optimized to use the `pmovmskb` instruction available on `x86-64`
/// platforms.
///
/// Other platforms are not likely to benefit from this code structure, so they
/// use SWAR techniques to test for ASCII in `usize`-sized chunks.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
/// Test for ASCII-ness `CHUNK_SIZE` bytes at a time.
/// This loop should be simple enough that LLVM can auto-vectorise it.
#[inline]
const fn is_ascii(bytes: &[u8]) -> bool {
// Process chunks of 32 bytes at a time in the fast path to enable
// auto-vectorization and use of `pmovmskb`. Two 128-bit vector registers
// can be OR'd together and then the resulting vector can be tested for
// non-ASCII bytes.
const CHUNK_SIZE: usize = 32;

let mut i = 0;

while i + CHUNK_SIZE <= bytes.len() {
let chunk_end = i + CHUNK_SIZE;

// Get LLVM to produce a `pmovmskb` instruction on x86-64 which
// creates a mask from the most significant bit of each byte.
// ASCII bytes are less than 128 (0x80), so their most significant
// bit is unset.
let mut count = 0;
while i < chunk_end {
count += bytes[i].is_ascii() as u8;
i += 1;
}

// All bytes should be <= 127 so count is equal to chunk size.
if count != CHUNK_SIZE as u8 {
return false;
}
}

// Process the remaining `bytes.len() % N` bytes.
let mut is_ascii = true;
while i < bytes.len() {
is_ascii &= bytes[i].is_ascii();
i += 1;
}

is_ascii
fn is_ascii_chunked<const CHUNK_SIZE: usize>(bytes: &[u8]) -> bool {
let (chunks, remainder) = bytes.as_chunks::<CHUNK_SIZE>();
chunks.iter().all(|chunk| is_ascii_scalar(chunk)) && is_ascii_scalar(remainder)
}
3 changes: 0 additions & 3 deletions library/core/src/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ mod specialize;

#[stable(feature = "inherent_ascii_escape", since = "1.60.0")]
pub use ascii::EscapeAscii;
#[unstable(feature = "str_internals", issue = "none")]
#[doc(hidden)]
pub use ascii::is_ascii_simple;
#[stable(feature = "slice_get_slice", since = "1.28.0")]
pub use index::SliceIndex;
#[unstable(feature = "slice_range", issue = "76393")]
Expand Down
Loading