Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on CPU #11113
Changes from 1 commit
```diff
@@ -264,11 +264,15 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
       !param.transpose_a) {
     target_stype = hint_has_value ? target_stype : kCSRStorage;
-    // dns, csr -> csr on CPU
-    if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) {
-      if (target_stype == kCSRStorage) {
+    if (dev_mask == mshadow::cpu::kDevMask) {
+      // dns, csr -> csr on CPU
+      if (target_stype == kCSRStorage && !param.transpose_b) {
         dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
                                          DispatchMode::kFComputeEx);
+      // dns, csr/csr.T -> dns on CPU
+      } else if (target_stype == kDefaultStorage) {
+        dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                         DispatchMode::kFComputeEx);
       }
     // dns, csr/csr.T -> dns on GPU
     } else if (dev_mask == mshadow::gpu::kDevMask) {
```
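To make the new dispatch concrete, here is a small illustrative sketch of the CPU decision this hunk implements. The helper name and enum are hypothetical, not MXNet code; `target_stype` corresponds to the caller's storage hint (`hint_has_value` above; at the Python level this is presumably the `forward_stype` argument of `mx.nd.dot`).

```cpp
#include <cstdio>

// Possible outcomes of this branch (illustrative only, not MXNet code).
enum Stype { kDefaultStorage, kCSRStorage, kNotDispatchedHere };

// Hypothetical mirror of the CPU branch above: pick the output storage for
// dot(dense, csr) from the requested storage hint and the transpose_b flag.
Stype InferDnsCsrOutStypeOnCPU(Stype target_stype, bool transpose_b) {
  if (target_stype == kCSRStorage && !transpose_b)
    return kCSRStorage;       // dns, csr -> csr (pre-existing path)
  if (target_stype == kDefaultStorage)
    return kDefaultStorage;   // dns, csr -> dns and dns, csr.T -> dns (new)
  return kNotDispatchedHere;  // csr output with transpose_b: left to fallback
}

int main() {
  // The combination this PR enables: dense output from a transposed csr rhs.
  std::printf("%d\n", InferDnsCsrOutStypeOnCPU(kDefaultStorage, true));  // 0
  return 0;
}
```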
```diff
@@ -327,7 +331,7 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
       dispatched = true;
     }
   }
-  if (!dispatched && dev_mask == mshadow::gpu::kDevMask && !param.transpose_a &&
+  if (!dispatched && !param.transpose_a &&
       lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
       ograd_stype == kDefaultStorage) {
     if (type_assign(&lhs_grad_stype, kDefaultStorage) &&
```
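For context on why the `dev_mask == mshadow::gpu::kDevMask` condition can be dropped here: for `out = dot(lhs, rhs)` with `lhs` dense, `rhs` CSR, and a dense output gradient `ograd`, the gradient with respect to the dense operand is

    grad_lhs = dot(ograd, rhs.T)

a dense-times-transposed-CSR product. That is exactly the dot(dns, csr.T) = dns operation this PR implements on CPU (the kernels below), so the backward path no longer needs to be restricted to GPU.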
```diff
@@ -655,7 +659,81 @@ struct DotDnsCsrCsrByRowBlocks {
   }
 };
 
+/*!
+ * \brief CPU Kernel of dot(dns1, csr) = dns2
+ * Parallelization by row blocks
+ */
+struct DotDnsCsrDnsByRowBlocks {
+  /*!
+   * \brief
+   * \param i the i-th thread
+   */
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* data_l,
+                                  const DType* data_r,
+                                  const IType* indptr_r,
+                                  const CType* col_idx_r,
+                                  const nnvm::dim_t seg_len,
+                                  const nnvm::dim_t num_rows_l,
+                                  const nnvm::dim_t num_cols_l,
+                                  const nnvm::dim_t num_rows_r,
+                                  const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
+    if (seg_start >= num_rows_l) return;
+    const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
+    for (dim_t j = 0; j < num_rows_r; ++j) {
+      if (indptr_r[j] == indptr_r[j+1]) continue;
+      for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
+        const CType col_idx = col_idx_r[k];
+        const DType val = data_r[k];
+        for (dim_t r = seg_start; r < seg_end; ++r) {
+          out[r*num_cols_r+col_idx] += data_l[r*num_cols_l+j] * val;
+        }
+      }
+    }
+  }
+};
```
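To see what this kernel computes, here is a minimal single-threaded sketch on a hard-coded 2x3 dense times 3x2 CSR example (standalone, hypothetical data; not part of the patch). The CSR rhs is traversed once; each nonzero (row `j`, column `col`, value `val`) scatters `data_l[r][j] * val` into output column `col` for every row `r` in the thread's segment, so the lhs is read column-by-column while the output accumulates in place.

```cpp
#include <cstdio>
#include <vector>

int main() {
  // lhs: 2x3 dense, row-major.
  std::vector<double> data_l = {1, 2, 3,
                                4, 5, 6};
  // rhs: 3x2 CSR with nonzeros (0,1)=10 and (2,0)=20.
  std::vector<double> data_r    = {10, 20};
  std::vector<int>    indptr_r  = {0, 1, 1, 2};
  std::vector<int>    col_idx_r = {1, 0};
  const int num_rows_l = 2, num_cols_l = 3, num_rows_r = 3, num_cols_r = 2;

  std::vector<double> out(num_rows_l * num_cols_r, 0.0);
  // One "thread" whose segment covers all lhs rows (seg_start=0, seg_end=2).
  for (int j = 0; j < num_rows_r; ++j) {            // CSR row j == lhs column j
    for (int k = indptr_r[j]; k < indptr_r[j + 1]; ++k) {
      const int col = col_idx_r[k];
      const double val = data_r[k];
      for (int r = 0; r < num_rows_l; ++r) {
        out[r * num_cols_r + col] += data_l[r * num_cols_l + j] * val;
      }
    }
  }
  for (int r = 0; r < num_rows_l; ++r)              // prints: 60 10 / 120 40
    std::printf("%g %g\n", out[r * num_cols_r], out[r * num_cols_r + 1]);
  return 0;
}
```

Because each thread owns a contiguous block of output rows, no two threads ever write the same output element, which is what makes the `+=` updates safe without atomics.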
The same hunk continues with the transposed variant, where the roles of the CSR row index and column index swap in the output update:

```diff
+
+/*!
+ * \brief CPU Kernel of dot(dns1, csr.T) = dns2
+ * Parallelization by row blocks
+ */
+struct DotDnsCsrTransDnsByRowBlocks {
+  /*!
+   * \brief
+   * \param i the i-th thread
+   */
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* data_l,
+                                  const DType* data_r,
+                                  const IType* indptr_r,
+                                  const CType* col_idx_r,
+                                  const nnvm::dim_t seg_len,
+                                  const nnvm::dim_t num_rows_l,
+                                  const nnvm::dim_t num_cols_l,
+                                  const nnvm::dim_t num_rows_r,
+                                  const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
+    if (seg_start >= num_rows_l) return;
+    const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
+    for (dim_t j = 0; j < num_rows_r; ++j) {
+      if (indptr_r[j] == indptr_r[j+1]) continue;
+      for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
+        const CType col_idx = col_idx_r[k];
+        const DType val = data_r[k];
+        for (dim_t r = seg_start; r < seg_end; ++r) {
+          out[r*num_rows_r+j] += data_l[r*num_cols_l+col_idx] * val;
+        }
+      }
+    }
+  }
+};
 
 /*!
  * \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
```

> **Review comment** (on the kernels' `Map` doc comments): Please complete documentation.
>
> **Author reply:** Thanks, it's done.
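The transposed variant differs from the previous kernel only in the innermost update: row `j` of the CSR rhs is column `j` of `rhs.T`, so the nonzero (`j`, `col_idx`, `val`) now reads lhs column `col_idx` and accumulates into output column `j`, i.e. `out[r][j] += data_l[r][col_idx] * val`. A minimal single-threaded sketch (standalone, hypothetical data; the same 3x2 CSR as the previous example, against a 1x2 lhs):

```cpp
#include <cstdio>
#include <vector>

int main() {
  // lhs: 1x2 dense. rhs: 3x2 CSR, so rhs.T is 2x3 and out = dot(lhs, rhs.T) is 1x3.
  std::vector<double> data_l = {1, 2};
  std::vector<double> data_r    = {10, 20};      // nonzeros (0,1)=10 and (2,0)=20
  std::vector<int>    indptr_r  = {0, 1, 1, 2};
  std::vector<int>    col_idx_r = {1, 0};
  const int num_cols_l = 2, num_rows_r = 3;

  std::vector<double> out(num_rows_r, 0.0);      // single lhs row, r = 0
  for (int j = 0; j < num_rows_r; ++j) {         // CSR row j == output column j
    for (int k = indptr_r[j]; k < indptr_r[j + 1]; ++k) {
      out[j] += data_l[0 * num_cols_l + col_idx_r[k]] * data_r[k];
    }
  }
  std::printf("%g %g %g\n", out[0], out[1], out[2]);  // prints: 20 0 20
  return 0;
}
```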
```diff
@@ -1031,13 +1109,58 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
 }
 
 /*
- * \brief Impl of dot(dns, csr) = dense (GPU only)
+ * \brief Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns
 */
 inline void DotDnsCsrDnsImpl(const OpContext& ctx, const cpu& cpu_dev,
-                            const TBlob& dns, const NDArray& rhs,
-                            const OpReqType req, NDArray* ret,
-                            const bool transpose_b) {
-  LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU";
+                             const TBlob& dns, const NDArray& rhs,
+                             const OpReqType req, NDArray* ret,
+                             const bool transpose_b) {
+  if (req == kNullOp) return;
+  CHECK_EQ(rhs.storage_type(), kCSRStorage);
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (!rhs.storage_initialized()) {
+    FillZerosCsrImpl(s, *ret);
+    return;
+  }
+
+  using nnvm::dim_t;
+
+  const TBlob data_r = rhs.data();
+  const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
+  const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
+  const TBlob& data_l = dns;
+  const TBlob data_out = ret->data();
+
+  MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, {     // data type
+    MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, {     // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, {  // col idx type
+        dim_t num_threads;
+        if (req == kWriteTo) {
+          num_threads = data_out.Size();
+          mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
+            s, num_threads, data_out.dptr<DType>());
+        }
+        num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
+        // seg by output row
+        dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
+        if (transpose_b) {
+          mxnet_op::Kernel<DotDnsCsrTransDnsByRowBlocks, cpu>::Launch(s, num_threads,
+              data_out.dptr<DType>(), data_l.dptr<DType>(),
+              data_r.dptr<DType>(), indptr_r.dptr<IType>(),
+              col_idx_r.dptr<CType>(), seg_len,
+              dns.shape_[0], dns.shape_[1],
+              rhs.shape()[0], rhs.shape()[1]);
+        } else {
+          mxnet_op::Kernel<DotDnsCsrDnsByRowBlocks, cpu>::Launch(s, num_threads,
+              data_out.dptr<DType>(), data_l.dptr<DType>(),
+              data_r.dptr<DType>(), indptr_r.dptr<IType>(),
+              col_idx_r.dptr<CType>(), seg_len,
+              dns.shape_[0], dns.shape_[1],
+              rhs.shape()[0], rhs.shape()[1]);
+        }
+      });
+    });
+  });
 }
 
 inline bool DotShape(const nnvm::NodeAttrs& attrs,
```

> **Review comment** (on the `MSHADOW_SGL_DBL_TYPE_SWITCH` line): fp16 will fail with this branch. I think there's a `MSHADOW_REAL_TYPE_SWITCH`.

> **Review comment** (on `if (req == kWriteTo)`): `req == kWriteTo || req == kWriteInplace`
>
> **Author reply:** Thanks, it's done! Could you review it again? :)
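The `seg_len` computation above is a ceiling division that splits the output rows into contiguous blocks, one per thread; threads whose block starts past the last row return immediately (the `seg_start >= num_rows_l` check in the kernels). A standalone sketch with hypothetical numbers, not MXNet code:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  const int num_rows = 10, num_threads = 4;
  const int seg_len = (num_rows + num_threads - 1) / num_threads;  // ceil -> 3
  for (int i = 0; i < num_threads; ++i) {
    const int seg_start = i * seg_len;
    if (seg_start >= num_rows) continue;  // this thread has no rows to process
    const int seg_end = std::min(seg_start + seg_len, num_rows);
    std::printf("thread %d: rows [%d, %d)\n", i, seg_start, seg_end);
  }
  return 0;
}
```

Note that the kernels accumulate with `+=`, so skipping the zero-fill is exactly what makes `kAddTo` work; the review comment above points out that `kWriteInplace` needs the zero-fill as well as `kWriteTo`.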
> **Review comment:** Please also update the doc in https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/dot.cc#L63-L64