standard sgd_update (#10614)
ZiyueHuang authored and eric-haibin-lin committed Apr 20, 2018
1 parent 615392b commit 9ccd647
Showing 3 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/optimizer.py
@@ -516,7 +516,7 @@ def _update_impl(self, index, weight, grad, state, multi_precision=False):
                 sgd_mom_update(weight, grad, state, out=weight,
                                lr=lr, wd=wd, **kwargs)
             else:
-                sgd_update(weight, grad, out=weight,
+                sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
                            lr=lr, wd=wd, **kwargs)
         else:
             if state[0] is not None:
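For context, the Python change above simply forwards the optimizer's lazy_update flag to the sgd_update operator. A minimal usage sketch of that flag, assuming a dense weight and a row-sparse gradient (the combination handled by SGDUpdateDnsRspImpl in the header change below); values are made up for illustration:

# Sketch: calling the sgd_update operator directly with the new flag.
import mxnet as mx

weight = mx.nd.ones((4, 2))            # dense weight
grad = mx.nd.zeros((4, 2))
grad[0] = 0.5
grad = grad.tostype('row_sparse')      # only row 0 is stored

# With lazy_update=False, weight decay is applied to every row,
# not only to the rows present in the gradient.
mx.nd.sgd_update(weight, grad, out=weight, lr=0.1, wd=0.01, lazy_update=False)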
15 changes: 13 additions & 2 deletions src/operator/optimizer_op-inl.h
@@ -47,6 +47,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
   float wd;
   float rescale_grad;
   float clip_gradient;
+  bool lazy_update;
   DMLC_DECLARE_PARAMETER(SGDParam) {
     DMLC_DECLARE_FIELD(lr)
     .describe("Learning rate");
@@ -63,6 +64,9 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
               "If clip_gradient <= 0, gradient clipping is turned off. "
               "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(lazy_update)
+    .set_default(true)
+    .describe("If true, lazy updates are applied.");
   }
};

@@ -177,14 +181,21 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
   Stream<xpu>* s = ctx.get_stream<xpu>();
   CHECK_EQ(grad.storage_type(), kRowSparseStorage);
   // if gradients are zeros, no weights are updated
-  if (!grad.storage_initialized() || req == kNullOp) return;
+  if (req == kNullOp) return;
   CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
   CHECK_GT(weight.shape_.Size(), 0);
 
   MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
     MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
         DType* weight_data = weight.dptr<DType>();
+        float wd = param.wd;
+        if (!param.lazy_update) {
+          Kernel<op_with_req<mshadow_op::mul, req_type>, xpu>::Launch(s, weight.Size(),
+            weight_data, weight_data, static_cast<DType>(1 - param.lr * param.wd));
+          wd = 0;
+        }
+        if (!grad.storage_initialized()) return;
         const IType* grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
         const DType* grad_val = grad.data().dptr<DType>();
         const nnvm::dim_t num_rows = grad.aux_shape(rowsparse::kIdx)[0];
@@ -196,7 +207,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
         Kernel<SGDDnsRspKernel<req_type, xpu>, xpu>::Launch(s, num_threads, row_length,
           out->dptr<DType>(), weight_data, grad_idx, grad_val,
           static_cast<DType>(param.clip_gradient),
-          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+          static_cast<DType>(param.lr), static_cast<DType>(wd),
           static_cast<DType>(param.rescale_grad));
       });
     });
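In effect, when lazy_update is false the whole weight array is first scaled by (1 - lr * wd) and wd is then zeroed before the per-row kernel runs, so weight decay is no longer limited to rows that appear in the gradient. A NumPy sketch of the two behaviours (illustrative only; rescale_grad and clip_gradient are omitted and the function name is made up):

import numpy as np

def sgd_rowsparse(weight, grad_rows, grad_vals, lr, wd, lazy_update=True):
    """Illustrative model of sgd_update with a row-sparse gradient (sketch, not the real op)."""
    weight = weight.copy()
    if not lazy_update:
        # standard update: decay every row up front, then drop wd from the row loop
        weight *= 1.0 - lr * wd
        wd = 0.0
    for i, row in enumerate(grad_rows):
        # lazy update: (1 - lr*wd) is applied only to rows present in the gradient
        weight[row] = (1.0 - lr * wd) * weight[row] - lr * grad_vals[i]
    return weight

w = np.ones((4, 2))
rows, vals = [0], np.array([[0.5, 0.5]])
lazy = sgd_rowsparse(w, rows, vals, lr=0.1, wd=0.01, lazy_update=True)
std = sgd_rowsparse(w, rows, vals, lr=0.1, wd=0.01, lazy_update=False)
print(lazy[1], std[1])   # rows without gradient: [1. 1.] (lazy) vs [0.999 0.999] (standard)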
2 changes: 1 addition & 1 deletion tests/python/unittest/test_optimizer.py
@@ -344,7 +344,7 @@ def test_std_sparse_sgd():
     opt1 = PySGD
     opt2 = mx.optimizer.SGD
     shape = (3, 4, 5)
-    mom_options = [{'momentum': 0.9}]
+    mom_options = [{'momentum': 0.0}, {'momentum': 0.9}]
     cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
     rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
     wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
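Adding the momentum 0.0 case makes the test also exercise the momentum-free sgd_update path touched by this commit. A rough sketch of the kind of dense-vs-sparse consistency check involved (not the actual test harness, which compares against a Python reference optimizer):

import mxnet as mx
import numpy as np

shape = (3, 4, 5)
dense_w = mx.nd.random.uniform(shape=shape)
sparse_w = dense_w.copy()
grad = mx.nd.random.uniform(shape=shape)

opt = mx.optimizer.SGD(learning_rate=0.1, momentum=0.0, wd=0.03)
opt.update(0, dense_w, grad, opt.create_state(0, dense_w))
opt.update(0, sparse_w, grad.tostype('row_sparse'), opt.create_state(0, sparse_w))

# With every row present in the gradient, lazy and standard updates coincide,
# so the dense and row-sparse results should match.
np.testing.assert_allclose(dense_w.asnumpy(), sparse_w.asnumpy(), rtol=1e-4, atol=1e-5)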
