Skip to content

Commit d749e82

Browse files
committed
optimize fmod performance
1 parent fd2aed8 commit d749e82

File tree

3 files changed

+253
-9
lines changed

3 files changed

+253
-9
lines changed

libm/src/math/generic/fmod.rs

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
/* SPDX-License-Identifier: MIT OR Apache-2.0 */
2-
use crate::support::{CastFrom, Float, Int, MinInt};
2+
use crate::support::{CastFrom, CastInto, Float, HInt, Int, MinInt, NarrowingDiv};
33

44
#[inline]
5-
pub fn fmod<F: Float>(x: F, y: F) -> F {
5+
pub fn fmod<F: Float>(x: F, y: F) -> F
6+
where
7+
F::Int: HInt,
8+
<F::Int as HInt>::D: NarrowingDiv,
9+
{
610
let _1 = F::Int::ONE;
711
let sx = x.to_bits() & F::SIGN_MASK;
812
let ux = x.to_bits() & !F::SIGN_MASK;
@@ -29,7 +33,7 @@ pub fn fmod<F: Float>(x: F, y: F) -> F {
2933

3034
// To compute `(num << ex) % (div << ey)`, first
3135
// evaluate `rem = (num << (ex - ey)) % div` ...
32-
let rem = reduction(num, ex - ey, div);
36+
let rem = reduction::<F>(num, ex - ey, div);
3337
// ... so the result will be `rem << ey`
3438

3539
if rem.is_zero() {
@@ -58,11 +62,55 @@ fn into_sig_exp<F: Float>(mut bits: F::Int) -> (F::Int, u32) {
5862
}
5963

6064
/// Compute the remainder `(x * 2.pow(e)) % y` without overflow.
///
/// The added (`+`) implementation chooses between several strategies:
/// - `F::BITS == 16`: the whole computation is done directly in `u64`.
/// - `F::BITS == 64 && e < 32`: shift-and-subtract, one bit per iteration.
/// - `e < F::BITS`: one widening shift followed by a narrowing division,
///   used only when that division yields `Some`.
/// - otherwise: `crate::support::linear_mul_reduction`.
61-
fn reduction<I: Int>(mut x: I, e: u32, y: I) -> I {
62-
x %= y;
63-
for _ in 0..e {
64-
x <<= 1;
65-
x = x.checked_sub(y).unwrap_or(x);
65+
fn reduction<F>(mut x: F::Int, e: u32, y: F::Int) -> F::Int
66+
where
67+
F: Float,
68+
F::Int: HInt,
69+
<<F as Float>::Int as HInt>::D: NarrowingDiv,
70+
{
71+
// `f16` only has 5 exponent bits, so even `f16::MAX = 65504.0` is only
72+
// a 40-bit integer multiple of the smallest subnormal.
73+
if F::BITS == 16 {
74+
debug_assert!(F::EXP_MAX - F::EXP_MIN == 29);
75+
debug_assert!(e <= 29);
76+
let u: u16 = x.cast();
77+
let v: u16 = y.cast();
// A 16-bit value shifted by at most 29 (see the assertion above) needs
// at most 45 bits, so the shifted dividend fits in `u64`.
78+
let u = (u as u64) << e;
79+
let v = v as u64;
// The remainder is smaller than `v`, which came from a `u16`.
80+
return F::Int::cast_from((u % v) as u16);
6681
}
67-
x
82+
83+
// Ensure `x < 2y` for later steps
84+
if x >= (y << 1) {
85+
// This case is only reached with subnormal divisors,
86+
// but it might be better to just normalize all significands
87+
// to make this unnecessary. The further calls could potentially
88+
// benefit from assuming a specific fixed leading bit position.
89+
x %= y;
90+
}
91+
92+
// The simple implementation seems to be fastest for a short reduction
93+
// at this size. The limit here was chosen empirically on an Intel Nehalem.
94+
// Less old CPUs that have faster `u64 * u64 -> u128` might not benefit,
95+
// and 32-bit systems or architectures without hardware multipliers might
96+
// want to do this in more cases.
97+
if F::BITS == 64 && e < 32 {
98+
// Assumes `x < 2y`
99+
for _ in 0..e {
// Subtract `y` only when it fits; doubling afterwards restores the
// `x < 2y` bound for the next iteration.
100+
x = x.checked_sub(y).unwrap_or(x);
101+
x <<= 1;
102+
}
// One final conditional subtraction after the last doubling.
103+
return x.checked_sub(y).unwrap_or(x);
104+
}
105+
106+
// Fast path for short reductions
107+
if e < F::BITS {
108+
let w = x.widen() << e;
// On `None` nothing is returned here and control falls through to the
// general path below.
109+
if let Some((_, r)) = w.checked_narrowing_div_rem(y) {
110+
return r;
111+
}
112+
}
113+
114+
// Assumes `x < 2y`
115+
crate::support::linear_mul_reduction(x, e, y)
68116
}

libm/src/math/support/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub(crate) mod feature_detect;
88
mod float_traits;
99
pub mod hex_float;
1010
mod int_traits;
11+
mod modular;
1112

1213
#[allow(unused_imports)]
1314
pub use big::{i256, u256};
@@ -29,6 +30,7 @@ pub use hex_float::hf128;
2930
#[allow(unused_imports)]
3031
pub use hex_float::{hf32, hf64};
3132
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt, NarrowingDiv};
33+
pub use modular::linear_mul_reduction;
3234

3335
/// Hint to the compiler that the current path is cold.
3436
pub fn cold_path() {

libm/src/math/support/modular.rs

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
use crate::support::int_traits::NarrowingDiv;
2+
use crate::support::{DInt, HInt, Int};
3+
4+
/// Contains:
5+
/// n in (R/8, R/4)
6+
/// x in [0, 2n)
7+
#[derive(Debug, Clone, PartialEq, Eq)]
8+
struct Reducer<U: HInt> {
9+
// let m = 2n
10+
m: U,
11+
// RR/2 = qm + r
12+
r: U,
13+
xq2: U::D,
14+
}
15+
16+
impl<U> Reducer<U>
17+
where
18+
U: HInt,
19+
U: Int<Unsigned = U>,
20+
{
21+
/// Construct a reducer for `(x << _) mod n`.
22+
///
23+
/// Requires `R/8 < n < R/4` and `x < 2n`.
24+
fn new(x: U, n: U) -> Self
25+
where
26+
U::D: NarrowingDiv,
27+
{
28+
let _1 = U::ONE;
29+
assert!(n > (_1 << (U::BITS - 3)));
30+
assert!(n < (_1 << (U::BITS - 2)));
31+
let m = n << 1;
32+
assert!(x < m);
33+
34+
// We need q and r s.t. RR/2 = qm + r
35+
// As R/4 < m < R/2,
36+
// we have R <= q < 2R
37+
// so let q = R + f
38+
// RR/2 = (R + f)m + r
39+
// R(R/2 - m) = fm + r
40+
41+
// v = R/2 - m < R/4 < m
42+
let v = (_1 << (U::BITS - 1)) - m;
43+
let (f, r) = v.widen_hi().checked_narrowing_div_rem(m).unwrap();
44+
45+
// xq < qm <= RR/2
46+
// 2xq < RR
47+
// 2xq = 2xR + 2xf;
48+
let x2: U = x << 1;
49+
let xq2 = x2.widen_hi() + x2.widen_mul(f);
50+
Self { m, r, xq2 }
51+
}
52+
53+
/// Extract the current remainder in the range `[0, 2n)`
54+
fn partial_remainder(&self) -> U {
55+
// RR/2 = qm + r, 0 <= r < m
56+
// 2xq = uR + v, 0 <= v < R
57+
// muR = 2mxq - mv
58+
// = xRR - 2xr - mv
59+
// mu + (2xr + mv)/R == xR
60+
61+
// 0 <= 2xq < RR
62+
// R <= q < 2R
63+
// 0 <= x < R/2
64+
// R/4 < m < R/2
65+
// 0 <= r < m
66+
// 0 <= mv < mR
67+
// 0 <= 2xr < rR < mR
68+
69+
// 0 <= (2xr + mv)/R < 2m
70+
// Add `mu` to each term to obtain:
71+
// mu <= xR < mu + 2m
72+
73+
// Since `0 <= 2m < R`, `xR` is the only multiple of `R` between
74+
// `mu` and `m(u+2)`, so we can truncate the latter to find `x`.
75+
let _1 = U::ONE;
76+
self.m.widen_mul(self.xq2.hi() + (_1 + _1)).hi()
77+
}
78+
79+
/// Maps the remainder `x` to `(x << k) - un`,
80+
/// for a suitable quotient `u`, which is returned.
81+
fn shift_reduce(&mut self, k: u32) -> U {
82+
assert!(k < U::BITS);
83+
// 2xq << k = aRR/2 + b;
84+
let a = self.xq2.hi() >> (U::BITS - 1 - k);
85+
let (lo, hi) = (self.xq2 << k).lo_hi();
86+
let b = U::D::from_lo_hi(lo, hi & (U::MAX >> 1));
87+
88+
// (2xq << k) - aqm
89+
// = aRR/2 + b - aqm
90+
// = a(RR/2 - qm) + b
91+
// = ar + b
92+
self.xq2 = a.widen_mul(self.r) + b;
93+
a
94+
}
95+
96+
/// Maps the remainder `x` to `x(R/2) - un`,
97+
/// for a suitable quotient `u`, which is returned.
98+
fn word_reduce(&mut self) -> U {
99+
// 2xq = uR + v
100+
let (v, u) = self.xq2.lo_hi();
101+
// xqR - uqm
102+
// = uRR/2 + vR/2 - uRR/2 + ur
103+
// = ur + (v/2)R
104+
self.xq2 = u.widen_mul(self.r) + U::widen_hi(v >> 1);
105+
u
106+
}
107+
}
108+
109+
/// Compute the remainder `(x << e) % y` with unbounded integers.
110+
/// Requires `x < 2y` and `y.leading_zeros() >= 2`
111+
pub fn linear_mul_reduction<U>(x: U, mut e: u32, y: U) -> U
112+
where
113+
U: HInt + Int<Unsigned = U>,
114+
U::D: NarrowingDiv,
115+
{
116+
assert!(y <= U::MAX >> 2);
117+
assert!(x < (y << 1));
118+
let _0 = U::ZERO;
119+
let _1 = U::ONE;
120+
121+
// power of two divisor
122+
if (y & (y - _1)).is_zero() {
123+
if e < U::BITS {
124+
return (x << e) & (y - _1);
125+
} else {
126+
return _0;
127+
}
128+
}
129+
130+
// shift the divisor so it has exactly two leading zeros
131+
let y_shift = y.leading_zeros() - 2;
132+
let mut m = Reducer::new(x, y << y_shift);
133+
e += y_shift;
134+
135+
while e >= U::BITS - 1 {
136+
m.word_reduce();
137+
e -= U::BITS - 1;
138+
}
139+
m.shift_reduce(e);
140+
141+
let rem = m.partial_remainder() >> y_shift;
142+
rem.checked_sub(y).unwrap_or(rem)
143+
}
144+
145+
#[cfg(test)]
mod test {
    use crate::support::linear_mul_reduction;
    use crate::support::modular::Reducer;

    /// Exhaustively exercises the `Reducer` operations on `u8` for every
    /// modulus in `33..=63` and every starting value below `2n`.
    #[test]
    fn reducer_ops() {
        for n in 33..=63_u8 {
            for x in 0..2 * n {
                let fresh = Reducer::new(x, n);
                let n = u32::from(n);
                // A freshly built reducer reports the value it was given.
                let x0 = u32::from(fresh.partial_remainder());
                assert_eq!(u32::from(x), x0);
                for k in 0..=7 {
                    let mut red = fresh.clone();
                    let quot = u32::from(red.shift_reduce(k));
                    let x1 = u32::from(red.partial_remainder());
                    assert_eq!(x1, (x0 << k) - quot * n);
                    assert!(x1 < 2 * n);
                    assert!((red.xq2 as u32).is_multiple_of(2 * x1));

                    // `word_reduce` must agree with
                    // `shift_reduce(U::BITS - 1)`
                    if k == 7 {
                        let mut by_word = fresh.clone();
                        let word_quot = by_word.word_reduce();
                        assert_eq!(quot, u32::from(word_quot));
                        assert_eq!(by_word, red);
                    }
                }
            }
        }
    }

    /// Compares `linear_mul_reduction` on `u8` with a bit-by-bit reference
    /// for every divisor below 64 and every shift below 100.
    #[test]
    fn reduction() {
        for y in 1..64u8 {
            for x in 0..2 * y {
                let mut expected = x % y;
                for e in 0..100 {
                    assert_eq!(expected, linear_mul_reduction(x, e, y));
                    // advance the reference remainder by one more bit
                    let doubled = expected << 1;
                    expected = if doubled >= y { doubled - y } else { doubled };
                }
            }
        }
    }
}

0 commit comments

Comments
 (0)