diff --git a/benchmark/python/cast_storage.py b/benchmark/python/cast_storage.py
new file mode 100644
index 000000000000..a92da7295d2b
--- /dev/null
+++ b/benchmark/python/cast_storage.py
@@ -0,0 +1,70 @@
+import ctypes
+
+from mxnet.test_utils import *
+import os
+import time
+import argparse
+
+from mxnet.base import check_call, _LIB
+
+parser = argparse.ArgumentParser(description="Benchmark cast storage operators",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet')
+args = parser.parse_args()
+
+def measure_cost(repeat, f, *args, **kwargs):
+    start = time.time()
+    results = []
+    for i in range(repeat):
+        (f(*args, **kwargs)).wait_to_read()
+    end = time.time()
+    diff = end - start
+    return diff / repeat
+
+
+def run_cast_storage_synthetic():
+    def dns_to_csr(m, n, density, ctx, repeat):
+        set_default_context(ctx)
+        data_shape = (m, n)
+        dns_data = rand_ndarray(data_shape, 'csr', density).todense()
+        dns_data.wait_to_read()
+
+        # do one warm up run, verify correctness
+        assert same(mx.nd.cast_storage(dns_data, stype='csr').asnumpy(), dns_data.asnumpy())
+
+        # start benchmarking
+        cost = measure_cost(repeat, mx.nd.cast_storage, dns_data, stype='csr')
+        results = '{:10.1f} {:>10} {:8d} {:8d} {:10.2f}'.format(density*100, str(ctx), m, n, cost*1000)
+        print(results)
+
+    check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))
+
+    # params
+    # m           number of rows
+    # n           number of columns
+    # density     density of the matrix
+    # num_repeat  number of benchmark runs to average over
+    # contexts    mx.cpu(), mx.gpu()
+    #             note: benchmark different contexts separately;
+    #                   to benchmark cpu, compile without CUDA
+    m = [ 512,    512]
+    n = [50000, 100000]
+    density = [1.00, 0.80, 0.60, 0.40, 0.20, 0.10, 0.05, 0.02, 0.01]
+    num_repeat = 10
+    contexts = [mx.gpu()]
+
+    # run benchmark
+    print("==================================================")
+    print(" cast_storage benchmark: dense to csr, size m x n ")
+    print("==================================================")
+    headline = '{:>10} {:>10} {:>8} {:>8} {:>10}'.format('density(%)', 'context', 'm', 'n', 'time(ms)')
+    print(headline)
+    for i in range(len(n)):
+        for ctx in contexts:
+            for den in density:
+                dns_to_csr(m[i], n[i], den, ctx, num_repeat)
+        print("")
+    print("==================================================")
+
+
+if __name__ == "__main__":
+    run_cast_storage_synthetic()
diff --git a/src/common/utils.cc b/src/common/utils.cc
index 4bcae02e990c..c0f7b4603f15 100644
--- a/src/common/utils.cc
+++ b/src/common/utils.cc
@@ -10,14 +10,12 @@
 namespace mxnet {
 namespace common {
-
 template<>
-void CastStorageDispatch<cpu>(mshadow::Stream<cpu>* s,
+void CastStorageDispatch<cpu>(const OpContext& ctx,
                               const NDArray& input,
                               const NDArray& output) {
-  mxnet::op::CastStorageComputeImpl<cpu>(s, input, output);
+  mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
 }
-
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/common/utils.cu b/src/common/utils.cu
index 7221a2b6ec6c..67cfdd84671e 100644
--- a/src/common/utils.cu
+++ b/src/common/utils.cu
@@ -11,10 +11,10 @@
 namespace mxnet {
 namespace common {
 
 template<>
-void CastStorageDispatch<gpu>(mshadow::Stream<gpu>* s,
+void CastStorageDispatch<gpu>(const OpContext& ctx,
                               const NDArray& input,
                               const NDArray& output) {
-  mxnet::op::CastStorageComputeImpl<gpu>(s, input, output);
+  mxnet::op::CastStorageComputeImpl<gpu>(ctx, input, output);
 }
 
 }  // namespace common
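Editor's note (not part of the patch): the benchmark script above measures `cast_storage` by averaging `num_repeat` runs and calling `wait_to_read()` to force the asynchronous operator to finish; one warm-up run doubles as a correctness check. A minimal, self-contained sketch of the same measurement pattern, using only the test-util helpers that the script itself imports (shapes and density here are arbitrary), would be:

```python
import time
import mxnet as mx
from mxnet.test_utils import rand_ndarray, same, set_default_context

def time_dns_to_csr(m=512, n=50000, density=0.2, repeat=10, ctx=mx.cpu()):
    set_default_context(ctx)
    # build a dense matrix with roughly `density` nonzeros
    dns_data = rand_ndarray((m, n), 'csr', density).todense()
    dns_data.wait_to_read()
    # warm-up run doubling as a correctness check
    assert same(mx.nd.cast_storage(dns_data, stype='csr').asnumpy(), dns_data.asnumpy())
    # operators run asynchronously, so wait_to_read() forces each cast to finish
    start = time.time()
    for _ in range(repeat):
        mx.nd.cast_storage(dns_data, stype='csr').wait_to_read()
    return (time.time() - start) / repeat * 1000.0  # milliseconds

print('dense->csr: {:.2f} ms'.format(time_dns_to_csr()))
```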
diff --git a/src/common/utils.h b/src/common/utils.h
index 254b6ce5bd21..a5371bfc2d58 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -24,11 +24,10 @@
 #include
 
 namespace mxnet {
-
 namespace common {
 
 template<typename xpu>
-void CastStorageDispatch(mshadow::Stream<xpu>* s, const NDArray& input, const NDArray& output);
+void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output);
 
 /*
  * \brief Get the corresponding tensor blobs from default storage NDArrays.
@@ -55,7 +54,7 @@ inline bool GetDefaultBlobs(const std::vector<NDArray>& nds,
             << "doesn't support NDArray inputs with non-default storage.";
       }
       NDArray temp(nd.shape(), nd.ctx(), false);
-      CastStorageDispatch(ctx.get_stream<xpu>(), nd, temp);
+      CastStorageDispatch<xpu>(ctx, nd, temp);
       temps->push_back(temp);
       blobs->push_back(temp.data());
       casted = true;
@@ -91,7 +90,7 @@ inline void CastNonDefaultStorage(const std::vector<NDArray>& dst,
             << "You are probably executing an operator which "
             << "doesn't support NDArray inputs with non-default storage.";
       }
-      CastStorageDispatch(ctx.get_stream<xpu>(), src[src_idx++], dst[i]);
+      CastStorageDispatch<xpu>(ctx, src[src_idx++], dst[i]);
     }
   }
   CHECK_EQ(src_idx, src.size()) << "Not all src NDArrays are casted";
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 4a402def143d..14b4922ea480 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -410,7 +410,7 @@ inline void CopyFromToDnsImpl(const NDArray from, NDArray *to, RunContext ctx) {
 
 // Make a copy of an NDArray based on storage type
 template<typename from_xpu, typename to_xpu>
-void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) {
+void CopyFromToImpl(const NDArray from, NDArray *to, RunContext rctx) {
   using namespace std;
   using namespace mshadow;
   // if storage type doesn't match, cast the storage first
@@ -423,10 +423,20 @@ void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) {
     << " to stype = " << to_stype << " is not supported";
   const auto from_ctx = from.ctx();
   const auto to_ctx = to->ctx();
-  auto s = ctx.get_stream<from_xpu>();
+  auto s = rctx.get_stream<from_xpu>();
+  bool is_train = mxnet::autograd::AutogradRuntime::Get()->IsTraining();
+  std::vector<Resource> requested;
+  if (is_same<from_xpu, gpu>::value && from_stype != to_stype) {
+    requested.push_back(ResourceManager::Get()->Request(from_ctx,
+        ResourceRequest(ResourceRequest::kTempSpace)));
+  }
+  OpContext opctx{is_train,
+                  rctx,
+                  engine::CallbackOnComplete(),
+                  requested};
   if (from_ctx == to_ctx && from_stype != to_stype) {
     // same ctx, different stypes, use cast op directly without copying
-    common::CastStorageDispatch(s, from, *to);
+    common::CastStorageDispatch<from_xpu>(opctx, from, *to);
   } else {
     NDArray casted_nd;  // an intermediate result before copying from to to
     if (from_stype == to_stype) {
@@ -439,22 +449,22 @@
         casted_nd = NDArray(to_stype, shape, from_ctx);
       }
       // convert from_nd to the same stype as to_nd
-      common::CastStorageDispatch(s, from, casted_nd);
+      common::CastStorageDispatch<from_xpu>(opctx, from, casted_nd);
     }
     if (to_stype == kDefaultStorage) {
-      CopyFromToDnsImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+      CopyFromToDnsImpl<from_xpu, to_xpu>(casted_nd, to, rctx);
     } else if (to_stype == kRowSparseStorage) {
-      CopyFromToRspImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+      CopyFromToRspImpl<from_xpu, to_xpu>(casted_nd, to, rctx);
     } else if (to_stype == kCSRStorage) {
-      CopyFromToCsrImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+      CopyFromToCsrImpl<from_xpu, to_xpu>(casted_nd, to, rctx);
     } else {
      LOG(FATAL) << "unknown storage type" << to_stype;
    }
  }
  if (is_same<from_xpu, gpu>::value || is_same<to_xpu, gpu>::value) {
    // Wait GPU kernel to complete
-    ctx.get_stream<gpu>()->Wait();
+    rctx.get_stream<gpu>()->Wait();
  }
}
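Editor's note (illustration, not part of the patch): the `CopyFromToImpl` change above builds an `OpContext` (with a `kTempSpace` resource on the source context when the copy runs on GPU and the storage types differ), so that a storage cast can execute as part of a copy. The user-visible path this enables is copying between NDArrays of different storage types; a hedged Python sketch, assuming the frontend `copyto` accepts a destination with a different storage type, would look like:

```python
import mxnet as mx
from mxnet.test_utils import rand_ndarray, same

# csr source, dense destination of the same shape: the copy casts the storage
# on the way (hypothetical usage; only the API names used in this diff).
shape = (8, 16)
src_csr = rand_ndarray(shape, 'csr', 0.25)
dst_dns = mx.nd.zeros(shape)
src_csr.copyto(dst_dns)
dst_dns.wait_to_read()
assert same(dst_dns.asnumpy(), src_csr.asnumpy())
```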
diff --git a/src/operator/tensor/cast_storage-inl.cuh b/src/operator/tensor/cast_storage-inl.cuh
index 0d4e601d0d2e..47fb3b42c356 100644
--- a/src/operator/tensor/cast_storage-inl.cuh
+++ b/src/operator/tensor/cast_storage-inl.cuh
@@ -9,15 +9,377 @@
 #include
 #include
+#include <cub/cub.cuh>
+
 namespace mxnet {
 namespace op {
+using mshadow::cuda::kBaseThreadNum;
 
-inline void CastStorageDnsRspImpl(mshadow::Stream<gpu>* s, const TBlob& dns, NDArray* rsp) {
+inline void CastStorageDnsRspImpl(const OpContext& ctx, const gpu& gpu_dev, const TBlob& dns, NDArray* rsp) {
   LOG(FATAL) << "CastStorageDnsRspImpl gpu version is not implemented.";
 }
 
-inline void CastStorageDnsCsrImpl(mshadow::Stream<gpu>* s, const TBlob& dns, NDArray* csr) {
-  LOG(FATAL) << "CastStorageDnsCsrImpl gpu version is not implemented.";
+/*!
+ * \brief Thread kernel for initializing the indptr in a csr tensor.
+ * Parallelized by matrix rows: 1 thread/row
+ */
+struct FillCsrIndPtrThreadKernel {
+  /*!
+   * \brief
+   * \param tid       global thread id
+   * \param indptr    index pointer array of the csr matrix
+   * \param dns       dense matrix
+   * \param num_rows  number of rows of the dense matrix
+   * \param num_cols  number of columns of the dense matrix
+   */
+  template<typename DType, typename IType>
+  __device__ __forceinline__ static void Map(int tid, IType* indptr, const DType* dns,
+                                             const int num_rows, const int num_cols) {
+    if (tid == 0) {
+      indptr[tid] = 0;
+    }
+    if (tid < num_rows) {
+      int nnz = 0;
+      const int offset = tid * num_cols;
+      for (int j = 0; j < num_cols; ++j) {
+        if (dns[offset+j] != 0) {
+          nnz++;
+        }
+      }
+      indptr[tid+1] = nnz;
+    }
+  }
+};
+
+/*!
+ * \brief Thread kernel for initializing the col_idx and value array of the csr matrix
+ * Parallelized by matrix rows: 1 thread/row
+ */
+struct FillCsrColIdxAndValsThreadKernel {
+  /*!
+   * \brief
+   * \param tid       global thread id
+   * \param val       data array of the csr matrix
+   * \param col_idx   column index array of the csr matrix
+   * \param indptr    index pointer array of the csr matrix
+   * \param dns       dense matrix
+   * \param num_rows  number of rows of the dense matrix
+   * \param num_cols  number of columns of the dense matrix
+   */
+  template<typename DType, typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid, DType* val, CType* col_idx,
+                                             const IType* indptr, const DType* dns,
+                                             const int num_rows, const int num_cols) {
+    if (tid < num_rows) {
+      const int offset = tid * num_cols;
+      int k = indptr[tid];
+      for (int j = 0; j < num_cols; ++j) {
+        if (dns[offset+j] != 0) {
+          val[k] = dns[offset+j];
+          col_idx[k] = j;
+          ++k;
+        }
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Warp kernel for initializing the indptr in a csr matrix
+ * Parallelized by matrix rows: 1 warp/row
+ */
+struct FillCsrIndPtrWarpKernel {
+  template<typename DType, typename IType>
+  __device__ __forceinline__ static void Map(int tid, IType* indptr, const DType* dns,
+                                             const int num_rows, const int num_cols) {
+    typedef cub::WarpReduce<int> WarpReduce;
+    const int warps_per_block = kBaseThreadNum / 32;
+    __shared__ typename WarpReduce::TempStorage temp_storage[warps_per_block];
+
+    if (tid == 0) {
+      indptr[tid] = 0;
+    }
+    const int warp_id   = tid / 32;          // global warp id
+    const int warp_lane = threadIdx.x / 32;  // local warp id within thread block
+    const int lane      = tid & (32-1);      // local thread id within warp
+    if (warp_id < num_rows) {
+      int lane_nnz = 0;
+      const int offset = warp_id * num_cols;
+      for (int j = lane; j < num_cols; j+=32) {
+        if (dns[offset+j] != 0) {
+          lane_nnz++;
+        }
+      }
+      int aggr = WarpReduce(temp_storage[warp_lane]).Sum(lane_nnz);
+      if (lane == 0) {
+        indptr[warp_id+1] = aggr;
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Warp kernel for initializing the col_idx and value array of the csr matrix
+ * Parallelized by matrix rows: 1 warp/row
+ */
+struct FillCsrColIdxAndValsWarpKernel {
+  template<typename DType, typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid, DType* val, CType* col_idx,
+                                             const IType* indptr, const DType* dns,
+                                             const int num_rows, const int num_cols) {
+    typedef cub::WarpScan<int> WarpScan;
+    const int warps_per_block = kBaseThreadNum / 32;
+    __shared__ typename WarpScan::TempStorage temp_storage[warps_per_block];
+    __shared__ volatile int warp_nnz[warps_per_block];
+
+    const int warp_id   = tid / 32;          // global warp id
+    const int warp_lane = threadIdx.x / 32;  // local warp id within thread block
+    const int lane      = tid & (32-1);      // local thread id within warp
+    if (warp_id < num_rows) {
+      const int offset = warp_id * num_cols;
+      int k = indptr[warp_id];
+      int nnz;
+      for (int j = lane; j < num_cols+lane; j+=32) {
+        nnz = 0;
+        if (j < num_cols) {
+          if (dns[offset+j] != 0) {
+            nnz++;
+          }
+        }
+        if (lane == 31) {
+          warp_nnz[warp_lane] = nnz;
+        }
+        // Compute index each thread has to write to
+        WarpScan(temp_storage[warp_lane]).ExclusiveSum(nnz, nnz);
+        if (j < num_cols) {
+          if (dns[offset+j] != 0) {
+            val[k+nnz] = dns[offset+j];
+            col_idx[k+nnz] = j;
+          }
+        }
+        if (lane == 31) {
+          warp_nnz[warp_lane] += nnz;
+        }
+        __syncwarp();
+        k += warp_nnz[warp_lane];
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Block kernel for initializing the indptr in a csr tensor.
+ * Parallelized by matrix rows: 1 threadBlock/row
+ */
+struct FillCsrIndPtrBlockKernel {
+  template<typename DType, typename IType>
+  __device__ __forceinline__ static void Map(int tid, IType* indptr, const DType* dns,
+                                             const int num_rows, const int num_cols) {
+    typedef cub::BlockReduce<int, kBaseThreadNum> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+
+    if (tid == 0) {
+      indptr[tid] = 0;
+    }
+    if (blockIdx.x < num_rows) {
+      int lane_nnz = 0;
+      const int offset = blockIdx.x * num_cols;
+      for (int j = threadIdx.x; j < num_cols; j+=kBaseThreadNum) {
+        if (dns[offset+j] != 0) {
+          lane_nnz++;
+        }
+      }
+      int aggr = BlockReduce(temp_storage).Sum(lane_nnz);
+      if (threadIdx.x == 0) {
+        indptr[blockIdx.x+1] = aggr;
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Block kernel for initializing the col_idx and value array of the csr matrix
+ * Parallelized by matrix rows: 1 threadBlock/row
+ */
+struct FillCsrColIdxAndValsBlockKernel {
+  template<typename DType, typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid, DType* val, CType* col_idx,
+                                             const IType* indptr, const DType* dns,
+                                             const int num_rows, const int num_cols) {
+    typedef cub::BlockScan<int, kBaseThreadNum> BlockScan;
+    __shared__ typename BlockScan::TempStorage temp_storage;
+    __shared__ volatile int block_nnz;
+
+    if (blockIdx.x < num_rows) {
+      const int offset = blockIdx.x * num_cols;
+      int k = indptr[blockIdx.x];
+      int nnz;
+      for (int j = threadIdx.x; j < num_cols+threadIdx.x; j+=kBaseThreadNum) {
+        nnz = 0;
+        if (j < num_cols) {
+          if (dns[offset+j] != 0) {
+            nnz++;
+          }
+        }
+        if (threadIdx.x == kBaseThreadNum-1) {
+          block_nnz = nnz;
+        }
+        // Compute index each thread has to write to
+        BlockScan(temp_storage).ExclusiveSum(nnz, nnz);
+        if (j < num_cols) {
+          if (dns[offset+j] != 0) {
+            val[k+nnz] = dns[offset+j];
+            col_idx[k+nnz] = j;
+          }
+        }
+        if (threadIdx.x == kBaseThreadNum-1) {
+          block_nnz += nnz;
+        }
+        __syncthreads();
+        k += block_nnz;
+      }
+    }
+  }
+};
+
+/*!
+ * \brief
+ * GPU implementation of casting a dense matrix to csr type.
+ */
+inline void CastStorageDnsCsrImpl(const OpContext& ctx,
+                                  const gpu& gpu_dev,
+                                  const TBlob& dns,
+                                  NDArray* csr) {
+  CHECK(csr != nullptr);
+  CHECK_EQ(csr->storage_type(), kCSRStorage);
+  CHECK_EQ(dns.shape_.ndim(), 2);
+  CHECK_EQ(dns.shape_, csr->shape());
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+  MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, {                     // data type
+    MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, {   // col_idx type
+        const index_t num_rows = dns.shape_[0];
+        const index_t num_cols = dns.shape_[1];
+        const int threads_per_warp  = 32;
+        const int threads_per_block = kBaseThreadNum;
+        const int min_num_warps = 512;
+        int num_threads;
+
+        csr->CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(num_rows+1));
+        IType* indptr = csr->aux_data(csr::kIndPtr).dptr<IType>();
+        DType* dns_data = dns.dptr<DType>();
+
+        // Different kernel versions are optimized for different matrix instances
+        // (1) 'Thread kernel' (one thread       computing one row)
+        // (2) 'Warp kernel'   (one warp         computing one row)
+        // (3) 'Block kernel'  (one thread block computing one row)
+        const int kernel_version = 0;
+        switch (kernel_version) {
+          case 1:
+            num_threads = num_rows;
+            mxnet_op::Kernel<FillCsrIndPtrThreadKernel, gpu>::Launch(s, num_threads,
+                indptr, dns_data, num_rows, num_cols);
+            break;
+          case 2:
+            num_threads = num_rows * threads_per_warp;
+            mxnet_op::Kernel<FillCsrIndPtrWarpKernel, gpu>::Launch(s, num_threads,
+                indptr, dns_data, num_rows, num_cols);
+            break;
+          case 3:
+            num_threads = num_rows * threads_per_block;
+            mxnet_op::Kernel<FillCsrIndPtrBlockKernel, gpu>::Launch(s, num_threads,
+                indptr, dns_data, num_rows, num_cols);
+            break;
+          default:
+            if (num_cols < threads_per_warp) {
+              num_threads = num_rows;
+              mxnet_op::Kernel<FillCsrIndPtrThreadKernel, gpu>::Launch(s, num_threads,
+                  indptr, dns_data, num_rows, num_cols);
+            } else if (num_cols < threads_per_block || num_rows > min_num_warps) {
+              num_threads = num_rows * threads_per_warp;
+              mxnet_op::Kernel<FillCsrIndPtrWarpKernel, gpu>::Launch(s, num_threads,
+                  indptr, dns_data, num_rows, num_cols);
+            } else {
+              num_threads = num_rows * threads_per_block;
+              mxnet_op::Kernel<FillCsrIndPtrBlockKernel, gpu>::Launch(s, num_threads,
+                  indptr, dns_data, num_rows, num_cols);
+            }
+            break;
+        }
+
+        // Determine temporary device storage requirements
+        void *d_temp_storage = NULL;
+        size_t temp_storage_bytes = 0;
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      indptr,
+                                      indptr,
+                                      static_cast<int>(num_rows+1),
+                                      mshadow::Stream<gpu>::GetStream(s));
+
+        // Allocate temporary storage
+        mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
+            .get_space_typed<gpu, 1, char>(mshadow::Shape1(temp_storage_bytes), s);
+        d_temp_storage = workspace.dptr_;
+
+        // Compute indptr through inclusive prefix sum
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      indptr,
+                                      indptr,
+                                      static_cast<int>(num_rows+1),
+                                      mshadow::Stream<gpu>::GetStream(s));
+
+        // Receive total number of nnz values from device
+        IType nnz = 0;
+        CUDA_CALL(cudaMemcpy(&nnz, &(indptr[num_rows]), sizeof(IType), cudaMemcpyDeviceToHost));
+
+        // Allocate column index array and data array of the csr matrix
+        csr->CheckAndAllocAuxData(csr::kIdx, mshadow::Shape1(static_cast<index_t>(nnz)));
+        csr->CheckAndAllocData(mshadow::Shape1(static_cast<index_t>(nnz)));
+
+        // Compute and fill column index array and data array of the csr matrix
+        switch (kernel_version) {
+          case 1:
+            num_threads = num_rows;
+            mxnet_op::Kernel<FillCsrColIdxAndValsThreadKernel, gpu>::Launch(s, num_threads,
+                csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
+                indptr, dns_data, num_rows, num_cols);
+            break;
+          case 2:
+            num_threads = num_rows * threads_per_warp;
+            mxnet_op::Kernel<FillCsrColIdxAndValsWarpKernel, gpu>::Launch(s, num_threads,
+                csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
+                indptr, dns_data, num_rows, num_cols);
+            break;
+          case 3:
+            num_threads = num_rows * threads_per_block;
+            mxnet_op::Kernel<FillCsrColIdxAndValsBlockKernel, gpu>::Launch(s, num_threads,
+                csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
+                indptr, dns_data, num_rows, num_cols);
+            break;
+          default:
+            if (num_cols < threads_per_warp) {
+              num_threads = num_rows;
+              mxnet_op::Kernel<FillCsrColIdxAndValsThreadKernel, gpu>::Launch(s, num_threads,
+                  csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
+                  indptr, dns_data, num_rows, num_cols);
+            } else if (num_cols < threads_per_block || num_rows > min_num_warps) {
+              num_threads = num_rows * threads_per_warp;
+              mxnet_op::Kernel<FillCsrColIdxAndValsWarpKernel, gpu>::Launch(s, num_threads,
+                  csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
+                  indptr, dns_data, num_rows, num_cols);
+            } else {
+              num_threads = num_rows * threads_per_block;
+              mxnet_op::Kernel<FillCsrColIdxAndValsBlockKernel, gpu>::Launch(s, num_threads,
+                  csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
+                  indptr, dns_data, num_rows, num_cols);
+            }
+            break;
+        }
+      });
+    });
+  });
 }
 
 }  // namespace op
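Editor's note (illustration, not part of the patch): the GPU dense-to-csr cast above is a two-pass construction. Pass 1 counts the nonzeros of each row into `indptr[1..num_rows]` (the FillCsrIndPtr*Kernel variants), a `cub::DeviceScan::InclusiveSum` turns those counts into row offsets, and pass 2 writes `col_idx` and `val` starting at each row's offset (the FillCsrColIdxAndVals*Kernel variants); the thread/warp/block flavors only differ in how many threads cooperate on one row. A NumPy sketch of the same two passes, checked against scipy, makes the data layout explicit:

```python
import numpy as np
from scipy import sparse as sp

def dense_to_csr(dns):
    """Two-pass dense->CSR construction mirroring the GPU kernels (NumPy sketch)."""
    num_rows, num_cols = dns.shape
    # Pass 1: count nonzeros per row, then prefix-sum the counts into row offsets
    # (the role of FillCsrIndPtr*Kernel followed by cub::DeviceScan::InclusiveSum).
    indptr = np.zeros(num_rows + 1, dtype=np.int64)
    indptr[1:] = np.count_nonzero(dns, axis=1)
    indptr = np.cumsum(indptr)
    nnz = indptr[-1]
    # Pass 2: for each row, write column indices and values starting at indptr[row]
    # (the role of FillCsrColIdxAndVals*Kernel).
    col_idx = np.empty(nnz, dtype=np.int64)
    val = np.empty(nnz, dtype=dns.dtype)
    for i in range(num_rows):
        k = indptr[i]
        for j in range(num_cols):
            if dns[i, j] != 0:
                val[k] = dns[i, j]
                col_idx[k] = j
                k += 1
    return indptr, col_idx, val

# sanity check against scipy's CSR layout
a = np.random.rand(5, 7) * (np.random.rand(5, 7) > 0.7)
indptr, col_idx, val = dense_to_csr(a)
ref = sp.csr_matrix(a)
assert (indptr == ref.indptr).all() and (col_idx == ref.indices).all() and (val == ref.data).all()
```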
diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h
index 915c44bcceb3..46ae105b80b6 100644
--- a/src/operator/tensor/cast_storage-inl.h
+++ b/src/operator/tensor/cast_storage-inl.h
@@ -46,12 +46,16 @@ struct MarkRspRowIdx {
  * \brief
  * CPU implementation of casting a dns tensor to rsp type.
  */
-inline void CastStorageDnsRspImpl(mshadow::Stream<cpu>* s, const TBlob& dns, NDArray* rsp) {
+inline void CastStorageDnsRspImpl(const OpContext& ctx,
+                                  const cpu& cpu_dev,
+                                  const TBlob& dns,
+                                  NDArray* rsp) {
   using namespace rowsparse;
   using namespace mshadow;
   CHECK(rsp != nullptr);
   CHECK_EQ(rsp->storage_type(), kRowSparseStorage);
   CHECK_EQ(dns.shape_, rsp->shape());
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, {             // data type
     MSHADOW_IDX_TYPE_SWITCH(rsp->aux_type(kIdx), RType, {  // row idx type
       const index_t num_rows = dns.shape_[0];
@@ -101,9 +105,8 @@ struct CastStorageRspDnsKernel {
  * since the shape is known at binding stage.
  */
 template<typename xpu>
-void CastStorageRspDnsImpl(mshadow::Stream<xpu>* s, const NDArray& rsp, TBlob* dns) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
+void CastStorageRspDnsImpl(const OpContext& ctx, const NDArray& rsp, TBlob* dns) {
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
   CHECK_EQ(rsp.storage_type(), kRowSparseStorage);
   MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, {
     MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
@@ -184,11 +187,15 @@ struct FillCsrColIdxAndVals {
  * \brief
  * CPU implementation of casting a dns tensor to csr type.
  */
-inline void CastStorageDnsCsrImpl(mshadow::Stream<cpu>* s, const TBlob& dns, NDArray* csr) {
+inline void CastStorageDnsCsrImpl(const OpContext& ctx,
+                                  const cpu& cpu_dev,
+                                  const TBlob& dns,
+                                  NDArray* csr) {
   CHECK(csr != nullptr);
   CHECK_EQ(csr->storage_type(), kCSRStorage);
   CHECK_EQ(dns.shape_.ndim(), 2);
   CHECK_EQ(dns.shape_, csr->shape());
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, {                     // data type
     MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, {   // col idx type
@@ -246,11 +253,12 @@ struct CopyCsrDataToDns {
  * \brief Casts a csr tensor to dns format.
  */
 template<typename xpu>
-void CastStorageCsrDnsImpl(mshadow::Stream<xpu>* s, const NDArray& csr, TBlob* dns) {
+void CastStorageCsrDnsImpl(const OpContext& ctx, const NDArray& csr, TBlob* dns) {
   CHECK(dns != nullptr);
   CHECK_EQ(csr.storage_type(), kCSRStorage);
   CHECK_EQ(dns->shape_.ndim(), 2);
   CHECK_EQ(dns->shape_, csr.shape());
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
   MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, {                   // data type
     MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, {   // col idx type
@@ -270,25 +278,23 @@ void CastStorageCsrDnsImpl(mshadow::Stream<xpu>* s, const NDArray& csr, TBlob* dns) {
 }
 
 template<typename xpu>
-void CastStorageComputeImpl(mshadow::Stream<xpu>* s,
+void CastStorageComputeImpl(const OpContext& ctx,
                             const NDArray& input,
                             const NDArray& output) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
   const auto src_stype = input.storage_type();
   const auto dst_stype = output.storage_type();
   if (src_stype == kRowSparseStorage && dst_stype == kDefaultStorage) {
     TBlob ret = output.data();
-    CastStorageRspDnsImpl(s, input, &ret);
+    CastStorageRspDnsImpl<xpu>(ctx, input, &ret);
   } else if (src_stype == kDefaultStorage && dst_stype == kRowSparseStorage) {
     NDArray ret = output;  // get rid of the const qualifer
-    CastStorageDnsRspImpl(s, input.data(), &ret);
+    CastStorageDnsRspImpl(ctx, xpu(), input.data(), &ret);
   } else if (src_stype == kDefaultStorage && dst_stype == kCSRStorage) {
     NDArray ret = output;  // get rid of the const qualifer
-    CastStorageDnsCsrImpl(s, input.data(), &ret);
+    CastStorageDnsCsrImpl(ctx, xpu(), input.data(), &ret);
   } else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) {
     TBlob ret = output.data();
-    CastStorageCsrDnsImpl(s, input, &ret);
+    CastStorageCsrDnsImpl<xpu>(ctx, input, &ret);
   } else {
     LOG(FATAL) << "Not implemented";
   }
@@ -326,12 +332,9 @@ void CastStorageComputeEx(const nnvm::NodeAttrs& attrs,
                           const std::vector<NDArray>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
   CHECK_EQ(inputs.size(), 1);
   CHECK_EQ(outputs.size(), 1);
-  CastStorageComputeImpl(s, inputs[0], outputs[0]);
+  CastStorageComputeImpl<xpu>(ctx, inputs[0], outputs[0]);
 }
 
 }  // namespace op
diff --git a/src/operator/tensor/cast_storage.cc b/src/operator/tensor/cast_storage.cc
index c435146a730b..f32133171130 100644
--- a/src/operator/tensor/cast_storage.cc
+++ b/src/operator/tensor/cast_storage.cc
@@ -22,6 +22,10 @@ NNVM_REGISTER_OP(cast_storage)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<FInferStorageType>("FInferStorageType", CastStorageInferStorageType)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .set_attr<FCompute>("FCompute", IdentityCompute<cpu>)
 .set_attr<FComputeEx>("FComputeEx", CastStorageComputeEx<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "The input.")
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 9cd041201e51..adc4c3b903bf 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -5,7 +5,7 @@
 from test_operator import *
 from test_optimizer import *
 from test_random import *
-from test_sparse_operator import test_sparse_nd_zeros, test_sparse_dot
+from test_sparse_operator import test_cast_storage_ex, test_sparse_dot, test_sparse_nd_zeros
 from test_sparse_ndarray import test_create_csr, test_create_row_sparse
 import mxnet as mx
 import numpy as np
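Editor's note (illustration, not part of the patch): `cast_storage` now declares a `kTempSpace` resource, which the GPU implementation uses for the cub prefix-sum workspace, and `test_cast_storage_ex` is imported into the GPU test module so the same unit test also runs on the GPU context. The shape of that round-trip check, runnable on whichever device `default_context()` selects, is roughly:

```python
import mxnet as mx
from mxnet.test_utils import rand_ndarray, same

def check_cast_storage_round_trip(shape, density):
    # runs on the current default context, which is how the GPU test module
    # reuses the same test function after the import change above
    csr_in = rand_ndarray(shape, 'csr', density)
    dns_in = csr_in.todense()
    csr_out = mx.nd.cast_storage(dns_in, stype='csr')
    dns_out = mx.nd.cast_storage(csr_out, stype='default')
    assert same(csr_out.asnumpy(), csr_in.asnumpy())
    assert same(dns_out.asnumpy(), dns_in.asnumpy())

check_cast_storage_round_trip((10, 100), 0.1)
```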
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index efcbdf288594..2495fdc51cc4 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -84,23 +84,28 @@ def test_dns_to_rsp(shape):
         ret = mx.nd.cast_storage(rsp_out, stype='default')
         assert same(ret.asnumpy(), dns_in.asnumpy())
 
-    def test_csr_to_dns(shape):
-        csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr')
-        mx_dns = csr.todense()
-        np_dns = sp.csr_matrix((values, indices, indptr), shape).todense()
-        assert_almost_equal(mx_dns.asnumpy(), np_dns)
-
-    def test_dns_to_csr(dns_in):
-        dns_in = np.array(dns_in)
+    def test_csr_to_dns(shape, density):
+        csr_in, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr', density)
+        dns_out = csr_in.todense()
+        assert same(csr_in.asnumpy(), dns_out.asnumpy())
+
+    def test_dns_to_csr(shape, density):
+        csr_in, (indptr, colidx, data) = rand_sparse_ndarray(shape, 'csr', density)
+        dns_in = csr_in.todense()
         csr_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), stype='csr')
-        ret = mx.nd.cast_storage(csr_out, stype='default')
-        assert same(ret.asnumpy(), dns_in)
+        assert same(csr_in.asnumpy(), csr_out.asnumpy())
 
     shape = rand_shape_2d()
-    test_rsp_to_dns(shape)
-    test_dns_to_rsp(shape)
-    test_csr_to_dns((4, 4))
-    test_dns_to_csr([[0, 1, 0], [0, 2, 0], [3, 0, 0], [0, 0, 4], [5, 6, 0], [0, 0, 7]])
+    if default_context().device_type is 'cpu':
+        test_rsp_to_dns(shape)
+        test_dns_to_rsp(shape)
+
+    density = [1.00, 0.50, 0.10, 0.05, 0.01]
+    for d in density:
+        test_csr_to_dns((rnd.randint(1, 10), rnd.randint(  1,   64)), d)
+        test_dns_to_csr((rnd.randint(1, 10), rnd.randint(  1,   31)), d)  # test gpu thread kernel
+        test_dns_to_csr((rnd.randint(1, 10), rnd.randint( 32,  512)), d)  # test gpu warp   kernel
+        test_dns_to_csr((rnd.randint(1, 10), rnd.randint(513, 1024)), d)  # test gpu block  kernel
 
 
 def test_sparse_dot():
@@ -133,16 +138,15 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
                             rtol=1e-3, atol=1e-4)
 
     lhs_shape = rand_shape_2d(50, 200)
-    test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False)
-    test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True)
-    test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default', False)
-    test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True)
-    test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)
-    test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True)
-    test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, 0.05)
-    # TODO(haibin/jun/stefan) test dot(csr.T, row_sparse) = dns gpu version
-    if Context.default_ctx == mx.cpu():
-        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, 0.05)
+    test_dot_csr(lhs_shape, (lhs_shape[1], 1),                  'default',    False)  # test gpu SpMV
+    test_dot_csr(lhs_shape, (lhs_shape[0], 1),                  'default',    True )  # (vector kernel)
+    test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default',    False)  # test gpu SpMM
+    test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default',    True )  # (scalar kernel)
+    if default_context().device_type is 'cpu':
+        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)
+        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True )
+        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, 0.05)
+        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True , 0.05)
 
 
 def test_sparse_slice():
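Editor's note (illustration, not part of the patch): the new `test_dns_to_csr` cases draw random column counts from three ranges (1-31, 32-512, 513-1024) so that the default selection branch in the GPU `CastStorageDnsCsrImpl` is exercised with the thread, warp, and block kernels respectively. A small Python mirror of that branch shows the mapping; the value 512 for `threads_per_block` is an assumption standing in for `mshadow::cuda::kBaseThreadNum`, which is not stated in the diff:

```python
def pick_dns_to_csr_kernel(num_rows, num_cols,
                           threads_per_warp=32,    # warp size
                           threads_per_block=512,  # assumed value of kBaseThreadNum
                           min_num_warps=512):
    """Mirror of the default kernel-selection branch in CastStorageDnsCsrImpl (gpu)."""
    if num_cols < threads_per_warp:
        return 'thread kernel'   # one thread per row
    elif num_cols < threads_per_block or num_rows > min_num_warps:
        return 'warp kernel'     # one warp per row
    else:
        return 'block kernel'    # one thread block per row

# the column ranges used by test_dns_to_csr above map onto the three kernels
assert pick_dns_to_csr_kernel(10,  16) == 'thread kernel'  # cols in [1, 31]
assert pick_dns_to_csr_kernel(10, 128) == 'warp kernel'    # cols in [32, 512]
assert pick_dns_to_csr_kernel(10, 700) == 'block kernel'   # cols in [513, 1024]
```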