Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
implement dot(dns, csr/csr.T)=dns on cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaotaoChen committed Jun 2, 2018
1 parent 9854583 commit dd989d6
Showing 1 changed file with 132 additions and 9 deletions.
141 changes: 132 additions & 9 deletions src/operator/tensor/dot-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,15 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
!param.transpose_a) {
target_stype = hint_has_value ? target_stype : kCSRStorage;
// dns, csr -> csr on CPU
if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) {
if (target_stype == kCSRStorage) {
if (dev_mask == mshadow::cpu::kDevMask) {
// dns, csr -> csr on CPU
if (target_stype == kCSRStorage && !param.transpose_b) {
dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
DispatchMode::kFComputeEx);
// dns, csr/csr.T -> dns on CPU
} else if (target_stype == kDefaultStorage) {
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFComputeEx);
}
// dns, csr/csr.T -> dns on GPU
} else if (dev_mask == mshadow::gpu::kDevMask) {
Expand Down Expand Up @@ -327,7 +331,7 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
dispatched = true;
}
}
if (!dispatched && dev_mask == mshadow::gpu::kDevMask && !param.transpose_a &&
if (!dispatched && !param.transpose_a &&
lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
ograd_stype == kDefaultStorage) {
if (type_assign(&lhs_grad_stype, kDefaultStorage) &&
Expand Down Expand Up @@ -655,7 +659,81 @@ struct DotDnsCsrCsrByRowBlocks {
}
};

/*!
* \brief CPU Kernel of dot(dns1, csr) = dns2
* Parallelization by row blocks
*/
struct DotDnsCsrDnsByRowBlocks {
/*!
* \brief
* \param i the i-th thread
*/
template<typename DType, typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
const DType* data_l,
const DType* data_r,
const IType* indptr_r,
const CType* col_idx_r,
const nnvm::dim_t seg_len,
const nnvm::dim_t num_rows_l,
const nnvm::dim_t num_cols_l,
const nnvm::dim_t num_rows_r,
const nnvm::dim_t num_cols_r) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows_l) return;
const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
for (dim_t j = 0; j < num_rows_r; ++j) {
if (indptr_r[j] == indptr_r[j+1]) continue;
for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
const CType col_idx = col_idx_r[k];
const DType val = data_r[k];
for (dim_t r = seg_start; r < seg_end; ++r) {
out[r*num_cols_r+col_idx] += data_l[r*num_cols_l+j] * val;
}
}
}
}
};

/*!
* \brief CPU Kernel of dot(dns1, csr.T) = dns2
* Parallelization by row blocks
*/
struct DotDnsCsrTransDnsByRowBlocks {
/*!
* \brief
* \param i the i-th thread
*/
template<typename DType, typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
const DType* data_l,
const DType* data_r,
const IType* indptr_r,
const CType* col_idx_r,
const nnvm::dim_t seg_len,
const nnvm::dim_t num_rows_l,
const nnvm::dim_t num_cols_l,
const nnvm::dim_t num_rows_r,
const nnvm::dim_t num_cols_r) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows_l) return;
const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
for (dim_t j = 0; j < num_rows_r; ++j) {
if (indptr_r[j] == indptr_r[j+1]) continue;
for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
const CType col_idx = col_idx_r[k];
const DType val = data_r[k];
for (dim_t r = seg_start; r < seg_end; ++r) {
out[r*num_rows_r+j] += data_l[r*num_cols_l+col_idx] * val;
}
}
}
}
};

/*!
* \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
Expand Down Expand Up @@ -1031,13 +1109,58 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
}

/*
* \brief Impl of dot(dns, csr) = dense (GPU only)
* \brief Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns
*/
inline void DotDnsCsrDnsImpl(const OpContext& ctx, const cpu& cpu_dev,
const TBlob& dns, const NDArray& rhs,
const OpReqType req, NDArray* ret,
const bool transpose_b) {
LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU";
const TBlob& dns, const NDArray& rhs,
const OpReqType req, NDArray* ret,
const bool transpose_b) {
if (req == kNullOp) return;
CHECK_EQ(rhs.storage_type(), kCSRStorage);
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
if (!rhs.storage_initialized()) {
FillZerosCsrImpl(s, *ret);
return;
}

using nnvm::dim_t;

const TBlob data_r = rhs.data();
const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
const TBlob& data_l = dns;
const TBlob data_out = ret->data();

MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // col idx type
dim_t num_threads;
if (req == kWriteTo) {
num_threads = data_out.Size();
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
s, num_threads, data_out.dptr<DType>());
}
num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
// seg by output row
dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
if (transpose_b) {
mxnet_op::Kernel<DotDnsCsrTransDnsByRowBlocks, cpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(),
data_r.dptr<DType>(), indptr_r.dptr<IType>(),
col_idx_r.dptr<CType>(), seg_len,
dns.shape_[0], dns.shape_[1],
rhs.shape()[0], rhs.shape()[1]);
} else {
mxnet_op::Kernel<DotDnsCsrDnsByRowBlocks, cpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(),
data_r.dptr<DType>(), indptr_r.dptr<IType>(),
col_idx_r.dptr<CType>(), seg_len,
dns.shape_[0], dns.shape_[1],
rhs.shape()[0], rhs.shape()[1]);
}
});
});
});
}

inline bool DotShape(const nnvm::NodeAttrs& attrs,
Expand Down

0 comments on commit dd989d6

Please sign in to comment.