forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request apache#8 from javelinjs/scala-package-cc
NDArray functions and SGD optimizer
- Loading branch information
Showing
7 changed files
with
160 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,4 +51,3 @@ class FactorScheduler(protected var step: Int, protected var factor: Float) exte | |
this.baseLR | ||
} | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package ml.dmlc.mxnet.optimizer | ||
|
||
import ml.dmlc.mxnet.{Optimizer, LRScheduler, NDArray} | ||
import ml.dmlc.mxnet.NDArrayConversions._ | ||
|
||
/** | ||
* A very simple SGD optimizer with momentum and weight regularization. | ||
* @author Yizhi Liu | ||
*/ | ||
class SGD(val learningRate: Float = 0.01f, val momentum: Float = 0.0f, | ||
val wd: Float = 0.0001f, rescaleGrad: Float = 1f, val clipGradient: Float = 0f, | ||
val lrScheduler: LRScheduler = null) extends Optimizer(rescaleGrad: Float) { | ||
/** | ||
* Update the parameters. | ||
* @param index An unique integer key used to index the parameters | ||
* @param weight weight ndarray | ||
* @param grad grad ndarray | ||
* @param state NDArray or other objects returned by initState | ||
* The auxiliary state used in optimization. | ||
*/ | ||
override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = { | ||
// TODO(bing) implement wd_bias, wd_gamma, wd_beta (copy from python package) | ||
val lr = | ||
(if (lrScheduler != null) { | ||
val scheduledLr = lrScheduler(numUpdate) | ||
updateCount(index) | ||
scheduledLr | ||
} else { | ||
this.learningRate | ||
}) * lrScale.getOrElse(index, 1f) | ||
|
||
var resdGrad = grad * rescaleGrad | ||
if (clipGradient != 0f) { | ||
resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) | ||
} | ||
if (state != null) { | ||
val mom = state.asInstanceOf[NDArray] | ||
mom *= momentum | ||
mom += -lr * (grad + wd * weight) | ||
weight += mom | ||
} else { | ||
require(momentum == 0f) | ||
weight += -lr * (grad + wd * weight) | ||
} | ||
} | ||
|
||
// Create additional optimizer state such as momentum. | ||
override def createState(index: Int, weight: NDArray): AnyRef = { | ||
if (momentum == 0.0f) { | ||
null | ||
} else { | ||
NDArray.zeros(weight.shape, weight.context) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters