diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
index c43e35568fe8..f7a01bad133a 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
@@ -20,13 +20,14 @@ package org.apache.mxnetexamples.rnn
 
 import org.apache.mxnet.Callback.Speedometer
 import org.apache.mxnet._
-import org.apache.mxnet.module.{BucketingModule, FitParams}
+import BucketIo.BucketSentenceIter
 import org.apache.mxnet.optimizer.SGD
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.{Logger, LoggerFactory}
-import BucketIo.BucketSentenceIter
 
 import scala.collection.JavaConverters._
+import org.apache.mxnet.module.BucketingModule
+import org.apache.mxnet.module.FitParams
 
 /**
  * Bucketing LSTM examples
@@ -53,14 +54,66 @@ object LstmBucketing {
     pred.waitToRead()
     val labelArr = label.T.toArray.map(_.toInt)
     var loss = .0
-    (0 until pred.shape(0)).foreach(i => {
-      val temp = pred.slice(i)
-      loss -= Math.log(Math.max(1e-10f, temp.toArray(labelArr(i))))
-      temp.dispose()
-    })
+    (0 until pred.shape(0)).foreach(i =>
+      loss -= Math.log(Math.max(1e-10f, pred.slice(i).toArray(labelArr(i))))
+    )
     Math.exp(loss / labelArr.length).toFloat
   }
 
+  def runTraining(trainData: String, validationData: String,
+                  ctx: Array[Context], numEpoch: Int): Unit = {
+    val batchSize = 32
+    val buckets = Array(10, 20, 30, 40, 50, 60)
+    val numHidden = 200
+    val numEmbed = 200
+    val numLstmLayer = 2
+
+    logger.info("Building vocab ...")
+    val vocab = BucketIo.defaultBuildVocab(trainData)
+
+    def BucketSymGen(key: AnyRef):
+    (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
+      val seqLen = key.asInstanceOf[Int]
+      val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
+        numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
+      (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
+    }
+
+    val initC = (0 until numLstmLayer).map(l =>
+      (s"l${l}_init_c_beta", (batchSize, numHidden))
+    )
+    val initH = (0 until numLstmLayer).map(l =>
+      (s"l${l}_init_h_beta", (batchSize, numHidden))
+    )
+    val initStates = initC ++ initH
+
+    val dataTrain = new BucketSentenceIter(trainData, vocab,
+      buckets, batchSize, initStates)
+    val dataVal = new BucketSentenceIter(validationData, vocab,
+      buckets, batchSize, initStates)
+
+    val model = new BucketingModule(
+      symGen = BucketSymGen,
+      defaultBucketKey = dataTrain.defaultBucketKey,
+      contexts = ctx)
+
+    val fitParams = new FitParams()
+    fitParams.setEvalMetric(
+      new CustomMetric(perplexity, name = "perplexity"))
+    fitParams.setKVStore("device")
+    fitParams.setOptimizer(
+      new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
+    fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
+    fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
+
+    logger.info("Start training ...")
+    model.fit(
+      trainData = dataTrain,
+      evalData = Some(dataVal),
+      numEpoch = numEpoch, fitParams)
+    logger.info("Finished training...")
+  }
+
   def main(args: Array[String]): Unit = {
     val inst = new LstmBucketing
     val parser: CmdLineParser = new CmdLineParser(inst)
@@ -71,56 +124,7 @@ object LstmBucketing {
       else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
       else Array(Context.cpu(0))
 
-      val batchSize = 32
-      val buckets = Array(10, 20, 30, 40, 50, 60)
-      val numHidden = 200
-      val numEmbed = 200
-      val numLstmLayer = 2
-
-      logger.info("Building vocab ...")
-      val vocab = BucketIo.defaultBuildVocab(inst.dataTrain)
-
-      def BucketSymGen(key: AnyRef):
-      (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
-        val seqLen = key.asInstanceOf[Int]
-        val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
-          numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
-        (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
-      }
-
-      val initC = (0 until numLstmLayer).map(l =>
-        (s"l${l}_init_c_beta", (batchSize, numHidden))
-      )
-      val initH = (0 until numLstmLayer).map(l =>
-        (s"l${l}_init_h_beta", (batchSize, numHidden))
-      )
-      val initStates = initC ++ initH
-
-      val dataTrain = new BucketSentenceIter(inst.dataTrain, vocab,
-        buckets, batchSize, initStates)
-      val dataVal = new BucketSentenceIter(inst.dataVal, vocab,
-        buckets, batchSize, initStates)
-
-      val model = new BucketingModule(
-        symGen = BucketSymGen,
-        defaultBucketKey = dataTrain.defaultBucketKey,
-        contexts = contexts)
-
-      val fitParams = new FitParams()
-      fitParams.setEvalMetric(
-        new CustomMetric(perplexity, name = "perplexity"))
-      fitParams.setKVStore("device")
-      fitParams.setOptimizer(
-        new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
-      fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
-      fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
-
-      logger.info("Start training ...")
-      model.fit(
-        trainData = dataTrain,
-        evalData = Some(dataVal),
-        numEpoch = inst.numEpoch, fitParams)
-      logger.info("Finished training...")
+      runTraining(inst.dataTrain, inst.dataVal, contexts, 5)
     } catch {
       case ex: Exception => logger.error(ex.getMessage, ex)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
index 243b70c0670d..ef572863dcfe 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
@@ -25,7 +25,6 @@ import scala.collection.JavaConverters._
 
 /**
  * Follows the demo, to test the char rnn:
 * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
- * @author Depeng Liang
 */
 object TestCharRnn {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
index 3afb93686b00..fb59705c9ef0 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
@@ -24,143 +24,144 @@ import scala.collection.JavaConverters._
 import org.apache.mxnet.optimizer.Adam
 
 /**
- * Follows the demo, to train the char rnn:
- * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
- * @author Depeng Liang
- */
+  * Follows the demo, to train the char rnn:
+  * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
+  */
 object TrainCharRnn {
 
   private val logger = LoggerFactory.getLogger(classOf[TrainCharRnn])
 
-  def main(args: Array[String]): Unit = {
-    val incr = new TrainCharRnn
-    val parser: CmdLineParser = new CmdLineParser(incr)
-    try {
-      parser.parseArgument(args.toList.asJava)
-      assert(incr.dataPath != null && incr.saveModelPath != null)
-
-      // The batch size for training
-      val batchSize = 32
-      // We can support various length input
-      // For this problem, we cut each input sentence to length of 129
-      // So we only need fix length bucket
-      val buckets = Array(129)
-      // hidden unit in LSTM cell
-      val numHidden = 512
-      // embedding dimension, which is, map a char to a 256 dim vector
-      val numEmbed = 256
-      // number of lstm layer
-      val numLstmLayer = 3
-      // we will show a quick demo in 2 epoch
-      // and we will see result by training 75 epoch
-      val numEpoch = 75
-      // learning rate
-      val learningRate = 0.001f
-      // we will use pure sgd without momentum
-      val momentum = 0.0f
-
-      val ctx = if (incr.gpu == -1) Context.cpu() else Context.gpu(incr.gpu)
-      val vocab = Utils.buildVocab(incr.dataPath)
-
-      // generate symbol for a length
-      def symGen(seqLen: Int): Symbol = {
-        Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1,
-          numHidden = numHidden, numEmbed = numEmbed,
-          numLabel = vocab.size + 1, dropout = 0.2f)
-      }
+  def runTrainCharRnn(dataPath: String, saveModelPath: String,
+                      ctx: Context, numEpoch: Int): Unit = {
+    // The batch size for training
+    val batchSize = 32
+    // We can support variable-length input.
+    // For this problem, we cut each input sentence to a length of 129,
+    // so we only need a single fixed-length bucket.
+    val buckets = Array(129)
+    // hidden units in the LSTM cell
+    val numHidden = 512
+    // embedding dimension, i.e. map a char to a 256-dim vector
+    val numEmbed = 256
+    // number of LSTM layers
+    val numLstmLayer = 3
+    // a quick demo needs only 2 epochs; good results show up around 75
+    // learning rate
+    val learningRate = 0.001f
+    // momentum is unused here; the optimizer below is Adam rather than plain SGD
+    val momentum = 0.0f
+
+    val vocab = Utils.buildVocab(dataPath)
+
+    // generate symbol for a length
+    def symGen(seqLen: Int): Symbol = {
+      Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1,
+        numHidden = numHidden, numEmbed = numEmbed,
+        numLabel = vocab.size + 1, dropout = 0.2f)
+    }
 
-      // initalize states for LSTM
-      val initC = for (l <- 0 until numLstmLayer)
-        yield (s"l${l}_init_c_beta", (batchSize, numHidden))
-      val initH = for (l <- 0 until numLstmLayer)
-        yield (s"l${l}_init_h_beta", (batchSize, numHidden))
-      val initStates = initC ++ initH
+    // initialize states for LSTM
+    val initC = for (l <- 0 until numLstmLayer)
+      yield (s"l${l}_init_c_beta", (batchSize, numHidden))
+    val initH = for (l <- 0 until numLstmLayer)
+      yield (s"l${l}_init_h_beta", (batchSize, numHidden))
+    val initStates = initC ++ initH
 
-      val dataTrain = new BucketIo.BucketSentenceIter(incr.dataPath, vocab, buckets,
-        batchSize, initStates, seperateChar = "\n",
-        text2Id = Utils.text2Id, readContent = Utils.readContent)
+    val dataTrain = new BucketIo.BucketSentenceIter(dataPath, vocab, buckets,
+      batchSize, initStates, seperateChar = "\n",
+      text2Id = Utils.text2Id, readContent = Utils.readContent)
 
-      // the network symbol
-      val symbol = symGen(buckets(0))
+    // the network symbol
+    val symbol = symGen(buckets(0))
 
-      val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel
-      val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels)
+    val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel
+    val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels)
 
-      val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
+    val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
 
-      val argNames = symbol.listArguments()
-      val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap
-      val auxNames = symbol.listAuxiliaryStates()
-      val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
+    val argNames = symbol.listArguments()
+    val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap
+    val auxNames = symbol.listAuxiliaryStates()
+    val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
 
-      val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
-        !datasAndLabels.contains(name)
-      }.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap
+    val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
+      !datasAndLabels.contains(name)
+    }.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap
 
-      argDict.foreach { case (name, ndArray) =>
-        if (!datasAndLabels.contains(name)) {
-          initializer.initWeight(name, ndArray)
-        }
+    argDict.foreach { case (name, ndArray) =>
+      if (!datasAndLabels.contains(name)) {
+        initializer.initWeight(name, ndArray)
       }
+    }
 
-      val data = argDict("data")
-      val label = argDict("softmax_label")
+    val data = argDict("data")
+    val label = argDict("softmax_label")
 
-      val executor = symbol.bind(ctx, argDict, gradDict)
+    val executor = symbol.bind(ctx, argDict, gradDict)
 
-      val opt = new Adam(learningRate = learningRate, wd = 0.0001f)
+    val opt = new Adam(learningRate = learningRate, wd = 0.0001f)
 
-      val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
-        (idx, name, grad, opt.createState(idx, argDict(name)))
-      }
+    val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
+      (idx, name, grad, opt.createState(idx, argDict(name)))
+    }
 
-      val evalMetric = new CustomMetric(Utils.perplexity, "perplexity")
-      val batchEndCallback = new Callback.Speedometer(batchSize, 50)
-      val epochEndCallback = Utils.doCheckpoint(s"${incr.saveModelPath}/obama")
-
-      for (epoch <- 0 until numEpoch) {
-        // Training phase
-        val tic = System.currentTimeMillis
-        evalMetric.reset()
-        var nBatch = 0
-        var epochDone = false
-        // Iterate over training data.
-        dataTrain.reset()
-        while (!epochDone) {
-          var doReset = true
-          while (doReset && dataTrain.hasNext) {
-            val dataBatch = dataTrain.next()
-
-            data.set(dataBatch.data(0))
-            label.set(dataBatch.label(0))
-            executor.forward(isTrain = true)
-            executor.backward()
-            paramsGrads.foreach { case (idx, name, grad, optimState) =>
-              opt.update(idx, argDict(name), grad, optimState)
-            }
-
-            // evaluate at end, so out_cpu_array can lazy copy
-            evalMetric.update(dataBatch.label, executor.outputs)
-
-            nBatch += 1
-            batchEndCallback.invoke(epoch, nBatch, evalMetric)
+    val evalMetric = new CustomMetric(Utils.perplexity, "perplexity")
+    val batchEndCallback = new Callback.Speedometer(batchSize, 50)
+    val epochEndCallback = Utils.doCheckpoint(s"${saveModelPath}/obama")
+
+    for (epoch <- 0 until numEpoch) {
+      // Training phase
+      val tic = System.currentTimeMillis
+      evalMetric.reset()
+      var nBatch = 0
+      var epochDone = false
+      // Iterate over training data.
+      dataTrain.reset()
+      while (!epochDone) {
+        var doReset = true
+        while (doReset && dataTrain.hasNext) {
+          val dataBatch = dataTrain.next()
+
+          data.set(dataBatch.data(0))
+          label.set(dataBatch.label(0))
+          executor.forward(isTrain = true)
+          executor.backward()
+          paramsGrads.foreach { case (idx, name, grad, optimState) =>
+            opt.update(idx, argDict(name), grad, optimState)
          }
-          if (doReset) {
-            dataTrain.reset()
-          }
-          // this epoch is done
-          epochDone = true
+
+          // evaluate at the end, so out_cpu_array can be lazily copied
+          evalMetric.update(dataBatch.label, executor.outputs)
+
+          nBatch += 1
+          batchEndCallback.invoke(epoch, nBatch, evalMetric)
        }
-        val (name, value) = evalMetric.get
-        name.zip(value).foreach { case (n, v) =>
-          logger.info(s"Epoch[$epoch] Train-$n=$v")
+        if (doReset) {
+          dataTrain.reset()
        }
-        val toc = System.currentTimeMillis
-        logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
-
-        epochEndCallback.invoke(epoch, symbol, argDict, auxDict)
+        // this epoch is done
+        epochDone = true
      }
-      executor.dispose()
+      val (name, value) = evalMetric.get
+      name.zip(value).foreach { case (n, v) =>
+        logger.info(s"Epoch[$epoch] Train-$n=$v")
+      }
+      val toc = System.currentTimeMillis
+      logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
+
+      epochEndCallback.invoke(epoch, symbol, argDict, auxDict)
+    }
+    executor.dispose()
+  }
+
+  def main(args: Array[String]): Unit = {
+    val incr = new TrainCharRnn
+    val parser: CmdLineParser = new CmdLineParser(incr)
+    try {
+      parser.parseArgument(args.toList.asJava)
+      val ctx = if (incr.gpu == -1) Context.cpu() else Context.gpu(incr.gpu)
+      assert(incr.dataPath != null && incr.saveModelPath != null)
+      runTrainCharRnn(incr.dataPath, incr.saveModelPath, ctx, 75)
     } catch {
       case ex: Exception => {
         logger.error(ex.getMessage, ex)
@@ -172,12 +173,6 @@ object TrainCharRnn {
 }
 
 class TrainCharRnn {
-  /*
-   * Get Training Data: E.g.
-   * mkdir data; cd data
-   * wget "http://data.mxnet.io/mxnet/data/char_lstm.zip"
-   * unzip -o char_lstm.zip
-   */
   @Option(name = "--data-path", usage = "the input train data file")
   private val dataPath: String = "./data/obama.txt"
   @Option(name = "--save-model-path", usage = "the model saving path")
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
index c2902309679d..3f9a9842e0a9 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
@@ -25,9 +25,6 @@ import org.apache.mxnet.Model
 import org.apache.mxnet.Symbol
 import scala.util.Random
 
-/**
- * @author Depeng Liang
- */
 object Utils {
 
   def readContent(path: String): String = Source.fromFile(path).mkString
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
new file mode 100644
index 000000000000..71157d48675c
--- /dev/null
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.rnn
+
+import java.io.File
+import java.net.URL
+
+import org.apache.commons.io.FileUtils
+import org.apache.mxnet.Context
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
+
+import scala.sys.process.Process
+
+class ExampleRNNSuite extends FunSuite with BeforeAndAfterAll {
+  private val logger = LoggerFactory.getLogger(classOf[ExampleRNNSuite])
+
+  def downloadUrl(url: String, filePath: String) : Unit = {
+    val tmpFile = new File(filePath)
+    if (!tmpFile.exists()) {
+      FileUtils.copyURLToFile(new URL(url), tmpFile)
+    }
+  }
+
+  override def beforeAll(): Unit = {
+    logger.info("Downloading RNN example data")
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    logger.info("tempDirPath: %s".format(tempDirPath))
+    val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/"
+    downloadUrl(baseUrl + "obama.zip", tempDirPath + "/RNN/obama.zip")
+    downloadUrl(baseUrl + "sherlockholmes.train.txt", tempDirPath + "/RNN/sherlockholmes.train.txt")
+    downloadUrl(baseUrl + "sherlockholmes.valid.txt", tempDirPath + "/RNN/sherlockholmes.valid.txt")
+    // TODO: confirm that this works on Windows
+    Process(s"unzip $tempDirPath/RNN/obama.zip -d $tempDirPath/RNN/") !
+  }
+
+  test("Example CI: Test LSTM Bucketing") {
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    var ctx = Context.cpu()
+    if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+      System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+      ctx = Context.gpu()
+    }
+    LstmBucketing.runTraining(tempDirPath + "/RNN/sherlockholmes.train.txt",
+      tempDirPath + "/RNN/sherlockholmes.valid.txt", Array(ctx), 3)
+  }
+
+  test("Example CI: Test TrainCharRNN") {
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+      System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+      val ctx = Context.gpu()
+      TrainCharRnn.runTrainCharRnn(tempDirPath + "/RNN/obama.txt",
+        tempDirPath, ctx, 1)
+    } else {
+      logger.info("This test is not supported on CPU; skipping...")
+    }
+  }
+}
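
Usage note (not part of the diff): the point of the refactor is that `LstmBucketing.runTraining` and `TrainCharRnn.runTrainCharRnn` are now plain entry points that `ExampleRNNSuite`, or any other caller, can invoke directly instead of going through `main`. A minimal sketch of driving them outside the test harness; the `RunRnnExamples` object and the local `data/RNN` paths are hypothetical, and the corpora are assumed to be downloaded already (CI fetches them in `beforeAll`):

```scala
import org.apache.mxnet.Context
import org.apache.mxnetexamples.rnn.{LstmBucketing, TrainCharRnn}

object RunRnnExamples {
  def main(args: Array[String]): Unit = {
    // Mirror the SCALA_TEST_ON_GPU switch that ExampleRNNSuite uses.
    val ctx =
      if (sys.env.get("SCALA_TEST_ON_GPU").contains("1")) Context.gpu()
      else Context.cpu()

    // Bucketing LSTM on the Sherlock Holmes corpus (3 epochs, as in the CI test).
    LstmBucketing.runTraining("data/RNN/sherlockholmes.train.txt",
      "data/RNN/sherlockholmes.valid.txt", Array(ctx), 3)

    // Char RNN on the Obama speeches; Utils.doCheckpoint should write
    // checkpoints under the given directory with the "obama" prefix.
    TrainCharRnn.runTrainCharRnn("data/RNN/obama.txt", "data/RNN", ctx, 1)
  }
}
```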
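For reference, the `perplexity` metric both examples feed to `CustomMetric` is the exponential of the mean negative log-likelihood the network assigns to the true labels, with probabilities clipped at 1e-10 to avoid `log(0)`. A self-contained sketch of the same arithmetic on plain arrays (the `PerplexityMath` object is illustrative, not part of the example code):

```scala
object PerplexityMath {
  // perplexity = exp(-(1/N) * sum_i log p_i), where p_i is the predicted
  // probability of the true label of sample i.
  def perplexity(trueLabelProbs: Array[Double]): Double = {
    val loss = trueLabelProbs.map(p => -math.log(math.max(1e-10, p))).sum
    math.exp(loss / trueLabelProbs.length)
  }

  def main(args: Array[String]): Unit = {
    // A uniform guess over a 4-symbol vocabulary scores a perplexity of 4.
    println(perplexity(Array(0.25, 0.25, 0.25, 0.25))) // prints ~4.0
  }
}
```

Lower is better; a model that always put probability 1 on the correct character would score exactly 1.0.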