Skip to content

Commit

Permalink
move init dns zeros to init_op.h for kvstore to use (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin committed Jun 22, 2017
1 parent 1914471 commit ad8e74c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 21 deletions.
24 changes: 3 additions & 21 deletions src/operator/optimizer_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "./mshadow_op.h"
#include "./elemwise_op_common.h"
#include "mxnet_op.h"
#include "./tensor/init_op.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -414,25 +415,6 @@ struct SGDMomRspDnsKernel {
}
};

template<typename xpu>
inline void InitDnsZeros(mshadow::Stream<xpu> *s, NDArray *out) {
using namespace rowsparse;
using namespace mshadow::expr;
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(out->storage_type(), kRowSparseStorage);
MSHADOW_REAL_TYPE_SWITCH(out->dtype(), DType, {
MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, {
auto num_rows = out->shape()[0];
out->CheckAndAlloc({Shape1(num_rows)});
auto idx = out->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
auto val = out->data();
Kernel<set_zero, xpu>::Launch(s, val.Size(), val.dptr<DType>());
ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1))
});
});
}

template<typename xpu>
inline void SGDMomUpdateRspDnsImpl(const SGDMomParam& param,
const OpContext &ctx,
Expand All @@ -452,7 +434,7 @@ inline void SGDMomUpdateRspDnsImpl(const SGDMomParam& param,
// fill mom with zero values if not initialized yet
if (!mom.storage_initialized()) {
NDArray mom_zeros = mom;
InitDnsZeros(s, &mom_zeros);
FillDnsZerosRspImpl(s, &mom_zeros);
}
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feed in kWriteTo as req for all operators.
Expand Down Expand Up @@ -493,7 +475,7 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
// fill mom with zero values in order to reuse the sgd mom dns impl
if (!mom.storage_initialized()) {
NDArray mom_zeros = mom;
InitDnsZeros(s, &mom_zeros);
FillDnsZerosRspImpl(s, &mom_zeros);
}
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feed in kWriteTo as req for all operators.
Expand Down
21 changes: 21 additions & 0 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,27 @@ void FillCompute(const nnvm::NodeAttrs& attrs,
});
}

// Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray,
// instead of the usual compact representation.
template<typename xpu>
inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
using namespace rowsparse;
using namespace mshadow::expr;
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(dst->storage_type(), kRowSparseStorage);
MSHADOW_REAL_TYPE_SWITCH(dst->dtype(), DType, {
MSHADOW_INT_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
auto num_rows = dst->shape()[0];
dst->CheckAndAlloc({Shape1(num_rows)});
auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
auto val = dst->data();
Kernel<set_zero, xpu>::Launch(s, val.Size(), val.dptr<DType>());
ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1))
});
});
}

// Fill a rsp NDArray with zeros by updating the aux shape.
template<typename xpu>
void FillZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
Expand Down

0 comments on commit ad8e74c

Please sign in to comment.