Skip to content

Commit

Permalink
dedup SIMD sample_single_inclusive_bitmask
Browse files Browse the repository at this point in the history
  • Loading branch information
dhardy committed Feb 24, 2022
1 parent a97199b commit 9d0c88a
Showing 1 changed file with 45 additions and 90 deletions.
135 changes: 45 additions & 90 deletions src/distributions/uniform/uniform_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,51 @@ macro_rules! uniform_simd_int_impl {
let range: $unsigned = ((high - low) + 1).cast();
(range, low)
}

///
#[inline(always)]
pub fn sample_single_inclusive_bitmask<R: Rng + ?Sized, B1, B2>(
low_b: B1, high_b: B2, rng: &mut R,
) -> $ty
where
B1: SampleBorrow<$ty> + Sized,
B2: SampleBorrow<$ty> + Sized,
{
let (mut range, low) = Self::sample_inc_setup(low_b, high_b);
let is_full_range = range.eq($unsigned::splat(0));

// generate bitmask
range -= 1;
let mut mask = range | 1;

mask |= mask >> 1;
mask |= mask >> 2;
mask |= mask >> 4;

const LANE_WIDTH: usize = std::mem::size_of::<$ty>() * 8 / <$ty>::lanes();
if LANE_WIDTH >= 16 { mask |= mask >> 8; }
if LANE_WIDTH >= 32 { mask |= mask >> 16; }
if LANE_WIDTH >= 64 { mask |= mask >> 32; }
if LANE_WIDTH >= 128 { mask |= mask >> 64; }

let mut v: $unsigned = rng.gen();
loop {
let masked = v & mask;
let accept = masked.le(range);
if accept.all() {
let masked: $ty = masked.cast();
// wrapping addition
let result = low + masked;
// `select` here compiles to a blend operation
// When `range.eq(0).none()` the compare and blend
// operations are avoided.
let v: $ty = v.cast();
return is_full_range.select(v, result);
}
// Replace only the failing lanes
v = accept.select(v, rng.gen());
}
}
}
};

Expand Down Expand Up @@ -806,51 +851,6 @@ macro_rules! uniform_simd_int_gt8_impl {
let cast_rand_bits: $ty = rand_bits.cast();
is_full_range.select(cast_rand_bits, low + cast_result)
}

/// Bitmask
#[inline(always)]
pub fn sample_single_inclusive_bitmask<R: Rng + ?Sized, B1, B2>(
low_b: B1, high_b: B2, rng: &mut R,
) -> $ty
where
B1: SampleBorrow<$ty> + Sized,
B2: SampleBorrow<$ty> + Sized,
{
let (mut range, low) = Self::sample_inc_setup(low_b, high_b);
let is_full_range = range.eq($unsigned::splat(0));

// generate bitmask
range -= 1;
let mut mask = range | 1;

mask |= mask >> 1;
mask |= mask >> 2;
mask |= mask >> 4;

const LANE_WIDTH: usize = std::mem::size_of::<$ty>() * 8 / <$ty>::lanes();
if LANE_WIDTH >= 16 { mask |= mask >> 8; }
if LANE_WIDTH >= 32 { mask |= mask >> 16; }
if LANE_WIDTH >= 64 { mask |= mask >> 32; }
if LANE_WIDTH >= 128 { mask |= mask >> 64; }

let mut v: $unsigned = rng.gen();
loop {
let masked = v & mask;
let accept = masked.le(range);
if accept.all() {
let masked: $ty = masked.cast();
// wrapping addition
let result = low + masked;
// `select` here compiles to a blend operation
// When `range.eq(0).none()` the compare and blend
// operations are avoided.
let v: $ty = v.cast();
return is_full_range.select(v, result);
}
// Replace only the failing lanes
v = accept.select(v, rng.gen());
}
}
}
};

Expand Down Expand Up @@ -966,51 +966,6 @@ macro_rules! uniform_simd_int_le8_impl {
let cast_rand_bits: $ty = rand_bits.cast();
is_full_range.select(cast_rand_bits, low + cast_result)
}

///
#[inline(always)]
pub fn sample_single_inclusive_bitmask<R: Rng + ?Sized, B1, B2>(
low_b: B1, high_b: B2, rng: &mut R,
) -> $ty
where
B1: SampleBorrow<$ty> + Sized,
B2: SampleBorrow<$ty> + Sized,
{
let (mut range, low) = Self::sample_inc_setup(low_b, high_b);
let is_full_range = range.eq($unsigned::splat(0));

// generate bitmask
range -= 1;
let mut mask = range | 1;

mask |= mask >> 1;
mask |= mask >> 2;
mask |= mask >> 4;

const LANE_WIDTH: usize = std::mem::size_of::<$ty>() * 8 / <$ty>::lanes();
if LANE_WIDTH >= 16 { mask |= mask >> 8; }
if LANE_WIDTH >= 32 { mask |= mask >> 16; }
if LANE_WIDTH >= 64 { mask |= mask >> 32; }
if LANE_WIDTH >= 128 { mask |= mask >> 64; }

let mut v: $unsigned = rng.gen();
loop {
let masked = v & mask;
let accept = masked.le(range);
if accept.all() {
let masked: $ty = masked.cast();
// wrapping addition
let result = low + masked;
// `select` here compiles to a blend operation
// When `range.eq(0).none()` the compare and blend
// operations are avoided.
let v: $ty = v.cast();
return is_full_range.select(v, result);
}
// Replace only the failing lanes
v = accept.select(v, rng.gen());
}
}
}
};

Expand Down

0 comments on commit 9d0c88a

Please sign in to comment.