This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on CPU #11113
Merged: eric-haibin-lin merged 5 commits into apache:master from XiaotaoChen:cxt-dns_csr/csr.T_dns on Jun 10, 2018
Changes from 3 commits. Commits (5 total):
dd989d6  implement dot(dns, csr/csr.T)=dns on cpu (XiaotaoChen)
8f53309  complete documentation related to dot(dns, csr/csr.T)=dns on cpu (XiaotaoChen)
4a7389c  Merge branch 'master' into cxt-dns_csr/csr.T_dns (XiaotaoChen)
08394b4  Merge branch 'master' into cxt-dns_csr/csr.T_dns (XiaotaoChen)
84cc230  support fp16 by replacing MSHADOW_SGL_DBL_TYPE_SWITCH with MSHADOW_RE… (XiaotaoChen)
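For orientation, here is a minimal single-threaded sketch of the computation this PR adds on CPU: a row-major dense LHS times a CSR RHS (optionally transposed), accumulated into a dense output. This is an editor's illustration with made-up names, not code from the PR; the kernels in the diff below compute the same accumulation but partition the output rows across threads.

#include <cstdint>
#include <vector>

// out(m x n) = dns(m x k) * csr      with k == rows_r, n == cols_r
// out(m x n) = dns(m x k) * csr.T    with k == cols_r, n == rows_r
// CSR storage: val (nonzeros), col_idx (their columns), indptr (row offsets).
void DenseTimesCsrReference(const std::vector<float>& dns, int64_t m, int64_t k,
                            const std::vector<float>& val,
                            const std::vector<int64_t>& col_idx,
                            const std::vector<int64_t>& indptr,
                            int64_t rows_r, int64_t cols_r,
                            bool transpose_b, std::vector<float>* out) {
  const int64_t n = transpose_b ? rows_r : cols_r;
  out->assign(m * n, 0.0f);
  for (int64_t j = 0; j < rows_r; ++j) {               // CSR row j
    for (int64_t p = indptr[j]; p < indptr[j + 1]; ++p) {
      const int64_t c = col_idx[p];                    // nonzero at (j, c)
      const float v = val[p];
      for (int64_t r = 0; r < m; ++r) {                // every output row
        if (transpose_b) {
          out->at(r * n + j) += dns[r * k + c] * v;    // csr.T entry (c, j)
        } else {
          out->at(r * n + c) += dns[r * k + j] * v;    // csr entry (j, c)
        }
      }
    }
  }
}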
@@ -264,11 +264,15 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
       !param.transpose_a) {
     target_stype = hint_has_value ? target_stype : kCSRStorage;
-    // dns, csr -> csr on CPU
-    if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) {
-      if (target_stype == kCSRStorage) {
+    if (dev_mask == mshadow::cpu::kDevMask) {
+      // dns, csr -> csr on CPU
+      if (target_stype == kCSRStorage && !param.transpose_b) {
         dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
                                          DispatchMode::kFComputeEx);
+      // dns, csr/csr.T -> dns on CPU
+      } else if (target_stype == kDefaultStorage) {
+        dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                         DispatchMode::kFComputeEx);
       }
     // dns, csr/csr.T -> dns on GPU
     } else if (dev_mask == mshadow::gpu::kDevMask) {
@@ -327,7 +331,7 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
       dispatched = true;
     }
   }
-  if (!dispatched && dev_mask == mshadow::gpu::kDevMask && !param.transpose_a &&
+  if (!dispatched && !param.transpose_a &&
       lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
       ograd_stype == kDefaultStorage) {
     if (type_assign(&lhs_grad_stype, kDefaultStorage) &&
@@ -655,7 +659,101 @@ struct DotDnsCsrCsrByRowBlocks {
   }
 };

+/*!
+ * \brief CPU Kernel of dot(dns1, csr) = dns2
+ * Parallelization by row blocks
+ */
+struct DotDnsCsrDnsByRowBlocks {
+  /*!
+   * \brief
+   * \param i the i-th thread
+   * \param out output matrix
+   * \param data_l data of lhs
+   * \param data_r values of csr
+   * \param indptr_r row offsets of csr
+   * \param col_idx_r column indices of csr
+   * \param seg_len workload of this thread
+   * \param num_rows_l number of rows in lhs
+   * \param num_cols_l number of columns in lhs
+   * \param num_rows_r number of rows in rhs
+   * \param num_cols_r number of columns in rhs
+   */
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* data_l,
+                                  const DType* data_r,
+                                  const IType* indptr_r,
+                                  const CType* col_idx_r,
+                                  const nnvm::dim_t seg_len,
+                                  const nnvm::dim_t num_rows_l,
+                                  const nnvm::dim_t num_cols_l,
+                                  const nnvm::dim_t num_rows_r,
+                                  const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
+    if (seg_start >= num_rows_l) return;
+    const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
+    for (dim_t j = 0; j < num_rows_r; ++j) {
+      if (indptr_r[j] == indptr_r[j+1]) continue;
+      for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
+        const CType col_idx = col_idx_r[k];
+        const DType val = data_r[k];
+        for (dim_t r = seg_start; r < seg_end; ++r) {
+          out[r*num_cols_r+col_idx] += data_l[r*num_cols_l+j] * val;
+        }
+      }
+    }
+  }
+};
+
+/*!
+ * \brief CPU Kernel of dot(dns1, csr.T) = dns2
+ * Parallelization by row blocks
+ */
+struct DotDnsCsrTransDnsByRowBlocks {
+  /*!
+   * \brief
+   * \param i the i-th thread
+   * \param out output matrix
+   * \param data_l data of lhs
+   * \param data_r values of csr
+   * \param indptr_r row offsets of csr
+   * \param col_idx_r column indices of csr
+   * \param seg_len workload of this thread
+   * \param num_rows_l number of rows in lhs
+   * \param num_cols_l number of columns in lhs
+   * \param num_rows_r number of rows in rhs
+   * \param num_cols_r number of columns in rhs
+   */
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* data_l,
+                                  const DType* data_r,
+                                  const IType* indptr_r,
+                                  const CType* col_idx_r,
+                                  const nnvm::dim_t seg_len,
+                                  const nnvm::dim_t num_rows_l,
+                                  const nnvm::dim_t num_cols_l,
+                                  const nnvm::dim_t num_rows_r,
+                                  const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
+    if (seg_start >= num_rows_l) return;
+    const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
+    for (dim_t j = 0; j < num_rows_r; ++j) {
+      if (indptr_r[j] == indptr_r[j+1]) continue;
+      for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
+        const CType col_idx = col_idx_r[k];
+        const DType val = data_r[k];
+        for (dim_t r = seg_start; r < seg_end; ++r) {
+          out[r*num_rows_r+j] += data_l[r*num_cols_l+col_idx] * val;
+        }
+      }
+    }
+  }
+};
+
 /*!
  * \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
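Both new kernels parallelize over blocks of output rows: thread i owns rows [i*seg_len, min((i+1)*seg_len, num_rows_l)), so every output element is written by exactly one thread and no atomics are needed, at the cost of each thread scanning the full CSR structure. A small standalone sketch of that split (editor's illustration, hypothetical names, mirroring how seg_len is derived below):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Ceil-divide the output rows among num_threads workers; threads whose
// segment starts past the last row do nothing.
void PrintRowBlocks(int64_t num_rows, int num_threads) {
  const int64_t seg_len = (num_rows + num_threads - 1) / num_threads;
  for (int i = 0; i < num_threads; ++i) {
    const int64_t seg_start = static_cast<int64_t>(i) * seg_len;
    if (seg_start >= num_rows) break;  // no rows left for this thread
    const int64_t seg_end = std::min(seg_start + seg_len, num_rows);
    std::printf("thread %d: rows [%lld, %lld)\n", i,
                static_cast<long long>(seg_start),
                static_cast<long long>(seg_end));
  }
}

// PrintRowBlocks(10, 4) prints [0,3), [3,6), [6,9), [9,10).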
@@ -1031,13 +1129,58 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
 }

 /*
- * \brief Impl of dot(dns, csr) = dense (GPU only)
+ * \brief Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns
  */
 inline void DotDnsCsrDnsImpl(const OpContext& ctx, const cpu& cpu_dev,
-                             const TBlob& dns, const NDArray& rhs,
-                             const OpReqType req, NDArray* ret,
-                             const bool transpose_b) {
-  LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU";
+                             const TBlob& dns, const NDArray& rhs,
+                             const OpReqType req, NDArray* ret,
+                             const bool transpose_b) {
+  if (req == kNullOp) return;
+  CHECK_EQ(rhs.storage_type(), kCSRStorage);
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (!rhs.storage_initialized()) {
+    FillZerosCsrImpl(s, *ret);
+    return;
+  }
+
+  using nnvm::dim_t;
+
+  const TBlob data_r = rhs.data();
+  const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
+  const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
+  const TBlob& data_l = dns;
+  const TBlob data_out = ret->data();
+
+  MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, {  // data type
Review comment: fp16 will fail with this branch. I think there's a MSHADOW_REAL_TYPE_SWITCH
+    MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, {  // col idx type
+        dim_t num_threads;
+        if (req == kWriteTo) {
Review comment: req == writeto || req == writeinplace
Reply: Thx, it's done! May you review it again :)
+          num_threads = data_out.Size();
+          mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
+            s, num_threads, data_out.dptr<DType>());
+        }
+        num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
+        // seg by output row
+        dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
+        if (transpose_b) {
+          mxnet_op::Kernel<DotDnsCsrTransDnsByRowBlocks, cpu>::Launch(s, num_threads,
+            data_out.dptr<DType>(), data_l.dptr<DType>(),
+            data_r.dptr<DType>(), indptr_r.dptr<IType>(),
+            col_idx_r.dptr<CType>(), seg_len,
+            dns.shape_[0], dns.shape_[1],
+            rhs.shape()[0], rhs.shape()[1]);
+        } else {
+          mxnet_op::Kernel<DotDnsCsrDnsByRowBlocks, cpu>::Launch(s, num_threads,
+            data_out.dptr<DType>(), data_l.dptr<DType>(),
+            data_r.dptr<DType>(), indptr_r.dptr<IType>(),
+            col_idx_r.dptr<CType>(), seg_len,
+            dns.shape_[0], dns.shape_[1],
+            rhs.shape()[0], rhs.shape()[1]);
+        }
+      });
+    });
+  });
+}

 inline bool DotShape(const nnvm::NodeAttrs& attrs,
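On the reviewer's fp16 note above: MSHADOW_SGL_DBL_TYPE_SWITCH only expands float32 and float64 cases, so fp16 inputs would fail at that type switch. The fifth commit (84cc230), per its message, swaps in the broader real-type switch. A rough fragment of that change (editor's sketch, not the actual commit diff):

// As shown in this 3-commit view: float and double only, fp16 fails here.
MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { /* ... */ });

// After commit 84cc230 (per its message and the review thread): the REAL
// switch also covers fp16 (mshadow::half::half_t).
MSHADOW_REAL_TYPE_SWITCH(data_r.type_flag_, DType, { /* ... */ });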
Review comment: Please also update doc in https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/dot.cc#L63-L64