This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit 2541955: add CI test
lanking520 committed Jul 13, 2018 (parent: 62bf495)
Showing 5 changed files with 251 additions and 182 deletions.
File: org/apache/mxnetexamples/rnn/LstmBucketing.scala
@@ -20,13 +20,14 @@ package org.apache.mxnetexamples.rnn

 import org.apache.mxnet.Callback.Speedometer
 import org.apache.mxnet._
+import org.apache.mxnet.module.{BucketingModule, FitParams}
+import BucketIo.BucketSentenceIter
 import org.apache.mxnet.optimizer.SGD
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.{Logger, LoggerFactory}
-import BucketIo.BucketSentenceIter
 
 import scala.collection.JavaConverters._
-import org.apache.mxnet.module.BucketingModule
-import org.apache.mxnet.module.FitParams
 
 /**
  * Bucketing LSTM examples
@@ -53,14 +54,66 @@ object LstmBucketing {
     pred.waitToRead()
     val labelArr = label.T.toArray.map(_.toInt)
     var loss = .0
-    (0 until pred.shape(0)).foreach(i => {
-      val temp = pred.slice(i)
-      loss -= Math.log(Math.max(1e-10f, temp.toArray(labelArr(i))))
-      temp.dispose()
-    })
+    (0 until pred.shape(0)).foreach(i =>
+      loss -= Math.log(Math.max(1e-10f, pred.slice(i).toArray(labelArr(i))))
+    )
     Math.exp(loss / labelArr.length).toFloat
   }
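Note: the loop above computes standard per-token perplexity, exp((-1/N) * sum(log p(label_i))); the rewrite only inlines the slice that the old loop bound to temp and disposed explicitly. A minimal self-contained sketch of the same metric over plain Scala arrays (names illustrative; no NDArray involved):

object PerplexityDemo extends App {
  // probs(i) is the predicted distribution for token i; labels(i) its true id.
  def perplexity(probs: Array[Array[Double]], labels: Array[Int]): Double = {
    val negLogLik = probs.zip(labels).map { case (row, lbl) =>
      -math.log(math.max(1e-10, row(lbl))) // clamp to avoid log(0)
    }.sum
    math.exp(negLogLik / labels.length)
  }

  // A confident, correct model scores near 1.0; chance level for a
  // two-symbol vocabulary would be 2.0.
  println(perplexity(Array(Array(0.9, 0.1), Array(0.2, 0.8)), Array(0, 1))) // ~1.18
}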

+  def runTraining(trainData : String, validationData : String,
+                  ctx : Array[Context], numEpoch : Int): Unit = {
+    val batchSize = 32
+    val buckets = Array(10, 20, 30, 40, 50, 60)
+    val numHidden = 200
+    val numEmbed = 200
+    val numLstmLayer = 2
+
+    logger.info("Building vocab ...")
+    val vocab = BucketIo.defaultBuildVocab(trainData)
+
+    def BucketSymGen(key: AnyRef):
+      (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
+      val seqLen = key.asInstanceOf[Int]
+      val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
+        numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
+      (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
+    }
+
+    val initC = (0 until numLstmLayer).map(l =>
+      (s"l${l}_init_c_beta", (batchSize, numHidden))
+    )
+    val initH = (0 until numLstmLayer).map(l =>
+      (s"l${l}_init_h_beta", (batchSize, numHidden))
+    )
+    val initStates = initC ++ initH
+
+    val dataTrain = new BucketSentenceIter(trainData, vocab,
+      buckets, batchSize, initStates)
+    val dataVal = new BucketSentenceIter(validationData, vocab,
+      buckets, batchSize, initStates)
+
+    val model = new BucketingModule(
+      symGen = BucketSymGen,
+      defaultBucketKey = dataTrain.defaultBucketKey,
+      contexts = ctx)
+
+    val fitParams = new FitParams()
+    fitParams.setEvalMetric(
+      new CustomMetric(perplexity, name = "perplexity"))
+    fitParams.setKVStore("device")
+    fitParams.setOptimizer(
+      new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
+    fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
+    fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
+
+    logger.info("Start training ...")
+    model.fit(
+      trainData = dataTrain,
+      evalData = Some(dataVal),
+      numEpoch = numEpoch, fitParams)
+    logger.info("Finished training...")
+  }
+
   def main(args: Array[String]): Unit = {
     val inst = new LstmBucketing
     val parser: CmdLineParser = new CmdLineParser(inst)
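Note: runTraining is where the bucketing pieces meet. Sentences are batched into fixed-length buckets (10 to 60 tokens) so the LSTM only needs to be unrolled to six distinct lengths; BucketingModule calls BucketSymGen once per bucket key to build and cache an unrolled symbol, sharing parameters across buckets; and the l${l}_init_c/h entries declare per-layer LSTM state inputs of shape (batchSize, numHidden) that the iterator supplies. A sketch of the bucket-assignment rule such an iterator typically applies (illustrative; BucketSentenceIter's internals are not part of this diff):

object BucketDemo extends App {
  val buckets = Array(10, 20, 30, 40, 50, 60)

  // A sentence is padded up to the smallest bucket that can hold it, so a
  // handful of unrolled networks serve every sentence length in the corpus.
  def bucketFor(sentenceLen: Int): Option[Int] = buckets.find(_ >= sentenceLen)

  assert(bucketFor(17).contains(20)) // a 17-token sentence pads to length 20
  assert(bucketFor(61).isEmpty)      // longer than the largest bucket
}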
Expand All @@ -71,56 +124,7 @@ object LstmBucketing {
       else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
       else Array(Context.cpu(0))
 
-      val batchSize = 32
-      val buckets = Array(10, 20, 30, 40, 50, 60)
-      val numHidden = 200
-      val numEmbed = 200
-      val numLstmLayer = 2
-
-      logger.info("Building vocab ...")
-      val vocab = BucketIo.defaultBuildVocab(inst.dataTrain)
-
-      def BucketSymGen(key: AnyRef):
-        (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
-        val seqLen = key.asInstanceOf[Int]
-        val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
-          numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
-        (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
-      }
-
-      val initC = (0 until numLstmLayer).map(l =>
-        (s"l${l}_init_c_beta", (batchSize, numHidden))
-      )
-      val initH = (0 until numLstmLayer).map(l =>
-        (s"l${l}_init_h_beta", (batchSize, numHidden))
-      )
-      val initStates = initC ++ initH
-
-      val dataTrain = new BucketSentenceIter(inst.dataTrain, vocab,
-        buckets, batchSize, initStates)
-      val dataVal = new BucketSentenceIter(inst.dataVal, vocab,
-        buckets, batchSize, initStates)
-
-      val model = new BucketingModule(
-        symGen = BucketSymGen,
-        defaultBucketKey = dataTrain.defaultBucketKey,
-        contexts = contexts)
-
-      val fitParams = new FitParams()
-      fitParams.setEvalMetric(
-        new CustomMetric(perplexity, name = "perplexity"))
-      fitParams.setKVStore("device")
-      fitParams.setOptimizer(
-        new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
-      fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
-      fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
-
-      logger.info("Start training ...")
-      model.fit(
-        trainData = dataTrain,
-        evalData = Some(dataVal),
-        numEpoch = inst.numEpoch, fitParams)
-      logger.info("Finished training...")
+      runTraining(inst.dataTrain, inst.dataVal, contexts, 5)
     } catch {
       case ex: Exception =>
         logger.error(ex.getMessage, ex)
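Note: the substance of this commit is the extraction above: the training loop moves out of main and into runTraining(trainData, validationData, ctx, numEpoch), so a CI test can drive a short run directly instead of going through args4j parsing. One behavioral change rides along: main now hard-codes 5 epochs instead of honoring inst.numEpoch. A sketch of the call a test could now make (the PTB paths are placeholders, not taken from this diff):

import org.apache.mxnet.Context
import org.apache.mxnetexamples.rnn.LstmBucketing

object LstmBucketingSmokeTest extends App {
  // Hypothetical smoke test: one epoch on CPU, placeholder data paths.
  LstmBucketing.runTraining("data/ptb.train.txt", "data/ptb.valid.txt",
    Array(Context.cpu(0)), numEpoch = 1)
}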
File: org/apache/mxnetexamples/rnn/TestCharRnn.scala
@@ -25,7 +25,6 @@ import scala.collection.JavaConverters._
 /**
  * Follows the demo, to test the char rnn:
  * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
- * @author Depeng Liang
  */
 object TestCharRnn {
 
[The diffs of the remaining three changed files did not load in this view.]
