From 9dc74b761d7bd3d63506686dfd52da35f1790cdf Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Wed, 22 Nov 2023 07:06:23 +0800 Subject: [PATCH] Use RNG (random number generator) provided by RAFT (#79) Replace the random number generator (RNG) implemented by wholegraph with RNG provided by RAFT. It is put forward by issue #7 and issue #23. Authors: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) Approvers: - https://github.com/dongxuy04 - Chuang Zhu (https://github.com/chuangz0) URL: https://github.com/rapidsai/wholegraph/pull/79 --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- cpp/src/wholegraph_ops/raft_random_gen.cu | 28 +++- ...ighted_sample_without_replacement_func.cuh | 84 +++++----- ...ighted_sample_without_replacement_func.cuh | 30 ++-- cpp/src/wholememory_ops/raft_random.cuh | 143 ------------------ .../graph_sampling_test_utils.cu | 24 ++- 6 files changed, 105 insertions(+), 206 deletions(-) delete mode 100644 cpp/src/wholememory_ops/raft_random.cuh diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 9d116b4dd..77e8f8059 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -57,7 +57,7 @@ endfunction() # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${WHOLEGRAPH_MIN_VERSION_raft} FORK rapidsai - PINNED_TAG branch-${WHOLEGRAPH_BRANCH_VERSION_raft} + PINNED_TAG branch-${WHOLEGRAPH_BRANCH_VERSION_raft} # When PINNED_TAG above doesn't match wholegraph, # force local raft clone in build directory diff --git a/cpp/src/wholegraph_ops/raft_random_gen.cu b/cpp/src/wholegraph_ops/raft_random_gen.cu index b7277781f..5e4c802e1 100644 --- a/cpp/src/wholegraph_ops/raft_random_gen.cu +++ b/cpp/src/wholegraph_ops/raft_random_gen.cu @@ -16,7 +16,9 @@ #include #include -#include + +#include +#include #include "error.hpp" #include "logger.hpp" @@ -37,15 +39,25 @@ wholememory_error_code_t generate_random_positive_int_cpu(int64_t random_seed, } auto* output_ptr = wholememory_tensor_get_data_pointer(output); - PCGenerator rng((unsigned long long)random_seed, subsequence, 0); + + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence); + for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) { if (output_tensor_desc.dtype == WHOLEMEMORY_DT_INT) { + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; int32_t random_num; - rng.next(random_num); + raft::random::detail::custom_next(rng, &random_num, params, 0, 0); static_cast(output_ptr)[i] = random_num; } else { + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; int64_t random_num; - rng.next(random_num); + raft::random::detail::custom_next(rng, &random_num, params, 0, 0); static_cast(output_ptr)[i] = random_num; } } @@ -65,9 +77,13 @@ wholememory_error_code_t generate_exponential_distribution_negative_float_cpu( return WHOLEMEMORY_INVALID_INPUT; } auto* output_ptr = wholememory_tensor_get_data_pointer(output); - PCGenerator rng((unsigned long long)random_seed, subsequence, 0); + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence); for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) { - float u = -rng.next_float(1.0f, 0.5f); + float u = 0.0; + rng.next(u); + u = -(0.5 + 0.5 * u); uint64_t random_num2 = 0; int seed_count = -1; do { diff --git a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh index 0581090c3..291b26b2d 100644 --- a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh @@ -18,6 +18,8 @@ #include #include +#include +#include #include #include #include @@ -25,7 +27,6 @@ #include #include "wholememory_ops/output_memory_handle.hpp" -#include "wholememory_ops/raft_random.cuh" #include "wholememory_ops/temp_memory_handle.hpp" #include "wholememory_ops/thrust_allocator.hpp" @@ -58,25 +59,25 @@ __global__ void get_sample_count_without_replacement_kernel( } template -__global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr, - wholememory_array_description_t wm_csr_row_ptr_desc, - wholememory_gref_t wm_csr_col_ptr, - wholememory_array_description_t wm_csr_col_ptr_desc, - const IdType* input_nodes, - const int input_node_count, - const int max_sample_count, - unsigned long long random_seed, - const int* sample_offset, - wholememory_array_description_t sample_offset_desc, - WMIdType* output, - int* src_lid, - int64_t* output_edge_gid_ptr) +__global__ void large_sample_kernel( + wholememory_gref_t wm_csr_row_ptr, + wholememory_array_description_t wm_csr_row_ptr_desc, + wholememory_gref_t wm_csr_col_ptr, + wholememory_array_description_t wm_csr_col_ptr_desc, + const IdType* input_nodes, + const int input_node_count, + const int max_sample_count, + raft::random::detail::DeviceState rngstate, + const int* sample_offset, + wholememory_array_description_t sample_offset_desc, + WMIdType* output, + int* src_lid, + int64_t* output_edge_gid_ptr) { int input_idx = blockIdx.x; if (input_idx >= input_node_count) return; int gidx = threadIdx.x + blockIdx.x * blockDim.x; - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); - + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx); wholememory::device_reference csr_row_ptr_gen(wm_csr_row_ptr); wholememory::device_reference csr_col_ptr_gen(wm_csr_col_ptr); @@ -104,8 +105,11 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr, } __syncthreads(); for (int idx = max_sample_count + threadIdx.x; idx < neighbor_count; idx += blockDim.x) { + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; int32_t rand_num; - rng.next(rand_num); + raft::random::detail::custom_next(rng, &rand_num, params, 0, 0); rand_num %= idx + 1; if (rand_num < max_sample_count) { atomicMax((int*)(output + offset + rand_num), idx); } } @@ -139,7 +143,7 @@ __global__ void unweighted_sample_without_replacement_kernel( const IdType* input_nodes, const int input_node_count, const int max_sample_count, - unsigned long long random_seed, + raft::random::detail::DeviceState rngstate, const int* sample_offset, wholememory_array_description_t sample_offset_desc, WMIdType* output, @@ -147,7 +151,7 @@ __global__ void unweighted_sample_without_replacement_kernel( int64_t* output_edge_gid_ptr) { int gidx = threadIdx.x + blockIdx.x * blockDim.x; - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx); int input_idx = blockIdx.x; if (input_idx >= input_node_count) return; @@ -193,9 +197,12 @@ __global__ void unweighted_sample_without_replacement_kernel( #pragma unroll for (int i = 0; i < ITEMS_PER_THREAD; i++) { int idx = i * BLOCK_DIM + threadIdx.x; - int32_t random_num; - rng.next(random_num); - int32_t r = idx < M ? (random_num % (N - idx)) : N; + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; + int32_t rand_num; + raft::random::detail::custom_next(rng, &rand_num, params, 0, 0); + int32_t r = idx < M ? rand_num % (N - idx) : N; sa_p[i] = ((uint64_t)r << 32UL) | idx; } __syncthreads(); @@ -364,6 +371,8 @@ void wholegraph_csr_unweighted_sample_without_replacement_func( (int64_t*)gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64); } // sample node + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); if (max_sample_count <= 0) { sample_all_kernel <<>>(wm_csr_row_ptr, @@ -392,7 +401,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func( (const IdType*)center_nodes, center_node_count, max_sample_count, - random_seed, + rngstate, (const int*)output_sample_offset, output_sample_offset_desc, (WMIdType*)output_dest_node_ptr, @@ -403,19 +412,20 @@ void wholegraph_csr_unweighted_sample_without_replacement_func( return; } - typedef void (*unweighted_sample_func_type)(wholememory_gref_t wm_csr_row_ptr, - wholememory_array_description_t wm_csr_row_ptr_desc, - wholememory_gref_t wm_csr_col_ptr, - wholememory_array_description_t wm_csr_col_ptr_desc, - const IdType* input_nodes, - const int input_node_count, - const int max_sample_count, - unsigned long long random_seed, - const int* sample_offset, - wholememory_array_description_t sample_offset_desc, - WMIdType* output, - int* src_lid, - int64_t* output_edge_gid_ptr); + typedef void (*unweighted_sample_func_type)( + wholememory_gref_t wm_csr_row_ptr, + wholememory_array_description_t wm_csr_row_ptr_desc, + wholememory_gref_t wm_csr_col_ptr, + wholememory_array_description_t wm_csr_col_ptr_desc, + const IdType* input_nodes, + const int input_node_count, + const int max_sample_count, + raft::random::detail::DeviceState rngstate, + const int* sample_offset, + wholememory_array_description_t sample_offset_desc, + WMIdType* output, + int* src_lid, + int64_t* output_edge_gid_ptr); static const unweighted_sample_func_type func_array[32] = { unweighted_sample_without_replacement_kernel, unweighted_sample_without_replacement_kernel, @@ -460,7 +470,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func( (const IdType*)center_nodes, center_node_count, max_sample_count, - random_seed, + rngstate, (const int*)output_sample_offset, output_sample_offset_desc, (WMIdType*)output_dest_node_ptr, diff --git a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh index a2915cd00..de75d7394 100644 --- a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh @@ -25,9 +25,10 @@ #include "raft/matrix/detail/select_warpsort.cuh" #include "raft/util/cuda_dev_essentials.cuh" #include "wholememory_ops/output_memory_handle.hpp" -#include "wholememory_ops/raft_random.cuh" #include "wholememory_ops/temp_memory_handle.hpp" #include "wholememory_ops/thrust_allocator.hpp" +#include +#include #include #include #include @@ -41,9 +42,12 @@ namespace wholegraph_ops { template -__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, PCGenerator& rng) +__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, + raft::random::detail::PCGenerator& rng) { - float u = -rng.next_float(1.0f, 0.5f); + float u = 0.0; + rng.next(u); + u = -(0.5 + 0.5 * u); uint64_t random_num2 = 0; int seed_count = -1; do { @@ -75,7 +79,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void generate_weighted_keys_and_idxs_ke const IdType* input_nodes, const int input_node_count, const int max_sample_count, - unsigned long long random_seed, + raft::random::detail::DeviceState rngstate, const int* target_neighbor_offset, WeightKeyType* output_weighted_keys, NeighborIdxType* output_idxs, @@ -93,7 +97,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void generate_weighted_keys_and_idxs_ke int neighbor_count = (int)(end - start); if (neighbor_count <= max_sample_count) { need_random = false; } - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx); int output_offset = target_neighbor_offset[input_idx]; output_weighted_keys += output_offset; output_idxs += output_offset; @@ -222,7 +226,7 @@ __launch_bounds__(256) __global__ void weighted_sample_without_replacement_raft_ const IdType* input_nodes, const int input_node_count, const int max_sample_count, - unsigned long long random_seed, + raft::random::detail::DeviceState rngstate, const int* sample_offset, wholememory_array_description_t sample_offset_desc, WMIdType* output, @@ -259,7 +263,7 @@ __launch_bounds__(256) __global__ void weighted_sample_without_replacement_raft_ uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr; bq_t queue(max_sample_count, warp_smem); - PCGenerator rng(random_seed, static_cast(gidx), static_cast(0)); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx); const int per_thread_lim = neighbor_count + raft::laneId(); for (int idx = threadIdx.x; idx < per_thread_lim; idx += blockDim.x) { WeightType weight_key = @@ -307,7 +311,7 @@ void launch_kernel(wholememory_gref_t wm_csr_row_ptr, const IdType* input_nodes, const int input_node_count, const int max_sample_count, - unsigned long long random_seed, + raft::random::detail::DeviceState rngstate, const int* sample_offset, wholememory_array_description_t sample_offset_desc, WMIdType* output, @@ -339,7 +343,7 @@ void launch_kernel(wholememory_gref_t wm_csr_row_ptr, input_nodes, input_node_count, max_sample_count, - random_seed, + rngstate, sample_offset, sample_offset_desc, output, @@ -374,7 +378,7 @@ void launch_kernel(wholememory_gref_t wm_csr_row_ptr, input_nodes, input_node_count, max_sample_count, - random_seed, + rngstate, sample_offset, sample_offset_desc, output, @@ -492,6 +496,8 @@ void wholegraph_csr_weighted_sample_without_replacement_func( gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64)); } + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); if (max_sample_count > sample_count_threshold) { wholememory_ops::wm_thrust_allocator tmp_thrust_allocator(p_env_fns); thrust::exclusive_scan(thrust::cuda::par(tmp_thrust_allocator).on(stream), @@ -541,7 +547,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func( (const IdType*)center_nodes, center_node_count, max_sample_count, - random_seed, + rngstate, tmp_neighbor_counts_mem_pointer, tmp_weights_buffer0_mem_pointer, local_idx_buffer0_mem_pointer, @@ -641,7 +647,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func( (const IdType*)center_nodes, center_node_count, max_sample_count, - random_seed, + rngstate, static_cast(output_sample_offset), output_sample_offset_desc, output_dest_node_ptr, diff --git a/cpp/src/wholememory_ops/raft_random.cuh b/cpp/src/wholememory_ops/raft_random.cuh deleted file mode 100644 index 8d1b9ac3b..000000000 --- a/cpp/src/wholememory_ops/raft_random.cuh +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -/** PCG random number generator from raft */ -struct PCGenerator { - /** - * @brief ctor. Initializes the state for RNG. This code is derived from PCG basic code - * @param seed the seed (can be same across all threads). Same as PCG's initstate - * @param subsequence is same as PCG's initseq - * @param offset unused - */ - __host__ __device__ __forceinline__ PCGenerator(uint64_t seed, - uint64_t subsequence, - uint64_t offset) - { - pcg_state = uint64_t(0); - inc = (subsequence << 1u) | 1u; - uint32_t discard; - next(discard); - pcg_state += seed; - next(discard); - skipahead(offset); - } - - // Based on "Random Number Generation with Arbitrary Strides" F. B. Brown - // Link https://mcnp.lanl.gov/pdf_files/anl-rn-arb-stride.pdf - __host__ __device__ __forceinline__ void skipahead(uint64_t offset) - { - uint64_t G = 1; - uint64_t h = 6364136223846793005ULL; - uint64_t C = 0; - uint64_t f = inc; - while (offset) { - if (offset & 1) { - G = G * h; - C = C * h + f; - } - f = f * (h + 1); - h = h * h; - offset >>= 1; - } - pcg_state = pcg_state * G + C; - } - - /** - * @defgroup NextRand Generate the next random number - * @brief This code is derived from PCG basic code - * @{ - */ - __host__ __device__ __forceinline__ uint32_t next_u32() - { - uint32_t ret; - uint64_t oldstate = pcg_state; - pcg_state = oldstate * 6364136223846793005ULL + inc; - uint32_t xorshifted = ((oldstate >> 18u) ^ oldstate) >> 27u; - uint32_t rot = oldstate >> 59u; - ret = (xorshifted >> rot) | (xorshifted << ((-rot) & 31)); - return ret; - } - __host__ __device__ __forceinline__ uint64_t next_u64() - { - uint64_t ret; - uint32_t a, b; - a = next_u32(); - b = next_u32(); - ret = uint64_t(a) | (uint64_t(b) << 32); - return ret; - } - - __host__ __device__ __forceinline__ int32_t next_i32() - { - int32_t ret; - uint32_t val; - val = next_u32(); - ret = int32_t(val & 0x7fffffff); - return ret; - } - - __host__ __device__ __forceinline__ int64_t next_i64() - { - int64_t ret; - uint64_t val; - val = next_u64(); - ret = int64_t(val & 0x7fffffffffffffff); - return ret; - } - - __host__ __device__ __forceinline__ float next_float() - { - float ret; - uint32_t val = next_u32() >> 8; - ret = static_cast(val) / (1U << 24); - return ret; - } - - __host__ __device__ __forceinline__ float next_float(float max, float min) - { - float ret; - uint32_t val = next_u32() >> 8; - ret = static_cast(val) / (1U << 24); - ret *= (max - min); - ret += min; - return ret; - } - - __host__ __device__ __forceinline__ double next_double() - { - double ret; - uint64_t val = next_u64() >> 11; - ret = static_cast(val) / (1LU << 53); - return ret; - } - - __host__ __device__ __forceinline__ void next(uint32_t& ret) { ret = next_u32(); } - __host__ __device__ __forceinline__ void next(uint64_t& ret) { ret = next_u64(); } - __host__ __device__ __forceinline__ void next(int32_t& ret) { ret = next_i32(); } - __host__ __device__ __forceinline__ void next(int64_t& ret) { ret = next_i64(); } - - __host__ __device__ __forceinline__ void next(float& ret) { ret = next_float(); } - __host__ __device__ __forceinline__ void next(double& ret) { ret = next_double(); } - - /** @} */ - - private: - uint64_t pcg_state; - uint64_t inc; -}; diff --git a/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu b/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu index 54f3ab934..45fa042ee 100644 --- a/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu +++ b/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu @@ -23,7 +23,8 @@ #include #include -#include "wholememory_ops/raft_random.cuh" +#include +#include #include namespace wholegraph_ops { @@ -383,12 +384,17 @@ void host_unweighted_sample_without_replacement( std::vector r(neighbor_count); for (int j = 0; j < device_num_threads; j++) { int local_gidx = gidx + j; - PCGenerator rng(random_seed, (uint64_t)local_gidx, (uint64_t)0); + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)local_gidx); + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; for (int k = 0; k < items_per_thread; k++) { int id = k * device_num_threads + j; int32_t random_num; - rng.next(random_num); + raft::random::detail::custom_next(rng, &random_num, params, 0, 0); if (id < neighbor_count) { r[id] = id < M ? (random_num % (N - id)) : N; } } } @@ -543,9 +549,11 @@ inline int count_one(unsigned long long num) } template -float host_gen_key_from_weight(const WeightType weight, PCGenerator& rng) +float host_gen_key_from_weight(const WeightType weight, raft::random::detail::PCGenerator& rng) { - float u = -rng.next_float(1.0f, 0.5f); + float u = 0.0; + rng.next(u); + u = -(0.5 + 0.5 * u); uint64_t random_num2 = 0; int seed_count = -1; do { @@ -624,7 +632,7 @@ void host_weighted_sample_without_replacement( std::priority_queue, std::vector>, cmp> small_heap; - auto consume_fun = [&](int id, PCGenerator& rng) { + auto consume_fun = [&](int id, raft::random::detail::PCGenerator& rng) { WeightType edge_weight = csr_weight_ptr[start + id]; WeightType weight = host_gen_key_from_weight(edge_weight, rng); process_count++; @@ -641,7 +649,9 @@ void host_weighted_sample_without_replacement( for (int j = 0; j < block_size; j++) { int local_gidx = gidx + j; - PCGenerator rng(random_seed, (uint64_t)local_gidx, (uint64_t)0); + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)local_gidx); for (int id = j; id < neighbor_count; id += block_size) { if (id < neighbor_count) { consume_fun(id, rng); } }