Sgd with row_sparse weight, dns gradient #83

Merged (5 commits) on Jun 13, 2017
2 changes: 1 addition & 1 deletion mshadow
Submodule mshadow updated 2 files
+4 −0 mshadow/half.h
+8 −4 mshadow/half2.h
3 changes: 0 additions & 3 deletions python/mxnet/model.py
@@ -123,9 +123,6 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
# state for the same index but on diff devs, TODO(mli)
# use a better solution later
w, g = p
# cast storage type if stype doesn't match
if g.storage_type != w.storage_type:
g = nd.cast_storage(g, w.storage_type)
updater(index*num_device+k, g, w)
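
For context, the three deleted lines were a Python-side fallback that cast the gradient's storage to match the weight before calling the updater; this PR makes that cast unnecessary for the row_sparse weight / dense gradient case by handling it inside the operator. A minimal sketch of the removed behavior, reusing the names from the deleted lines (the helper function and shapes are hypothetical):

import mxnet.ndarray as nd

def cast_grad_to_weight_stype(w, g):
    # Mirrors the removed fallback: convert the gradient to the weight's
    # storage type when the two do not match, otherwise pass it through.
    if g.storage_type != w.storage_type:
        g = nd.cast_storage(g, w.storage_type)
    return g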


9 changes: 9 additions & 0 deletions src/operator/operator_common.h
@@ -366,6 +366,15 @@ void FCompExFallback(const nnvm::NodeAttrs& attrs,
CastNonDefaultStorage<xpu>(outputs, temp_out, ctx, true);
}

#define CHECK_RSP_ALL_ROWS_NON_ZERO(rsp, func, param) \
{ \
CHECK(rsp.storage_shape()[0] == rsp.shape()[0]) << func \
<< " for RowSparse " << param << " is only implemented for " \
<< "RowSparse " << param << " with all rows containing non-zeros. " \
<< "Expects " << param << ".values.shape[0] (" << rsp.storage_shape()[0] \
<< ") == " << param << ".shape[0] (" << rsp.shape()[0] << ")."; \
}


} // namespace op
} // namespace mxnet
276 changes: 209 additions & 67 deletions src/operator/optimizer_op-inl.h
@@ -145,29 +145,84 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
});
}

/*! \brief kernel for SGD update with RowSparse weight and dense gradient
 */
template<int req>
struct SGDRspDnsKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, size_t num_cols, DType* out, const DType* weight,
const DType *grad, const DType clip_gradient, const DType lr,
const DType wd, const DType rescale_grad) {
bool contains_non_zeros = false;
index_t j = 0;
index_t offset = i * num_cols;
for (; j < num_cols; ++j) {
if (grad[offset + j] != 0) {
contains_non_zeros = true;
break;
}
}
if (!contains_non_zeros) return;
const DType rate = 1.f - lr * wd;
for (index_t j = 0; j < num_cols; j++) {
auto index = offset + j;
if (clip_gradient >= 0.0f) {
KERNEL_ASSIGN(out[index], req, rate * weight[index] -
lr * mshadow_op::clip::Map(rescale_grad * grad[index], clip_gradient));
} else {
KERNEL_ASSIGN(out[index], req, rate * weight[index] -
lr * rescale_grad * grad[index]);
}
}
}
};
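
Restated as an equation (a transcription of the kernel above, not a new rule): for each row i of the dense gradient that holds at least one non-zero entry, and each column j,

W_{ij} \leftarrow (1 - \mathrm{lr}\cdot\mathrm{wd})\,W_{ij} \;-\; \mathrm{lr}\cdot\operatorname{clip}\!\big(\mathrm{rescale\_grad}\cdot G_{ij},\ \mathrm{clip\_gradient}\big)

where clip is the identity when clip_gradient is negative. Rows whose gradient is entirely zero are skipped, so they receive neither the gradient step nor the weight-decay shrinkage.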

template<typename xpu>
inline void SGDUpdateRspDnsImpl(const SGDParam& param,
const OpContext &ctx,
const NDArray& weight,
const TBlob& grad,
const OpReqType req,
NDArray *out) {
using namespace mshadow;
using namespace mxnet_op;
using namespace rowsparse;
CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights");
CHECK_EQ(weight.storage_type(), kRowSparseStorage);
if (req == kNullOp) return;
CHECK(weight.storage_initialized());
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
auto weight_data = weight.data().dptr<DType>();
auto grad_data = grad.dptr<DType>();
auto num_rows = weight.aux_shape(kIdx)[0];
auto num_cols = weight.shape().ProdShape(1, weight.shape().ndim());
Kernel<SGDRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, num_cols,
out->data().dptr<DType>(), weight_data, grad_data,
static_cast<DType>(param.clip_gradient),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad));
});
});
}
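
A hypothetical front-end call that exercises this new row_sparse-weight / dense-gradient path; the sgd_update operator name comes from optimizer_op.cc below, and the cast_storage call mirrors the model.py hunk above, so the exact argument spellings on this branch are assumptions:

from mxnet import ndarray as nd

# weight stored as row_sparse with every row present, as the check above requires
weight = nd.cast_storage(nd.ones((3, 2)), 'row_sparse')
grad = nd.array([[0.5, 0.5], [0.0, 0.0], [1.0, 1.0]])  # plain dense gradient
nd.sgd_update(weight, grad, lr=0.1, wd=0.0, out=weight)
# Row 1 of the weight should stay untouched: its gradient row has no non-zeros,
# so the kernel returns early for that row.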

template<typename xpu>
inline void SGDUpdateRspRspImpl(const SGDParam& param,
const OpContext& ctx,
const NDArray& weight,
const NDArray& grad,
const OpReqType& req,
NDArray *out) {
if (weight.storage_shape()[0] == weight.shape()[0] &&
out->storage_shape()[0] == out->shape()[0]) {
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feed in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
// reuse dns rsp implementation when storage_shape == shape
TBlob out_blob = out->data();
SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, out_req, &out_blob);
} else {
LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented for "
<< "RowSparse weights with all rows containing non-zeros. "
<< "Expects weights.values.shape[0] (" << weight.storage_shape()[0]
<< ") == weights.shape[0] (" << weight.shape()[0] << ").";
}
CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights");
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feeds in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
// reuse dns rsp implementation when storage_shape == shape
TBlob out_blob = out->data();
SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, out_req, &out_blob);
}

template<typename xpu>
@@ -188,6 +243,9 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
} else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) {
NDArray out = outputs[0];
SGDUpdateRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
} else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage) {
NDArray out = outputs[0];
SGDUpdateRspDnsImpl<xpu>(param, ctx, inputs[0], inputs[1].data(), req[0], &out);
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDUpdate<xpu>, "SGDUpdate");
}
@@ -267,21 +325,22 @@ struct SGDMomDnsRspDnsKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, size_t width, DType* out_data,
DType* mom_data, const DType* weight_data, const IType* grad_idx,
const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
const DType param_lr, const DType param_wd, const DType param_rescale_grad) {
const DType* grad_data, const DType clip_gradient, const DType momentum,
const DType lr, const DType wd, const DType rescale_grad) {
const DType rate = lr * wd;
for (size_t j = 0; j < width; j++) {
uint64_t data_i = grad_idx[i] * width + j;
uint64_t grad_i = i * width + j;
if (param_clip_gradient >= 0.0f) {
mom_data[data_i] = param_momentum * mom_data[data_i]
- param_lr * param_wd * weight_data[data_i]
- param_lr *
mshadow_op::clip::Map(param_rescale_grad * grad_data[grad_i],
param_clip_gradient);
if (clip_gradient >= 0.0f) {
mom_data[data_i] = momentum * mom_data[data_i]
- rate * weight_data[data_i]
- lr *
mshadow_op::clip::Map(rescale_grad * grad_data[grad_i],
clip_gradient);
} else {
mom_data[data_i] = param_momentum * mom_data[data_i]
- param_lr * param_wd * weight_data[data_i]
- param_lr * param_rescale_grad * grad_data[grad_i];
mom_data[data_i] = momentum * mom_data[data_i]
- rate * weight_data[data_i]
- lr * rescale_grad * grad_data[grad_i];
}
KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]);
}
@@ -323,6 +382,100 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
});
}

template<int req>
struct SGDMomRspDnsKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, size_t num_cols, DType* out, DType* mom,
const DType* weight, const DType *grad,
const DType clip_gradient, const DType momentum,
const DType lr, const DType wd, const DType rescale_grad) {
bool contains_non_zeros = false;
index_t j = 0;
index_t offset = i * num_cols;
for (; j < num_cols; ++j) {
if (grad[offset + j] != 0) {
contains_non_zeros = true;
break;
}
}
if (!contains_non_zeros) return;
const DType rate = lr * wd;
for (index_t j = 0; j < num_cols; j++) {
auto index = offset + j;
if (clip_gradient >= 0.0f) {
mom[index] = momentum * mom[index] - rate * weight[index]
- lr * mshadow_op::clip::Map(rescale_grad * grad[index], clip_gradient);
} else {
mom[index] = momentum * mom[index] - rate * weight[index]
- lr * rescale_grad * grad[index];
}
KERNEL_ASSIGN(out[index], req, weight[index] + mom[index]);
}
}
};
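
In equation form, again just restating the kernel: for each gradient row i containing a non-zero and each column j,

M_{ij} \leftarrow \mathrm{momentum}\cdot M_{ij} \;-\; \mathrm{lr}\cdot\mathrm{wd}\cdot W_{ij} \;-\; \mathrm{lr}\cdot\operatorname{clip}\!\big(\mathrm{rescale\_grad}\cdot G_{ij},\ \mathrm{clip\_gradient}\big)
W_{ij} \leftarrow W_{ij} + M_{ij}

All-zero gradient rows are skipped entirely, so their momentum state is left unchanged as well.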

template<typename xpu>
inline void InitDnsZeros(mshadow::Stream<xpu> *s, NDArray *out) {
using namespace rowsparse;
using namespace mshadow::expr;
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(out->storage_type(), kRowSparseStorage);
MSHADOW_REAL_TYPE_SWITCH(out->dtype(), DType, {
MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, {
auto num_rows = out->shape()[0];
out->CheckAndAlloc({Shape1(num_rows)});
auto idx = out->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
auto val = out->data();
Kernel<set_zero, xpu>::Launch(s, val.Size(), val.dptr<DType>());
ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1))
});
});
}

template<typename xpu>
inline void SGDMomUpdateRspDnsImpl(const SGDMomParam& param,
const OpContext &ctx,
const NDArray& weight,
const TBlob& grad,
const NDArray& mom,
const OpReqType req,
NDArray *out) {
using namespace mshadow;
using namespace mxnet_op;
using namespace rowsparse;
CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
Stream<xpu>* s = ctx.get_stream<xpu>();
CHECK_EQ(weight.storage_type(), kRowSparseStorage);
if (req == kNullOp) return;
CHECK(weight.storage_initialized());
// fill mom with zero values if not initialized yet
if (!mom.storage_initialized()) {
NDArray mom_zeros = mom;
InitDnsZeros(s, &mom_zeros);
}
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feeds in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, {
auto weight_data = weight.data().dptr<DType>();
auto grad_data = grad.dptr<DType>();
auto mom_data = mom.data().dptr<DType>();
auto num_rows = weight.aux_shape(kIdx)[0];
auto num_cols = weight.shape().ProdShape(1, weight.shape().ndim());
Kernel<SGDMomRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, num_cols,
out->data().dptr<DType>(), mom_data, weight_data, grad_data,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad));
});
});
}
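
A matching hypothetical call for this momentum path (row_sparse weight and momentum, dense gradient); the sgd_mom_update name and keyword arguments follow SGDMomParam, but the exact Python binding and cast_storage behavior on this branch are assumptions:

from mxnet import ndarray as nd

weight = nd.cast_storage(nd.ones((3, 2)), 'row_sparse')
mom = nd.cast_storage(nd.zeros((3, 2)), 'row_sparse')  # no stored rows yet, which should
                                                       # exercise the InitDnsZeros path above
grad = nd.ones((3, 2)) * 0.5                           # dense gradient
nd.sgd_mom_update(weight, grad, mom, lr=0.1, momentum=0.9, wd=1e-4, out=weight)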


template<typename xpu>
inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
const OpContext& ctx,
@@ -335,38 +488,22 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
using namespace mshadow::expr;
using namespace mxnet_op;
using namespace rowsparse;
if (weight.storage_shape()[0] == weight.shape()[0] &&
out->storage_shape()[0] == out->shape()[0]) {
Stream<xpu>* s = ctx.get_stream<xpu>();
// fill mom with zero values in order to reuse the sgd mom dns impl
if (!mom.storage_initialized()) {
MSHADOW_REAL_TYPE_SWITCH(mom.dtype(), DType, {
MSHADOW_INT_TYPE_SWITCH(mom.aux_type(kIdx), IType, {
auto num_rows = mom.shape()[0];
mom.CheckAndAlloc({Shape1(num_rows)});
auto mom_idx = mom.aux_data(kIdx).FlatTo1D<xpu, IType>(s);
auto mom_val = mom.data();
// TODO(haibin) this is single-thread execution
Kernel<set_zero, xpu>::Launch(s, mom_val.Size(), mom_val.dptr<DType>());
ASSIGN_DISPATCH(mom_idx, kWriteTo, range<IType>(0, num_rows, 1, 1))
});
});
}
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feed in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
mom.data(), out_req, &out_blob);
} else {
LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented for "
<< "RowSparse weights with all rows containing non-zeros. "
<< "Expects weights.values.shape[0] (" << weight.storage_shape()[0]
<< ") == weights.shape[0] (" << weight.shape()[0] << ").";
CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
Stream<xpu>* s = ctx.get_stream<xpu>();
// fill mom with zero values in order to reuse the sgd mom dns impl
if (!mom.storage_initialized()) {
NDArray mom_zeros = mom;
InitDnsZeros(s, &mom_zeros);
}
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feeds in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
mom.data(), out_req, &out_blob);
}

template<typename xpu>
@@ -377,23 +514,28 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &outputs) {
using namespace mxnet_op;
const SGDMomParam& param = nnvm::get<SGDMomParam>(attrs.parsed);
auto weight_stype = inputs[0].storage_type();
auto grad_stype = inputs[1].storage_type();
auto mom_stype = inputs[2].storage_type();
auto &weight = inputs[0];
auto &grad = inputs[1];
auto &mom = inputs[2];
auto weight_stype = weight.storage_type();
auto grad_stype = grad.storage_type();
auto mom_stype = mom.storage_type();
if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage &&
mom_stype == kDefaultStorage) {
TBlob out = outputs[0].data();
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
inputs[2].data(), req[0], &out);
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
mom.data(), req[0], &out);
} else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
mom_stype == kRowSparseStorage) {
NDArray out = outputs[0];
SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1],
inputs[2], req[0], &out);
SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
} else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage &&
mom_stype == kRowSparseStorage) {
NDArray out = outputs[0];
SGDMomUpdateRspDnsImpl<xpu>(param, ctx, weight, grad.data(), mom, req[0], &out);
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage &&
mom_stype == kDefaultStorage) {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs,
SGDMomUpdate<xpu>, "SGDMomUpdate");
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDMomUpdate<xpu>, "SGDMomUpdate");
}
}

6 changes: 3 additions & 3 deletions src/operator/optimizer_op.cc
@@ -22,8 +22,8 @@ It updates the weights using::

weight = weight - learning_rate * gradient

If gradients are stored with `row_sparse` storage,
where update is applied only to rows whose gradient has non-zero entries.
If weights are stored with `row_sparse` storage,
update is applied only to rows whose gradient has non-zero entries.

)code" ADD_FILELINE)
.set_num_inputs(2)
@@ -56,7 +56,7 @@ It updates the weights using::

Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

If gradients are stored with `row_sparse` storage,
If weights are stored with `row_sparse` storage,
only rows whose gradients contain non-zero entries are updated (for both weight and momentum).

)code" ADD_FILELINE)