[Bugfix] Wrap cub with CUB_NS_PREFIX and remove dependency on Thrust to fix linking issues with Torch 1.8 #2758

Merged · 4 commits · Mar 22, 2021
Changes from 3 commits
2 changes: 1 addition & 1 deletion src/array/cuda/array_cumsum.cu
@@ -4,9 +4,9 @@
  * \brief Array cumsum GPU implementation
  */
 #include <dgl/array.h>
-#include <cub/cub.cuh>
 #include "../../runtime/cuda/cuda_common.h"
 #include "./utils.h"
+#include "./dgl_cub.cuh"
 
 namespace dgl {
 using runtime::NDArray;
73 changes: 49 additions & 24 deletions src/array/cuda/array_nonzero.cu
@@ -3,46 +3,71 @@
  * \file array/cpu/array_nonzero.cc
  * \brief Array nonzero CPU implementation
  */
-#include <thrust/iterator/counting_iterator.h>
-#include <thrust/copy.h>
-#include <thrust/functional.h>
-#include <thrust/device_vector.h>
-
 #include <dgl/array.h>
 #include "../../runtime/cuda/cuda_common.h"
 #include "./utils.h"
+#include "./dgl_cub.cuh"
 
 namespace dgl {
 using runtime::NDArray;
 namespace aten {
 namespace impl {
 
 template <typename IdType>
-struct IsNonZero {
-  __device__ bool operator() (const IdType val) {
-    return val != 0;
+struct IsNonZeroIndex {
+  explicit IsNonZeroIndex(const IdType * array) : array_(array) {
+  }
+
+  __device__ bool operator() (const int64_t index) {
+    return array_[index] != 0;
   }
+
+  const IdType * array_;
 };
 
 template <DLDeviceType XPU, typename IdType>
 IdArray NonZero(IdArray array) {
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  const auto& ctx = array->ctx;
+  auto device = runtime::DeviceAPI::Get(ctx);
+
   const int64_t len = array->shape[0];
-  IdArray ret = NewIdArray(len, array->ctx, 64);
-  thrust::device_ptr<IdType> in_data(array.Ptr<IdType>());
-  thrust::device_ptr<int64_t> out_data(ret.Ptr<int64_t>());
-  // TODO(minjie): should take control of the memory allocator.
-  // See PyTorch's implementation here:
-  // https://github.com/pytorch/pytorch/blob/1f7557d173c8e9066ed9542ada8f4a09314a7e17/
-  // aten/src/THC/generic/THCTensorMath.cu#L104
-  auto startiter = thrust::make_counting_iterator<int64_t>(0);
-  auto enditer = startiter + len;
-  auto indices_end = thrust::copy_if(thrust::cuda::par.on(thr_entry->stream),
-                                     startiter,
-                                     enditer,
-                                     in_data,
-                                     out_data,
-                                     IsNonZero<IdType>());
-  const int64_t num_nonzeros = indices_end - out_data;
+  IdArray ret = NewIdArray(len, ctx, 64);
+
+  cudaStream_t stream = 0;
+
+  const IdType * const in_data = static_cast<const IdType*>(array->data);
+  int64_t * const out_data = static_cast<int64_t*>(ret->data);
+
+  IsNonZeroIndex<IdType> comp(in_data);
+  cub::CountingInputIterator<int64_t> counter(0);
+
+  // room for cub to output on GPU
+  int64_t * d_num_nonzeros = static_cast<int64_t*>(
+      device->AllocWorkspace(ctx, sizeof(int64_t)));
+
+  size_t temp_size = 0;
+  cub::DeviceSelect::If(nullptr, temp_size, counter, out_data,
+      d_num_nonzeros, len, comp, stream);
+  void * temp = device->AllocWorkspace(ctx, temp_size);
+  cub::DeviceSelect::If(temp, temp_size, counter, out_data,
+      d_num_nonzeros, len, comp, stream);
+  device->FreeWorkspace(ctx, temp);
+
+  // copy number of selected elements from GPU to CPU
+  int64_t num_nonzeros;
+  device->CopyDataFromTo(
+      d_num_nonzeros, 0,
+      &num_nonzeros, 0,
+      sizeof(num_nonzeros),
+      ctx,
+      DGLContext{kDLCPU, 0},
+      DGLType{kDLInt, 64, 1},
+      stream);
+  device->FreeWorkspace(ctx, d_num_nonzeros);
+  device->StreamSync(ctx, stream);
 
   // truncate array to size
   return ret.CreateView({num_nonzeros}, ret->dtype, 0);
 }
 
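Note on the new NonZero path: thrust::copy_if is replaced by cub::DeviceSelect::If, which is called twice. The first call passes a null temporary-storage pointer so cub only reports how many scratch bytes it needs; the second call runs the actual stream compaction into the allocated workspace. The number of selected elements is written to device memory and copied back to the host before the output is truncated with CreateView. Below is a minimal standalone sketch of that pattern; it uses plain cudaMalloc/cudaMemcpy where the DGL code uses the DeviceAPI workspace helpers, so it illustrates the cub calls only, not DGL's allocator.

```cuda
// Sketch only: the same two-phase cub::DeviceSelect::If pattern as the new
// NonZero, with raw CUDA allocations instead of device->AllocWorkspace and
// device->CopyDataFromTo. Compile with nvcc.
#include <cub/cub.cuh>
#include <cstdint>
#include <cstdio>

struct IsNonZeroIndex {
  explicit IsNonZeroIndex(const int64_t* array) : array_(array) {}
  __device__ bool operator()(const int64_t index) const {
    return array_[index] != 0;
  }
  const int64_t* array_;
};

int main() {
  const int64_t len = 8;
  const int64_t h_in[8] = {0, 3, 0, 0, 7, 1, 0, 2};

  int64_t *d_in = nullptr, *d_out = nullptr, *d_num_selected = nullptr;
  cudaMalloc(&d_in, len * sizeof(int64_t));
  cudaMalloc(&d_out, len * sizeof(int64_t));
  cudaMalloc(&d_num_selected, sizeof(int64_t));
  cudaMemcpy(d_in, h_in, len * sizeof(int64_t), cudaMemcpyHostToDevice);

  IsNonZeroIndex comp(d_in);
  cub::CountingInputIterator<int64_t> counter(0);

  // Pass 1: null temp pointer, cub only writes the required scratch size.
  size_t temp_size = 0;
  cub::DeviceSelect::If(nullptr, temp_size, counter, d_out,
                        d_num_selected, len, comp);

  // Pass 2: allocate the scratch space and run the actual selection.
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_size);
  cub::DeviceSelect::If(d_temp, temp_size, counter, d_out,
                        d_num_selected, len, comp);

  // The selected count lives on the device; copy it back to know how much
  // of d_out (the nonzero positions) is valid.
  int64_t num_nonzeros = 0;
  cudaMemcpy(&num_nonzeros, d_num_selected, sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  std::printf("nonzero count: %lld\n", static_cast<long long>(num_nonzeros));

  cudaFree(d_temp);
  cudaFree(d_num_selected);
  cudaFree(d_out);
  cudaFree(d_in);
  return 0;
}
```
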
2 changes: 1 addition & 1 deletion src/array/cuda/array_sort.cu
@@ -4,9 +4,9 @@
  * \brief Array sort GPU implementation
  */
 #include <dgl/array.h>
-#include <cub/cub.cuh>
 #include "../../runtime/cuda/cuda_common.h"
 #include "./utils.h"
+#include "./dgl_cub.cuh"
 
 namespace dgl {
 using runtime::NDArray;
2 changes: 1 addition & 1 deletion src/array/cuda/csr_sort.cu
@@ -4,9 +4,9 @@
  * \brief Sort CSR index
  */
 #include <dgl/array.h>
-#include <cub/cub.cuh>
 #include "../../runtime/cuda/cuda_common.h"
 #include "./utils.h"
+#include "./dgl_cub.cuh"
 
 namespace dgl {
 
17 changes: 17 additions & 0 deletions src/array/cuda/dgl_cub.cuh
@@ -0,0 +1,17 @@
+/*!
+ *  Copyright (c) 2021 by Contributors
+ * \file array/cuda/dgl_cub.cuh
+ * \brief Wrapper to place cub in dgl namespace.
+ */
+
+#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_
+#define DGL_ARRAY_CUDA_DGL_CUB_CUH_
+
+// include cub in a safe manner
+#define CUB_NS_PREFIX namespace dgl {
+#define CUB_NS_POSTFIX }
+#include "cub/cub.cuh"
+#undef CUB_NS_POSTFIX
+#undef CUB_NS_PREFIX
+
+#endif
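
How the wrapper works: cub's headers open and close their namespace with the CUB_NS_PREFIX / CUB_NS_POSTFIX macros, which are empty by default. Defining them as `namespace dgl {` / `}` before the include puts DGL's copy of cub into `dgl::cub`, so its symbols cannot collide with another cub instantiation linked into the same process, such as the one shipped with Torch 1.8. Source files that already live inside `namespace dgl` can keep writing unqualified `cub::...`, which now resolves to the wrapped copy. A toy sketch of using the wrapped header (ExclusiveSumExample is a hypothetical function, not DGL code):

```cuda
// Hypothetical illustration of using the namespace-wrapped cub.
#include "./dgl_cub.cuh"  // cub now lives in dgl::cub

namespace dgl {

// Inside namespace dgl, an unqualified cub:: lookup finds dgl::cub::.
void ExclusiveSumExample(const int* d_in, int* d_out, int n,
                         cudaStream_t stream) {
  // Query the scratch size, then allocate and run, as at the other cub call sites.
  size_t temp_bytes = 0;
  cub::DeviceScan::ExclusiveSum(nullptr, temp_bytes, d_in, d_out, n, stream);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_in, d_out, n, stream);
  cudaFree(d_temp);
}

}  // namespace dgl

// Outside the namespace, the wrapped symbols need full qualification:
// dgl::cub::CountingInputIterator<int> iter(0);
```
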
2 changes: 1 addition & 1 deletion src/array/cuda/utils.cu
@@ -5,7 +5,7 @@
  */
 
 #include "./utils.h"
-#include <cub/cub.cuh>
+#include "./dgl_cub.cuh"
 #include "../../runtime/cuda/cuda_common.h"
 
 namespace dgl {
2 changes: 1 addition & 1 deletion src/runtime/cuda/cuda_hashtable.cu
@@ -4,11 +4,11 @@
  * \brief Device level functions for within cuda kernels.
  */
 
-#include <cub/cub.cuh>
 #include <cassert>
 
 #include "cuda_hashtable.cuh"
 #include "../../kernel/cuda/atomic.cuh"
#include "../../array/cuda/dgl_cub.cuh"

using namespace dgl::kernel::cuda;
