From 82971a308fffae9f7b335a78b5f0879bc82468bd Mon Sep 17 00:00:00 2001
From: Diggory Hardy
Date: Tue, 16 Nov 2021 09:30:25 +0000
Subject: [PATCH] temp

---
 benches/uniform.rs                       |  45 +++--
 src/distributions/uniform/uniform_int.rs | 236 ++++++++++++++---------
 2 files changed, 177 insertions(+), 104 deletions(-)

diff --git a/benches/uniform.rs b/benches/uniform.rs
index cb743b4f788..cdd8376482e 100644
--- a/benches/uniform.rs
+++ b/benches/uniform.rs
@@ -86,6 +86,24 @@ fn uniform_int(c: &mut Criterion) {
     bench_int!(c, i128, (i128::MIN, 1));
 }
 
+#[cfg(feature = "simd_support")]
+macro_rules! bench_dist_simd_group {
+    ($name:literal, $T:ty, $f:ident, $g:expr, $inputs:expr) => {
+        for input in $inputs {
+            $g.bench_with_input(
+                BenchmarkId::new($name, input.0),
+                &input.1,
+                |b, (low, high)| {
+                    let mut rng = BenchRng::from_entropy();
+                    let (low, high) = (<$T>::splat(*low), <$T>::splat(*high));
+                    let dist = Uniform::new_inclusive(low, high);
+                    b.iter(|| <$T as SampleUniform>::Sampler::$f(&dist.0, &mut rng))
+                },
+            );
+        }
+    };
+}
+
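+// Note: `bench_dist_simd_group!` times sampling from a pre-built `Uniform`
+// distribution (`sample`, `sample_canon`, ...), while `bench_simd_group!`
+// below times the single-use `sample_single_*` / `sample_inclusive_*` entry
+// points; both iterate over the same `inputs`.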
 #[cfg(feature = "simd_support")]
 macro_rules! bench_simd_group {
     ($name:literal, $T:ty, $f:ident, $g:expr, $inputs:expr) => {
@@ -114,20 +132,21 @@ macro_rules! bench_simd_group {
 #[cfg(feature = "simd_support")]
 macro_rules! bench_simd {
     ($c:expr, $T:ty, $high:expr/*, $incr:expr*/) => {{
-        let mut g = $c.benchmark_group(concat!("uniform_simd_", stringify!($T)));
         let inputs = &[("high_reject", $high), ("low_reject", (-1, 2))];
 
+        let mut g = $c.benchmark_group(concat!("uniform_dist_simd_", stringify!($T)));
+        bench_dist_simd_group!("Old", $T, sample, g, inputs);
+        bench_dist_simd_group!("Canon", $T, sample_canon, g, inputs);
+        bench_dist_simd_group!("Canon-Lemire", $T, sample_canon_lemire, g, inputs);
+        bench_dist_simd_group!("Bitmask", $T, sample_bitmask, g, inputs);
+        drop(g);
+
+        let mut g = $c.benchmark_group(concat!("uniform_int_", stringify!($T)));
         bench_simd_group!("Old", $T, sample_single_inclusive, g, inputs);
         bench_simd_group!("Canon", $T, sample_single_inclusive_canon, g, inputs);
-        bench_simd_group!(
-            "Canon-branchless",
-            $T,
-            sample_single_inclusive_canon_branchless,
-            g,
-            inputs
-        );
-        bench_simd_group!("Canon-scalar", $T, sample_inclusive_canon_scalar, g, inputs);
+        bench_simd_group!("Canon-Lemire", $T, sample_inclusive_canon_lemire, g, inputs);
         bench_simd_group!("Bitmask", $T, sample_single_inclusive_bitmask, g, inputs);
+
     }};
 }
 
@@ -138,11 +157,11 @@ fn uniform_simd(c: &mut Criterion) {
     bench_simd!(c, i8x2, (i8::MIN, 116));
     bench_simd!(c, i8x4, (i8::MIN, 116));
     bench_simd!(c, i8x8, (i8::MIN, 116));
-    bench_simd!(c, i8x16, (i8::MIN, 116));
-    bench_simd!(c, i8x32, (i8::MIN, 116));
-    bench_simd!(c, i8x64, (i8::MIN, 116));
+// bench_simd!(c, i8x16, (i8::MIN, 116));
+// bench_simd!(c, i8x32, (i8::MIN, 116));
+// bench_simd!(c, i8x64, (i8::MIN, 116));
     bench_simd!(c, i16x8, (i16::MIN, 32407));
-    bench_simd!(c, i16x16, (i16::MIN, 32407));
+// bench_simd!(c, i16x16, (i16::MIN, 32407));
     bench_simd!(c, i32x4, (i32::MIN, 1));
     bench_simd!(c, i32x8, (i32::MIN, 1));
     bench_simd!(c, i64x2, (i64::MIN, 1));

diff --git a/src/distributions/uniform/uniform_int.rs b/src/distributions/uniform/uniform_int.rs
index e5eb5d2ecaa..40f28c9285b 100644
--- a/src/distributions/uniform/uniform_int.rs
+++ b/src/distributions/uniform/uniform_int.rs
@@ -701,6 +701,48 @@ macro_rules! uniform_simd_int_impl {
         }
 
         impl UniformInt<$ty> {
+            /// Sample, Bitmask method
+            #[inline]
+            pub fn sample_bitmask<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
+                let mut range: $unsigned = self.range.cast();
+                let is_full_range = range.eq($unsigned::splat(0));
+
+                // The old implementation used a mix of methods for different
+                // integer sizes; we use only the leading-zeros method here for
+                // a better comparison.
+
+                // generate bitmask
+                range -= 1;
+                let mut mask = range | 1;
+
+                mask |= mask >> 1;
+                mask |= mask >> 2;
+                mask |= mask >> 4;
+
+                const LANE_WIDTH: usize = std::mem::size_of::<$ty>() * 8 / <$ty>::lanes();
+                if LANE_WIDTH >= 16 { mask |= mask >> 8; }
+                if LANE_WIDTH >= 32 { mask |= mask >> 16; }
+                if LANE_WIDTH >= 64 { mask |= mask >> 32; }
+                if LANE_WIDTH >= 128 { mask |= mask >> 64; }
+
+                let mut v: $unsigned = rng.gen();
+                loop {
+                    let masked = v & mask;
+                    let accept = masked.le(range);
+                    if accept.all() {
+                        let masked: $ty = masked.cast();
+                        // wrapping addition
+                        let result = self.low + masked;
+                        // `select` here compiles to a blend operation.
+                        // When `range.eq(0).none()`, the compare and blend
+                        // operations are avoided.
+                        let v: $ty = v.cast();
+                        return is_full_range.select(v, result);
+                    }
+                    // Replace only the failing lanes
+                    v = accept.select(v, rng.gen());
+                }
+            }
+
             #[inline(always)]
             fn sample_inc_setup<B1, B2>(low_b: B1, high_b: B2) -> ($unsigned, $ty)
             where
                 B1: SampleBorrow<$ty> + Sized,
                 B2: SampleBorrow<$ty> + Sized,
             {
@@ -715,8 +757,8 @@ macro_rules! uniform_simd_int_impl {
                 (range, low)
             }
 
-            ///
-            #[inline(always)]
+            /// Sample single inclusive, using the Bitmask method
+            #[inline]
             pub fn sample_single_inclusive_bitmask<B1, B2, R: Rng + ?Sized>(
                 low_b: B1, high_b: B2, rng: &mut R,
             ) -> $ty
@@ -770,7 +812,7 @@ macro_rules! uniform_simd_int_impl {
     )+
     };
 }
-
+/*
 #[cfg(feature = "simd_support")]
 macro_rules! uniform_simd_int_gt8_impl {
     ($ty:ident, $unsigned:ident) => {
@@ -791,7 +833,7 @@ macro_rules! uniform_simd_int_gt8_impl {
            let (_, overflowed) = lo_order.overflowing_add(cast_new_hi);
            *result += overflowed.select($unsigned::splat(1), $unsigned::splat(0));
        }
-
+/*
        /// Canon's method
        #[inline(always)]
        pub fn sample_single_inclusive_canon_branchless<B1, B2, R: Rng + ?Sized>(
@@ -850,6 +892,7 @@ macro_rules! uniform_simd_int_gt8_impl {
            let cast_rand_bits: $ty = rand_bits.cast();
            is_full_range.select(cast_rand_bits, low + cast_result)
        }
+*/
     }
     };
@@ -858,54 +901,77 @@ macro_rules! uniform_simd_int_gt8_impl {
     uniform_simd_int_gt8_impl!{ $signed, $unsigned }
     )+};
 }
+*/
 
+// These are naive ports of the above algorithms to SIMD.
+// Caveat: supports only up to x8, and at x8 it relies on 512-bit SIMD.
+// Caveat: always generates u64 samples, which is often more than is needed.
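+//
+// For reference, a scalar sketch of Canon's method as ported below
+// (illustrative only, not part of the patch; assumes `rand::Rng` is in
+// scope and `range > 0`, the full-range case being handled separately via
+// `is_full_range`):
+//
+//     fn canon_u64<R: Rng + ?Sized>(range: u64, rng: &mut R) -> u64 {
+//         // The high half of a widening multiply maps a uniform u64 onto
+//         // [0, range) with at most one part in 2^64 of bias.
+//         let m = (rng.gen::<u64>() as u128) * (range as u128);
+//         let (mut result, lo_order) = ((m >> 64) as u64, m as u64);
+//         // A follow-up sample can only bump `result` when `lo_order` is
+//         // within `range` of wrapping, so this branch is rarely taken.
+//         if lo_order > 0u64.wrapping_sub(range) {
+//             let new = (rng.gen::<u64>() as u128) * (range as u128);
+//             let (_, overflowed) = lo_order.overflowing_add((new >> 64) as u64);
+//             result += overflowed as u64;
+//         }
+//         result
+//     }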
 #[cfg(feature = "simd_support")]
 macro_rules! uniform_simd_int_le8_impl {
-    ($ty:ident, $unsigned:ident, $u64xN_type:ident, $u_extra_large:ident) => {
+    ($ty:ident, $unsigned:ident, $u_extra_large:ident) => {
         impl UniformInt<$ty> {
             #[inline(always)]
             fn canon_successive<R: Rng + ?Sized>(
-                range: $unsigned,
-                result: &mut $unsigned,
-                lo_order: $unsigned,
+                range: $u_extra_large,
+                result: &mut $u_extra_large,
+                lo_order: $u_extra_large,
                 rng: &mut R
             ) {
                 // ...generate a new sample with 64 more bits, enough that bias is undetectable
-                let new_bits: $u_extra_large = rng.gen::<$u64xN_type>().cast();
-                let large_range: $u_extra_large = range.cast();
-                let (new_hi_order, _) = new_bits.wmul(large_range);
+                let new_bits: $u_extra_large = rng.gen::<$u_extra_large>();
+                let (new_hi_order, _) = new_bits.wmul(range);
 
                 // and adjust if needed
-                let cast_new_hi: $unsigned = new_hi_order.cast();
-                let (_, overflowed) = lo_order.overflowing_add(cast_new_hi);
-                *result += overflowed.select($unsigned::splat(1), $unsigned::splat(0));
+                let (_, overflowed) = lo_order.overflowing_add(new_hi_order);
+                *result += overflowed.select($u_extra_large::splat(1), $u_extra_large::splat(0));
             }
 
-            ///
-            #[inline(always)]
-            pub fn sample_single_inclusive_canon_branchless<B1, B2, R: Rng + ?Sized>(
-                low_b: B1, high_b: B2, rng: &mut R,
-            ) -> $ty
-            where
-                B1: SampleBorrow<$ty> + Sized,
-                B2: SampleBorrow<$ty> + Sized,
-            {
-                let (range, low) = Self::sample_inc_setup(low_b, high_b);
+            /// Sample, Canon's method
+            #[inline]
+            pub fn sample_canon<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
+                let range: $unsigned = self.range.cast();
                 let is_full_range = range.eq($unsigned::splat(0));
+                let range: $u_extra_large = range.cast();
 
                 // generate a sample using a sensible integer type
-                let rand_bits = rng.gen::<$unsigned>();
+                let rand_bits = rng.gen::<$u_extra_large>();
                 let (mut result, lo_order) = rand_bits.wmul(range);
 
-                Self::canon_successive(range, &mut result, lo_order, rng);
+                if lo_order.gt(0 - range).any() {
+                    Self::canon_successive(range, &mut result, lo_order, rng);
+                }
 
-                let cast_result: $ty = result.cast();
-                let cast_rand_bits: $ty = rand_bits.cast();
-                is_full_range.select(cast_rand_bits, low + cast_result)
+                // truncate and return the result:
+                let result: $ty = result.cast();
+                let rand_bits: $ty = rand_bits.cast();
+                is_full_range.select(rand_bits, self.low + result)
             }
 
-            ///
-            #[inline(always)]
-            pub fn sample_inclusive_canon_scalar<B1, B2, R: Rng + ?Sized>(
+            /// Sample, Canon's method with Lemire's early-out
+            #[inline]
+            pub fn sample_canon_lemire<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
+                let range: $unsigned = self.range.cast();
+                let is_full_range = range.eq($unsigned::splat(0));
+                let range: $u_extra_large = range.cast();
+
+                // generate a sample using a sensible integer type
+                let rand_bits = rng.gen::<$u_extra_large>();
+                let (mut result, lo_order) = rand_bits.wmul(range);
+
+                let nrmr: $unsigned = self.nrmr.cast();
+                let nrmr: $u_extra_large = nrmr.cast();
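+                // Lemire's early-out: `self.nrmr` holds `(0 - range) % range`
+                // in the lane width (cf. the same formula in
+                // `sample_inclusive_canon_lemire` below). Only samples whose
+                // low-order product falls below this threshold can be biased,
+                // and since `nrmr < range` this triggers no more often than
+                // the `lo_order.gt(0 - range)` test in `sample_canon`.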
+                if lo_order.lt(nrmr).any() {
+                    Self::canon_successive(range, &mut result, lo_order, rng);
+                }
+
+                // truncate and return the result:
+                let result: $ty = result.cast();
+                let rand_bits: $ty = rand_bits.cast();
+                is_full_range.select(rand_bits, self.low + result)
+            }
+
+            /// Sample single inclusive, using Canon's method
+            #[inline]
+            pub fn sample_single_inclusive_canon<B1, B2, R: Rng + ?Sized>(
                 low_b: B1, high_b: B2, rng: &mut R,
             ) -> $ty
             where
@@ -914,36 +980,25 @@ macro_rules! uniform_simd_int_le8_impl {
             {
                 let (range, low) = Self::sample_inc_setup(low_b, high_b);
                 let is_full_range = range.eq($unsigned::splat(0));
+                let range: $u_extra_large = range.cast();
 
                 // generate a sample using a sensible integer type
-                let rand_bits = rng.gen::<$unsigned>();
+                let rand_bits = rng.gen::<$u_extra_large>();
                 let (mut result, lo_order) = rand_bits.wmul(range);
 
-                // ...generate a new sample with 64 more bits, enough that bias is undetectable
-                let new_bits: $u_extra_large = rng.gen::<$u64xN_type>().cast();
-                let large_range: $u_extra_large = range.cast();
-
-                // let (new_hi_order, _) = new_bits.wmul(large_range);
-                let mut new_hi_order = <$u_extra_large>::default();
-
-                for i in 0..<$ty>::lanes() {
-                    let (shi, _slo) = new_bits.extract(i).wmul(large_range.extract(i));
-                    new_hi_order = new_hi_order.replace(i, shi);
+                if lo_order.gt(0 - range).any() {
+                    Self::canon_successive(range, &mut result, lo_order, rng);
                 }
 
-                // and adjust if needed
-                let cast_new_hi: $unsigned = new_hi_order.cast();
-                let (_, overflowed) = lo_order.overflowing_add(cast_new_hi);
-                result += overflowed.select($unsigned::splat(1), $unsigned::splat(0));
-
-                let cast_result: $ty = result.cast();
-                let cast_rand_bits: $ty = rand_bits.cast();
-                is_full_range.select(cast_rand_bits, low + cast_result)
+                // truncate and return the result:
+                let result: $ty = result.cast();
+                let rand_bits: $ty = rand_bits.cast();
+                is_full_range.select(rand_bits, low + result)
             }
 
-            ///
-            #[inline(always)]
-            pub fn sample_single_inclusive_canon<B1, B2, R: Rng + ?Sized>(
+            /// Sample single inclusive, using Canon's method with Lemire's early-out
+            #[inline]
+            pub fn sample_inclusive_canon_lemire<B1, B2, R: Rng + ?Sized>(
                 low_b: B1, high_b: B2, rng: &mut R,
             ) -> $ty
             where
@@ -952,25 +1007,28 @@ macro_rules! uniform_simd_int_le8_impl {
             {
                 let (range, low) = Self::sample_inc_setup(low_b, high_b);
                 let is_full_range = range.eq($unsigned::splat(0));
+                let range: $u_extra_large = range.cast();
 
                 // generate a sample using a sensible integer type
-                let rand_bits = rng.gen::<$unsigned>();
+                let rand_bits = rng.gen::<$u_extra_large>();
                 let (mut result, lo_order) = rand_bits.wmul(range);
 
-                if lo_order.gt(0 - range).any() {
+                let nrmr = (0 - range) % range;
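+                // In wrapping arithmetic `0 - range` is `2^LANE_BITS - range`,
+                // so this computes `2^LANE_BITS mod range`: the number of
+                // low-order values that land in an over-represented bucket.
+                // E.g. for 64-bit lanes and range = 10,
+                // (2^64 - 10) % 10 == 6 == 2^64 mod 10.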
+                if lo_order.lt(nrmr).any() {
                     Self::canon_successive(range, &mut result, lo_order, rng);
                 }
 
-                let cast_result: $ty = result.cast();
-                let cast_rand_bits: $ty = rand_bits.cast();
-                is_full_range.select(cast_rand_bits, low + cast_result)
+                // truncate and return the result:
+                let result: $ty = result.cast();
+                let rand_bits: $ty = rand_bits.cast();
+                is_full_range.select(rand_bits, low + result)
             }
         }
     };
 
-    ($(($unsigned:ident, $signed:ident, $u64xN_type:ident, $u_extra_large:ident)),+) => {$(
-        uniform_simd_int_le8_impl!{ $unsigned, $unsigned, $u64xN_type, $u_extra_large }
-        uniform_simd_int_le8_impl!{ $signed, $unsigned, $u64xN_type, $u_extra_large }
+    ($(($unsigned:ident, $signed:ident, $u_extra_large:ident),)+) => {$(
+        uniform_simd_int_le8_impl!{ $unsigned, $unsigned, $u_extra_large }
+        uniform_simd_int_le8_impl!{ $signed, $unsigned, $u_extra_large }
     )+};
 }
 
@@ -1022,41 +1080,37 @@ uniform_simd_int_impl! {
     u8
 }
 
-#[cfg(feature = "simd_support")]
-uniform_simd_int_gt8_impl! {
-    (u8x16, i8x16),
-    (u8x32, i8x32),
-    (u8x64, i8x64),
-
-    (u16x16, i16x16),
-    (u16x32, i16x32),
-
-    (u32x16, i32x16)
-}
+// #[cfg(feature = "simd_support")]
+// uniform_simd_int_gt8_impl! {
+//     (u8x16, i8x16),
+//     (u8x32, i8x32),
+//     (u8x64, i8x64),
+//
+//     (u16x16, i16x16),
+//     (u16x32, i16x32),
+//
+//     (u32x16, i32x16)
+// }
 
 #[cfg(feature = "simd_support")]
 uniform_simd_int_le8_impl! {
-    (u8x2, i8x2, i64x2, u64x2),
-    (u8x4, i8x4, i64x4, u64x4),
-    (u8x8, i8x8, i64x8, u64x8),
-
-    (u16x2, i16x2, i64x2, u64x2),
-    (u16x4, i16x4, i64x4, u64x4),
-    (u16x8, i16x8, i64x8, u64x8),
+    (u8x2, i8x2, u64x2),
+    (u8x4, i8x4, u64x4),
+    (u8x8, i8x8, u64x8),
 
-    (u32x2, i32x2, i64x2, u64x2),
-    (u32x4, i32x4, i64x4, u64x4),
-    (u32x8, i32x8, i64x8, u64x8),
+    (u16x2, i16x2, u64x2),
+    (u16x4, i16x4, u64x4),
+    (u16x8, i16x8, u64x8),
 
-    (u64x2, i64x2, i64x2, u64x2),
-    (u64x4, i64x4, i64x4, u64x4),
-    (u64x8, i64x8, i64x8, u64x8),
+    (u32x2, i32x2, u64x2),
+    (u32x4, i32x4, u64x4),
+    (u32x8, i32x8, u64x8),
 
-    (u128x2, i128x2, i64x2, u128x2),
-    (u128x4, i128x4, i64x4, u128x4)
+    (u64x2, i64x2, u64x2),
+    (u64x4, i64x4, u64x4),
+    (u64x8, i64x8, u64x8),
 
-    // (usizex2, isizex2, i64x2, u64x2),
-    // (usizex4, isizex4, i64x4, u64x4),
-    // (usizex8, isizex8, i64x8, u64x8)
+    (u128x2, i128x2, u128x2),
+    (u128x4, i128x4, u128x4),
 }
 
 #[cfg(test)]
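+// Usage sketch for the distribution-based samplers above (illustrative
+// only; assumes the `simd_support` feature and a packed_simd vector type
+// such as `i8x8` in scope):
+//
+//     use rand::distributions::{Distribution, Uniform};
+//     let (low, high) = (i8x8::splat(-20), i8x8::splat(120));
+//     let dist = Uniform::new_inclusive(low, high);
+//     let mut rng = rand::thread_rng();
+//     let x: i8x8 = dist.sample(&mut rng); // one draw per lane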