diff --git a/benchmark/python/cast_storage.py b/benchmark/python/cast_storage.py
index a92da7295d2b..38398e5e164a 100644
--- a/benchmark/python/cast_storage.py
+++ b/benchmark/python/cast_storage.py
@@ -23,17 +23,17 @@ def measure_cost(repeat, f, *args, **kwargs):
 
 
 def run_cast_storage_synthetic():
-    def dns_to_csr(m, n, density, ctx, repeat):
+    def dense_to_sparse(m, n, density, ctx, repeat, stype):
         set_default_context(ctx)
         data_shape = (m, n)
-        dns_data = rand_ndarray(data_shape, 'csr', density).todense()
+        dns_data = rand_ndarray(data_shape, stype, density).todense()
         dns_data.wait_to_read()
 
         # do one warm up run, verify correctness
-        assert same(mx.nd.cast_storage(dns_data, stype='csr').asnumpy(), dns_data.asnumpy())
+        assert same(mx.nd.cast_storage(dns_data, stype).asnumpy(), dns_data.asnumpy())
 
         # start benchmarking
-        cost = measure_cost(repeat, mx.nd.cast_storage, dns_data, stype='csr')
+        cost = measure_cost(repeat, mx.nd.cast_storage, dns_data, stype)
         results = '{:10.1f} {:>10} {:8d} {:8d} {:10.2f}'.format(density*100, str(ctx), m, n, cost*1000)
         print(results)
 
@@ -46,24 +46,36 @@ def dns_to_csr(m, n, density, ctx, repeat):
     # num_repeat      number of benchmark runs to average over
     # contexts        mx.cpu(), mx.gpu()
     #                 note: benchmark different contexts separately; to benchmark cpu, compile without CUDA
+    # benchmarks      dns_to_csr, dns_to_rsp
     m = [ 512, 512]
     n = [50000, 100000]
     density = [1.00, 0.80, 0.60, 0.40, 0.20, 0.10, 0.05, 0.02, 0.01]
     num_repeat = 10
     contexts = [mx.gpu()]
+    benchmarks = ["dns_to_csr", "dns_to_rsp"]
 
     # run benchmark
-    print("==================================================")
-    print(" cast_storage benchmark: dense to csr, size m x n ")
-    print("==================================================")
-    headline = '{:>10} {:>10} {:>8} {:>8} {:>10}'.format('density(%)', 'context', 'm', 'n', 'time(ms)')
-    print(headline)
-    for i in range(len(n)):
-        for ctx in contexts:
-            for den in density:
-                dns_to_csr(m[i], n[i], den, ctx, num_repeat)
+    for b in benchmarks:
+        stype = ''
+        print("==================================================")
+        if b == "dns_to_csr":
+            stype = 'csr'
+            print(" cast_storage benchmark: dense to csr, size m x n ")
+        elif b == "dns_to_rsp":
+            stype = 'row_sparse'
+            print(" cast_storage benchmark: dense to rsp, size m x n ")
+        else:
+            print("invalid benchmark: %s" % b)
+            continue
+        print("==================================================")
+        headline = '{:>10} {:>10} {:>8} {:>8} {:>10}'.format('density(%)', 'context', 'm', 'n', 'time(ms)')
+        print(headline)
+        for i in range(len(n)):
+            for ctx in contexts:
+                for den in density:
+                    dense_to_sparse(m[i], n[i], den, ctx, num_repeat, stype)
+            print("")
         print("")
-    print("==================================================")
 
 
 if __name__ == "__main__":
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 06ef759c53d0..ca1e8415f7ee 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -110,9 +110,15 @@ def rand_ndarray(shape, stype, density=None):
 def rand_shape_2d(dim0=10, dim1=10):
     return rnd.randint(1, dim0 + 1), rnd.randint(1, dim1 + 1)
 
+
 def rand_shape_3d(dim0=10, dim1=10, dim2=10):
     return rnd.randint(1, dim0 + 1), rnd.randint(1, dim1 + 1), rnd.randint(1, dim2 + 1)
 
+
+def rand_shape_nd(n, dim=10):
+    return rnd.randint(1, dim+1, size=n)
+
+
 def np_reduce(dat, axis, keepdims, numpy_reduce_func):
     """Compatible reduce for old version of NumPy.
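Note: the benchmark above times repeated `mx.nd.cast_storage` calls on a synthetic dense matrix of a given density. The following is a minimal standalone sketch of one such benchmark round, assuming an MXNet build that includes the sparse `cast_storage` operator and `mxnet.test_utils.rand_ndarray` exercised by this PR; the helper name `time_cast` is illustrative, not part of the codebase.

```python
# Hypothetical sketch: time dense -> sparse conversion, mirroring measure_cost()
import time
import mxnet as mx
from mxnet.test_utils import rand_ndarray

def time_cast(m, n, density, stype, repeat=10, ctx=mx.cpu()):
    # dense matrix with roughly `density` fraction of non-zero entries
    dns = rand_ndarray((m, n), stype, density).todense().as_in_context(ctx)
    dns.wait_to_read()
    # warm-up run before timing
    mx.nd.cast_storage(dns, stype=stype).wait_to_read()
    start = time.time()
    for _ in range(repeat):
        mx.nd.cast_storage(dns, stype=stype).wait_to_read()
    return (time.time() - start) / repeat

if __name__ == "__main__":
    for d in [0.50, 0.10, 0.01]:
        print("density %.2f: %.3f ms" % (d, time_cast(512, 50000, d, 'csr') * 1000))
```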
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index d4a473c8be0c..694c50c4fe05 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -10,6 +10,9 @@
 #include
 #include
 #include
+#ifdef __CUDACC__
+#include "../common/cuda_utils.h"
+#endif  // __CUDACC__
 
 namespace mxnet {
 namespace op {
@@ -32,6 +35,13 @@ int get_num_threads(const int N);
       i < (n); \
       i += blockDim.x * gridDim.x)
 
+inline cudaDeviceProp cuda_get_device_prop() {
+  int device;
+  CUDA_CALL(cudaGetDevice(&device));
+  cudaDeviceProp deviceProp;
+  CUDA_CALL(cudaGetDeviceProperties(&deviceProp, device));
+  return deviceProp;
+}
 
 /*!
  * \brief Get the number of blocks for cuda kernel given N
diff --git a/src/operator/tensor/cast_storage-inl.cuh b/src/operator/tensor/cast_storage-inl.cuh
index 47fb3b42c356..dfb09d632ac8 100644
--- a/src/operator/tensor/cast_storage-inl.cuh
+++ b/src/operator/tensor/cast_storage-inl.cuh
@@ -8,19 +8,300 @@
 
 #include
 #include
+#include
 #include
 
 namespace mxnet {
 namespace op {
 
-using mshadow::cuda::kBaseThreadNum;
-
-inline void CastStorageDnsRspImpl(const OpContext& ctx, const gpu& gpu_dev, const TBlob& dns, NDArray* rsp) {
-  LOG(FATAL) << "CastStorageDnsRspImpl gpu version is not implemented.";
+/*!
+ * \brief Thread kernel for marking non-zero rows of a tensor.
+ * Parallelized by tensor rows: 1 thread/row
+ */
+struct MarkRspRowIdxThreadKernel {
+  /*!
+   * \brief
+   * \param tid         global thread id
+   * \param row_flg     row flag array to mark non-zero rows
+   * \param dns         dense matrix data
+   * \param num_rows    number of rows (size of first dimension of tensor)
+   * \param row_length  number of elements per row
+   */
+  template<typename DType, typename RType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             RType* row_flg,
+                                             const DType* dns,
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t row_length) {
+    using nnvm::dim_t;
+    if (tid < num_rows) {
+      dim_t j = 0;
+      dim_t offset = tid * row_length;
+      for (; j < row_length; ++j) {
+        if (dns[offset+j] != 0) {
+          break;
+        }
+      }
+      if (j < row_length) {
+        row_flg[tid] = 1;  // mark as one for non-zero row
+      } else {
+        row_flg[tid] = 0;  // mark as zero for zero row
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Warp kernel for marking non-zero rows of a tensor.
+ * Parallelized by tensor rows: 1 warp/row
+ */
+struct MarkRspRowIdxWarpKernel {
+  template<typename DType, typename RType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             RType* row_flg,
+                                             const DType* dns,
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t row_length) {
+    using nnvm::dim_t;
+    typedef cub::WarpReduce<dim_t> WarpReduce;
+    const dim_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32;
+    __shared__ typename WarpReduce::TempStorage temp_storage[warps_per_block];
+
+    const dim_t warp_id   = tid / 32;          // global warp id
+    const dim_t warp_lane = threadIdx.x / 32;  // local warp id within thread block
+    const dim_t lane      = tid & (32-1);      // local thread id within warp
+
+    if (warp_id < num_rows) {
+      dim_t flg = 0;
+      dim_t offset = warp_id * row_length;
+      for (dim_t j = lane; j < row_length; j+=32) {
+        if (dns[offset+j] != 0) {
+          // avoid break: causes slower performance on sparse tensors (<20% density),
+          // due to thread divergence
+          flg++;
+        }
+      }
+      dim_t aggr = WarpReduce(temp_storage[warp_lane]).Sum(flg);
+      if (lane == 0) {
+        if (aggr > 0) {
+          row_flg[warp_id] = 1;  // mark as one for non-zero row
+        } else {
+          row_flg[warp_id] = 0;  // mark as zero for zero row
+        }
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Block kernel for marking non-zero rows of a tensor.
+ * Parallelized by tensor rows: 1 threadBlock/row + */ +struct MarkRspRowIdxBlockKernel { + template + __device__ __forceinline__ static void Map(int tid, + RType* row_flg, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t row_length) { + using nnvm::dim_t; + using mshadow::cuda::kBaseThreadNum; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + if (blockIdx.x < num_rows) { + dim_t flg = 0; + dim_t offset = blockIdx.x * row_length; + for (dim_t j = threadIdx.x; j < row_length; j+=kBaseThreadNum) { + if (dns[offset+j] != 0) { + // avoid break: causes slower performance on sparse tensors (<20% density), + // due to thread divergence + flg++; + } + } + dim_t aggr = BlockReduce(temp_storage).Sum(flg); + if (threadIdx.x == 0) { + if (aggr > 0) { + row_flg[blockIdx.x] = 1; // mark as one for non-zero row + } else { + row_flg[blockIdx.x] = 0; // mark as zero for zero row + } + } + } + } +}; + +/*! + * \brief Kernel for filling the row index array of the rsp tensor. + * Parallelized by tensor rows: 1 thread/row + */ +struct FillRspRowIdxKernel { + /*! + * \brief + * \param tid global thread id + * \param row_idx row index array to store indices of non-zero rows + * \param row_flg_sum inclusive prefix sum array over marked row flag array + * \param num_rows number of rows (size of first dimension of tensor) + */ + template + __device__ __forceinline__ static void Map(int tid, + RType* row_idx, + const RType* row_flg_sum, + const nnvm::dim_t num_rows) { + if (tid < num_rows) { + nnvm::dim_t prev = (tid == 0)? 0 : row_flg_sum[tid-1]; + if (row_flg_sum[tid] > prev) { + row_idx[prev] = tid; + } + } + } +}; + +/*! + * \brief Kernel for filling the value array of the rsp tensor. + * Parallelized by rsp tensor elements: 1 thread/element + */ +struct FillRspValsKernel { + /*! + * \brief + * \param tid global thread id + * \param rsp_val value array of rsp tensor to store data + * \param row_idx indices of non-zero rows + * \param dns dense matrix data + * \param nnr number of non-zero rows + * \param row_length number of elements per row + */ + template + __device__ __forceinline__ static void Map(int tid, + DType* rsp_val, + const RType* row_idx, + const DType* dns, + const nnvm::dim_t nnr, + const nnvm::dim_t row_length) { + using nnvm::dim_t; + if (tid < nnr*row_length) { + const dim_t row_id = tid / row_length; + const dim_t row_el = tid % row_length; + const dim_t dns_idx = row_idx[row_id] * row_length + row_el; + rsp_val[tid] = dns[dns_idx]; + } + } +}; + +/*! + * \brief GPU implementation of casting a dns tensor to rsp type. 
+ */ +inline void CastStorageDnsRspImpl(const OpContext& ctx, + const gpu& gpu_dev, + const TBlob& dns, + NDArray* rsp) { + CHECK(rsp != nullptr); + CHECK_EQ(rsp->storage_type(), kRowSparseStorage); + CHECK_EQ(dns.shape_, rsp->shape()); + using mshadow::Shape1; + using mxnet_op::Kernel; + using nnvm::dim_t; + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, { // row idx type + const dim_t num_rows = dns.shape_[0]; + const dim_t row_length = dns.shape_.ProdShape(1, dns.shape_.ndim()); + const dim_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize; + const dim_t threads_per_block = mshadow::cuda::kBaseThreadNum; + const dim_t min_num_warps = 512; + dim_t num_threads; + // TODO: remove kernel dependency on warpSize=32 + if (threads_per_warp != 32) { + LOG(FATAL) << "CastStorageDnsRspImpl GPU kernels expect warpSize=32"; + } + // Determine temporary device storage requirements + RType* row_flg = NULL; + void* d_temp_storage = NULL; + size_t temp_storage_bytes = 0; + cub::DeviceScan::InclusiveSum(d_temp_storage, + temp_storage_bytes, + row_flg, + row_flg, + num_rows, + mshadow::Stream::GetStream(s)); + + // Allocate temp storage for marking non-zero rows and for cub's prefix sum + mshadow::Tensor workspace = ctx.requested[0] + .get_space_typed(Shape1(num_rows*sizeof(RType)+temp_storage_bytes), s); + row_flg = reinterpret_cast(workspace.dptr_); + d_temp_storage = workspace.dptr_ + num_rows*sizeof(RType); + + // Mark non-zero rows as 'one' in row_flg + // Different kernel versions are optimized for different matrix instances + // (1) 'Thread kernel' (one thread computing one row) + // (2) 'Warp kernel' (one warp computing one row) + // (3) 'Block kernel' (one thread block computing one row) + const int kernel_version = 0; + switch (kernel_version) { + case 1: + num_threads = num_rows; + Kernel::Launch(s, num_threads, + row_flg, dns.dptr(), num_rows, row_length); + break; + case 2: + num_threads = num_rows * threads_per_warp; + Kernel::Launch(s, num_threads, + row_flg, dns.dptr(), num_rows, row_length); + break; + case 3: + num_threads = num_rows * threads_per_block; + Kernel::Launch(s, num_threads, + row_flg, dns.dptr(), num_rows, row_length); + break; + default: + if (row_length < threads_per_warp) { + num_threads = num_rows; + Kernel::Launch(s, num_threads, + row_flg, dns.dptr(), num_rows, row_length); + } else if (row_length < threads_per_block || num_rows > min_num_warps) { + num_threads = num_rows * threads_per_warp; + Kernel::Launch(s, num_threads, + row_flg, dns.dptr(), num_rows, row_length); + } else { + num_threads = num_rows * threads_per_block; + Kernel::Launch(s, num_threads, + row_flg, dns.dptr(), num_rows, row_length); + } + break; + } + // Compute non-zero row indices through inclusive prefix sum + cub::DeviceScan::InclusiveSum(d_temp_storage, + temp_storage_bytes, + row_flg, + row_flg, + num_rows, + mshadow::Stream::GetStream(s)); + + // Get total number of non-zero rows from device + RType nnr = 0; + CUDA_CALL(cudaMemcpy(&nnr, &row_flg[num_rows-1], sizeof(RType), cudaMemcpyDeviceToHost)); + + // Allocate rsp tensor row index array and fill + rsp->CheckAndAllocAuxData(rowsparse::kIdx, Shape1(static_cast(nnr))); + if (0 == nnr) return; + RType* row_idx = rsp->aux_data(rowsparse::kIdx).dptr(); + num_threads = num_rows; + Kernel::Launch(s, num_threads, + row_idx, row_flg, num_rows); + + // Construct shape of rsp tensor data, allocate, and fill + auto 
storage_shape = dns.shape_; + storage_shape[0] = nnr; + rsp->CheckAndAllocData(storage_shape); + num_threads = nnr * row_length; + Kernel::Launch(s, num_threads, + rsp->data().dptr(), row_idx, dns.dptr(), nnr, row_length); + }); + }); } /*! - * \brief Thread kernel for initializing the indptr in a csr tensor. + * \brief Thread kernel for initializing the indptr in a csr matrix. * Parallelized by matrix rows: 1 thread/row */ struct FillCsrIndPtrThreadKernel { @@ -33,15 +314,19 @@ struct FillCsrIndPtrThreadKernel { * \param num_cols number of columns of the dense matrix */ template - __device__ __forceinline__ static void Map(int tid, IType* indptr, const DType* dns, - const int num_rows, const int num_cols) { + __device__ __forceinline__ static void Map(int tid, + IType* indptr, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + using nnvm::dim_t; if (tid == 0) { indptr[tid] = 0; } if (tid < num_rows) { - int nnz = 0; - const int offset = tid * num_cols; - for (int j = 0; j < num_cols; ++j) { + dim_t nnz = 0; + const dim_t offset = tid * num_cols; + for (dim_t j = 0; j < num_cols; ++j) { if (dns[offset+j] != 0) { nnz++; } @@ -52,7 +337,7 @@ struct FillCsrIndPtrThreadKernel { }; /*! - * \brief Thread kernel for initializing the col_idx and value array of the csr matrix + * \brief Thread kernel for initializing the col_idx and value array of the csr matrix. * Parallelized by matrix rows: 1 thread/row */ struct FillCsrColIdxAndValsThreadKernel { @@ -67,13 +352,18 @@ struct FillCsrColIdxAndValsThreadKernel { * \param num_cols number of columns of the dense matrix */ template - __device__ __forceinline__ static void Map(int tid, DType* val, CType* col_idx, - const IType* indptr, const DType* dns, - const int num_rows, const int num_cols) { + __device__ __forceinline__ static void Map(int tid, + DType* val, + CType* col_idx, + const IType* indptr, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + using nnvm::dim_t; if (tid < num_rows) { - const int offset = tid * num_cols; - int k = indptr[tid]; - for (int j = 0; j < num_cols; ++j) { + const dim_t offset = tid * num_cols; + dim_t k = indptr[tid]; + for (dim_t j = 0; j < num_cols; ++j) { if (dns[offset+j] != 0) { val[k] = dns[offset+j]; col_idx[k] = j; @@ -85,32 +375,36 @@ struct FillCsrColIdxAndValsThreadKernel { }; /*! - * \brief Warp kernel for initializing the indptr in a csr matrix + * \brief Warp kernel for initializing the indptr in a csr matrix. 
* Parallelized by matrix rows: 1 warp/row */ struct FillCsrIndPtrWarpKernel { template - __device__ __forceinline__ static void Map(int tid, IType* indptr, const DType* dns, - const int num_rows, const int num_cols) { - typedef cub::WarpReduce WarpReduce; - const int warps_per_block = kBaseThreadNum / 32; + __device__ __forceinline__ static void Map(int tid, + IType* indptr, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + using nnvm::dim_t; + typedef cub::WarpReduce WarpReduce; + const dim_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32; __shared__ typename WarpReduce::TempStorage temp_storage[warps_per_block]; if (tid == 0) { indptr[tid] = 0; } - const int warp_id = tid / 32; // global warp id - const int warp_lane = threadIdx.x / 32; // local warp id within thread block - const int lane = tid & (32-1); // local thread id within warp + const dim_t warp_id = tid / 32; // global warp id + const dim_t warp_lane = threadIdx.x / 32; // local warp id within thread block + const dim_t lane = tid & (32-1); // local thread id within warp if (warp_id < num_rows) { - int lane_nnz = 0; - const int offset = warp_id * num_cols; - for (int j = lane; j < num_cols; j+=32) { + dim_t lane_nnz = 0; + const dim_t offset = warp_id * num_cols; + for (dim_t j = lane; j < num_cols; j+=32) { if (dns[offset+j] != 0) { lane_nnz++; } } - int aggr = WarpReduce(temp_storage[warp_lane]).Sum(lane_nnz); + dim_t aggr = WarpReduce(temp_storage[warp_lane]).Sum(lane_nnz); if (lane == 0) { indptr[warp_id+1] = aggr; } @@ -119,27 +413,32 @@ struct FillCsrIndPtrWarpKernel { }; /*! - * \brief Warp kernel for initializing the col_idx and value array of the csr matrix + * \brief Warp kernel for initializing the col_idx and value array of the csr matrix. * Parallelized by matrix rows: 1 warp/row */ struct FillCsrColIdxAndValsWarpKernel { template - __device__ __forceinline__ static void Map(int tid, DType* val, CType* col_idx, - const IType* indptr, const DType* dns, - const int num_rows, const int num_cols) { - typedef cub::WarpScan WarpScan; - const int warps_per_block = kBaseThreadNum / 32; + __device__ __forceinline__ static void Map(int tid, + DType* val, + CType* col_idx, + const IType* indptr, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + using nnvm::dim_t; + typedef cub::WarpScan WarpScan; + const dim_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32; __shared__ typename WarpScan::TempStorage temp_storage[warps_per_block]; - __shared__ volatile int warp_nnz[warps_per_block]; + __shared__ volatile dim_t warp_nnz[warps_per_block]; - const int warp_id = tid / 32; // global warp id - const int warp_lane = threadIdx.x / 32; // local warp id within thread block - const int lane = tid & (32-1); // local thread id within warp + const dim_t warp_id = tid / 32; // global warp id + const dim_t warp_lane = threadIdx.x / 32; // local warp id within thread block + const dim_t lane = tid & (32-1); // local thread id within warp if (warp_id < num_rows) { - const int offset = warp_id * num_cols; - int k = indptr[warp_id]; - int nnz; - for (int j = lane; j < num_cols+lane; j+=32) { + const dim_t offset = warp_id * num_cols; + dim_t k = indptr[warp_id]; + dim_t nnz; + for (dim_t j = lane; j < num_cols+lane; j+=32) { nnz = 0; if (j < num_cols) { if (dns[offset+j] != 0) { @@ -168,28 +467,33 @@ struct FillCsrColIdxAndValsWarpKernel { }; /*! - * \brief Block kernel for initializing the indptr in a csr tensor. 
+ * \brief Block kernel for initializing the indptr in a csr matrix. * Parallelized by matrix rows: 1 threadBlock/row */ struct FillCsrIndPtrBlockKernel { template - __device__ __forceinline__ static void Map(int tid, IType* indptr, const DType* dns, - const int num_rows, const int num_cols) { - typedef cub::BlockReduce BlockReduce; + __device__ __forceinline__ static void Map(int tid, + IType* indptr, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + using mshadow::cuda::kBaseThreadNum; + using nnvm::dim_t; + typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; if (tid == 0) { indptr[tid] = 0; } if (blockIdx.x < num_rows) { - int lane_nnz = 0; - const int offset = blockIdx.x * num_cols; - for (int j = threadIdx.x; j < num_cols; j+=kBaseThreadNum) { + dim_t lane_nnz = 0; + const dim_t offset = blockIdx.x * num_cols; + for (dim_t j = threadIdx.x; j < num_cols; j+=kBaseThreadNum) { if (dns[offset+j] != 0) { lane_nnz++; } } - int aggr = BlockReduce(temp_storage).Sum(lane_nnz); + dim_t aggr = BlockReduce(temp_storage).Sum(lane_nnz); if (threadIdx.x == 0) { indptr[blockIdx.x+1] = aggr; } @@ -198,23 +502,29 @@ struct FillCsrIndPtrBlockKernel { }; /*! - * \brief Block kernel for initializing the col_idx and value array of the csr matrix + * \brief Block kernel for initializing the col_idx and value array of the csr matrix. * Parallelized by matrix rows: 1 threadBlock/row */ struct FillCsrColIdxAndValsBlockKernel { template - __device__ __forceinline__ static void Map(int tid, DType* val, CType* col_idx, - const IType* indptr, const DType* dns, - const int num_rows, const int num_cols) { - typedef cub::BlockScan BlockScan; + __device__ __forceinline__ static void Map(int tid, + DType* val, + CType* col_idx, + const IType* indptr, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + using mshadow::cuda::kBaseThreadNum; + using nnvm::dim_t; + typedef cub::BlockScan BlockScan; __shared__ typename BlockScan::TempStorage temp_storage; - __shared__ volatile int block_nnz; + __shared__ volatile dim_t block_nnz; if (blockIdx.x < num_rows) { - const int offset = blockIdx.x * num_cols; - int k = indptr[blockIdx.x]; - int nnz; - for (int j = threadIdx.x; j < num_cols+threadIdx.x; j+=kBaseThreadNum) { + const dim_t offset = blockIdx.x * num_cols; + dim_t k = indptr[blockIdx.x]; + dim_t nnz; + for (dim_t j = threadIdx.x; j < num_cols+threadIdx.x; j+=kBaseThreadNum) { nnz = 0; if (j < num_cols) { if (dns[offset+j] != 0) { @@ -243,8 +553,7 @@ struct FillCsrColIdxAndValsBlockKernel { }; /*! - * \brief - * GPU implementation of casting a dense matrix to csr type. + * \brief GPU implementation of casting a dense matrix to csr type. 
*/ inline void CastStorageDnsCsrImpl(const OpContext& ctx, const gpu& gpu_dev, @@ -254,18 +563,24 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx, CHECK_EQ(csr->storage_type(), kCSRStorage); CHECK_EQ(dns.shape_.ndim(), 2); CHECK_EQ(dns.shape_, csr->shape()); + using mshadow::Shape1; + using mxnet_op::Kernel; + using nnvm::dim_t; mshadow::Stream* s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col_idx type - const index_t num_rows = dns.shape_[0]; - const index_t num_cols = dns.shape_[1]; - const int threads_per_warp = 32; - const int threads_per_block = kBaseThreadNum; - const int min_num_warps = 512; - int num_threads; - - csr->CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(num_rows+1)); + const dim_t num_rows = dns.shape_[0]; + const dim_t num_cols = dns.shape_[1]; + const dim_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize; + const dim_t threads_per_block = mshadow::cuda::kBaseThreadNum; + const dim_t min_num_warps = 512; + dim_t num_threads; + // TODO: remove kernel dependency on warpSize=32 + if (threads_per_warp != 32) { + LOG(FATAL) << "CastStorageDnsCsrImpl GPU kernels expect warpSize=32"; + } + csr->CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows+1)); IType* indptr = csr->aux_data(csr::kIndPtr).dptr(); DType* dns_data = dns.dptr(); @@ -277,32 +592,32 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx, switch (kernel_version) { case 1: num_threads = num_rows; - mxnet_op::Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, indptr, dns_data, num_rows, num_cols); break; case 2: num_threads = num_rows * threads_per_warp; - mxnet_op::Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, indptr, dns_data, num_rows, num_cols); break; case 3: num_threads = num_rows * threads_per_block; - mxnet_op::Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, indptr, dns_data, num_rows, num_cols); break; default: if (num_cols < threads_per_warp) { num_threads = num_rows; - mxnet_op::Kernel::Launch(s, num_threads, - indptr, dns_data, num_rows, num_cols); + Kernel::Launch(s, num_threads, + indptr, dns_data, num_rows, num_cols); } else if (num_cols < threads_per_block || num_rows > min_num_warps) { num_threads = num_rows * threads_per_warp; - mxnet_op::Kernel::Launch(s, num_threads, - indptr, dns_data, num_rows, num_cols); + Kernel::Launch(s, num_threads, + indptr, dns_data, num_rows, num_cols); } else { num_threads = num_rows * threads_per_block; - mxnet_op::Kernel::Launch(s, num_threads, - indptr, dns_data, num_rows, num_cols); + Kernel::Launch(s, num_threads, + indptr, dns_data, num_rows, num_cols); } break; } @@ -314,12 +629,12 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx, temp_storage_bytes, indptr, indptr, - static_cast(num_rows+1), + num_rows+1, mshadow::Stream::GetStream(s)); // Allocate temporary storage mshadow::Tensor workspace = ctx.requested[0] - .get_space_typed(mshadow::Shape1(temp_storage_bytes), s); + .get_space_typed(Shape1(temp_storage_bytes), s); d_temp_storage = workspace.dptr_; // Compute indptr through inclusive prefix sum @@ -327,7 +642,7 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx, temp_storage_bytes, indptr, indptr, - static_cast(num_rows+1), + num_rows+1, mshadow::Stream::GetStream(s)); // Receive total number of nnz values from device @@ -335,43 +650,43 @@ inline void CastStorageDnsCsrImpl(const OpContext& 
ctx, CUDA_CALL(cudaMemcpy(&nnz, &(indptr[num_rows]), sizeof(IType), cudaMemcpyDeviceToHost)); // Allocate column index array and data array of the csr matrix - csr->CheckAndAllocAuxData(csr::kIdx, mshadow::Shape1(static_cast(nnz))); - csr->CheckAndAllocData(mshadow::Shape1(static_cast(nnz))); + csr->CheckAndAllocAuxData(csr::kIdx, Shape1(static_cast(nnz))); + csr->CheckAndAllocData(Shape1(static_cast(nnz))); // Compute and fill column index array and data array of the csr matrix switch (kernel_version) { case 1: num_threads = num_rows; - mxnet_op::Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); break; case 2: num_threads = num_rows * threads_per_warp; - mxnet_op::Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); break; case 3: num_threads = num_rows * threads_per_block; - mxnet_op::Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); break; default: if (num_cols < threads_per_warp) { num_threads = num_rows; - mxnet_op::Kernel::Launch(s, num_threads, - csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), - indptr, dns_data, num_rows, num_cols); + Kernel::Launch(s, num_threads, + csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), + indptr, dns_data, num_rows, num_cols); } else if (num_cols < threads_per_block || num_rows > min_num_warps) { num_threads = num_rows * threads_per_warp; - mxnet_op::Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); } else { num_threads = num_rows * threads_per_block; - mxnet_op::Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); } diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h index 46ae105b80b6..2ad1957a4648 100644 --- a/src/operator/tensor/cast_storage-inl.h +++ b/src/operator/tensor/cast_storage-inl.h @@ -20,21 +20,21 @@ namespace mxnet { namespace op { /*! - * \brief Kernel for marking row_idx of a RSP matrix per row + * \brief CPU Kernel for marking row_idx of a RSP tensor per row. */ struct MarkRspRowIdx { - // i represents the row index of the matrix data + // i represents the row index of the tensor data template MSHADOW_XINLINE static void Map(int i, RType* row_idx, const DType* data, - const index_t num_cols) { + const index_t row_length) { index_t j = 0; - index_t offset = i * num_cols; - for (; j < num_cols; ++j) { + index_t offset = i * row_length; + for (; j < row_length; ++j) { if (data[offset+j] != 0) { break; } } - if (num_cols == j) { + if (row_length == j) { row_idx[i] = 0; // mark as zero for zero row } else { row_idx[i] = 1; // mark as one for non-zero row @@ -43,8 +43,7 @@ struct MarkRspRowIdx { }; /*! - * \brief - * CPU implementation of casting a dns tensor to rsp type. + * \brief CPU implementation of casting a dns tensor to rsp type. 
 */
 inline void CastStorageDnsRspImpl(const OpContext& ctx,
                                   const cpu& cpu_dev,
@@ -89,19 +88,19 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx,
 // TODO(haibin) Use memcopy instead will be much faster than assigning each individual element
 struct CastStorageRspDnsKernel {
   template<typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int i, const index_t width, const IType* idx, const DType *data,
-                                  DType* dns) {
+  MSHADOW_XINLINE static void Map(int i, const index_t row_length, const IType* idx,
+                                  const DType *data, DType* dns) {
     auto rid = idx[i];
-    auto dns_offset = rid * width;
-    auto rsp_offset = i * width;
-    for (size_t col = 0; col < width; col++) {
+    auto dns_offset = rid * row_length;
+    auto rsp_offset = i * row_length;
+    for (size_t col = 0; col < row_length; col++) {
       dns[dns_offset + col] = data[rsp_offset + col];
     }
   }
 };
 
 /*!
- * \brief This function assumes that the meomry for dns has been allocated already
+ * \brief This function assumes that the memory for dns has been allocated already
  * since the shape is known at binding stage.
  */
 template<typename xpu>
 void CastStorageRspDnsImpl(const OpContext& ctx, const NDArray& rsp, TBlob* dns) {
@@ -128,7 +127,7 @@ void CastStorageRspDnsImpl(const OpContext& ctx, const NDArray& rsp, TBlob* dns)
 }
 
 /*!
- * \brief This is the kernel for initializing the indptr in a csr tensor.
+ * \brief CPU kernel for initializing the indptr in a csr matrix.
  */
 struct FillCsrIndPtr {
   /*!
@@ -153,8 +152,7 @@ struct FillCsrIndPtr {
 };
 
 /*!
- * \brief This is the kernel for initializing the col_idx and value array
- * of the csr tensor
+ * \brief CPU kernel for initializing the col_idx and value array of the csr matrix.
  */
 struct FillCsrColIdxAndVals {
   /*!
@@ -170,10 +168,10 @@ struct FillCsrColIdxAndVals {
   template<typename DType, typename IType, typename CType>
   MSHADOW_XINLINE static void Map(int i, DType* val, CType* col_idx,
                                   const IType* indptr, const DType* dns,
-                                  const int num_rows, const int num_cols) {
-    const int offset = i * num_cols;
-    int k = indptr[i];
-    for (int j = 0; j < num_cols; ++j) {
+                                  const index_t num_rows, const index_t num_cols) {
+    const index_t offset = i * num_cols;
+    IType k = indptr[i];
+    for (index_t j = 0; j < num_cols; ++j) {
       if (dns[offset+j] != 0) {
         val[k] = dns[offset+j];
         col_idx[k] = j;
@@ -184,8 +182,7 @@ struct FillCsrColIdxAndVals {
 };
 
 /*!
- * \brief
- * CPU implementation of casting a dns tensor to csr type.
+ * \brief CPU implementation of casting a dns matrix to csr type.
  */
 inline void CastStorageDnsCsrImpl(const OpContext& ctx,
                                   const cpu& cpu_dev,
@@ -226,7 +223,7 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx,
 }
 
 /*!
- * \brief This is the kernel for copying csr.data to its corresponding dns tensor.
+ * \brief This is the kernel for copying csr.data to its corresponding dns matrix.
  */
 struct CopyCsrDataToDns {
   /*!
@@ -250,7 +247,7 @@ struct CopyCsrDataToDns {
 };
 
 /*!
- * \brief Casts a csr tensor to dns format.
+ * \brief Casts a csr matrix to dns format.
  */
 template<typename xpu>
 void CastStorageCsrDnsImpl(const OpContext& ctx, const NDArray& csr, TBlob* dns) {
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 23fbb8fa2b05..2d073ac4e7a5 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -69,43 +69,49 @@ def test_elemwise_add_ex_multiple_stages():
 
 
 # TODO(haibin) also add test for backward pass.
 def test_cast_storage_ex():
-    def test_rsp_to_dns(shape):
-        rsp, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse')
-        dns_out = mx.nd.cast_storage(rsp, stype='default')
-        dns_expected = np.zeros(shape, dtype=default_dtype())
-        if row_idx is not None:
-            for k, v in enumerate(row_idx):
-                dns_expected[v, :] = data[k]
-        assert same(dns_out.asnumpy(), dns_expected)
-
-    def test_dns_to_rsp(shape):
-        dns_in = rand_ndarray(shape, 'default')
-        rsp_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), stype='row_sparse')
-        ret = mx.nd.cast_storage(rsp_out, stype='default')
-        assert same(ret.asnumpy(), dns_in.asnumpy())
+    def test_rsp_to_dns(shape, density):
+        rsp_in, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse', density)
+        dns_out = mx.nd.cast_storage(rsp_in, stype='default')
+        assert same(rsp_in.asnumpy(), dns_out.asnumpy())
+
+    def test_dns_to_rsp(shape, density):
+        rsp_in, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse', density)
+        rsp_out = mx.nd.cast_storage(mx.nd.array(rsp_in.todense(), dtype=default_dtype()), stype='row_sparse')
+        assert same(rsp_in.asnumpy(), rsp_out.asnumpy())
 
     def test_csr_to_dns(shape, density):
         csr_in, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr', density)
-        dns_out = csr_in.todense()
+        dns_out = mx.nd.cast_storage(csr_in, stype='default')
         assert same(csr_in.asnumpy(), dns_out.asnumpy())
 
     def test_dns_to_csr(shape, density):
         csr_in, (indptr, colidx, data) = rand_sparse_ndarray(shape, 'csr', density)
-        dns_in = csr_in.todense()
-        csr_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), stype='csr')
+        csr_out = mx.nd.cast_storage(mx.nd.array(csr_in.todense(), dtype=default_dtype()), stype='csr')
         assert same(csr_in.asnumpy(), csr_out.asnumpy())
 
-    shape = rand_shape_2d()
-    if default_context().device_type is 'cpu':
-        test_rsp_to_dns(shape)
-        test_dns_to_rsp(shape)
-
     density = [1.00, 0.50, 0.10, 0.05, 0.01]
     for d in density:
-        test_csr_to_dns((rnd.randint(1, 10), rnd.randint(  1,   64)), d)
-        test_dns_to_csr((rnd.randint(1, 10), rnd.randint(  1,   31)), d)  # test gpu thread kernel
-        test_dns_to_csr((rnd.randint(1, 10), rnd.randint( 32,  512)), d)  # test gpu warp   kernel
-        test_dns_to_csr((rnd.randint(1, 10), rnd.randint(513, 1024)), d)  # test gpu block  kernel
+        shape_2d = rand_shape_2d()
+        shape_3d = rand_shape_3d()
+        test_csr_to_dns(shape_2d, d)
+        test_dns_to_csr(shape_2d, d)
+        test_rsp_to_dns(shape_2d, d)
+        test_dns_to_rsp(shape_2d, d)
+        test_rsp_to_dns(shape_3d, d)
+        test_dns_to_rsp(shape_3d, d)
+        for i in range(4, 6):
+            shape = rand_shape_nd(i, 5)
+            test_dns_to_rsp(shape, d)
+            test_rsp_to_dns(shape, d)
+        # Test specific gpu kernels
+        if default_context().device_type == 'gpu':
+            test_dns_to_csr((rnd.randint(1, 10), rnd.randint(  1,   32)), d)  # test gpu thread kernel
+            test_dns_to_csr((rnd.randint(1, 10), rnd.randint( 32,  512)), d)  # test gpu warp   kernel
+            test_dns_to_csr((rnd.randint(1, 10), rnd.randint(512, 1024)), d)  # test gpu block  kernel
+            test_dns_to_rsp((rnd.randint(1, 10), rnd.randint(  1,   32)), d)  # test gpu thread kernel
+            test_dns_to_rsp((rnd.randint(1, 10), rnd.randint( 32,  512)), d)  # test gpu warp   kernel
+            test_dns_to_rsp((rnd.randint(1, 10), rnd.randint(512, 1024)), d)  # test gpu block  kernel
+
 
 def test_sparse_dot():
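Note: the new GPU dense-to-row_sparse path in cast_storage-inl.cuh works in four stages: mark non-zero rows (MarkRspRowIdx{Thread,Warp,Block}Kernel), run an inclusive prefix sum over the flags (cub::DeviceScan::InclusiveSum), compact the flags into a row index array (FillRspRowIdxKernel), and gather the retained rows into the value array (FillRspValsKernel). The following is a minimal NumPy sketch of that pipeline for reference only; it is not the device code, and the function name `dense_to_rsp` is illustrative.

```python
# NumPy sketch of the dense -> row_sparse conversion pipeline implemented by the GPU kernels.
import numpy as np

def dense_to_rsp(dns):
    dns2d = dns.reshape(dns.shape[0], -1)                  # rows x row_length
    row_flg = (dns2d != 0).any(axis=1).astype(np.int64)    # MarkRspRowIdx*Kernel
    row_flg_sum = np.cumsum(row_flg)                       # inclusive prefix sum (DeviceScan)
    nnr = int(row_flg_sum[-1])                             # number of non-zero rows
    row_idx = np.zeros(nnr, dtype=np.int64)                # FillRspRowIdxKernel
    for i in range(dns2d.shape[0]):
        prev = 0 if i == 0 else row_flg_sum[i - 1]
        if row_flg_sum[i] > prev:
            row_idx[prev] = i
    values = dns[row_idx]                                  # FillRspValsKernel (row gather)
    return row_idx, values

dns = np.zeros((6, 3))
dns[1] = 1.0
dns[4, 2] = 2.0
idx, vals = dense_to_rsp(dns)
print(idx)    # [1 4]
print(vals)   # the two retained rows
```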