Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on CPU #11113

Merged
merged 5 commits into from
Jun 10, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 152 additions & 9 deletions src/operator/tensor/dot-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,15 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
!param.transpose_a) {
target_stype = hint_has_value ? target_stype : kCSRStorage;
// dns, csr -> csr on CPU
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) {
if (target_stype == kCSRStorage) {
if (dev_mask == mshadow::cpu::kDevMask) {
// dns, csr -> csr on CPU
if (target_stype == kCSRStorage && !param.transpose_b) {
dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
DispatchMode::kFComputeEx);
// dns, csr/csr.T -> dns on CPU
} else if (target_stype == kDefaultStorage) {
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFComputeEx);
}
// dns, csr/csr.T -> dns on GPU
} else if (dev_mask == mshadow::gpu::kDevMask) {
Expand Down Expand Up @@ -327,7 +331,7 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
dispatched = true;
}
}
if (!dispatched && dev_mask == mshadow::gpu::kDevMask && !param.transpose_a &&
if (!dispatched && !param.transpose_a &&
lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
ograd_stype == kDefaultStorage) {
if (type_assign(&lhs_grad_stype, kDefaultStorage) &&
Expand Down Expand Up @@ -655,7 +659,101 @@ struct DotDnsCsrCsrByRowBlocks {
}
};

/*!
* \brief CPU Kernel of dot(dns1, csr) = dns2
* Parallelization by row blocks
*/
struct DotDnsCsrDnsByRowBlocks {
/*!
* \brief
* \param i the i-th thread
* \param out output matrix
* \param data_l data of lhs
* \param data_r values of csr
* \param indptr_r row offsets of csr
* \param col_idx_r column indices of csr
* \param seg_len workload of this thread
* \param num_rows_l number of rows in lhs
* \param num_cols_l number of columns in lhs
* \param num_rows_r number of rows in rhs
* \param num_cols_r number of columns in rhs
*/
template<typename DType, typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
const DType* data_l,
const DType* data_r,
const IType* indptr_r,
const CType* col_idx_r,
const nnvm::dim_t seg_len,
const nnvm::dim_t num_rows_l,
const nnvm::dim_t num_cols_l,
const nnvm::dim_t num_rows_r,
const nnvm::dim_t num_cols_r) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows_l) return;
const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
for (dim_t j = 0; j < num_rows_r; ++j) {
if (indptr_r[j] == indptr_r[j+1]) continue;
for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
const CType col_idx = col_idx_r[k];
const DType val = data_r[k];
for (dim_t r = seg_start; r < seg_end; ++r) {
out[r*num_cols_r+col_idx] += data_l[r*num_cols_l+j] * val;
}
}
}
}
};

/*!
* \brief CPU Kernel of dot(dns1, csr.T) = dns2
* Parallelization by row blocks
*/
struct DotDnsCsrTransDnsByRowBlocks {
/*!
* \brief
* \param i the i-th thread
* \param out output matrix
* \param data_l data of lhs
* \param data_r values of csr
* \param indptr_r row offsets of csr
* \param col_idx_r column indices of csr
* \param seg_len workload of this thread
* \param num_rows_l number of rows in lhs
* \param num_cols_l number of columns in lhs
* \param num_rows_r number of rows in rhs
* \param num_cols_r number of columns in rhs
*/
template<typename DType, typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i,
DType* out,
const DType* data_l,
const DType* data_r,
const IType* indptr_r,
const CType* col_idx_r,
const nnvm::dim_t seg_len,
const nnvm::dim_t num_rows_l,
const nnvm::dim_t num_cols_l,
const nnvm::dim_t num_rows_r,
const nnvm::dim_t num_cols_r) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows_l) return;
const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
for (dim_t j = 0; j < num_rows_r; ++j) {
if (indptr_r[j] == indptr_r[j+1]) continue;
for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
const CType col_idx = col_idx_r[k];
const DType val = data_r[k];
for (dim_t r = seg_start; r < seg_end; ++r) {
out[r*num_rows_r+j] += data_l[r*num_cols_l+col_idx] * val;
}
}
}
}
};

/*!
* \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
Expand Down Expand Up @@ -1031,13 +1129,58 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
}

/*
* \brief Impl of dot(dns, csr) = dense (GPU only)
* \brief Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns
*/
inline void DotDnsCsrDnsImpl(const OpContext& ctx, const cpu& cpu_dev,
const TBlob& dns, const NDArray& rhs,
const OpReqType req, NDArray* ret,
const bool transpose_b) {
LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU";
const TBlob& dns, const NDArray& rhs,
const OpReqType req, NDArray* ret,
const bool transpose_b) {
if (req == kNullOp) return;
CHECK_EQ(rhs.storage_type(), kCSRStorage);
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
if (!rhs.storage_initialized()) {
FillZerosCsrImpl(s, *ret);
return;
}

using nnvm::dim_t;

const TBlob data_r = rhs.data();
const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
const TBlob& data_l = dns;
const TBlob data_out = ret->data();

MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fp16 will fail with this branch. I think there's a MSHADOW_REAL_TYPE_SWITCH

MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // col idx type
dim_t num_threads;
if (req == kWriteTo) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

req == writeto || req == writeinplace

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, it's Done ! May you review it again :)

num_threads = data_out.Size();
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
s, num_threads, data_out.dptr<DType>());
}
num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
// seg by output row
dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
if (transpose_b) {
mxnet_op::Kernel<DotDnsCsrTransDnsByRowBlocks, cpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(),
data_r.dptr<DType>(), indptr_r.dptr<IType>(),
col_idx_r.dptr<CType>(), seg_len,
dns.shape_[0], dns.shape_[1],
rhs.shape()[0], rhs.shape()[1]);
} else {
mxnet_op::Kernel<DotDnsCsrDnsByRowBlocks, cpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(),
data_r.dptr<DType>(), indptr_r.dptr<IType>(),
col_idx_r.dptr<CType>(), seg_len,
dns.shape_[0], dns.shape_[1],
rhs.shape()[0], rhs.shape()[1]);
}
});
});
});
}

inline bool DotShape(const nnvm::NodeAttrs& attrs,
Expand Down
4 changes: 2 additions & 2 deletions src/operator/tensor/dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ forward_stype option for output storage type. Implemented sparse operations incl
- dot(csr, default) = default
- dot(csr, row_sparse) = default
- dot(default, csr) = csr (CPU only)
- dot(default, csr, forward_stype='default') = default (GPU only)
- dot(default, csr, transpose_b=True, forward_stype='default') = default (GPU only)
- dot(default, csr, forward_stype='default') = default
- dot(default, csr, transpose_b=True, forward_stype='default') = default

If the combination of input storage types and forward_stype does not match any of the
above patterns, ``dot`` will fallback and generate output with default storage.
Expand Down