[MXNET-263] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU (apache#10371)

* add support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU

* add unit test for new op and forward_stype_hint parameter to dot

* update documentation for dot

* address code reviews

* fix flaky test_gluon:test_lambda by loosening the atol

* switch dot(dns, csr) case to a deterministic algorithm with unit test for determinism

* address code reviews and add backward
haojin2 authored and eric-haibin-lin committed Apr 26, 2018
1 parent 6cacb57 commit e391eff
Showing 6 changed files with 441 additions and 42 deletions.
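
A minimal usage sketch of the feature this commit adds, assuming the Python frontend. The shapes, array names, and context are illustrative, not taken from this diff; the forward_stype_hint keyword mentioned in the commit message is only noted in a comment because its exact spelling is not shown here.

# Hypothetical usage sketch (names and shapes are illustrative, not from this diff).
import mxnet as mx

ctx = mx.gpu(0)
dns = mx.nd.random.uniform(shape=(32, 64), ctx=ctx)                   # dense lhs
csr = mx.nd.random.uniform(shape=(64, 128), ctx=ctx).tostype('csr')   # sparse rhs

# dot(dns, csr) = dns: the new GPU path converts the csr rhs to csc internally
out = mx.nd.dot(dns, csr)                                             # shape (32, 128)

# dot(dns, csr.T) = dns: reads the csr rhs directly, one thread per output element
dns_t = mx.nd.random.uniform(shape=(32, 128), ctx=ctx)
out_t = mx.nd.dot(dns_t, csr, transpose_b=True)                       # shape (32, 64)

# The commit message also mentions a forward_stype_hint parameter to dot for
# requesting the forward output storage type; it is not exercised here.
print(out.shape, out_t.shape)
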
src/operator/tensor/dot-inl.cuh: 245 additions & 0 deletions
@@ -27,6 +27,8 @@

#include <mxnet/base.h>
#include <mxnet/operator.h>
#include "./init_op.h"
#include "./sort_op.h"
#include "./util/tensor_util-inl.h"
#include "./util/tensor_util-inl.cuh"

@@ -442,6 +444,99 @@ struct DotCsrRspDnsScalarKernel {
}
};

/*!
* \brief GPU Kernel to scatter each row id to its corresponding CSR entries
* \param tid global thread id
* \param csr_indptr indptr array of csr
* \param csr_rows array of row id of csr elements
* \param num_rows total number of rows in csr matrix
* Parallelization by rows of the csr matrix: 1 thread/row
*/
struct CsrRowScatterKernel {
template<typename CType>
__device__ __forceinline__ static void Map(int tid,
const CType* csr_indptr,
CType* csr_rows,
const nnvm::dim_t num_rows) {
if (tid < num_rows) {
for (CType i = csr_indptr[tid]; i < csr_indptr[tid+1]; ++i) {
csr_rows[i] = tid;
}
}
}
};

struct CscDataIndicesKernel {
/*!
* \brief GPU Kernel to gather csc data and row indices through the sorted permutation
* \param tid global thread id
* \param original_idx_ptr original positions of the csr elements before the sort by column
* \param csr_data_ptr csr matrix data
* \param csr_rows_ptr array of row id of csr elements
* \param csc_data_ptr output csc matrix data
* \param csc_indices_ptr output csc matrix row indices
* \param nnz number of non-zero elements
* Parallelization by output elements: 1 thread/element
*/
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
const IType* original_idx_ptr,
const DType* csr_data_ptr,
const CType* csr_rows_ptr,
DType* csc_data_ptr,
IType* csc_indices_ptr,
const nnvm::dim_t nnz) {
using nnvm::dim_t;
if (tid < nnz) {
const IType origin = original_idx_ptr[tid];
csc_data_ptr[tid] = csr_data_ptr[origin];
csc_indices_ptr[tid] = csr_rows_ptr[origin];
}
}
};

/*!
* \brief GPU Kernel of dot(dns, csr.T) = dns
* Parallelization by output elements: 1 thread/element
*/
struct DotDnsCsrTransDnsKernel {
/*!
* \brief Map function: one thread computes one element of the dense output
* \param tid global thread id
* \param lhs_data lhs dense matrix data
* \param rhs_data csr matrix data
* \param rhs_indices csr matrix column indices
* \param rhs_indptr csr matrix row pointer
* \param out output matrix data
* \param lhs_num_cols lhs dns matrix number of columns
* \param out_num_rows output dns matrix number of rows
* \param out_num_cols output dns matrix number of columns
*/
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
const DType* lhs_data,
const DType* rhs_data,
const IType* rhs_indices,
const CType* rhs_indptr,
DType* out,
const nnvm::dim_t lhs_num_cols,
const nnvm::dim_t out_num_rows,
const nnvm::dim_t out_num_cols) {
using nnvm::dim_t;
if (tid < out_num_rows*out_num_cols) {
const dim_t i = static_cast<dim_t>(tid) % out_num_rows; // i = row this thread computes
const dim_t k = static_cast<dim_t>(tid) / out_num_rows; // k = col this thread computes
// Compute inner product of i-th row and k-th col
DType sum = 0;
for (CType col_id = rhs_indptr[k]; col_id < rhs_indptr[k + 1]; ++col_id) {
sum += lhs_data[i * lhs_num_cols + rhs_indices[col_id]] * rhs_data[col_id];
}
out[i * out_num_cols + k] = sum;
}
}
};
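
For reference, a serial NumPy/SciPy sketch of the per-element computation in DotDnsCsrTransDnsKernel above: each GPU thread owns one (i, k) output element and accumulates the inner product of dense row i with the nonzeros stored in CSR row k, which is dot(dns, csr.T). All names and shapes below are illustrative.

# Serial sketch of DotDnsCsrTransDnsKernel: one (i, k) pair per GPU thread.
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)
dns = rng.random((4, 6))                                          # lhs, shape (M, K)
csr = sp.random(5, 6, density=0.3, format='csr', random_state=0)  # rhs, shape (N, K)

out = np.zeros((dns.shape[0], csr.shape[0]))                      # dot(dns, csr.T), shape (M, N)
for i in range(out.shape[0]):          # row handled by a thread
    for k in range(out.shape[1]):      # column handled by a thread
        # walk the nonzeros of CSR row k: rhs_indptr[k] .. rhs_indptr[k+1]
        for j in range(csr.indptr[k], csr.indptr[k + 1]):
            out[i, k] += dns[i, csr.indices[j]] * csr.data[j]

assert np.allclose(out, dns @ csr.toarray().T)
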

/*!
* \brief GPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
*/
@@ -895,6 +990,156 @@ inline void DotCsrRspDnsImpl(const OpContext& ctx,
});
}

// Returns the number of bits needed to represent a, i.e. floor(log2(a)) + 1
inline int log2i(size_t a) {
int k = 1;
while (a >>= 1) k++;
return k;
}

/*
* \brief GPU Impl of dot(dns, csr) = csr
*/
inline void DotDnsCsrCsrImpl(const OpContext& ctx, const gpu& gpu_dev,
const TBlob& lhs, const NDArray& rhs,
const OpReqType req, NDArray* ret) {
LOG(FATAL) << "dot(dense, csr) = csr is not implemented on GPU";
}

/*
* \brief GPU Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns
*/
inline void DotDnsCsrDnsImpl(const OpContext& ctx, const gpu& gpu_dev,
const TBlob& dns, const NDArray& rhs,
const OpReqType req, NDArray* ret,
const bool transpose_b) {
if (req == kNullOp) {
return;
}
CHECK_EQ(req, kWriteTo);
CHECK_EQ(rhs.storage_type(), kCSRStorage);

using namespace mshadow;
using namespace mshadow::expr;
using nnvm::dim_t;

/* Initialize data structures */
mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
TBlob csr_data = rhs.data();
TBlob csr_indices = rhs.aux_data(csr::kIdx);
TBlob csr_indptr = rhs.aux_data(csr::kIndPtr);
if (!rhs.storage_initialized()) {
// rhs has no non-zero entries, so the dense output is all zeros
Fill<false>(s, ret->data(), req, 0);
return;
}

MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { // indices type
MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { // indptr type
const nnvm::dim_t out_num_rows = ret->shape()[0];
const nnvm::dim_t out_num_cols = ret->shape()[1];
// if dot(dense, csr) = dns, transform to csc first
if (!transpose_b) {
const nnvm::dim_t num_csr_rows = rhs.shape()[0];
const nnvm::dim_t num_csr_cols = rhs.shape()[1];
const nnvm::dim_t num_dns_rows = dns.shape_[0];
const nnvm::dim_t nnz = rhs.storage_shape().Size();

IType* original_idx_ptr = nullptr;
IType* csc_indices_ptr = nullptr;
IType* csc_cols_ptr = nullptr;
CType* csr_rows_ptr = nullptr;
CType* csc_indptr_ptr = nullptr;
DType* csc_data_ptr = nullptr;
char* temp_storage_ptr = nullptr;
size_t original_idx_bytes = nnz*sizeof(IType);
size_t csc_indices_bytes = nnz*sizeof(IType);
size_t csc_cols_bytes = nnz*sizeof(IType);
size_t csr_rows_bytes = nnz*sizeof(CType);
size_t csc_indptr_bytes = (num_csr_cols+1)*sizeof(CType);
size_t csc_data_bytes = nnz*sizeof(DType);
size_t scan_temp_storage_bytes = 0;
size_t temp_storage_bytes = SortByKeyWorkspaceSize<IType, IType, gpu>(nnz);
IType* csr_indices_ptr = csr_indices.dptr<IType>();
cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
scan_temp_storage_bytes,
csc_indptr_ptr,
csc_indptr_ptr,
num_csr_cols+1,
mshadow::Stream<gpu>::GetStream(s));
temp_storage_bytes = std::max(temp_storage_bytes, scan_temp_storage_bytes);
temp_storage_bytes += (sizeof(dim_t) - temp_storage_bytes % sizeof(dim_t));
size_t total_workspace_bytes =
original_idx_bytes + csc_indices_bytes + csc_cols_bytes + csr_rows_bytes +
csc_indptr_bytes + csc_data_bytes + temp_storage_bytes;
total_workspace_bytes += (sizeof(IType) - total_workspace_bytes % sizeof(IType));
Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(total_workspace_bytes), s);
original_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_);
csc_indices_ptr = reinterpret_cast<IType*>(workspace.dptr_ + original_idx_bytes);
csc_cols_ptr = reinterpret_cast<IType*>(workspace.dptr_ + original_idx_bytes +
csc_indices_bytes);
csr_rows_ptr = reinterpret_cast<CType*>(workspace.dptr_ + original_idx_bytes +
csc_indices_bytes + csc_cols_bytes);
csc_indptr_ptr = reinterpret_cast<CType*>(workspace.dptr_ + original_idx_bytes +
csc_indices_bytes + csc_cols_bytes +
csr_rows_bytes);
temp_storage_ptr = workspace.dptr_ + original_idx_bytes + csc_indices_bytes +
csc_cols_bytes + csr_rows_bytes + csc_indptr_bytes;
csc_data_ptr = reinterpret_cast<DType*>(
workspace.dptr_ + total_workspace_bytes - csc_data_bytes);

// Fill original_idx
mxnet_op::Kernel<range_fwd, gpu>::Launch(
s, nnz, 1, IType(0), IType(1), kWriteTo, original_idx_ptr);
// Fill csc_cols with copy of csr_indices
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kWriteTo>, gpu>::Launch(
s, nnz, csc_cols_ptr, csr_indices_ptr);
// Allocate the tensors needed for SortByKey
Tensor<gpu, 1, IType> original_idx(original_idx_ptr, Shape1(nnz), s);
Tensor<gpu, 1, IType> csc_cols(csc_cols_ptr, Shape1(nnz), s);
Tensor<gpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);

int num_bits = log2i(num_csr_cols - 1);
SortByKey(csc_cols, original_idx, true, &temp_storage, 0, num_bits);

// Scatter csr indptr to row id
mxnet_op::Kernel<CsrRowScatterKernel, gpu>::Launch(
s, num_csr_rows, csr_indptr.dptr<CType>(), csr_rows_ptr, num_csr_rows);
// Reset indptr to zero
mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(s, num_csr_cols+1, csc_indptr_ptr);
// Histogram on the sorted cols
mxnet_op::Kernel<HistogramKernel, gpu>::Launch(
s, nnz, csc_indptr_ptr, csc_cols_ptr, nnz);
// Scan the bin counts for every column to get csc_indptr
cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
temp_storage_bytes,
csc_indptr_ptr,
csc_indptr_ptr,
num_csr_cols+1,
mshadow::Stream<gpu>::GetStream(s));
// Assign data to csc matrix arrays
mxnet_op::Kernel<CscDataIndicesKernel, gpu>::Launch(
s, nnz, original_idx_ptr, csr_data.dptr<DType>(), csr_rows_ptr, csc_data_ptr,
csc_indices_ptr, nnz);

mxnet_op::Kernel<DotDnsCsrTransDnsKernel, gpu>::Launch(
s, out_num_rows * out_num_cols, dns.dptr<DType>(),
csc_data_ptr, csc_indices_ptr, csc_indptr_ptr,
ret->data().dptr<DType>(), dns.shape_[1],
out_num_rows, out_num_cols);
} else {
mxnet_op::Kernel<DotDnsCsrTransDnsKernel, gpu>::Launch(
s, out_num_rows * out_num_cols, dns.dptr<DType>(),
csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
csr_indptr.dptr<CType>(), ret->data().dptr<DType>(),
dns.shape_[1], out_num_rows, out_num_cols);
}
});
});
});
}
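
To make the non-transposed path above easier to follow, here is a hedged serial NumPy sketch of the same CSR-to-CSC conversion that the workspace code performs with GPU kernels: range fill, stable sort of column ids carrying the original positions, row-id scatter, histogram plus exclusive scan to build csc_indptr, and finally a gather of the data and row ids. Variable names are illustrative and SciPy stands in for the MXNet containers.

# Serial NumPy equivalent of the CSR -> CSC transform used by the dot(dns, csr) path.
import numpy as np
import scipy.sparse as sp

rhs = sp.random(5, 7, density=0.3, format='csr', random_state=1)   # csr rhs
nnz, num_csr_rows, num_csr_cols = rhs.nnz, rhs.shape[0], rhs.shape[1]

original_idx = np.arange(nnz)                          # range_fwd kernel
csc_cols = rhs.indices.copy()                          # copy of csr_indices
order = np.argsort(csc_cols, kind='stable')            # SortByKey(csc_cols, original_idx)
csc_cols = csc_cols[order]
original_idx = original_idx[order]

csr_rows = np.zeros(nnz, dtype=np.int64)               # CsrRowScatterKernel
for r in range(num_csr_rows):
    csr_rows[rhs.indptr[r]:rhs.indptr[r + 1]] = r

counts = np.bincount(csc_cols, minlength=num_csr_cols)   # HistogramKernel on sorted cols
csc_indptr = np.zeros(num_csr_cols + 1, dtype=np.int64)
csc_indptr[1:] = np.cumsum(counts)                     # cub::DeviceScan::ExclusiveSum

csc_data = rhs.data[original_idx]                      # CscDataIndicesKernel: gather data
csc_indices = csr_rows[original_idx]                   #   and row ids via the permutation

csc = sp.csc_matrix((csc_data, csc_indices, csc_indptr), shape=rhs.shape)
assert (csc.toarray() == rhs.toarray()).all()
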

} // namespace op
} // namespace mxnet
