This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Documentation changes. Added full reference #12153

Merged
10 changes: 7 additions & 3 deletions python/mxnet/optimizer.py
@@ -1,3 +1,4 @@
+# coding: utf-8
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,7 +16,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# coding: utf-8
 # pylint: disable=too-many-lines
 """Weight updating functions."""
 import logging
@@ -548,15 +548,19 @@ def update_multi_precision(self, index, weight, grad, state):
 
 @register
 class Signum(Optimizer):
-    """The Signum optimizer that takes the sign of gradient or momentum.
+    r"""The Signum optimizer that takes the sign of gradient or momentum.
 
     The optimizer updates the weight by::
 
         rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
        state = momentum * state + (1-momentum)*rescaled_grad
         weight = (1 - lr * wd_lh) * weight - lr * sign(state)
 
-    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
+    Reference:
+    Jeremy Bernstein, Yu-Xiang Wang, Kamyar Azizzadenesheli & Anima Anandkumar. (2018).
+    signSGD: Compressed Optimisation for Non-Convex Problems. In ICML'18.
+
+    See: https://arxiv.org/abs/1802.04434
 
     For details of the update algorithm see
     :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
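The update rule quoted in the docstring can be sketched as a standalone function. This is an illustrative NumPy translation of the pseudocode above, not MXNet's actual implementation (which lives in the fused `signum_update` kernel); the function name and parameter defaults here are assumptions chosen to mirror the docstring's symbols.

```python
import numpy as np

def signum_update(weight, grad, state, lr=0.01, momentum=0.9,
                  wd=0.0, wd_lh=0.0, rescale_grad=1.0, clip_gradient=None):
    """One Signum step, mirroring the docstring's pseudocode:

        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
        state  = momentum * state + (1 - momentum) * rescaled_grad
        weight = (1 - lr * wd_lh) * weight - lr * sign(state)
    """
    if clip_gradient is not None:
        grad = np.clip(grad, -clip_gradient, clip_gradient)
    rescaled_grad = rescale_grad * grad + wd * weight
    # Momentum buffer accumulates an exponential average of rescaled gradients.
    state[:] = momentum * state + (1.0 - momentum) * rescaled_grad
    # Only the sign of the momentum drives the step; magnitude is discarded.
    weight[:] = (1.0 - lr * wd_lh) * weight - lr * np.sign(state)
    return weight, state
```

Note that with `momentum=0` this degenerates to signSGD (the step is `-lr * sign(grad)`), which is why the class documents both `signsgd_update` and `signum_update`.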