Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
Use RNG (random number generator) provided by RAFT (#79)
Browse files Browse the repository at this point in the history
Replace the random number generator (RNG) implemented by wholegraph with RNG provided by RAFT. 
It is put forward by issue #7 and issue #23.

Authors:
  - https://github.com/linhu-nv
  - Brad Rees (https://github.com/BradReesWork)

Approvers:
  - https://github.com/dongxuy04
  - Chuang Zhu (https://github.com/chuangz0)

URL: #79
  • Loading branch information
linhu-nv authored Nov 21, 2023
1 parent 7a3c873 commit 9dc74b7
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 206 deletions.
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ endfunction()
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${WHOLEGRAPH_MIN_VERSION_raft}
FORK rapidsai
PINNED_TAG branch-${WHOLEGRAPH_BRANCH_VERSION_raft}
PINNED_TAG branch-${WHOLEGRAPH_BRANCH_VERSION_raft}

# When PINNED_TAG above doesn't match wholegraph,
# force local raft clone in build directory
Expand Down
28 changes: 22 additions & 6 deletions cpp/src/wholegraph_ops/raft_random_gen.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include <cmath>
#include <wholememory/wholegraph_op.h>
#include <wholememory_ops/raft_random.cuh>

#include <raft/random/rng_device.cuh>
#include <raft/random/rng_state.hpp>

#include "error.hpp"
#include "logger.hpp"
Expand All @@ -37,15 +39,25 @@ wholememory_error_code_t generate_random_positive_int_cpu(int64_t random_seed,
}

auto* output_ptr = wholememory_tensor_get_data_pointer(output);
PCGenerator rng((unsigned long long)random_seed, subsequence, 0);

raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate(_rngstate);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence);

for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) {
if (output_tensor_desc.dtype == WHOLEMEMORY_DT_INT) {
raft::random::detail::UniformDistParams<int32_t> params;
params.start = 0;
params.end = 1;
int32_t random_num;
rng.next(random_num);
raft::random::detail::custom_next(rng, &random_num, params, 0, 0);
static_cast<int*>(output_ptr)[i] = random_num;
} else {
raft::random::detail::UniformDistParams<int64_t> params;
params.start = 0;
params.end = 1;
int64_t random_num;
rng.next(random_num);
raft::random::detail::custom_next(rng, &random_num, params, 0, 0);
static_cast<int64_t*>(output_ptr)[i] = random_num;
}
}
Expand All @@ -65,9 +77,13 @@ wholememory_error_code_t generate_exponential_distribution_negative_float_cpu(
return WHOLEMEMORY_INVALID_INPUT;
}
auto* output_ptr = wholememory_tensor_get_data_pointer(output);
PCGenerator rng((unsigned long long)random_seed, subsequence, 0);
raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate(_rngstate);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence);
for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) {
float u = -rng.next_float(1.0f, 0.5f);
float u = 0.0;
rng.next(u);
u = -(0.5 + 0.5 * u);
uint64_t random_num2 = 0;
int seed_count = -1;
do {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
#include <random>
#include <thrust/scan.h>

#include <raft/random/rng_device.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/util/integer_utils.hpp>
#include <wholememory/device_reference.cuh>
#include <wholememory/env_func_ptrs.h>
#include <wholememory/global_reference.h>
#include <wholememory/tensor_description.h>

#include "wholememory_ops/output_memory_handle.hpp"
#include "wholememory_ops/raft_random.cuh"
#include "wholememory_ops/temp_memory_handle.hpp"
#include "wholememory_ops/thrust_allocator.hpp"

Expand Down Expand Up @@ -58,25 +59,25 @@ __global__ void get_sample_count_without_replacement_kernel(
}

template <typename IdType, typename LocalIdType, typename WMIdType, typename WMOffsetType>
__global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr,
wholememory_array_description_t wm_csr_row_ptr_desc,
wholememory_gref_t wm_csr_col_ptr,
wholememory_array_description_t wm_csr_col_ptr_desc,
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
int* src_lid,
int64_t* output_edge_gid_ptr)
__global__ void large_sample_kernel(
wholememory_gref_t wm_csr_row_ptr,
wholememory_array_description_t wm_csr_row_ptr_desc,
wholememory_gref_t wm_csr_col_ptr,
wholememory_array_description_t wm_csr_col_ptr_desc,
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
int* src_lid,
int64_t* output_edge_gid_ptr)
{
int input_idx = blockIdx.x;
if (input_idx >= input_node_count) return;
int gidx = threadIdx.x + blockIdx.x * blockDim.x;
PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0);

raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);
wholememory::device_reference<WMOffsetType> csr_row_ptr_gen(wm_csr_row_ptr);
wholememory::device_reference<WMIdType> csr_col_ptr_gen(wm_csr_col_ptr);

Expand Down Expand Up @@ -104,8 +105,11 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr,
}
__syncthreads();
for (int idx = max_sample_count + threadIdx.x; idx < neighbor_count; idx += blockDim.x) {
raft::random::detail::UniformDistParams<int32_t> params;
params.start = 0;
params.end = 1;
int32_t rand_num;
rng.next(rand_num);
raft::random::detail::custom_next(rng, &rand_num, params, 0, 0);
rand_num %= idx + 1;
if (rand_num < max_sample_count) { atomicMax((int*)(output + offset + rand_num), idx); }
}
Expand Down Expand Up @@ -139,15 +143,15 @@ __global__ void unweighted_sample_without_replacement_kernel(
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
int* src_lid,
int64_t* output_edge_gid_ptr)
{
int gidx = threadIdx.x + blockIdx.x * blockDim.x;
PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);
int input_idx = blockIdx.x;
if (input_idx >= input_node_count) return;

Expand Down Expand Up @@ -193,9 +197,12 @@ __global__ void unweighted_sample_without_replacement_kernel(
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; i++) {
int idx = i * BLOCK_DIM + threadIdx.x;
int32_t random_num;
rng.next(random_num);
int32_t r = idx < M ? (random_num % (N - idx)) : N;
raft::random::detail::UniformDistParams<int32_t> params;
params.start = 0;
params.end = 1;
int32_t rand_num;
raft::random::detail::custom_next(rng, &rand_num, params, 0, 0);
int32_t r = idx < M ? rand_num % (N - idx) : N;
sa_p[i] = ((uint64_t)r << 32UL) | idx;
}
__syncthreads();
Expand Down Expand Up @@ -364,6 +371,8 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
(int64_t*)gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64);
}
// sample node
raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate(_rngstate);
if (max_sample_count <= 0) {
sample_all_kernel<IdType, int, WMIdType, int64_t>
<<<center_node_count, 64, 0, stream>>>(wm_csr_row_ptr,
Expand Down Expand Up @@ -392,7 +401,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
(const int*)output_sample_offset,
output_sample_offset_desc,
(WMIdType*)output_dest_node_ptr,
Expand All @@ -403,19 +412,20 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
return;
}

typedef void (*unweighted_sample_func_type)(wholememory_gref_t wm_csr_row_ptr,
wholememory_array_description_t wm_csr_row_ptr_desc,
wholememory_gref_t wm_csr_col_ptr,
wholememory_array_description_t wm_csr_col_ptr_desc,
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
int* src_lid,
int64_t* output_edge_gid_ptr);
typedef void (*unweighted_sample_func_type)(
wholememory_gref_t wm_csr_row_ptr,
wholememory_array_description_t wm_csr_row_ptr_desc,
wholememory_gref_t wm_csr_col_ptr,
wholememory_array_description_t wm_csr_col_ptr_desc,
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
int* src_lid,
int64_t* output_edge_gid_ptr);
static const unweighted_sample_func_type func_array[32] = {
unweighted_sample_without_replacement_kernel<IdType, int, WMIdType, int64_t, 32, 1>,
unweighted_sample_without_replacement_kernel<IdType, int, WMIdType, int64_t, 32, 2>,
Expand Down Expand Up @@ -460,7 +470,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
(const int*)output_sample_offset,
output_sample_offset_desc,
(WMIdType*)output_dest_node_ptr,
Expand Down
30 changes: 18 additions & 12 deletions cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
#include "raft/matrix/detail/select_warpsort.cuh"
#include "raft/util/cuda_dev_essentials.cuh"
#include "wholememory_ops/output_memory_handle.hpp"
#include "wholememory_ops/raft_random.cuh"
#include "wholememory_ops/temp_memory_handle.hpp"
#include "wholememory_ops/thrust_allocator.hpp"
#include <raft/random/rng_device.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/util/integer_utils.hpp>
#include <wholememory/device_reference.cuh>
#include <wholememory/env_func_ptrs.h>
Expand All @@ -41,9 +42,12 @@
namespace wholegraph_ops {

template <typename WeightType>
__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, PCGenerator& rng)
__device__ __forceinline__ float gen_key_from_weight(const WeightType weight,
raft::random::detail::PCGenerator& rng)
{
float u = -rng.next_float(1.0f, 0.5f);
float u = 0.0;
rng.next(u);
u = -(0.5 + 0.5 * u);
uint64_t random_num2 = 0;
int seed_count = -1;
do {
Expand Down Expand Up @@ -75,7 +79,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void generate_weighted_keys_and_idxs_ke
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* target_neighbor_offset,
WeightKeyType* output_weighted_keys,
NeighborIdxType* output_idxs,
Expand All @@ -93,7 +97,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void generate_weighted_keys_and_idxs_ke
int neighbor_count = (int)(end - start);
if (neighbor_count <= max_sample_count) { need_random = false; }

PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);
int output_offset = target_neighbor_offset[input_idx];
output_weighted_keys += output_offset;
output_idxs += output_offset;
Expand Down Expand Up @@ -222,7 +226,7 @@ __launch_bounds__(256) __global__ void weighted_sample_without_replacement_raft_
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
Expand Down Expand Up @@ -259,7 +263,7 @@ __launch_bounds__(256) __global__ void weighted_sample_without_replacement_raft_

uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr;
bq_t queue(max_sample_count, warp_smem);
PCGenerator rng(random_seed, static_cast<uint64_t>(gidx), static_cast<uint64_t>(0));
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);
const int per_thread_lim = neighbor_count + raft::laneId();
for (int idx = threadIdx.x; idx < per_thread_lim; idx += blockDim.x) {
WeightType weight_key =
Expand Down Expand Up @@ -307,7 +311,7 @@ void launch_kernel(wholememory_gref_t wm_csr_row_ptr,
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
Expand Down Expand Up @@ -339,7 +343,7 @@ void launch_kernel(wholememory_gref_t wm_csr_row_ptr,
input_nodes,
input_node_count,
max_sample_count,
random_seed,
rngstate,
sample_offset,
sample_offset_desc,
output,
Expand Down Expand Up @@ -374,7 +378,7 @@ void launch_kernel(wholememory_gref_t wm_csr_row_ptr,
input_nodes,
input_node_count,
max_sample_count,
random_seed,
rngstate,
sample_offset,
sample_offset_desc,
output,
Expand Down Expand Up @@ -492,6 +496,8 @@ void wholegraph_csr_weighted_sample_without_replacement_func(
gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64));
}

raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate(_rngstate);
if (max_sample_count > sample_count_threshold) {
wholememory_ops::wm_thrust_allocator tmp_thrust_allocator(p_env_fns);
thrust::exclusive_scan(thrust::cuda::par(tmp_thrust_allocator).on(stream),
Expand Down Expand Up @@ -541,7 +547,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
tmp_neighbor_counts_mem_pointer,
tmp_weights_buffer0_mem_pointer,
local_idx_buffer0_mem_pointer,
Expand Down Expand Up @@ -641,7 +647,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
static_cast<const int*>(output_sample_offset),
output_sample_offset_desc,
output_dest_node_ptr,
Expand Down
Loading

0 comments on commit 9dc74b7

Please sign in to comment.