[GraphBolt][CUDA] SampleNeighbors (Without replacement for now) #6770

Merged: 59 commits merged on Dec 28, 2023
Changes from all commits (59 commits)
f243cd9  add the function prototype (mfbalin, Dec 18, 2023)
9d19abe  refactor sliceCSCIndptr function for reuse (mfbalin, Dec 18, 2023)
5183fea  implement one more required component. (mfbalin, Dec 18, 2023)
f5e7db4  fix CSRToCOO (mfbalin, Dec 18, 2023)
284e5dc  almost finished (mfbalin, Dec 19, 2023)
63b630f  fix bugs (mfbalin, Dec 19, 2023)
3675521  refactor different components into their own files (mfbalin, Dec 19, 2023)
159e409  change API to be more general (mfbalin, Dec 19, 2023)
e545a5b  fix sort parameter bug (mfbalin, Dec 19, 2023)
10a774b  refactor dispatch portion (mfbalin, Dec 19, 2023)
40c01b4  refactor devicetoHostCopy (mfbalin, Dec 19, 2023)
35d2344  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 20, 2023)
103711f  fix num_bits calculation bug (mfbalin, Dec 20, 2023)
1f87981  unique_and_compact workaround (mfbalin, Dec 20, 2023)
aa4dd8b  fixes for the tests (mfbalin, Dec 20, 2023)
5865f08  linting (mfbalin, Dec 20, 2023)
3678b2d  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 20, 2023)
24ca56a  fix test failure (mfbalin, Dec 20, 2023)
b19359b  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 20, 2023)
0359cdb  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 25, 2023)
5c46df6  resolve merge conflicts (mfbalin, Dec 25, 2023)
ff2e89f  linting (mfbalin, Dec 25, 2023)
35b3222  fix future merge conflict (mfbalin, Dec 25, 2023)
b0ac06d  remove changes in other PRs (mfbalin, Dec 25, 2023)
2280a2f  remove changes in other PRs (mfbalin, Dec 25, 2023)
3d18890  add documentation for SampleNeighbors (mfbalin, Dec 25, 2023)
4ef45d9  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 26, 2023)
3e2760e  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 26, 2023)
f634b76  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 26, 2023)
48b7389  skip link tests that require IsIn for CUDA (mfbalin, Dec 26, 2023)
88f235b  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 26, 2023)
da35b8e  address some of the reviews (mfbalin, Dec 26, 2023)
e1ea695  linting (mfbalin, Dec 26, 2023)
5a1afc1  polish the implementation and fix pinned indices bug (mfbalin, Dec 26, 2023)
da1a52b  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 26, 2023)
207da5f  enable tests that require `gb.isin`. (mfbalin, Dec 26, 2023)
6232c0e  add working lightning example to test (mfbalin, Dec 26, 2023)
2b18364  improve variable names for better readability (mfbalin, Dec 26, 2023)
125ea18  linting (mfbalin, Dec 26, 2023)
f745c41  fix example dataloader params (mfbalin, Dec 26, 2023)
3ce33ed  fix example for new lightning version (mfbalin, Dec 26, 2023)
344bd2b  create a cuda sampling op header (mfbalin, Dec 27, 2023)
21d2469  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 27, 2023)
265a617  fix the documentation bug (mfbalin, Dec 27, 2023)
e58affd  remove unnecessary include (mfbalin, Dec 27, 2023)
2435e55  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 27, 2023)
6afcd29  remove doc for removed param (mfbalin, Dec 27, 2023)
f9c5cb6  use more efficient sorting algorithm (mfbalin, Dec 27, 2023)
d7ac4f6  use bfloat16 for random numbers to speedup sort and fix CopyTo test (mfbalin, Dec 27, 2023)
16fda67  add comment about the used datatype (mfbalin, Dec 27, 2023)
afec25d  optimize edge id type for sort (mfbalin, Dec 27, 2023)
c53b476  eliminate synchronization for max_in_degree. (mfbalin, Dec 28, 2023)
726112c  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 28, 2023)
5622a62  take back CopyTo fix, use extra_attrs (mfbalin, Dec 28, 2023)
61d4756  address reviews and take back example change (mfbalin, Dec 28, 2023)
5b65e12  take back example change fully (mfbalin, Dec 28, 2023)
3ede9f3  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 28, 2023)
1b5e596  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 28, 2023)
0732bd7  Merge branch 'master' into gb_cuda_sampling (mfbalin, Dec 28, 2023)
42 changes: 42 additions & 0 deletions graphbolt/include/graphbolt/cuda_sampling_ops.h
@@ -11,6 +11,48 @@
namespace graphbolt {
namespace ops {

/**
* @brief Sample neighboring edges of the given nodes and return the induced
* subgraph.
*
* @param indptr Index pointer array of the CSC.
* @param indices Indices array of the CSC.
* @param nodes The nodes from which to sample neighbors.
* @param fanouts The number of edges to be sampled for each node, with or
* without considering edge types.
* - When the length is 1, the single fanout applies to all neighbors of the
* node as a collective, regardless of the edge type.
* - Otherwise, the length should equal the number of edge types, and each
* fanout value corresponds to a specific edge type of the node.
* Each fanout value must be either -1 or a non-negative integer.
* - When the value is -1, all neighbors (with non-zero probability, if
* weighted) are chosen. This is equivalent to passing a fanout that is
* >= the number of neighbors (with replacement set to false).
* - When the value is a non-negative integer, it acts as an upper bound on
* the number of neighbors selected per node (and edge type).
* @param replace Boolean indicating whether the sampling is performed with or
* without replacement. If true, a value can be selected multiple times;
* otherwise, each value can be selected only once.
* @param layer Boolean indicating whether neighbors should be sampled in a
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param type_per_edge A tensor representing the type of each edge, if present.
* @param probs_or_mask An optional tensor with (unnormalized) probabilities
* corresponding to each neighboring edge of a node. It must be
* a 1D tensor, with the number of elements equaling the total number of edges.
*
* @return An intrusive pointer to a FusedSampledSubgraph object containing
* the sampled graph's information.
*/
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt);
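
A minimal caller-side sketch of the declaration above, assuming the CSC tensors and the seed nodes already reside on the GPU; the variable names and the fanout value of 10 are illustrative only and not part of this header:

// Hypothetical usage: sample at most 10 neighbors per seed node of a
// homogeneous graph, without replacement, keeping the picked edge IDs.
std::vector<int64_t> fanouts{10};
auto subgraph = graphbolt::ops::SampleNeighbors(
    indptr, indices, seed_nodes, fanouts,
    /*replace=*/false, /*layer=*/false, /*return_eids=*/true);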

/**
* @brief Return the subgraph induced on the inbound edges of the given nodes.
* @param nodes Type agnostic node IDs to form the subgraph.
2 changes: 1 addition & 1 deletion graphbolt/src/cuda/index_select_impl.cu
@@ -124,7 +124,7 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
const IdType* index_sorted_ptr = sorted_index.data_ptr<IdType>();
const int64_t* permutation_ptr = permutation.data_ptr<int64_t>();

- cudaStream_t stream = cuda::GetCurrentStream();
+ auto stream = cuda::GetCurrentStream();

if (aligned_feature_size == 1) {
// Use a single thread to process each output row to avoid wasting threads.
319 changes: 319 additions & 0 deletions graphbolt/src/cuda/neighbor_sampler.cu
@@ -0,0 +1,319 @@
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/neighbor_sampler.cu
* @brief Neighbor sampling operator implementation on CUDA.
*/
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAStream.h>
#include <curand_kernel.h>
#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_sampling_ops.h>
#include <thrust/gather.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>

#include <algorithm>
#include <array>
#include <cub/cub.cuh>
#include <cuda/std/tuple>
#include <limits>
#include <numeric>
#include <type_traits>

#include "../random.h"
#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;

/**
* @brief Fills the random_arr with random numbers and the edge_ids array with
* original edge ids. When random_arr is sorted along with edge_ids, the first
* fanout elements of each row give us the sampled edges.
*/
template <
typename float_t, typename indptr_t, typename indices_t, typename weights_t,
typename edge_id_t>
__global__ void _ComputeRandoms(
const int64_t num_edges, const indptr_t* const sliced_indptr,
const indptr_t* const sub_indptr, const indices_t* const csr_rows,
const weights_t* const weights, const indices_t* const indices,
const uint64_t random_seed, float_t* random_arr, edge_id_t* edge_ids) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = gridDim.x * blockDim.x;
curandStatePhilox4_32_10_t rng;
const auto labor = indices != nullptr;

if (!labor) {
curand_init(random_seed, i, 0, &rng);
}

while (i < num_edges) {
const auto row_position = csr_rows[i];
const auto row_offset = i - sub_indptr[row_position];
const auto in_idx = sliced_indptr[row_position] + row_offset;

if (labor) {
constexpr uint64_t kCurandSeed = 999961;
curand_init(kCurandSeed, random_seed, indices[in_idx], &rng);
}

const auto rnd = curand_uniform(&rng);
const auto prob = weights ? weights[in_idx] : static_cast<weights_t>(1);
const auto exp_rnd = -__logf(rnd);
const float_t adjusted_rnd = prob > 0
? static_cast<float_t>(exp_rnd / prob)
: std::numeric_limits<float_t>::infinity();
random_arr[i] = adjusted_rnd;
edge_ids[i] = row_offset;

i += stride;
}
}
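
For intuition, the following standalone host-side sketch (an illustration, not part of the PR) shows the same idea the kernel above implements: each edge draws the key -log(u)/w, and taking the first fanout entries of each row after sorting the keys yields a weighted sample without replacement.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <numeric>
#include <random>
#include <vector>

// Sample up to k edges of a single row without replacement, proportionally
// to the given weights.
std::vector<int64_t> SampleRowWithoutReplacement(
    const std::vector<float>& weights, int64_t k, uint64_t seed) {
  std::mt19937_64 rng(seed);
  std::uniform_real_distribution<float> uniform(
      std::numeric_limits<float>::min(), 1.0f);
  std::vector<float> keys(weights.size());
  for (size_t i = 0; i < weights.size(); ++i) {
    // Edges with non-positive weight get an infinite key and are never picked.
    keys[i] = weights[i] > 0
                  ? -std::log(uniform(rng)) / weights[i]
                  : std::numeric_limits<float>::infinity();
  }
  std::vector<int64_t> edge_ids(weights.size());
  std::iota(edge_ids.begin(), edge_ids.end(), int64_t{0});
  // Sorting edge ids by their keys mirrors the segmented sort used on the GPU.
  std::sort(edge_ids.begin(), edge_ids.end(),
            [&](int64_t a, int64_t b) { return keys[a] < keys[b]; });
  if (k < static_cast<int64_t>(edge_ids.size())) edge_ids.resize(k);
  return edge_ids;  // row-local edge id offsets of the sampled neighbors
}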

template <typename indptr_t>
struct MinInDegreeFanout {
const indptr_t* in_degree;
int64_t fanout;
__host__ __device__ auto operator()(int64_t i) {
return static_cast<indptr_t>(
min(static_cast<int64_t>(in_degree[i]), fanout));
}
};

template <typename indptr_t, typename indices_t>
struct IteratorFunc {
indptr_t* indptr;
indices_t* indices;
__host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }
};

template <typename indptr_t>
struct AddOffset {
indptr_t offset;
template <typename edge_id_t>
__host__ __device__ indptr_t operator()(edge_id_t x) {
return x + offset;
}
};

template <typename indptr_t, typename indices_t>
struct IteratorFuncAddOffset {
indptr_t* indptr;
indptr_t* sliced_indptr;
indices_t* indices;
__host__ __device__ auto operator()(int64_t i) {
return thrust::transform_output_iterator{
indices + indptr[i], AddOffset<indptr_t>{sliced_indptr[i]}};
}
};
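// Note: IteratorFunc and IteratorFuncAddOffset above produce, for each row i,
// the per-segment input and output pointers consumed by cub::DeviceCopy::Batched
// further below; AddOffset converts the sorted row-local edge id offsets back
// into global edge ids by adding sliced_indptr[i] as they are written out.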

c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> probs_or_mask) {
TORCH_CHECK(
fanouts.size() == 1, "Heterogeneous sampling is not supported yet!");
TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
// Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
// are all resident on the GPU. If not, it is better to first extract them
// before calling this function.
auto allocator = cuda::GetAllocator();
const auto stream = cuda::GetCurrentStream();
const auto num_rows = nodes.size(0);
const auto fanout =
fanouts[0] >= 0 ? fanouts[0] : std::numeric_limits<int64_t>::max();
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
auto max_in_degree = torch::empty(
1,
c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
AT_DISPATCH_INTEGRAL_TYPES(
indptr.scalar_type(), "SampleNeighborsInDegree", ([&] {
size_t tmp_storage_size = 0;
cub::DeviceReduce::Max(
nullptr, tmp_storage_size, in_degree.data_ptr<scalar_t>(),
max_in_degree.data_ptr<scalar_t>(), num_rows, stream);
auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
cub::DeviceReduce::Max(
tmp_storage.get(), tmp_storage_size, in_degree.data_ptr<scalar_t>(),
max_in_degree.data_ptr<scalar_t>(), num_rows, stream);
}));
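// max_in_degree lives in pinned host memory and is filled asynchronously by
// the reduction above; it is only read on the host later, after a subsequent
// call has synchronized with the stream.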
auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
auto sub_indptr = ExclusiveCumSum(in_degree);
auto output_indptr = torch::empty_like(sub_indptr);
auto coo_rows = CSRToCOO(sub_indptr, indices.scalar_type());
const auto num_edges = coo_rows.size(0);
const auto random_seed = RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
torch::Tensor picked_eids;
torch::Tensor output_indices;

AT_DISPATCH_INTEGRAL_TYPES(
indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
using indptr_t = scalar_t;
thrust::counting_iterator<int64_t> iota(0);
auto sampled_degree = thrust::make_transform_iterator(
iota, MinInDegreeFanout<indptr_t>{
in_degree.data_ptr<indptr_t>(), fanout});

{ // Compute output_indptr.
size_t tmp_storage_size = 0;
cub::DeviceScan::ExclusiveSum(
nullptr, tmp_storage_size, sampled_degree,
output_indptr.data_ptr<indptr_t>(), num_rows + 1, stream);
auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
cub::DeviceScan::ExclusiveSum(
tmp_storage.get(), tmp_storage_size, sampled_degree,
output_indptr.data_ptr<indptr_t>(), num_rows + 1, stream);
}

auto num_sampled_edges =
cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};

// Find the smallest integer type to store the edge id offsets.
// CSRToCOO above had a synchronization inside, so it is safe to read
// max_in_degree now.
const int num_bits =
cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
std::array<int, 4> type_bits = {8, 16, 32, 64};
const auto type_index =
std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
type_bits.begin();
std::array<torch::ScalarType, 5> types = {
torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
torch::kLong};
auto edge_id_dtype = types[type_index];
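// For example, a max in-degree of around 3000 needs 12 bits, so kInt16 is
// chosen; around 70000 needs 17 bits, so kInt32 is chosen.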
AT_DISPATCH_INTEGRAL_TYPES(
edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
using edge_id_t = std::make_unsigned_t<scalar_t>;
TORCH_CHECK(
num_bits <= sizeof(edge_id_t) * 8,
"Selected edge_id_t must be capable of storing edge_ids.");
// Using bfloat16 for random numbers works just as reliably as
// float32 and provides around a 30% speedup.
using rnd_t = nv_bfloat16;
auto randoms = allocator.AllocateStorage<rnd_t>(num_edges);
auto randoms_sorted = allocator.AllocateStorage<rnd_t>(num_edges);
auto edge_id_segments =
allocator.AllocateStorage<edge_id_t>(num_edges);
auto sorted_edge_id_segments =
allocator.AllocateStorage<edge_id_t>(num_edges);
AT_DISPATCH_INTEGRAL_TYPES(
indices.scalar_type(), "SampleNeighborsIndices", ([&] {
using indices_t = scalar_t;
auto probs_or_mask_scalar_type = torch::kFloat32;
if (probs_or_mask.has_value()) {
probs_or_mask_scalar_type =
probs_or_mask.value().scalar_type();
}
GRAPHBOLT_DISPATCH_ALL_TYPES(
probs_or_mask_scalar_type, "SampleNeighborsProbs",
([&] {
using probs_t = scalar_t;
probs_t* probs_ptr = nullptr;
if (probs_or_mask.has_value()) {
probs_ptr =
probs_or_mask.value().data_ptr<probs_t>();
}
const indices_t* indices_ptr =
layer ? indices.data_ptr<indices_t>() : nullptr;
const dim3 block(BLOCK_SIZE);
const dim3 grid(
(num_edges + BLOCK_SIZE - 1) / BLOCK_SIZE);
// Compute row and random number pairs.
CUDA_KERNEL_CALL(
_ComputeRandoms, grid, block, 0, stream,
num_edges, sliced_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>(),
coo_rows.data_ptr<indices_t>(), probs_ptr,
indices_ptr, random_seed, randoms.get(),
edge_id_segments.get());
}));
}));

// Sort the random numbers along with the edge ids; after
// sorting, the first fanout elements of each row will
// give us the sampled edges.
size_t tmp_storage_size = 0;
CUDA_CALL(cub::DeviceSegmentedSort::SortPairs(
nullptr, tmp_storage_size, randoms.get(),
randoms_sorted.get(), edge_id_segments.get(),
sorted_edge_id_segments.get(), num_edges, num_rows,
sub_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>() + 1, stream));
auto tmp_storage =
allocator.AllocateStorage<char>(tmp_storage_size);
CUDA_CALL(cub::DeviceSegmentedSort::SortPairs(
tmp_storage.get(), tmp_storage_size, randoms.get(),
randoms_sorted.get(), edge_id_segments.get(),
sorted_edge_id_segments.get(), num_edges, num_rows,
sub_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>() + 1, stream));

picked_eids = torch::empty(
static_cast<indptr_t>(num_sampled_edges),
nodes.options().dtype(indptr.scalar_type()));

auto input_buffer_it = thrust::make_transform_iterator(
iota, IteratorFunc<indptr_t, edge_id_t>{
sub_indptr.data_ptr<indptr_t>(),
sorted_edge_id_segments.get()});
auto output_buffer_it = thrust::make_transform_iterator(
iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
output_indptr.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>()});
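// cub::DeviceCopy::Batched takes a 32-bit buffer count, so the batched
// copies below are issued in chunks of at most INT32_MAX segments.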
constexpr int64_t max_copy_at_once =
std::numeric_limits<int32_t>::max();

// Copy the sampled edge ids into picked_eids tensor.
for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
size_t tmp_storage_size = 0;
CUDA_CALL(cub::DeviceCopy::Batched(
nullptr, tmp_storage_size, input_buffer_it + i,
output_buffer_it + i, sampled_degree + i,
std::min(num_rows - i, max_copy_at_once), stream));
auto tmp_storage =
allocator.AllocateStorage<char>(tmp_storage_size);
CUDA_CALL(cub::DeviceCopy::Batched(
tmp_storage.get(), tmp_storage_size, input_buffer_it + i,
output_buffer_it + i, sampled_degree + i,
std::min(num_rows - i, max_copy_at_once), stream));
}
}));

output_indices = torch::empty(
picked_eids.size(0),
picked_eids.options().dtype(indices.scalar_type()));

// Compute: output_indices = indices.gather(0, picked_eids);
AT_DISPATCH_INTEGRAL_TYPES(
indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
using indices_t = scalar_t;
const auto exec_policy =
thrust::cuda::par_nosync(allocator).on(stream);
thrust::gather(
exec_policy, picked_eids.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
indices.data_ptr<indices_t>(),
output_indices.data_ptr<indices_t>());
}));
}));

torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);

return c10::make_intrusive<sampling::FusedSampledSubgraph>(
output_indptr, output_indices, nodes, torch::nullopt,
subgraph_reverse_edge_ids, torch::nullopt);
}

} // namespace ops
} // namespace graphbolt