diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index d4b17074d48c..6d414bb0328a 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -34,7 +34,7 @@ object BucketIo { type ReadContent = String => String def defaultReadContent(path: String): String = { - Source.fromFile(path).mkString.replaceAll("\\. |\n", " ") + Source.fromFile(path, "UTF-8").mkString.replaceAll("\\. |\n", " ") } def defaultBuildVocab(path: String): Map[String, Int] = { @@ -56,7 +56,7 @@ object BucketIo { val tmp = sentence.split(" ").filter(_.length() > 0) for (w <- tmp) yield theVocab(w) } - words.toArray + words } def defaultGenBuckets(sentences: Array[String], batchSize: Int, @@ -162,8 +162,6 @@ object BucketIo { labelBuffer.append(NDArray.zeros(_batchSize, buckets(iBucket))) } - private val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2)) - private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey)) tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2)) } @@ -208,12 +206,13 @@ object BucketIo { tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2)) } val batchProvideLabel = ListMap("softmax_label" -> labelBuf.shape) - new DataBatch(IndexedSeq(dataBuf) ++ initStateArrays, - IndexedSeq(labelBuf), - getIndex(), - getPad(), - this.buckets(bucketIdx).asInstanceOf[AnyRef], - batchProvideData, batchProvideLabel) + val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2)) + new DataBatch(IndexedSeq(dataBuf.copy()) ++ initStateArrays, + IndexedSeq(labelBuf.copy()), + getIndex(), + getPad(), + this.buckets(bucketIdx).asInstanceOf[AnyRef], + batchProvideData, batchProvideLabel) } /** diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala index bf29a47fcf81..872ef7871fb0 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala @@ -18,13 +18,10 @@ package org.apache.mxnetexamples.rnn -import org.apache.mxnet.Symbol +import org.apache.mxnet.{Shape, Symbol} import scala.collection.mutable.ArrayBuffer -/** - * @author Depeng Liang - */ object Lstm { final case class LSTMState(c: Symbol, h: Symbol) @@ -35,27 +32,22 @@ object Lstm { def lstm(numHidden: Int, inData: Symbol, prevState: LSTMState, param: LSTMParam, seqIdx: Int, layerIdx: Int, dropout: Float = 0f): LSTMState = { val inDataa = { - if (dropout > 0f) Symbol.Dropout()()(Map("data" -> inData, "p" -> dropout)) + if (dropout > 0f) Symbol.api.Dropout(data = Some(inData), p = Some(dropout)) else inData } - val i2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_i2h")()(Map("data" -> inDataa, - "weight" -> param.i2hWeight, - "bias" -> param.i2hBias, - "num_hidden" -> numHidden * 4)) - val h2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_h2h")()(Map("data" -> prevState.h, - "weight" -> param.h2hWeight, - "bias" -> param.h2hBias, - "num_hidden" -> numHidden * 4)) + val i2h = Symbol.api.FullyConnected(data = Some(inDataa), weight = Some(param.i2hWeight), + bias = Some(param.i2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_i2h") + val h2h = Symbol.api.FullyConnected(data = 
Some(prevState.h), weight = Some(param.h2hWeight), + bias = Some(param.h2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_h2h") val gates = i2h + h2h - val sliceGates = Symbol.SliceChannel(s"t${seqIdx}_l${layerIdx}_slice")( - gates)(Map("num_outputs" -> 4)) - val ingate = Symbol.Activation()()(Map("data" -> sliceGates.get(0), "act_type" -> "sigmoid")) - val inTransform = Symbol.Activation()()(Map("data" -> sliceGates.get(1), "act_type" -> "tanh")) - val forgetGate = Symbol.Activation()()( - Map("data" -> sliceGates.get(2), "act_type" -> "sigmoid")) - val outGate = Symbol.Activation()()(Map("data" -> sliceGates.get(3), "act_type" -> "sigmoid")) + val sliceGates = Symbol.api.SliceChannel(data = Some(gates), num_outputs = 4, + name = s"t${seqIdx}_l${layerIdx}_slice") + val ingate = Symbol.api.Activation(data = Some(sliceGates.get(0)), act_type = "sigmoid") + val inTransform = Symbol.api.Activation(data = Some(sliceGates.get(1)), act_type = "tanh") + val forgetGate = Symbol.api.Activation(data = Some(sliceGates.get(2)), act_type = "sigmoid") + val outGate = Symbol.api.Activation(data = Some(sliceGates.get(3)), act_type = "sigmoid") val nextC = (forgetGate * prevState.c) + (ingate * inTransform) - val nextH = outGate * Symbol.Activation()()(Map("data" -> nextC, "act_type" -> "tanh")) + val nextH = outGate * Symbol.api.Activation(data = Some(nextC), act_type = "tanh") LSTMState(c = nextC, h = nextH) } @@ -74,11 +66,11 @@ object Lstm { val lastStatesBuf = ArrayBuffer[LSTMState]() for (i <- 0 until numLstmLayer) { paramCellsBuf.append(LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"), - i2hBias = Symbol.Variable(s"l${i}_i2h_bias"), - h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"), - h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))) + i2hBias = Symbol.Variable(s"l${i}_i2h_bias"), + h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"), + h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))) lastStatesBuf.append(LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"), - h = Symbol.Variable(s"l${i}_init_h_beta"))) + h = Symbol.Variable(s"l${i}_init_h_beta"))) } val paramCells = paramCellsBuf.toArray val lastStates = lastStatesBuf.toArray @@ -87,10 +79,10 @@ object Lstm { // embeding layer val data = Symbol.Variable("data") var label = Symbol.Variable("softmax_label") - val embed = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize, - "weight" -> embedWeight, "output_dim" -> numEmbed)) - val wordvec = Symbol.SliceChannel()()( - Map("data" -> embed, "num_outputs" -> seqLen, "squeeze_axis" -> 1)) + val embed = Symbol.api.Embedding(data = Some(data), input_dim = inputSize, + weight = Some(embedWeight), output_dim = numEmbed, name = "embed") + val wordvec = Symbol.api.SliceChannel(data = Some(embed), + num_outputs = seqLen, squeeze_axis = Some(true)) val hiddenAll = ArrayBuffer[Symbol]() var dpRatio = 0f @@ -101,22 +93,23 @@ object Lstm { for (i <- 0 until numLstmLayer) { if (i == 0) dpRatio = 0f else dpRatio = dropout val nextState = lstm(numHidden, inData = hidden, - prevState = lastStates(i), - param = paramCells(i), - seqIdx = seqIdx, layerIdx = i, dropout = dpRatio) + prevState = lastStates(i), + param = paramCells(i), + seqIdx = seqIdx, layerIdx = i, dropout = dpRatio) hidden = nextState.h lastStates(i) = nextState } // decoder - if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout)) + if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout)) hiddenAll.append(hidden) } - val hiddenConcat = Symbol.Concat()(hiddenAll: 
_*)(Map("dim" -> 0)) - val pred = Symbol.FullyConnected("pred")()(Map("data" -> hiddenConcat, "num_hidden" -> numLabel, - "weight" -> clsWeight, "bias" -> clsBias)) - label = Symbol.transpose()(label)() - label = Symbol.Reshape()()(Map("data" -> label, "target_shape" -> "(0,)")) - val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> pred, "label" -> label)) + val hiddenConcat = Symbol.api.Concat(data = hiddenAll.toArray, num_args = hiddenAll.length, + dim = Some(0)) + val pred = Symbol.api.FullyConnected(data = Some(hiddenConcat), num_hidden = numLabel, + weight = Some(clsWeight), bias = Some(clsBias)) + label = Symbol.api.transpose(data = Some(label)) + label = Symbol.api.Reshape(data = Some(label), target_shape = Some(Shape(0))) + val sm = Symbol.api.SoftmaxOutput(data = Some(pred), label = Some(label), name = "softmax") sm } @@ -131,35 +124,35 @@ object Lstm { var lastStates = Array[LSTMState]() for (i <- 0 until numLstmLayer) { paramCells = paramCells :+ LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"), - i2hBias = Symbol.Variable(s"l${i}_i2h_bias"), - h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"), - h2hBias = Symbol.Variable(s"l${i}_h2h_bias")) + i2hBias = Symbol.Variable(s"l${i}_i2h_bias"), + h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"), + h2hBias = Symbol.Variable(s"l${i}_h2h_bias")) lastStates = lastStates :+ LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"), - h = Symbol.Variable(s"l${i}_init_h_beta")) + h = Symbol.Variable(s"l${i}_init_h_beta")) } assert(lastStates.length == numLstmLayer) val data = Symbol.Variable("data") - var hidden = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize, - "weight" -> embedWeight, "output_dim" -> numEmbed)) + var hidden = Symbol.api.Embedding(data = Some(data), input_dim = inputSize, + weight = Some(embedWeight), output_dim = numEmbed, name = "embed") var dpRatio = 0f // stack LSTM for (i <- 0 until numLstmLayer) { if (i == 0) dpRatio = 0f else dpRatio = dropout val nextState = lstm(numHidden, inData = hidden, - prevState = lastStates(i), - param = paramCells(i), - seqIdx = seqIdx, layerIdx = i, dropout = dpRatio) + prevState = lastStates(i), + param = paramCells(i), + seqIdx = seqIdx, layerIdx = i, dropout = dpRatio) hidden = nextState.h lastStates(i) = nextState } // decoder - if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout)) - val fc = Symbol.FullyConnected("pred")()(Map("data" -> hidden, "num_hidden" -> numLabel, - "weight" -> clsWeight, "bias" -> clsBias)) - val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> fc)) + if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout)) + val fc = Symbol.api.FullyConnected(data = Some(hidden), + num_hidden = numLabel, weight = Some(clsWeight), bias = Some(clsBias)) + val sm = Symbol.api.SoftmaxOutput(data = Some(fc), name = "softmax") var output = Array(sm) for (state <- lastStates) { output = output :+ state.c diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala index 44ee6e778d27..f7a01bad133a 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala @@ -30,9 +30,8 @@ import org.apache.mxnet.module.BucketingModule import org.apache.mxnet.module.FitParams /** - * Bucketing LSTM examples - * @author Yizhi Liu - */ + * Bucketing 
LSTM examples + */ class LstmBucketing { @Option(name = "--data-train", usage = "training set") private val dataTrain: String = "example/rnn/sherlockholmes.train.txt" @@ -61,6 +60,60 @@ object LstmBucketing { Math.exp(loss / labelArr.length).toFloat } + def runTraining(trainData : String, validationData : String, + ctx : Array[Context], numEpoch : Int): Unit = { + val batchSize = 32 + val buckets = Array(10, 20, 30, 40, 50, 60) + val numHidden = 200 + val numEmbed = 200 + val numLstmLayer = 2 + + logger.info("Building vocab ...") + val vocab = BucketIo.defaultBuildVocab(trainData) + + def BucketSymGen(key: AnyRef): + (Symbol, IndexedSeq[String], IndexedSeq[String]) = { + val seqLen = key.asInstanceOf[Int] + val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size, + numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size) + (sym, IndexedSeq("data"), IndexedSeq("softmax_label")) + } + + val initC = (0 until numLstmLayer).map(l => + (s"l${l}_init_c_beta", (batchSize, numHidden)) + ) + val initH = (0 until numLstmLayer).map(l => + (s"l${l}_init_h_beta", (batchSize, numHidden)) + ) + val initStates = initC ++ initH + + val dataTrain = new BucketSentenceIter(trainData, vocab, + buckets, batchSize, initStates) + val dataVal = new BucketSentenceIter(validationData, vocab, + buckets, batchSize, initStates) + + val model = new BucketingModule( + symGen = BucketSymGen, + defaultBucketKey = dataTrain.defaultBucketKey, + contexts = ctx) + + val fitParams = new FitParams() + fitParams.setEvalMetric( + new CustomMetric(perplexity, name = "perplexity")) + fitParams.setKVStore("device") + fitParams.setOptimizer( + new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f)) + fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f)) + fitParams.setBatchEndCallback(new Speedometer(batchSize, 50)) + + logger.info("Start training ...") + model.fit( + trainData = dataTrain, + evalData = Some(dataVal), + numEpoch = numEpoch, fitParams) + logger.info("Finished training...") + } + def main(args: Array[String]): Unit = { val inst = new LstmBucketing val parser: CmdLineParser = new CmdLineParser(inst) @@ -71,56 +124,7 @@ object LstmBucketing { else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt)) else Array(Context.cpu(0)) - val batchSize = 32 - val buckets = Array(10, 20, 30, 40, 50, 60) - val numHidden = 200 - val numEmbed = 200 - val numLstmLayer = 2 - - logger.info("Building vocab ...") - val vocab = BucketIo.defaultBuildVocab(inst.dataTrain) - - def BucketSymGen(key: AnyRef): - (Symbol, IndexedSeq[String], IndexedSeq[String]) = { - val seqLen = key.asInstanceOf[Int] - val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size, - numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size) - (sym, IndexedSeq("data"), IndexedSeq("softmax_label")) - } - - val initC = (0 until numLstmLayer).map(l => - (s"l${l}_init_c_beta", (batchSize, numHidden)) - ) - val initH = (0 until numLstmLayer).map(l => - (s"l${l}_init_h_beta", (batchSize, numHidden)) - ) - val initStates = initC ++ initH - - val dataTrain = new BucketSentenceIter(inst.dataTrain, vocab, - buckets, batchSize, initStates) - val dataVal = new BucketSentenceIter(inst.dataVal, vocab, - buckets, batchSize, initStates) - - val model = new BucketingModule( - symGen = BucketSymGen, - defaultBucketKey = dataTrain.defaultBucketKey, - contexts = contexts) - - val fitParams = new FitParams() - fitParams.setEvalMetric( - new CustomMetric(perplexity, name = "perplexity")) - 
fitParams.setKVStore("device") - fitParams.setOptimizer( - new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f)) - fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f)) - fitParams.setBatchEndCallback(new Speedometer(batchSize, 50)) - - logger.info("Start training ...") - model.fit( - trainData = dataTrain, - evalData = Some(dataVal), - numEpoch = inst.numEpoch, fitParams) - logger.info("Finished training...") + runTraining(inst.dataTrain, inst.dataVal, contexts, 5) } catch { case ex: Exception => logger.error(ex.getMessage, ex) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md new file mode 100644 index 000000000000..5289fc7b1b4e --- /dev/null +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md @@ -0,0 +1,48 @@ +# RNN Example for MXNet Scala +This folder contains the following examples, written in the new Scala type-safe API: +- [x] LSTM Bucketing +- [x] CharRNN Inference: Generate similar text based on the model +- [x] CharRNN Training: Train the language model using an RNN + +These examples are for illustration only and are not tuned to achieve the best accuracy. + +## Setup +### Download the Network Definition, Weights and Training Data +`obama.zip` contains the training inputs (Obama's speech) for the CharRNN examples and `sherlockholmes` contains the data for LSTM Bucketing +```bash +https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/obama.zip +https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.train.txt +https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.valid.txt +``` +### Unzip the file +```bash +unzip obama.zip +``` +### Argument Configuration +Then define the arguments that you would like to pass to the model: + +#### LSTM Bucketing +```bash +--data-train +/sherlockholmes.train.txt +--data-val +/sherlockholmes.valid.txt +--cpus + +--gpus + +``` +#### TrainCharRnn +```bash +--data-path +/obama.txt +--save-model-path +/ +``` +#### TestCharRnn +```bash +--data-path +/obama.txt +--model-prefix +/obama +``` \ No newline at end of file diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala index 243b70c0670d..4786d5d59535 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala @@ -25,66 +25,68 @@ import scala.collection.JavaConverters._ /** * Follows the demo, to test the char rnn: * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb - * @author Depeng Liang */ object TestCharRnn { private val logger = LoggerFactory.getLogger(classOf[TrainCharRnn]) - def main(args: Array[String]): Unit = { - val stcr = new TestCharRnn - val parser: CmdLineParser = new CmdLineParser(stcr) - try { - parser.parseArgument(args.toList.asJava) - assert(stcr.dataPath != null && stcr.modelPrefix != null && stcr.starterSentence != null) + def runTestCharRNN(dataPath: String, modelPrefix: String, starterSentence : String): Unit = { + // The batch size for training + val batchSize = 32 + // We can support various length input + // For this problem, we cut each input sentence to length of 129 + // So we only need fix length bucket + val buckets = List(129) + // hidden unit in LSTM 
cell + val numHidden = 512 + // embedding dimension, which is, map a char to a 256 dim vector + val numEmbed = 256 + // number of lstm layer + val numLstmLayer = 3 - // The batch size for training - val batchSize = 32 - // We can support various length input - // For this problem, we cut each input sentence to length of 129 - // So we only need fix length bucket - val buckets = List(129) - // hidden unit in LSTM cell - val numHidden = 512 - // embedding dimension, which is, map a char to a 256 dim vector - val numEmbed = 256 - // number of lstm layer - val numLstmLayer = 3 + // build char vocabluary from input + val vocab = Utils.buildVocab(dataPath) - // build char vocabluary from input - val vocab = Utils.buildVocab(stcr.dataPath) + // load from check-point + val (_, argParams, _) = Model.loadCheckpoint(modelPrefix, 75) - // load from check-point - val (_, argParams, _) = Model.loadCheckpoint(stcr.modelPrefix, 75) + // build an inference model + val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1, + numHidden = numHidden, numEmbed = numEmbed, + numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f) - // build an inference model - val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1, - numHidden = numHidden, numEmbed = numEmbed, - numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f) + // generate a sequence of 1200 chars + val seqLength = 1200 + val inputNdarray = NDArray.zeros(1) + val revertVocab = Utils.makeRevertVocab(vocab) - // generate a sequence of 1200 chars - val seqLength = 1200 - val inputNdarray = NDArray.zeros(1) - val revertVocab = Utils.makeRevertVocab(vocab) + // Feel free to change the starter sentence + var output = starterSentence + val randomSample = true + var newSentence = true + val ignoreLength = output.length() - // Feel free to change the starter sentence - var output = stcr.starterSentence - val randomSample = true - var newSentence = true - val ignoreLength = output.length() + for (i <- 0 until seqLength) { + if (i <= ignoreLength - 1) Utils.makeInput(output(i), vocab, inputNdarray) + else Utils.makeInput(output.takeRight(1)(0), vocab, inputNdarray) + val prob = model.forward(inputNdarray, newSentence) + newSentence = false + val nextChar = Utils.makeOutput(prob, revertVocab, randomSample) + if (nextChar == "") newSentence = true + if (i >= ignoreLength) output = output ++ nextChar + } - for (i <- 0 until seqLength) { - if (i <= ignoreLength - 1) Utils.makeInput(output(i), vocab, inputNdarray) - else Utils.makeInput(output.takeRight(1)(0), vocab, inputNdarray) - val prob = model.forward(inputNdarray, newSentence) - newSentence = false - val nextChar = Utils.makeOutput(prob, revertVocab, randomSample) - if (nextChar == "") newSentence = true - if (i >= ignoreLength) output = output ++ nextChar - } + // Let's see what we can learned from char in Obama's speech. + logger.info(output) + } - // Let's see what we can learned from char in Obama's speech. 
- logger.info(output) + def main(args: Array[String]): Unit = { + val stcr = new TestCharRnn + val parser: CmdLineParser = new CmdLineParser(stcr) + try { + parser.parseArgument(args.toList.asJava) + assert(stcr.dataPath != null && stcr.modelPrefix != null && stcr.starterSentence != null) + runTestCharRNN(stcr.dataPath, stcr.modelPrefix, stcr.starterSentence) } catch { case ex: Exception => { logger.error(ex.getMessage, ex) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala index 3afb93686b00..fb59705c9ef0 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala @@ -24,143 +24,144 @@ import scala.collection.JavaConverters._ import org.apache.mxnet.optimizer.Adam /** - * Follows the demo, to train the char rnn: - * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb - * @author Depeng Liang - */ + * Follows the demo, to train the char rnn: + * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb + */ object TrainCharRnn { private val logger = LoggerFactory.getLogger(classOf[TrainCharRnn]) - def main(args: Array[String]): Unit = { - val incr = new TrainCharRnn - val parser: CmdLineParser = new CmdLineParser(incr) - try { - parser.parseArgument(args.toList.asJava) - assert(incr.dataPath != null && incr.saveModelPath != null) - - // The batch size for training - val batchSize = 32 - // We can support various length input - // For this problem, we cut each input sentence to length of 129 - // So we only need fix length bucket - val buckets = Array(129) - // hidden unit in LSTM cell - val numHidden = 512 - // embedding dimension, which is, map a char to a 256 dim vector - val numEmbed = 256 - // number of lstm layer - val numLstmLayer = 3 - // we will show a quick demo in 2 epoch - // and we will see result by training 75 epoch - val numEpoch = 75 - // learning rate - val learningRate = 0.001f - // we will use pure sgd without momentum - val momentum = 0.0f - - val ctx = if (incr.gpu == -1) Context.cpu() else Context.gpu(incr.gpu) - val vocab = Utils.buildVocab(incr.dataPath) - - // generate symbol for a length - def symGen(seqLen: Int): Symbol = { - Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1, - numHidden = numHidden, numEmbed = numEmbed, - numLabel = vocab.size + 1, dropout = 0.2f) - } + def runTrainCharRnn(dataPath: String, saveModelPath: String, + ctx : Context, numEpoch : Int): Unit = { + // The batch size for training + val batchSize = 32 + // We can support various length input + // For this problem, we cut each input sentence to length of 129 + // So we only need fix length bucket + val buckets = Array(129) + // hidden unit in LSTM cell + val numHidden = 512 + // embedding dimension, which is, map a char to a 256 dim vector + val numEmbed = 256 + // number of lstm layer + val numLstmLayer = 3 + // we will show a quick demo in 2 epoch + // learning rate + val learningRate = 0.001f + // we will use pure sgd without momentum + val momentum = 0.0f + + val vocab = Utils.buildVocab(dataPath) + + // generate symbol for a length + def symGen(seqLen: Int): Symbol = { + Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1, + numHidden = numHidden, numEmbed = numEmbed, + numLabel = vocab.size + 1, dropout = 0.2f) + } - // initalize states for LSTM - val initC = for (l <- 0 until numLstmLayer) 
- yield (s"l${l}_init_c_beta", (batchSize, numHidden)) - val initH = for (l <- 0 until numLstmLayer) - yield (s"l${l}_init_h_beta", (batchSize, numHidden)) - val initStates = initC ++ initH + // initalize states for LSTM + val initC = for (l <- 0 until numLstmLayer) + yield (s"l${l}_init_c_beta", (batchSize, numHidden)) + val initH = for (l <- 0 until numLstmLayer) + yield (s"l${l}_init_h_beta", (batchSize, numHidden)) + val initStates = initC ++ initH - val dataTrain = new BucketIo.BucketSentenceIter(incr.dataPath, vocab, buckets, - batchSize, initStates, seperateChar = "\n", - text2Id = Utils.text2Id, readContent = Utils.readContent) + val dataTrain = new BucketIo.BucketSentenceIter(dataPath, vocab, buckets, + batchSize, initStates, seperateChar = "\n", + text2Id = Utils.text2Id, readContent = Utils.readContent) - // the network symbol - val symbol = symGen(buckets(0)) + // the network symbol + val symbol = symGen(buckets(0)) - val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel - val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels) + val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel + val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels) - val initializer = new Xavier(factorType = "in", magnitude = 2.34f) + val initializer = new Xavier(factorType = "in", magnitude = 2.34f) - val argNames = symbol.listArguments() - val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap - val auxNames = symbol.listAuxiliaryStates() - val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap + val argNames = symbol.listArguments() + val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap + val auxNames = symbol.listAuxiliaryStates() + val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap - val gradDict = argNames.zip(argShapes).filter { case (name, shape) => - !datasAndLabels.contains(name) - }.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap + val gradDict = argNames.zip(argShapes).filter { case (name, shape) => + !datasAndLabels.contains(name) + }.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap - argDict.foreach { case (name, ndArray) => - if (!datasAndLabels.contains(name)) { - initializer.initWeight(name, ndArray) - } + argDict.foreach { case (name, ndArray) => + if (!datasAndLabels.contains(name)) { + initializer.initWeight(name, ndArray) } + } - val data = argDict("data") - val label = argDict("softmax_label") + val data = argDict("data") + val label = argDict("softmax_label") - val executor = symbol.bind(ctx, argDict, gradDict) + val executor = symbol.bind(ctx, argDict, gradDict) - val opt = new Adam(learningRate = learningRate, wd = 0.0001f) + val opt = new Adam(learningRate = learningRate, wd = 0.0001f) - val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) => - (idx, name, grad, opt.createState(idx, argDict(name))) - } + val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) => + (idx, name, grad, opt.createState(idx, argDict(name))) + } - val evalMetric = new CustomMetric(Utils.perplexity, "perplexity") - val batchEndCallback = new Callback.Speedometer(batchSize, 50) - val epochEndCallback = Utils.doCheckpoint(s"${incr.saveModelPath}/obama") - - for (epoch <- 0 until numEpoch) { - // Training phase - val tic = System.currentTimeMillis - evalMetric.reset() - var nBatch = 0 - var epochDone = false - // Iterate over training data. 
- dataTrain.reset() - while (!epochDone) { - var doReset = true - while (doReset && dataTrain.hasNext) { - val dataBatch = dataTrain.next() - - data.set(dataBatch.data(0)) - label.set(dataBatch.label(0)) - executor.forward(isTrain = true) - executor.backward() - paramsGrads.foreach { case (idx, name, grad, optimState) => - opt.update(idx, argDict(name), grad, optimState) - } - - // evaluate at end, so out_cpu_array can lazy copy - evalMetric.update(dataBatch.label, executor.outputs) - - nBatch += 1 - batchEndCallback.invoke(epoch, nBatch, evalMetric) + val evalMetric = new CustomMetric(Utils.perplexity, "perplexity") + val batchEndCallback = new Callback.Speedometer(batchSize, 50) + val epochEndCallback = Utils.doCheckpoint(s"${saveModelPath}/obama") + + for (epoch <- 0 until numEpoch) { + // Training phase + val tic = System.currentTimeMillis + evalMetric.reset() + var nBatch = 0 + var epochDone = false + // Iterate over training data. + dataTrain.reset() + while (!epochDone) { + var doReset = true + while (doReset && dataTrain.hasNext) { + val dataBatch = dataTrain.next() + + data.set(dataBatch.data(0)) + label.set(dataBatch.label(0)) + executor.forward(isTrain = true) + executor.backward() + paramsGrads.foreach { case (idx, name, grad, optimState) => + opt.update(idx, argDict(name), grad, optimState) } - if (doReset) { - dataTrain.reset() - } - // this epoch is done - epochDone = true + + // evaluate at end, so out_cpu_array can lazy copy + evalMetric.update(dataBatch.label, executor.outputs) + + nBatch += 1 + batchEndCallback.invoke(epoch, nBatch, evalMetric) } - val (name, value) = evalMetric.get - name.zip(value).foreach { case (n, v) => - logger.info(s"Epoch[$epoch] Train-$n=$v") + if (doReset) { + dataTrain.reset() } - val toc = System.currentTimeMillis - logger.info(s"Epoch[$epoch] Time cost=${toc - tic}") - - epochEndCallback.invoke(epoch, symbol, argDict, auxDict) + // this epoch is done + epochDone = true } - executor.dispose() + val (name, value) = evalMetric.get + name.zip(value).foreach { case (n, v) => + logger.info(s"Epoch[$epoch] Train-$n=$v") + } + val toc = System.currentTimeMillis + logger.info(s"Epoch[$epoch] Time cost=${toc - tic}") + + epochEndCallback.invoke(epoch, symbol, argDict, auxDict) + } + executor.dispose() + } + + def main(args: Array[String]): Unit = { + val incr = new TrainCharRnn + val parser: CmdLineParser = new CmdLineParser(incr) + try { + parser.parseArgument(args.toList.asJava) + val ctx = if (incr.gpu == -1) Context.cpu() else Context.gpu(incr.gpu) + assert(incr.dataPath != null && incr.saveModelPath != null) + runTrainCharRnn(incr.dataPath, incr.saveModelPath, ctx, 75) } catch { case ex: Exception => { logger.error(ex.getMessage, ex) @@ -172,12 +173,6 @@ object TrainCharRnn { } class TrainCharRnn { - /* - * Get Training Data: E.g. 
- * mkdir data; cd data - * wget "http://data.mxnet.io/mxnet/data/char_lstm.zip" - * unzip -o char_lstm.zip - */ @Option(name = "--data-path", usage = "the input train data file") private val dataPath: String = "./data/obama.txt" @Option(name = "--save-model-path", usage = "the model saving path") diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala index c2902309679d..3f9a9842e0a9 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala @@ -25,9 +25,6 @@ import org.apache.mxnet.Model import org.apache.mxnet.Symbol import scala.util.Random -/** - * @author Depeng Liang - */ object Utils { def readContent(path: String): String = Source.fromFile(path).mkString diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala new file mode 100644 index 000000000000..b393a433305a --- /dev/null +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnetexamples.rnn + + +import org.apache.mxnet.{Context, NDArrayCollector} +import org.apache.mxnetexamples.Util +import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore} +import org.slf4j.LoggerFactory + +import scala.sys.process.Process + +@Ignore +class ExampleRNNSuite extends FunSuite with BeforeAndAfterAll { + private val logger = LoggerFactory.getLogger(classOf[ExampleRNNSuite]) + + override def beforeAll(): Unit = { + logger.info("Downloading LSTM model") + val tempDirPath = System.getProperty("java.io.tmpdir") + logger.info("tempDirPath: %s".format(tempDirPath)) + val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/" + Util.downloadUrl(baseUrl + "obama.zip", tempDirPath + "/RNN/obama.zip") + Util.downloadUrl(baseUrl + "sherlockholmes.train.txt", + tempDirPath + "/RNN/sherlockholmes.train.txt") + Util.downloadUrl(baseUrl + "sherlockholmes.valid.txt", + tempDirPath + "/RNN/sherlockholmes.valid.txt") + // TODO: Need to confirm with Windows + Process(s"unzip $tempDirPath/RNN/obama.zip -d $tempDirPath/RNN/") ! 
+ } + + test("Example CI: Test LSTM Bucketing") { + val tempDirPath = System.getProperty("java.io.tmpdir") + var ctx = Context.cpu() + if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && + System.getenv("SCALA_TEST_ON_GPU").toInt == 1) { + ctx = Context.gpu() + } + LstmBucketing.runTraining(tempDirPath + "/RNN/sherlockholmes.train.txt", + tempDirPath + "/RNN/sherlockholmes.valid.txt", Array(ctx), 1) + } + + test("Example CI: Test TrainCharRNN") { + val tempDirPath = System.getProperty("java.io.tmpdir") + if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && + System.getenv("SCALA_TEST_ON_GPU").toInt == 1) { + val ctx = Context.gpu() + TrainCharRnn.runTrainCharRnn(tempDirPath + "/RNN/obama.txt", + tempDirPath, ctx, 1) + } else { + logger.info("CPU not supported for this test, skipped...") + } + } + + test("Example CI: Test TestCharRNN") { + val tempDirPath = System.getProperty("java.io.tmpdir") + val ctx = Context.gpu() + TestCharRnn.runTestCharRNN(tempDirPath + "/RNN/obama.txt", + tempDirPath + "/RNN/obama", "The joke") + } +}