Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on CPU #11113

Merged
merged 5 commits into from
Jun 10, 2018
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 132 additions & 9 deletions src/operator/tensor/dot-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,15 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
!param.transpose_a) {
target_stype = hint_has_value ? target_stype : kCSRStorage;
// dns, csr -> csr on CPU
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) {
if (target_stype == kCSRStorage) {
if (dev_mask == mshadow::cpu::kDevMask) {
// dns, csr -> csr on CPU
if (target_stype == kCSRStorage && !param.transpose_b) {
dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
DispatchMode::kFComputeEx);
// dns, csr/csr.T -> dns on CPU
} else if (target_stype == kDefaultStorage) {
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFComputeEx);
}
// dns, csr/csr.T -> dns on GPU
} else if (dev_mask == mshadow::gpu::kDevMask) {
Expand Down Expand Up @@ -327,7 +331,7 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
dispatched = true;
}
}
if (!dispatched && dev_mask == mshadow::gpu::kDevMask && !param.transpose_a &&
if (!dispatched && !param.transpose_a &&
lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
ograd_stype == kDefaultStorage) {
if (type_assign(&lhs_grad_stype, kDefaultStorage) &&
Expand Down Expand Up @@ -655,7 +659,81 @@ struct DotDnsCsrCsrByRowBlocks {
}
};

/*!
 * \brief CPU Kernel of dot(dns1, csr) = dns2
 * Parallelization by row blocks
 */
struct DotDnsCsrDnsByRowBlocks {
  /*!
   * \brief Accumulate one block of output rows [i*seg_len, i*seg_len+seg_len)
   * for dot(dns, csr): for every nonzero (j, col) of the csr rhs,
   * out[r, col] += dns[r, j] * csr[j, col].
   * \param i the i-th thread (selects the block of lhs/output rows)
   * \param out dense output, shape (num_rows_l, num_cols_r)
   * \param data_l dense lhs, shape (num_rows_l, num_cols_l)
   * \param data_r nonzero values of the csr rhs
   * \param indptr_r row pointer array of the csr rhs
   * \param col_idx_r column indices of the csr rhs
   * \param seg_len number of output rows processed per thread
   * \param num_rows_l number of rows of the dense lhs
   * \param num_cols_l number of columns of the dense lhs
   * \param num_rows_r number of rows of the csr rhs
   * \param num_cols_r number of columns of the csr rhs
   */
  template<typename DType, typename IType, typename CType>
  MSHADOW_CINLINE static void Map(int i,
                                  DType* out,
                                  const DType* data_l,
                                  const DType* data_r,
                                  const IType* indptr_r,
                                  const CType* col_idx_r,
                                  const nnvm::dim_t seg_len,
                                  const nnvm::dim_t num_rows_l,
                                  const nnvm::dim_t num_cols_l,
                                  const nnvm::dim_t num_rows_r,
                                  const nnvm::dim_t num_cols_r) {
    using nnvm::dim_t;
    const dim_t row_begin = static_cast<dim_t>(i) * seg_len;
    if (row_begin >= num_rows_l) return;
    const dim_t row_end = std::min(row_begin + seg_len, num_rows_l);
    // Iterate the csr rhs once; scatter each nonzero into every output row
    // of this thread's block so writes never race across threads.
    for (dim_t rrow = 0; rrow < num_rows_r; ++rrow) {
      const IType nz_begin = indptr_r[rrow];
      const IType nz_end = indptr_r[rrow + 1];
      for (IType nz = nz_begin; nz < nz_end; ++nz) {
        const CType out_col = col_idx_r[nz];
        const DType rval = data_r[nz];
        for (dim_t lrow = row_begin; lrow < row_end; ++lrow) {
          out[lrow * num_cols_r + out_col] +=
            data_l[lrow * num_cols_l + rrow] * rval;
        }
      }
    }
  }
};

/*!
 * \brief CPU Kernel of dot(dns1, csr.T) = dns2
 * Parallelization by row blocks
 */
struct DotDnsCsrTransDnsByRowBlocks {
  /*!
   * \brief Accumulate one block of output rows [i*seg_len, i*seg_len+seg_len)
   * for dot(dns, csr.T): for every nonzero (j, col) of the csr rhs,
   * out[r, j] += dns[r, col] * csr[j, col].
   * \param i the i-th thread (selects the block of lhs/output rows)
   * \param out dense output, shape (num_rows_l, num_rows_r)
   * \param data_l dense lhs, shape (num_rows_l, num_cols_l)
   * \param data_r nonzero values of the csr rhs (pre-transpose layout)
   * \param indptr_r row pointer array of the csr rhs
   * \param col_idx_r column indices of the csr rhs
   * \param seg_len number of output rows processed per thread
   * \param num_rows_l number of rows of the dense lhs
   * \param num_cols_l number of columns of the dense lhs
   * \param num_rows_r number of rows of the csr rhs (pre-transpose)
   * \param num_cols_r number of columns of the csr rhs (pre-transpose)
   */
  template<typename DType, typename IType, typename CType>
  MSHADOW_CINLINE static void Map(int i,
                                  DType* out,
                                  const DType* data_l,
                                  const DType* data_r,
                                  const IType* indptr_r,
                                  const CType* col_idx_r,
                                  const nnvm::dim_t seg_len,
                                  const nnvm::dim_t num_rows_l,
                                  const nnvm::dim_t num_cols_l,
                                  const nnvm::dim_t num_rows_r,
                                  const nnvm::dim_t num_cols_r) {
    using nnvm::dim_t;
    const dim_t seg_start = i * seg_len;
    // Threads past the end of the lhs rows have no work.
    if (seg_start >= num_rows_l) return;
    const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
    for (dim_t j = 0; j < num_rows_r; ++j) {
      // Skip empty csr rows.
      if (indptr_r[j] == indptr_r[j+1]) continue;
      for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
        const CType col_idx = col_idx_r[k];
        const DType val = data_r[k];
        for (dim_t r = seg_start; r < seg_end; ++r) {
          // csr row j becomes output column j under the transpose.
          out[r*num_rows_r+j] += data_l[r*num_cols_l+col_idx] * val;
        }
      }
    }
  }
};

/*!
* \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
Expand Down Expand Up @@ -1031,13 +1109,58 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
}

/*
* \brief Impl of dot(dns, csr) = dense (GPU only)
* \brief Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns
*/
inline void DotDnsCsrDnsImpl(const OpContext& ctx, const cpu& cpu_dev,
const TBlob& dns, const NDArray& rhs,
const OpReqType req, NDArray* ret,
const bool transpose_b) {
LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU";
const TBlob& dns, const NDArray& rhs,
const OpReqType req, NDArray* ret,
const bool transpose_b) {
if (req == kNullOp) return;
CHECK_EQ(rhs.storage_type(), kCSRStorage);
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
if (!rhs.storage_initialized()) {
FillZerosCsrImpl(s, *ret);
return;
}

using nnvm::dim_t;

const TBlob data_r = rhs.data();
const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
const TBlob& data_l = dns;
const TBlob data_out = ret->data();

MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fp16 will fail with this branch. I think there's a MSHADOW_REAL_TYPE_SWITCH

MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // col idx type
dim_t num_threads;
if (req == kWriteTo) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

req == writeto || req == writeinplace

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, it's Done ! May you review it again :)

num_threads = data_out.Size();
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
s, num_threads, data_out.dptr<DType>());
}
num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
// seg by output row
dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
if (transpose_b) {
mxnet_op::Kernel<DotDnsCsrTransDnsByRowBlocks, cpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(),
data_r.dptr<DType>(), indptr_r.dptr<IType>(),
col_idx_r.dptr<CType>(), seg_len,
dns.shape_[0], dns.shape_[1],
rhs.shape()[0], rhs.shape()[1]);
} else {
mxnet_op::Kernel<DotDnsCsrDnsByRowBlocks, cpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(),
data_r.dptr<DType>(), indptr_r.dptr<IType>(),
col_idx_r.dptr<CType>(), seg_len,
dns.shape_[0], dns.shape_[1],
rhs.shape()[0], rhs.shape()[1]);
}
});
});
});
}

inline bool DotShape(const nnvm::NodeAttrs& attrs,
Expand Down