From c481dff0fbc2b2912344f82f75c903ad0ddd7cb3 Mon Sep 17 00:00:00 2001
From: XiaotaoChen
Date: Fri, 1 Jun 2018 08:21:43 +0800
Subject: [PATCH] implement dot(dns, csr/csr.T)=dns on cpu

---
 src/operator/tensor/dot-inl.h | 142 +++++++++++++++++++++++++++++++++---
 1 file changed, 136 insertions(+), 6 deletions(-)

diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index ffdb706e5e3c..f4e0f9239ad9 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -264,11 +264,15 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
       !param.transpose_a) {
     target_stype = hint_has_value ? target_stype : kCSRStorage;
-    // dns, csr -> csr on CPU
-    if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) {
-      if (target_stype == kCSRStorage) {
+    if (dev_mask == mshadow::cpu::kDevMask) {
+      // dns, csr -> csr on CPU
+      if (target_stype == kCSRStorage && !param.transpose_b) {
         dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
                                          DispatchMode::kFComputeEx);
+      // dns, csr/csr.T -> dns on CPU
+      } else if (target_stype == kDefaultStorage) {
+        dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                         DispatchMode::kFComputeEx);
       }
     // dns, csr/csr.T -> dns on GPU
     } else if (dev_mask == mshadow::gpu::kDevMask) {
@@ -327,7 +331,7 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
       dispatched = true;
     }
   }
-  if (!dispatched && dev_mask == mshadow::gpu::kDevMask && !param.transpose_a &&
+  if (!dispatched && !param.transpose_a &&
       lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
       ograd_stype == kDefaultStorage) {
     if (type_assign(&lhs_grad_stype, kDefaultStorage) &&
@@ -655,7 +659,81 @@ struct DotDnsCsrCsrByRowBlocks {
   }
 };
 
+/*!
+ * \brief CPU Kernel of dot(dns1, csr) = dns2
+ * Parallelization by row blocks
+ */
+struct DotDnsCsrDnsByRowBlocks {
+  /*!
+   * \brief Each thread computes one contiguous segment of output rows
+   * \param i the i-th thread
+   */
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* data_l,
+                                  const DType* data_r,
+                                  const IType* indptr_r,
+                                  const CType* col_idx_r,
+                                  const nnvm::dim_t seg_len,
+                                  const nnvm::dim_t num_rows_l,
+                                  const nnvm::dim_t num_cols_l,
+                                  const nnvm::dim_t num_rows_r,
+                                  const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
+    if (seg_start >= num_rows_l) return;
+    const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
+    for (dim_t j = 0; j < num_rows_r; ++j) {
+      if (indptr_r[j] == indptr_r[j+1]) continue;
+      for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
+        const CType col_idx = col_idx_r[k];
+        const DType val = data_r[k];
+        for (dim_t r = seg_start; r < seg_end; ++r) {
+          out[r*num_cols_r+col_idx] += data_l[r*num_cols_l+j] * val;
+        }
+      }
+    }
+  }
+};
+/*!
+ * \brief CPU Kernel of dot(dns1, csr.T) = dns2
+ * Parallelization by row blocks
+ */
+struct DotDnsCsrTransDnsByRowBlocks {
+  /*!
+   * \brief Each thread computes one contiguous segment of output rows
+   * \param i the i-th thread
+   */
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* data_l,
+                                  const DType* data_r,
+                                  const IType* indptr_r,
+                                  const CType* col_idx_r,
+                                  const nnvm::dim_t seg_len,
+                                  const nnvm::dim_t num_rows_l,
+                                  const nnvm::dim_t num_cols_l,
+                                  const nnvm::dim_t num_rows_r,
+                                  const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
+    if (seg_start >= num_rows_l) return;
+    const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
+    for (dim_t j = 0; j < num_rows_r; ++j) {
+      if (indptr_r[j] == indptr_r[j+1]) continue;
+      for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
+        const CType col_idx = col_idx_r[k];
+        const DType val = data_r[k];
+        for (dim_t r = seg_start; r < seg_end; ++r) {
+          out[r*num_rows_r+j] += data_l[r*num_cols_l+col_idx] * val;
+        }
+      }
+    }
+  }
+};
 
 /*!
  * \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
  */
@@ -1031,13 +1109,65 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
 }
 
 /*
- * \brief Impl of dot(dns, csr) = dense (GPU only)
+ * \brief Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns
  */
 inline void DotDnsCsrDnsImpl(const OpContext& ctx, const cpu& cpu_dev,
                              const TBlob& dns, const NDArray& rhs,
                              const OpReqType req, NDArray* ret,
                              const bool transpose_b) {
-  LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU";
+  if (kNullOp == req) return;
+  CHECK_EQ(rhs.storage_type(), kCSRStorage);
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (!rhs.storage_initialized()) {
+    // The output here is dense, so zero-fill it directly;
+    // FillZerosCsrImpl only applies to CSR outputs.
+    if (kWriteTo == req) {
+      MSHADOW_SGL_DBL_TYPE_SWITCH(ret->data().type_flag_, DType, {
+        mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
+          s, ret->data().Size(), ret->data().dptr<DType>());
+      });
+    }
+    return;
+  }
+
+  using nnvm::dim_t;
+
+  const TBlob data_r = rhs.data();
+  const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
+  const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
+  const TBlob& data_l = dns;
+  const TBlob data_out = ret->data();
+
+  MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, {  // data type
+    MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, {  // col idx type
+        dim_t num_threads;
+        if (kWriteTo == req) {
+          num_threads = data_out.Size();
+          mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
+            s, num_threads, data_out.dptr<DType>());
+        }
+        num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
+        // segment the work by output row
+        dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
+        if (transpose_b) {
+          mxnet_op::Kernel<DotDnsCsrTransDnsByRowBlocks, cpu>::Launch(
+              s, num_threads, data_out.dptr<DType>(), data_l.dptr<DType>(),
+              data_r.dptr<DType>(), indptr_r.dptr<IType>(),
+              col_idx_r.dptr<CType>(), seg_len,
+              dns.shape_[0], dns.shape_[1],
+              rhs.shape()[0], rhs.shape()[1]);
+        } else {
+          mxnet_op::Kernel<DotDnsCsrDnsByRowBlocks, cpu>::Launch(
+              s, num_threads, data_out.dptr<DType>(), data_l.dptr<DType>(),
+              data_r.dptr<DType>(), indptr_r.dptr<IType>(),
+              col_idx_r.dptr<CType>(), seg_len,
+              dns.shape_[0], dns.shape_[1],
+              rhs.shape()[0], rhs.shape()[1]);
+        }
+      });
+    });
+  });
 }
 
 inline bool DotShape(const nnvm::NodeAttrs& attrs,
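
Note (not part of the patch): below is a minimal standalone C++ sketch of the
row-block scheme the two kernels above implement. The function names
(dot_dns_csr_seg, dot_dns_csr_trans_seg) and the toy matrices are made up for
illustration and are not MXNet APIs; in the patch, seg_start/seg_end come from
the per-thread seg_len, so each thread owns a disjoint block of output rows.

// Standalone sketch of the row-block kernels in this patch (illustration
// only; names and toy data are hypothetical, not MXNet code).
#include <cstdio>
#include <vector>

using dim_t = long;

// dot(dns, csr): each CSR nonzero (row j, column col, value val) contributes
// dns(r, j) * val to out(r, col) for every dense row r in [seg_start, seg_end).
void dot_dns_csr_seg(double* out, const double* dns, const double* data,
                     const dim_t* indptr, const dim_t* col_idx,
                     dim_t seg_start, dim_t seg_end,
                     dim_t num_cols_l, dim_t num_rows_r, dim_t num_cols_r) {
  for (dim_t j = 0; j < num_rows_r; ++j)
    for (dim_t k = indptr[j]; k < indptr[j + 1]; ++k)
      for (dim_t r = seg_start; r < seg_end; ++r)
        out[r * num_cols_r + col_idx[k]] += dns[r * num_cols_l + j] * data[k];
}

// dot(dns, csr.T): the same nonzero (j, col, val) now contributes
// dns(r, col) * val to out(r, j).
void dot_dns_csr_trans_seg(double* out, const double* dns, const double* data,
                           const dim_t* indptr, const dim_t* col_idx,
                           dim_t seg_start, dim_t seg_end,
                           dim_t num_cols_l, dim_t num_rows_r) {
  for (dim_t j = 0; j < num_rows_r; ++j)
    for (dim_t k = indptr[j]; k < indptr[j + 1]; ++k)
      for (dim_t r = seg_start; r < seg_end; ++r)
        out[r * num_rows_r + j] += dns[r * num_cols_l + col_idx[k]] * data[k];
}

int main() {
  // csr (3x2) = [[10, 0], [0, 20], [30, 0]]
  std::vector<double> data = {10, 20, 30};
  std::vector<dim_t> indptr = {0, 1, 2, 3};
  std::vector<dim_t> col_idx = {0, 1, 0};

  // dns (2x3) * csr (3x2) -> out (2x2); expected [[100, 40], [220, 100]]
  std::vector<double> dns = {1, 2, 3, 4, 5, 6};
  std::vector<double> out(4, 0.0);
  dot_dns_csr_seg(out.data(), dns.data(), data.data(), indptr.data(),
                  col_idx.data(), 0, 2, 3, 3, 2);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);

  // dns2 (2x2) * csr.T (2x3) -> out2 (2x3); expected [[10, 40, 30], [30, 80, 90]]
  std::vector<double> dns2 = {1, 2, 3, 4};
  std::vector<double> out2(6, 0.0);
  dot_dns_csr_trans_seg(out2.data(), dns2.data(), data.data(), indptr.data(),
                        col_idx.data(), 0, 2, 2, 3);
  std::printf("%g %g %g %g %g %g\n", out2[0], out2[1], out2[2],
              out2[3], out2[4], out2[5]);
  return 0;
}

Keeping the CSR traversal in the outer loop makes the sparse accesses
sequential, and partitioning by output rows (rather than columns) gives each
thread a private region of the output, which is why the kernels can accumulate
with a plain += and no atomics.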