Skip to content

Commit

Permalink
<random>: Implement Lemire's fast integer generation (microsoft#3012)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicole Mazzuca <mazzucan@outlook.com>
Co-authored-by: Stephan T. Lavavej <stl@nuwen.net>
  • Loading branch information
3 people authored and CaseyCarter committed Oct 6, 2022
1 parent a72c4e6 commit 908f716
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 10 deletions.
6 changes: 2 additions & 4 deletions benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,5 @@ function(add_benchmark name)
target_link_libraries(benchmark-${name} PRIVATE benchmark::benchmark)
endfunction()

add_benchmark(std_copy
src/std_copy.cpp
CXX_STANDARD 23
)
add_benchmark(std_copy src/std_copy.cpp)
add_benchmark(random_integer_generation src/random_integer_generation.cpp)
102 changes: 102 additions & 0 deletions benchmarks/src/random_integer_generation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <benchmark/benchmark.h>
#include <cstdint>
#include <random>

/// Test URBGs alone

/// Baseline: invokes a default-constructed std::mt19937 once per benchmark iteration,
/// with no range adaptation applied to the output.
static void BM_mt19937(benchmark::State& state) {
    using Engine = std::mt19937;
    Engine engine;
    for (auto _ : state) {
        // DoNotOptimize keeps the otherwise-unused result from being discarded.
        benchmark::DoNotOptimize(engine());
    }
}
BENCHMARK(BM_mt19937);

/// Baseline: invokes a default-constructed std::mt19937_64 once per benchmark iteration,
/// with no range adaptation applied to the output.
static void BM_mt19937_64(benchmark::State& state) {
    using Engine = std::mt19937_64;
    Engine engine;
    for (auto _ : state) {
        // DoNotOptimize keeps the otherwise-unused result from being discarded.
        benchmark::DoNotOptimize(engine());
    }
}
BENCHMARK(BM_mt19937_64);

/// Baseline: invokes a default-constructed std::minstd_rand (a linear congruential engine)
/// once per benchmark iteration, with no range adaptation applied to the output.
static void BM_lcg(benchmark::State& state) {
    using Engine = std::minstd_rand;
    Engine engine;
    for (auto _ : state) {
        // DoNotOptimize keeps the otherwise-unused result from being discarded.
        benchmark::DoNotOptimize(engine());
    }
}
BENCHMARK(BM_lcg);

/// Draws one value uniformly from [10'000'000, 20'000'000] using std::random_device.
/// The result is used as the range bound handed to the generators under test.
std::uint32_t GetMax() {
    std::random_device entropy;
    std::uniform_int_distribution<std::uint32_t> bound_range(10'000'000, 20'000'000);
    const std::uint32_t picked = bound_range(entropy);
    return picked;
}

// Initialized once, before main runs, from GetMax().
static const std::uint32_t maximum = GetMax(); // random divisor to prevent strength reduction

/// Test mt19937

/// Times the original adaptor, std::_Rng_from_urng, mapping std::mt19937 output into
/// [0, maximum) with a 32-bit difference type.
static void BM_raw_mt19937_old(benchmark::State& state) {
    using Engine = std::mt19937;
    Engine engine;
    std::_Rng_from_urng<std::uint32_t, Engine> bounded(engine);
    for (auto _ : state) {
        benchmark::DoNotOptimize(bounded(maximum));
    }
}
BENCHMARK(BM_raw_mt19937_old);

/// Times the new adaptor, std::_Rng_from_urng_v2, mapping std::mt19937 output into
/// [0, maximum) with a 32-bit difference type.
static void BM_raw_mt19937_new(benchmark::State& state) {
    using Engine = std::mt19937;
    Engine engine;
    std::_Rng_from_urng_v2<std::uint32_t, Engine> bounded(engine);
    for (auto _ : state) {
        benchmark::DoNotOptimize(bounded(maximum));
    }
}
BENCHMARK(BM_raw_mt19937_new);

/// Test mt19937_64

/// Times the original adaptor, std::_Rng_from_urng, mapping std::mt19937_64 output into
/// [0, maximum) with a 64-bit difference type (maximum is widened from 32 bits at the call).
static void BM_raw_mt19937_64_old(benchmark::State& state) {
    using Engine = std::mt19937_64;
    Engine engine;
    std::_Rng_from_urng<std::uint64_t, Engine> bounded(engine);
    for (auto _ : state) {
        benchmark::DoNotOptimize(bounded(maximum));
    }
}
BENCHMARK(BM_raw_mt19937_64_old);

/// Times the new adaptor, std::_Rng_from_urng_v2, mapping std::mt19937_64 output into
/// [0, maximum) with a 64-bit difference type (maximum is widened from 32 bits at the call).
static void BM_raw_mt19937_64_new(benchmark::State& state) {
    using Engine = std::mt19937_64;
    Engine engine;
    std::_Rng_from_urng_v2<std::uint64_t, Engine> bounded(engine);
    for (auto _ : state) {
        benchmark::DoNotOptimize(bounded(maximum));
    }
}
BENCHMARK(BM_raw_mt19937_64_new);

/// Test minstd_rand

/// Times the original adaptor, std::_Rng_from_urng, mapping std::minstd_rand output into
/// [0, maximum) with a 32-bit difference type.
static void BM_raw_lcg_old(benchmark::State& state) {
    using Engine = std::minstd_rand;
    Engine engine;
    std::_Rng_from_urng<std::uint32_t, Engine> bounded(engine);
    for (auto _ : state) {
        benchmark::DoNotOptimize(bounded(maximum));
    }
}
BENCHMARK(BM_raw_lcg_old);

/// Times the new adaptor, std::_Rng_from_urng_v2, mapping std::minstd_rand output into
/// [0, maximum) with a 32-bit difference type.
static void BM_raw_lcg_new(benchmark::State& state) {
    using Engine = std::minstd_rand;
    Engine engine;
    std::_Rng_from_urng_v2<std::uint32_t, Engine> bounded(engine);
    for (auto _ : state) {
        benchmark::DoNotOptimize(bounded(maximum));
    }
}
BENCHMARK(BM_raw_lcg_new);

// Google Benchmark macro that supplies this executable's entry point; it runs the benchmarks
// registered above with BENCHMARK(...).
BENCHMARK_MAIN();
6 changes: 4 additions & 2 deletions stl/inc/random
Original file line number Diff line number Diff line change
Expand Up @@ -1844,7 +1844,9 @@ private:

template <class _Engine>
result_type _Eval(_Engine& _Eng, _Ty _Min, _Ty _Max) const { // compute next value in range [_Min, _Max]
_Rng_from_urng<_Uty, _Engine> _Generator(_Eng);
conditional_t<_Has_static_min_max<_Engine>::value, _Rng_from_urng_v2<_Uty, _Engine>,
_Rng_from_urng<_Uty, _Engine>>
_Generator(_Eng);

const _Uty _Umin = _Adjust(static_cast<_Uty>(_Min));
const _Uty _Umax = _Adjust(static_cast<_Uty>(_Max));
Expand All @@ -1862,7 +1864,7 @@ private:

static _Uty _Adjust(_Uty _Uval) { // convert signed ranges to unsigned ranges and vice versa
if constexpr (is_signed_v<_Ty>) {
const _Uty _Adjuster = (static_cast<_Uty>(-1) >> 1) + 1; // 2^(N-1)
constexpr _Uty _Adjuster = (static_cast<_Uty>(-1) >> 1) + 1; // 2^(N-1)

if (_Uval < _Adjuster) {
return static_cast<_Uty>(_Uval + _Adjuster);
Expand Down
132 changes: 128 additions & 4 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <yvals.h>
#if _STL_COMPILER_PREPROCESSOR

#include <__msvc_int128.hpp>
#include <__msvc_iter_core.hpp>
#include <climits>
#include <cstdlib>
Expand Down Expand Up @@ -6030,7 +6031,8 @@ public:

using _Udiff = conditional_t<sizeof(_Ty1) < sizeof(_Ty0), _Ty0, _Ty1>;

explicit _Rng_from_urng(_Urng& _Func) : _Ref(_Func), _Bits(CHAR_BIT * sizeof(_Udiff)), _Bmask(_Udiff(-1)) {
explicit _Rng_from_urng(_Urng& _Func)
: _Ref(_Func), _Bits(CHAR_BIT * sizeof(_Udiff)), _Bmask(static_cast<_Udiff>(-1)) {
for (; static_cast<_Udiff>((_Urng::max)() - (_Urng::min)()) < _Bmask; _Bmask >>= 1) {
--_Bits;
}
Expand All @@ -6041,7 +6043,7 @@ public:
_Udiff _Ret = 0; // random bits
_Udiff _Mask = 0; // 2^N - 1, _Ret is within [0, _Mask]

while (_Mask < _Udiff(_Index - 1)) { // need more random bits
while (_Mask < static_cast<_Udiff>(_Index - 1)) { // need more random bits
_Ret <<= _Bits - 1; // avoid full shift
_Ret <<= 1;
_Ret |= _Get_bits();
Expand All @@ -6051,7 +6053,7 @@ public:
}

// _Ret is [0, _Mask], _Index - 1 <= _Mask, return if unbiased
if (_Ret / _Index < _Mask / _Index || _Mask % _Index == _Udiff(_Index - 1)) {
if (_Ret / _Index < _Mask / _Index || _Mask % _Index == static_cast<_Udiff>(_Index - 1)) {
return static_cast<_Diff>(_Ret % _Index);
}
}
Expand All @@ -6075,7 +6077,7 @@ public:
private:
_Udiff _Get_bits() { // return a random value within [0, _Bmask]
for (;;) { // repeat until random value is in range
_Udiff _Val = static_cast<_Udiff>(_Ref() - (_Urng::min)());
const _Udiff _Val = static_cast<_Udiff>(_Ref() - (_Urng::min)());

if (_Val <= _Bmask) {
return _Val;
Expand All @@ -6088,6 +6090,128 @@ private:
_Udiff _Bmask; // 2^_Bits - 1
};

template <class _Diff, class _Urng>
class _Rng_from_urng_v2 { // wrap a URNG as an RNG
public:
using _Ty0 = make_unsigned_t<_Diff>;
using _Ty1 = _Invoke_result_t<_Urng&>;

using _Udiff = conditional_t<sizeof(_Ty1) < sizeof(_Ty0), _Ty0, _Ty1>;
static constexpr unsigned int _Udiff_bits = sizeof(_Udiff) * CHAR_BIT;
using _Uprod = conditional_t<_Udiff_bits <= 16, uint32_t, conditional_t<_Udiff_bits <= 32, uint64_t, _Unsigned128>>;

explicit _Rng_from_urng_v2(_Urng& _Func) : _Ref(_Func) {}

_Diff operator()(_Diff _Index) { // adapt _Urng closed range to [0, _Index)
// From Daniel Lemire, "Fast Random Integer Generation in an Interval", ACM Trans. Model. Comput. Simul. 29 (1),
// 2019.
//
// Algorithm 5 <-> This Code:
// m <-> _Product
// l <-> _Rem
// s <-> _Index
// t <-> _Threshold
// L <-> _Generated_bits
// 2^L - 1 <-> _Mask

_Udiff _Mask = _Bmask;
unsigned int _Niter = 1;

if constexpr (_Bits < _Udiff_bits) {
while (_Mask < static_cast<_Udiff>(_Index - 1)) {
_Mask <<= _Bits;
_Mask |= _Bmask;
++_Niter;
}
}

// x <- random integer in [0, 2^L)
// m <- x * s
auto _Product = _Get_random_product(_Index, _Niter);
// l <- m mod 2^L
auto _Rem = static_cast<_Udiff>(_Product) & _Mask;

if (_Rem < _Index) {
// t <- (2^L - s) mod s
const auto _Threshold = (_Mask - _Index + 1) % _Index;
while (_Rem < _Threshold) {
_Product = _Get_random_product(_Index, _Niter);
_Rem = static_cast<_Udiff>(_Product) & _Mask;
}
}

unsigned int _Generated_bits;
if constexpr (_Bits < _Udiff_bits) {
_Generated_bits = static_cast<unsigned int>(_Popcount(_Mask));
} else {
_Generated_bits = _Udiff_bits;
}

// m / 2^L
return static_cast<_Diff>(_Product >> _Generated_bits);
}

_Udiff _Get_all_bits() {
_Udiff _Ret = _Get_bits();

if constexpr (_Bits < _Udiff_bits) {
for (unsigned int _Num = _Bits; _Num < _Udiff_bits; _Num += _Bits) { // don't mask away any bits
_Ret <<= _Bits;
_Ret |= _Get_bits();
}
}

return _Ret;
}

_Rng_from_urng_v2(const _Rng_from_urng_v2&) = delete;
_Rng_from_urng_v2& operator=(const _Rng_from_urng_v2&) = delete;

private:
_Udiff _Get_bits() { // return a random value within [0, _Bmask]
static constexpr auto _Urng_min = (_Urng::min)();
for (;;) { // repeat until random value is in range
const _Udiff _Val = _Ref() - _Urng_min;

if (_Val <= _Bmask) {
return _Val;
}
}
}

static constexpr size_t _Calc_bits() {
auto _Bits_local = _Udiff_bits;
auto _Bmask_local = static_cast<_Udiff>(-1);
for (; (_Urng::max)() - (_Urng::min)() < _Bmask_local; _Bmask_local >>= 1) {
--_Bits_local;
}

return _Bits_local;
}

_Uprod _Get_random_product(const _Diff _Index, unsigned int _Niter) {
_Udiff _Ret = _Get_bits();
if constexpr (_Bits < _Udiff_bits) {
while (--_Niter > 0) {
_Ret <<= _Bits;
_Ret |= _Get_bits();
}
}

if constexpr (is_same_v<_Udiff, uint64_t>) {
uint64_t _High;
const auto _Low = _Base128::_UMul128(_Ret, static_cast<_Udiff>(_Index), _High);
return _Uprod{_Low, _High};
} else {
return _Uprod{_Ret} * _Uprod{_Index};
}
}

_Urng& _Ref; // reference to URNG
static constexpr size_t _Bits = _Calc_bits(); // number of random bits generated by _Get_bits()
static constexpr _Udiff _Bmask = static_cast<_Udiff>(-1) >> (_Udiff_bits - _Bits); // 2^_Bits - 1
};

extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xbad_alloc();
extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xinvalid_argument(_In_z_ const char*);
extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xlength_error(_In_z_ const char*);
Expand Down
1 change: 1 addition & 0 deletions tests/std/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ tests\Dev11_1150223_shared_mutex
tests\Dev11_1158803_regex_thread_safety
tests\Dev11_1180290_filesystem_error_code
tests\GH_000177_forbidden_aliasing
tests\GH_000178_uniform_int
tests\GH_000342_filebuf_close
tests\GH_000431_copy_move_family
tests\GH_000431_equal_family
Expand Down
4 changes: 4 additions & 0 deletions tests/std/tests/GH_000178_uniform_int/env.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

RUNALL_INCLUDE ..\usual_matrix.lst
Loading

0 comments on commit 908f716

Please sign in to comment.