From a46bb74c3c0f9ab24914345ca35565ac9d8e0711 Mon Sep 17 00:00:00 2001
From: Hao Jin
Date: Tue, 17 Apr 2018 00:19:05 -0700
Subject: [PATCH] switch dot(dns, csr) case to a deterministic algorithm with
 unit test for determinism

---
 src/operator/tensor/dot-inl.cuh               | 208 ++++++++++--------
 src/operator/tensor/util/tensor_util-inl.cuh  |  16 ++
 tests/python/unittest/test_sparse_operator.py |  25 +++
 3 files changed, 161 insertions(+), 88 deletions(-)

diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh
index 19cac543bf50..fffefbcf699b 100644
--- a/src/operator/tensor/dot-inl.cuh
+++ b/src/operator/tensor/dot-inl.cuh
@@ -27,11 +27,12 @@
 #include <mxnet/base.h>
 #include <mxnet/operator.h>
+#include "./indexing_op.h"
+#include "./init_op.h"
+#include "./sort_op.h"
 #include "./util/tensor_util-inl.h"
 #include "./util/tensor_util-inl.cuh"

-typedef unsigned long long AtomicIType;
-
 namespace mxnet {
 namespace op {

@@ -445,53 +446,59 @@ struct DotCsrRspDnsScalarKernel {
 };

 /*!
- * \brief GPU Kernel to re-arrange nnz elements to csc order
- * Parallelization by output elements: 1 thread/row of csr
+ * \brief GPU Kernel to scatter row id to corresponding entries
+ * \param tid        global thread id
+ * \param csr_indptr indptr array of csr
+ * \param csr_rows   array of row id of csr elements
+ * \param num_rows   total number of rows in csr matrix
+ * Parallelization by output elements: 1 thread/row
  */
-struct CscDataIndicesKernel {
-  template<typename DType, typename IType, typename CType>
+struct CsrRowScatterKernel {
+  template<typename CType>
   __device__ __forceinline__ static void Map(int tid,
-                                             const DType* csr_data,
-                                             const IType* csr_indices,
                                              const CType* csr_indptr,
-                                             DType* csc_data,
-                                             AtomicIType* csc_indices,
-                                             AtomicIType* csc_indptr,
-                                             AtomicIType* col_counters,
-                                             const nnvm::dim_t num_rows,
-                                             const nnvm::dim_t num_cols) {
+                                             CType* csr_rows,
+                                             const nnvm::dim_t num_rows) {
     if (tid < num_rows) {
-      for (CType i = csr_indptr[tid]; i < csr_indptr[tid + 1]; ++i) {
-        // target column
-        const IType target_col = csr_indices[i];
-        const int target_offset = atomicAdd(&col_counters[target_col], 1);
-        const int new_pos = csc_indptr[target_col] + target_offset;
-        csc_data[new_pos] = csr_data[i];
-        csc_indices[new_pos] = tid;
+      for (CType i = csr_indptr[tid]; i < csr_indptr[tid+1]; ++i) {
+        csr_rows[i] = tid;
       }
     }
   }
 };

 /*!
- * \brief GPU Kernel of getting count for every column
+ * \brief GPU Kernel for generating the transposed (csc) matrix
+ * \param tid global thread id
  * Parallelization by output elements: 1 thread/element
  */
-struct CsrTransHistogramKernel {
+struct CscDataIndicesKernel {
   /*!
   * \brief
   * \param tid global thread id
-   * \param in_indices csr matrix column indices
-   * \param out_indptr csr matrix row pointer
-   * \param nnz number of non-zero elements in csr
+   * \param original_idx_ptr original position of each non-zero entry, sorted by column id
+   * \param csr_data_ptr     csr matrix data
+   * \param csr_rows_ptr     row id of each non-zero entry of the csr matrix
+   * \param csc_data_ptr     output csc matrix data
+   * \param csc_indices_ptr  output csc matrix row indices
+   * \param nnz              number of non-zero elements in csr
   */
-  template<typename IType>
+  template<typename DType, typename IType, typename CType>
   __device__ __forceinline__ static void Map(int tid,
-                                             const IType* in_indices,
-                                             AtomicIType* out_indptr,
+                                             const IType* original_idx_ptr,
+                                             const DType* csr_data_ptr,
+                                             const CType* csr_rows_ptr,
+                                             DType* csc_data_ptr,
+                                             IType* csc_indices_ptr,
                                              const nnvm::dim_t nnz) {
+    using nnvm::dim_t;
     if (tid < nnz) {
-      atomicAdd(&out_indptr[in_indices[tid]], 1);
+      const IType origin = original_idx_ptr[tid];
+      csc_data_ptr[tid] = csr_data_ptr[origin];
+      csc_indices_ptr[tid] = csr_rows_ptr[origin];
     }
   }
 };
@@ -525,14 +532,14 @@ struct DotDnsCsrTransDnsKernel {
                                              const nnvm::dim_t out_num_cols) {
     using nnvm::dim_t;
     if (tid < out_num_rows*out_num_cols) {
-      const dim_t i = static_cast<dim_t>(tid) / out_num_cols;  // i = row this thread computes
-      const dim_t k = static_cast<dim_t>(tid) % out_num_cols;  // k = col this thread computes
+      const dim_t i = static_cast<dim_t>(tid) % out_num_rows;  // i = row this thread computes
+      const dim_t k = static_cast<dim_t>(tid) / out_num_rows;  // k = col this thread computes
       // Compute inner product of i-th row and k-th col
       DType sum = 0;
       for (CType col_id = rhs_indptr[k]; col_id < rhs_indptr[k + 1]; ++col_id) {
         sum += lhs_data[i * lhs_num_cols + rhs_indices[col_id]] * rhs_data[col_id];
       }
-      out[tid] = sum;
+      out[i * out_num_cols + k] = sum;
     }
   }
 };
@@ -1037,65 +1044,90 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx,
   const nnvm::dim_t num_dns_rows = dns.shape_[0];
   const nnvm::dim_t nnz = rhs.storage_shape().Size();

-  DType* csc_data_ptr = NULL;
-  AtomicIType* csc_indices_ptr = NULL;
-  AtomicIType* csc_indptr_ptr = NULL;
-  AtomicIType* col_counters = NULL;
-  size_t ull_num_bytes = sizeof(AtomicIType);
-  void* temp_storage = NULL;
-  size_t temp_storage_bytes = 0;
-
-  // Get necessary temporary storage amount
-  cub::DeviceScan::ExclusiveSum(NULL,
-                                temp_storage_bytes,
-                                csc_indices_ptr,
-                                csc_indices_ptr,
-                                num_csr_cols + 1,
-                                Stream<gpu>::GetStream(s));
-  // Align to multiple of ull_num_bytes
-  temp_storage_bytes += (ull_num_bytes - (temp_storage_bytes % ull_num_bytes));
-  size_t csc_data_size = nnz*sizeof(DType);
-  size_t csc_indices_size = nnz*ull_num_bytes;
-  size_t csc_indptr_size = (num_csr_cols+1)*ull_num_bytes;
-  size_t col_counters_size = (num_csr_cols+1)*ull_num_bytes;
-  Tensor<gpu, 1, char> workspace =
-    ctx.requested[0].get_space_typed<gpu, 1, char>(
-      Shape1(csc_data_size + csc_indices_size +
-             csc_indptr_size + col_counters_size +
-             temp_storage_bytes),
-      s);
-  csc_indices_ptr = reinterpret_cast<AtomicIType*>(workspace.dptr_);
-  csc_indptr_ptr = reinterpret_cast<AtomicIType*>(
-    workspace.dptr_ + csc_indices_size);
-  col_counters = reinterpret_cast<AtomicIType*>(
-    workspace.dptr_ + csc_indices_size + csc_indptr_size);
-  csc_data_ptr = reinterpret_cast<DType*>(workspace.dptr_ + csc_indices_size +
-                                          csc_indptr_size + col_counters_size);
-  temp_storage = reinterpret_cast<void*>(workspace.dptr_ + csc_data_size +
-                                         csc_indices_size + csc_indptr_size +
-                                         col_counters_size);
-  mxnet_op::Kernel<set_zero, gpu>::Launch(
-    s, num_dns_rows*num_csr_cols, ret->data().dptr<DType>());
-  // Reset values for indptr, ready for histogramming
-  mxnet_op::Kernel<set_zero, gpu>::Launch(
-    s, num_csr_cols+1, csc_indptr_ptr);
-  // Histogramming on col id
-  mxnet_op::Kernel<CsrTransHistogramKernel, gpu>::Launch(
-    s, nnz, csr_indices.dptr<IType>(), csc_indptr_ptr, nnz);
-  cub::DeviceScan::ExclusiveSum(temp_storage,
+  IType* original_idx_ptr = nullptr;
+  IType* csc_indices_ptr = nullptr;
+  IType* csc_cols_ptr = nullptr;
+  CType* csr_rows_ptr = nullptr;
+  CType* csc_indptr_ptr = nullptr;
+  DType* csc_data_ptr = nullptr;
+  char* temp_storage_ptr = nullptr;
+  size_t original_idx_bytes = nnz*sizeof(IType);
+  size_t csc_indices_bytes = nnz*sizeof(IType);
+  size_t csc_cols_bytes = nnz*sizeof(IType);
+  size_t csr_rows_bytes = nnz*sizeof(CType);
+  size_t csc_indptr_bytes = (num_csr_cols+1)*sizeof(CType);
+  size_t csc_data_bytes = nnz*sizeof(DType);
+  size_t unique_temp_storage_bytes = 0;
+  size_t scan_temp_storage_bytes = 0;
+  size_t temp_storage_bytes = SortByKeyWorkspaceSize<IType, IType, gpu>(nnz);
+  size_t *null_ptr = nullptr;
+  IType* csr_indices_ptr = csr_indices.dptr<IType>();
+  cub::DeviceSelect::Unique(NULL, unique_temp_storage_bytes, csr_indices_ptr,
+                            csr_indices_ptr, null_ptr, nnz, Stream<gpu>::GetStream(s));
+  cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
+                                scan_temp_storage_bytes,
+                                csc_indptr_ptr,
+                                csc_indptr_ptr,
+                                num_csr_cols+1,
+                                mshadow::Stream<gpu>::GetStream(s));
+  temp_storage_bytes = std::max(temp_storage_bytes, unique_temp_storage_bytes);
+  temp_storage_bytes = std::max(temp_storage_bytes, scan_temp_storage_bytes);
+  temp_storage_bytes += (sizeof(dim_t) - temp_storage_bytes % sizeof(dim_t));
+  size_t total_workspace_bytes =
+    original_idx_bytes + csc_indices_bytes + csc_cols_bytes + csr_rows_bytes +
+    csc_indptr_bytes + csc_data_bytes + temp_storage_bytes;
+  total_workspace_bytes += (sizeof(IType) - total_workspace_bytes % sizeof(IType));
+  Tensor<gpu, 1, char> workspace = ctx.requested[0]
+    .get_space_typed<gpu, 1, char>(Shape1(total_workspace_bytes), s);
+  original_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_);
+  csc_indices_ptr = reinterpret_cast<IType*>(workspace.dptr_ + original_idx_bytes);
+  csc_cols_ptr = reinterpret_cast<IType*>(workspace.dptr_ + original_idx_bytes +
+                                          csc_indices_bytes);
+  csr_rows_ptr = reinterpret_cast<CType*>(workspace.dptr_ + original_idx_bytes +
+                                          csc_indices_bytes + csc_cols_bytes);
+  csc_indptr_ptr = reinterpret_cast<CType*>(workspace.dptr_ + original_idx_bytes +
+                                            csc_indices_bytes + csc_cols_bytes +
+                                            csr_rows_bytes);
+  temp_storage_ptr = workspace.dptr_ + original_idx_bytes + csc_indices_bytes +
+                     csc_cols_bytes + csr_rows_bytes + csc_indptr_bytes;
+  csc_data_ptr = reinterpret_cast<DType*>(
+    workspace.dptr_ + total_workspace_bytes - csc_data_bytes);
+
+  // Fill original_idx
+  mxnet_op::Kernel<range_fwd, gpu>::Launch(
+    s, nnz, 1, IType(0), IType(1), kWriteTo, original_idx_ptr);
+  // Fill csc_cols with copy of csr_indices
+  mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kWriteTo>, gpu>::Launch(
+    s, nnz, csc_cols_ptr, csr_indices_ptr);
+
+  // Allocate the tensors needed for SortByKey
+  Tensor<gpu, 1, IType> original_idx(original_idx_ptr, Shape1(nnz), s);
+  Tensor<gpu, 1, IType> csc_cols(csc_cols_ptr, Shape1(nnz), s);
+  Tensor<gpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);
+
+  SortByKey(csc_cols, original_idx, true, &temp_storage, 0, static_cast<int>(nnz));
+
+  // Scatter csr indptr to row id
+  mxnet_op::Kernel<CsrRowScatterKernel, gpu>::Launch(
+    s, num_csr_rows, csr_indptr.dptr<CType>(), csr_rows_ptr, num_csr_rows);
+  // Reset indptr to zero
+  mxnet_op::Kernel<set_zero, gpu>::Launch(s, num_csr_cols+1, csc_indptr_ptr);
+  // Histogram on the sorted cols
+  mxnet_op::Kernel<HistogramKernel, gpu>::Launch(
+    s, nnz, csc_indptr_ptr, csc_cols_ptr, nnz);
+
+  // Scan the bin counts for every column to get csc_indptr
+  cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
                                 temp_storage_bytes,
                                 csc_indptr_ptr,
                                 csc_indptr_ptr,
-                                num_csr_cols + 1,
-                                Stream<gpu>::GetStream(s));
-  // Reset values for col_counter, ready for the final transform
-  mxnet_op::Kernel<set_zero, gpu>::Launch(
-    s, num_csr_cols+1, col_counters);
-  // Transform to CSC
+                                num_csr_cols+1,
+                                mshadow::Stream<gpu>::GetStream(s));
+  // Assign data to csc matrix arrays
   mxnet_op::Kernel<CscDataIndicesKernel, gpu>::Launch(
-    s, num_csr_rows, csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
-    csr_indptr.dptr<CType>(), csc_data_ptr, csc_indices_ptr,
-    csc_indptr_ptr, col_counters, num_csr_rows, num_csr_cols);
+    s, nnz, original_idx_ptr, csr_data.dptr<DType>(), csr_rows_ptr, csc_data_ptr,
+    csc_indices_ptr, nnz);
+
   mxnet_op::Kernel<DotDnsCsrTransDnsKernel, gpu>::Launch(
     s, out_num_rows * out_num_cols, dns.dptr<DType>(), csc_data_ptr,
     csc_indices_ptr, csc_indptr_ptr,
diff --git a/src/operator/tensor/util/tensor_util-inl.cuh b/src/operator/tensor/util/tensor_util-inl.cuh
index f38e8e117c94..c9ee625af0c8 100644
--- a/src/operator/tensor/util/tensor_util-inl.cuh
+++ b/src/operator/tensor/util/tensor_util-inl.cuh
@@ -231,6 +231,22 @@ struct MarkCsrColWarpKernel {
   }
 };

+/*!
+ * \brief GPU Kernel to compute a histogram (input types should be integer types)
+ * Parallelization by output elements: 1 thread/input element
+ */
+struct HistogramKernel {
+  template<typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             IType* target,
+                                             const CType* source,
+                                             const nnvm::dim_t num_elems) {
+    if (tid < num_elems) {
+      atomicAdd(&target[source[tid]], 1);
+    }
+  }
+};
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 06712e78f110..4a5766aff730 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1326,6 +1326,31 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
     test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40)


+@with_seed()
+def test_sparse_dot_determinism():
+    def test_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpose_a, transpose_b):
+        lhs_row = rnd.randint(200, 400)
+        lhs_col = rnd.randint(200, 400)
+        if transpose_a:
+            if transpose_b:
+                rhs_shape = (rnd.randint(200, 400), lhs_row)
+            else:
+                rhs_shape = (lhs_row, rnd.randint(200, 400))
+        else:
+            if transpose_b:
+                rhs_shape = (rnd.randint(200, 400), lhs_col)
+            else:
+                rhs_shape = (lhs_col, rnd.randint(200, 400))
+        lhs_shape = (lhs_row, lhs_col)
+        lhs = rand_ndarray(lhs_shape, lhs_stype, density=lhs_density)
+        rhs = rand_ndarray(rhs_shape, rhs_stype, density=rhs_density)
+        res1 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b)
+        res2 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b)
+        assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.0, atol=0.0)
+    test_dot_determinism('default', 'csr', 1.0, 0.1, False, False)
+    test_dot_determinism('default', 'csr', 1.0, 0.1, False, True)
+
+
 @with_seed()
 def test_sparse_slice():
     def check_csr_slice(shape, slice_input):
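
The deterministic dot(dns, csr) path in this patch replaces the old atomicAdd-based scatter with an explicit CSC conversion of the csr operand: it records the original position of every non-zero (the original_idx array), stable-sorts the column indices together with those positions (SortByKey), histograms the sorted columns and exclusive-scans the counts into csc_indptr (HistogramKernel + cub::DeviceScan::ExclusiveSum), and finally gathers data and row ids in the sorted order (CscDataIndicesKernel). The host-side sketch below mirrors that sequence with NumPy/SciPy purely for illustration; it is not the GPU implementation, and the helper name csr_to_csc_deterministic is made up for this example.

import numpy as np
import scipy.sparse as sp

def csr_to_csc_deterministic(data, indices, indptr, num_rows, num_cols):
    # Illustrative host-side analogue of the GPU steps in DotDnsCsrDnsImpl (hypothetical helper).
    nnz = len(data)
    # Position of every non-zero in the csr arrays (the original_idx array on the GPU).
    original_idx = np.arange(nnz)
    # Row id of every non-zero (CsrRowScatterKernel).
    csr_rows = np.repeat(np.arange(num_rows), np.diff(indptr))
    # Stable sort by column id fixes a reproducible order (SortByKey on csc_cols/original_idx).
    order = np.argsort(indices, kind='stable')
    csc_cols = indices[order]
    original_idx = original_idx[order]
    # Column histogram + exclusive scan -> csc_indptr (HistogramKernel + DeviceScan::ExclusiveSum).
    counts = np.bincount(csc_cols, minlength=num_cols)
    csc_indptr = np.concatenate(([0], np.cumsum(counts)))
    # Gather data and row indices in the sorted order (CscDataIndicesKernel).
    csc_data = data[original_idx]
    csc_indices = csr_rows[original_idx]
    return csc_data, csc_indices, csc_indptr

# Quick self-check against scipy's own csr -> csc conversion.
rng = np.random.RandomState(0)
m = sp.random(6, 8, density=0.3, format='csr', random_state=rng)
d, i, p = csr_to_csc_deterministic(m.data, m.indices, m.indptr, *m.shape)
ref = m.tocsc()
assert np.array_equal(d, ref.data)
assert np.array_equal(i, ref.indices)
assert np.array_equal(p, ref.indptr)

Because the stable sort fixes the order of entries that share a column id, every csc slot is written exactly once from a known source element; the old col_counters/atomicAdd scatter produced a run-dependent ordering within each column, which changed the floating-point summation order in DotDnsCsrTransDnsKernel and hence the output bits between runs.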