Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-836] RNN Example for Scala (#11753)
Browse files Browse the repository at this point in the history
* initial fix for RNN

* add CI test

* add encoding format

* scala style fix

* update readme

* test char RNN works

* ignore the test due to memory leaks
  • Loading branch information
lanking520 authored and nswamy committed Aug 21, 2018
1 parent 332a664 commit 38f80af
Show file tree
Hide file tree
Showing 8 changed files with 399 additions and 286 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object BucketIo {
type ReadContent = String => String

def defaultReadContent(path: String): String = {
Source.fromFile(path).mkString.replaceAll("\\. |\n", " <eos> ")
Source.fromFile(path, "UTF-8").mkString.replaceAll("\\. |\n", " <eos> ")
}

def defaultBuildVocab(path: String): Map[String, Int] = {
Expand All @@ -56,7 +56,7 @@ object BucketIo {
val tmp = sentence.split(" ").filter(_.length() > 0)
for (w <- tmp) yield theVocab(w)
}
words.toArray
words
}

def defaultGenBuckets(sentences: Array[String], batchSize: Int,
Expand Down Expand Up @@ -162,8 +162,6 @@ object BucketIo {
labelBuffer.append(NDArray.zeros(_batchSize, buckets(iBucket)))
}

private val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))

private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey))
tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
}
Expand Down Expand Up @@ -208,12 +206,13 @@ object BucketIo {
tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
}
val batchProvideLabel = ListMap("softmax_label" -> labelBuf.shape)
new DataBatch(IndexedSeq(dataBuf) ++ initStateArrays,
IndexedSeq(labelBuf),
getIndex(),
getPad(),
this.buckets(bucketIdx).asInstanceOf[AnyRef],
batchProvideData, batchProvideLabel)
val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
new DataBatch(IndexedSeq(dataBuf.copy()) ++ initStateArrays,
IndexedSeq(labelBuf.copy()),
getIndex(),
getPad(),
this.buckets(bucketIdx).asInstanceOf[AnyRef],
batchProvideData, batchProvideLabel)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@

package org.apache.mxnetexamples.rnn

import org.apache.mxnet.Symbol
import org.apache.mxnet.{Shape, Symbol}

import scala.collection.mutable.ArrayBuffer

/**
* @author Depeng Liang
*/
object Lstm {

final case class LSTMState(c: Symbol, h: Symbol)
Expand All @@ -35,27 +32,22 @@ object Lstm {
def lstm(numHidden: Int, inData: Symbol, prevState: LSTMState,
param: LSTMParam, seqIdx: Int, layerIdx: Int, dropout: Float = 0f): LSTMState = {
val inDataa = {
if (dropout > 0f) Symbol.Dropout()()(Map("data" -> inData, "p" -> dropout))
if (dropout > 0f) Symbol.api.Dropout(data = Some(inData), p = Some(dropout))
else inData
}
val i2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_i2h")()(Map("data" -> inDataa,
"weight" -> param.i2hWeight,
"bias" -> param.i2hBias,
"num_hidden" -> numHidden * 4))
val h2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_h2h")()(Map("data" -> prevState.h,
"weight" -> param.h2hWeight,
"bias" -> param.h2hBias,
"num_hidden" -> numHidden * 4))
val i2h = Symbol.api.FullyConnected(data = Some(inDataa), weight = Some(param.i2hWeight),
bias = Some(param.i2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_i2h")
val h2h = Symbol.api.FullyConnected(data = Some(prevState.h), weight = Some(param.h2hWeight),
bias = Some(param.h2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_h2h")
val gates = i2h + h2h
val sliceGates = Symbol.SliceChannel(s"t${seqIdx}_l${layerIdx}_slice")(
gates)(Map("num_outputs" -> 4))
val ingate = Symbol.Activation()()(Map("data" -> sliceGates.get(0), "act_type" -> "sigmoid"))
val inTransform = Symbol.Activation()()(Map("data" -> sliceGates.get(1), "act_type" -> "tanh"))
val forgetGate = Symbol.Activation()()(
Map("data" -> sliceGates.get(2), "act_type" -> "sigmoid"))
val outGate = Symbol.Activation()()(Map("data" -> sliceGates.get(3), "act_type" -> "sigmoid"))
val sliceGates = Symbol.api.SliceChannel(data = Some(gates), num_outputs = 4,
name = s"t${seqIdx}_l${layerIdx}_slice")
val ingate = Symbol.api.Activation(data = Some(sliceGates.get(0)), act_type = "sigmoid")
val inTransform = Symbol.api.Activation(data = Some(sliceGates.get(1)), act_type = "tanh")
val forgetGate = Symbol.api.Activation(data = Some(sliceGates.get(2)), act_type = "sigmoid")
val outGate = Symbol.api.Activation(data = Some(sliceGates.get(3)), act_type = "sigmoid")
val nextC = (forgetGate * prevState.c) + (ingate * inTransform)
val nextH = outGate * Symbol.Activation()()(Map("data" -> nextC, "act_type" -> "tanh"))
val nextH = outGate * Symbol.api.Activation(data = Some(nextC), "tanh")
LSTMState(c = nextC, h = nextH)
}

Expand All @@ -74,11 +66,11 @@ object Lstm {
val lastStatesBuf = ArrayBuffer[LSTMState]()
for (i <- 0 until numLstmLayer) {
paramCellsBuf.append(LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
lastStatesBuf.append(LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
h = Symbol.Variable(s"l${i}_init_h_beta")))
h = Symbol.Variable(s"l${i}_init_h_beta")))
}
val paramCells = paramCellsBuf.toArray
val lastStates = lastStatesBuf.toArray
Expand All @@ -87,10 +79,10 @@ object Lstm {
// embeding layer
val data = Symbol.Variable("data")
var label = Symbol.Variable("softmax_label")
val embed = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
"weight" -> embedWeight, "output_dim" -> numEmbed))
val wordvec = Symbol.SliceChannel()()(
Map("data" -> embed, "num_outputs" -> seqLen, "squeeze_axis" -> 1))
val embed = Symbol.api.Embedding(data = Some(data), input_dim = inputSize,
weight = Some(embedWeight), output_dim = numEmbed, name = "embed")
val wordvec = Symbol.api.SliceChannel(data = Some(embed),
num_outputs = seqLen, squeeze_axis = Some(true))

val hiddenAll = ArrayBuffer[Symbol]()
var dpRatio = 0f
Expand All @@ -101,22 +93,23 @@ object Lstm {
for (i <- 0 until numLstmLayer) {
if (i == 0) dpRatio = 0f else dpRatio = dropout
val nextState = lstm(numHidden, inData = hidden,
prevState = lastStates(i),
param = paramCells(i),
seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
prevState = lastStates(i),
param = paramCells(i),
seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
hidden = nextState.h
lastStates(i) = nextState
}
// decoder
if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
hiddenAll.append(hidden)
}
val hiddenConcat = Symbol.Concat()(hiddenAll: _*)(Map("dim" -> 0))
val pred = Symbol.FullyConnected("pred")()(Map("data" -> hiddenConcat, "num_hidden" -> numLabel,
"weight" -> clsWeight, "bias" -> clsBias))
label = Symbol.transpose()(label)()
label = Symbol.Reshape()()(Map("data" -> label, "target_shape" -> "(0,)"))
val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> pred, "label" -> label))
val hiddenConcat = Symbol.api.Concat(data = hiddenAll.toArray, num_args = hiddenAll.length,
dim = Some(0))
val pred = Symbol.api.FullyConnected(data = Some(hiddenConcat), num_hidden = numLabel,
weight = Some(clsWeight), bias = Some(clsBias))
label = Symbol.api.transpose(data = Some(label))
label = Symbol.api.Reshape(data = Some(label), target_shape = Some(Shape(0)))
val sm = Symbol.api.SoftmaxOutput(data = Some(pred), label = Some(label), name = "softmax")
sm
}

Expand All @@ -131,35 +124,35 @@ object Lstm {
var lastStates = Array[LSTMState]()
for (i <- 0 until numLstmLayer) {
paramCells = paramCells :+ LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
lastStates = lastStates :+ LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
h = Symbol.Variable(s"l${i}_init_h_beta"))
h = Symbol.Variable(s"l${i}_init_h_beta"))
}
assert(lastStates.length == numLstmLayer)

val data = Symbol.Variable("data")

var hidden = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
"weight" -> embedWeight, "output_dim" -> numEmbed))
var hidden = Symbol.api.Embedding(data = Some(data), input_dim = inputSize,
weight = Some(embedWeight), output_dim = numEmbed, name = "embed")

var dpRatio = 0f
// stack LSTM
for (i <- 0 until numLstmLayer) {
if (i == 0) dpRatio = 0f else dpRatio = dropout
val nextState = lstm(numHidden, inData = hidden,
prevState = lastStates(i),
param = paramCells(i),
seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
prevState = lastStates(i),
param = paramCells(i),
seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
hidden = nextState.h
lastStates(i) = nextState
}
// decoder
if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
val fc = Symbol.FullyConnected("pred")()(Map("data" -> hidden, "num_hidden" -> numLabel,
"weight" -> clsWeight, "bias" -> clsBias))
val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> fc))
if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
val fc = Symbol.api.FullyConnected(data = Some(hidden),
num_hidden = numLabel, weight = Some(clsWeight), bias = Some(clsBias))
val sm = Symbol.api.SoftmaxOutput(data = Some(fc), name = "softmax")
var output = Array(sm)
for (state <- lastStates) {
output = output :+ state.c
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ import org.apache.mxnet.module.BucketingModule
import org.apache.mxnet.module.FitParams

/**
* Bucketing LSTM examples
* @author Yizhi Liu
*/
* Bucketing LSTM examples
*/
class LstmBucketing {
@Option(name = "--data-train", usage = "training set")
private val dataTrain: String = "example/rnn/sherlockholmes.train.txt"
Expand Down Expand Up @@ -61,6 +60,60 @@ object LstmBucketing {
Math.exp(loss / labelArr.length).toFloat
}

def runTraining(trainData : String, validationData : String,
ctx : Array[Context], numEpoch : Int): Unit = {
val batchSize = 32
val buckets = Array(10, 20, 30, 40, 50, 60)
val numHidden = 200
val numEmbed = 200
val numLstmLayer = 2

logger.info("Building vocab ...")
val vocab = BucketIo.defaultBuildVocab(trainData)

def BucketSymGen(key: AnyRef):
(Symbol, IndexedSeq[String], IndexedSeq[String]) = {
val seqLen = key.asInstanceOf[Int]
val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
(sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
}

val initC = (0 until numLstmLayer).map(l =>
(s"l${l}_init_c_beta", (batchSize, numHidden))
)
val initH = (0 until numLstmLayer).map(l =>
(s"l${l}_init_h_beta", (batchSize, numHidden))
)
val initStates = initC ++ initH

val dataTrain = new BucketSentenceIter(trainData, vocab,
buckets, batchSize, initStates)
val dataVal = new BucketSentenceIter(validationData, vocab,
buckets, batchSize, initStates)

val model = new BucketingModule(
symGen = BucketSymGen,
defaultBucketKey = dataTrain.defaultBucketKey,
contexts = ctx)

val fitParams = new FitParams()
fitParams.setEvalMetric(
new CustomMetric(perplexity, name = "perplexity"))
fitParams.setKVStore("device")
fitParams.setOptimizer(
new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))

logger.info("Start training ...")
model.fit(
trainData = dataTrain,
evalData = Some(dataVal),
numEpoch = numEpoch, fitParams)
logger.info("Finished training...")
}

def main(args: Array[String]): Unit = {
val inst = new LstmBucketing
val parser: CmdLineParser = new CmdLineParser(inst)
Expand All @@ -71,56 +124,7 @@ object LstmBucketing {
else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
else Array(Context.cpu(0))

val batchSize = 32
val buckets = Array(10, 20, 30, 40, 50, 60)
val numHidden = 200
val numEmbed = 200
val numLstmLayer = 2

logger.info("Building vocab ...")
val vocab = BucketIo.defaultBuildVocab(inst.dataTrain)

def BucketSymGen(key: AnyRef):
(Symbol, IndexedSeq[String], IndexedSeq[String]) = {
val seqLen = key.asInstanceOf[Int]
val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
(sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
}

val initC = (0 until numLstmLayer).map(l =>
(s"l${l}_init_c_beta", (batchSize, numHidden))
)
val initH = (0 until numLstmLayer).map(l =>
(s"l${l}_init_h_beta", (batchSize, numHidden))
)
val initStates = initC ++ initH

val dataTrain = new BucketSentenceIter(inst.dataTrain, vocab,
buckets, batchSize, initStates)
val dataVal = new BucketSentenceIter(inst.dataVal, vocab,
buckets, batchSize, initStates)

val model = new BucketingModule(
symGen = BucketSymGen,
defaultBucketKey = dataTrain.defaultBucketKey,
contexts = contexts)

val fitParams = new FitParams()
fitParams.setEvalMetric(
new CustomMetric(perplexity, name = "perplexity"))
fitParams.setKVStore("device")
fitParams.setOptimizer(
new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))

logger.info("Start training ...")
model.fit(
trainData = dataTrain,
evalData = Some(dataVal),
numEpoch = inst.numEpoch, fitParams)
logger.info("Finished training...")
runTraining(inst.dataTrain, inst.dataVal, contexts, 5)
} catch {
case ex: Exception =>
logger.error(ex.getMessage, ex)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# RNN Example for MXNet Scala
This folder contains the following examples writing in new Scala type-safe API:
- [x] LSTM Bucketing
- [x] CharRNN Inference : Generate similar text based on the model
- [x] CharRNN Training: Training the language model using RNN

These example is only for Illustration and not modeled to achieve the best accuracy.

## Setup
### Download the Network Definition, Weights and Training Data
`obama.zip` contains the training inputs (Obama's speech) for CharCNN examples and `sherlockholmes` contains the data for LSTM Bucketing
```bash
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/obama.zip
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.train.txt
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.valid.txt
```
### Unzip the file
```bash
unzip obama.zip
```
### Arguement Configuration
Then you need to define the arguments that you would like to pass in the model:

#### LSTM Bucketing
```bash
--data-train
<path>/sherlockholmes.train.txt
--data-val
<path>/sherlockholmes.valid.txt
--cpus
<num_cpus>
--gpus
<num_gpu>
```
#### TrainCharRnn
```bash
--data-path
<path>/obama.txt
--save-model-path
<path>/
```
#### TestCharRnn
```bash
--data-path
<path>/obama.txt
--model-prefix
<path>/obama
```
Loading

0 comments on commit 38f80af

Please sign in to comment.