Skip to content

Commit

Permalink
[MXNET-689] add DataDesc type for the Scala Package (apache#11844)
Browse files Browse the repository at this point in the history
* add dataDesc

* Add amend

* add changes with dataLayout and labelLayout

* add deprecation and example changes

* Gan and Customop fixes

* change the DType

* add one more class to convert Strings to DTypes

* convert layout to global

* scala style fix

* Revert to 8c7d1f8

* fix coding style issue

* print full stacktraces

* apply changes to new constructor

* add databatch bcc

* introduce undefined field

* Fix crashes when change provideData to provideDataDesc

It looks like forcing a conversion from Float32 to Int32 causes a crash on the JVM. This needs to be addressed.

* change spacing and revert test

* apply DataDesc on DataBatch

* unit test for NDArrayIter and MXDataiter

* apply changes on CR

* change NDArrayIter and revert the rest

* revert change on examples

* apply final changes

* remove the provideLabelShape

* add TODO about the findings
  • Loading branch information
lanking520 authored and nswamy committed Aug 17, 2018
1 parent c927f1f commit 3d8627c
Show file tree
Hide file tree
Showing 19 changed files with 425 additions and 132 deletions.
11 changes: 11 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,15 @@ object DType extends Enumeration {
case DType.Unknown => 0
}
}
/**
 * Look up a DType by its string name.
 * @param dtypeStr one of "UInt8", "Int32", "Float16", "Float32" or "Float64"
 * @return the DType value with that name
 * @throws IllegalArgumentException if dtypeStr is not one of the names above
 */
private[mxnet] def getType(dtypeStr: String): DType = dtypeStr match {
  case "Float16" => DType.Float16
  case "Float32" => DType.Float32
  case "Float64" => DType.Float64
  case "Int32" => DType.Int32
  case "UInt8" => DType.UInt8
  case _ =>
    // Any other name (including null) is rejected rather than mapped to a default.
    throw new IllegalArgumentException(
      s"DType: $dtypeStr not found! please set it in DType.scala")
}
}
121 changes: 86 additions & 35 deletions scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import org.slf4j.LoggerFactory
import scala.annotation.varargs
import scala.collection.immutable.ListMap
import scala.collection.mutable.ListBuffer

/**
* IO iterators for loading training & validation data
*/
Expand Down Expand Up @@ -110,18 +109,22 @@ object IO {
}

// Convert data into canonical form.
private[mxnet] def initData(data: IndexedSeq[NDArray],
allowEmpty: Boolean,
defaultName: String): IndexedSeq[(String, NDArray)] = {
private[mxnet] def initDataDesc(data: IndexedSeq[NDArray],
allowEmpty: Boolean,
defaultName: String,
defaultDType: DType,
defaultLayout: String): IndexedSeq[(DataDesc, NDArray)] = {
require(data != null)
require(data != IndexedSeq.empty || allowEmpty)
if (data == IndexedSeq.empty) {
IndexedSeq()
} else if (data.length == 1) {
IndexedSeq((defaultName, data(0)))
IndexedSeq((new DataDesc(defaultName, data(0).shape,
defaultDType, defaultLayout), data(0)))
} else {
data.zipWithIndex.map(item => {
(defaultName + "_" + item._2, item._1)
(new DataDesc(defaultName + "_" + item._2, item._1.shape,
defaultDType, defaultLayout), item._1)
}).toIndexedSeq
}
}
Expand All @@ -136,11 +139,28 @@ class DataBatch(val data: IndexedSeq[NDArray],
val pad: Int,
// the key for the bucket that should be used for this batch,
// for bucketing io only
val bucketKey: AnyRef = null,
// use ListMap to indicate the order of data/label loading
val bucketKey: AnyRef,
// use DataDesc to indicate the order of data/label loading
// (must match the order of input data/label)
private val providedData: ListMap[String, Shape] = null,
private val providedLabel: ListMap[String, Shape] = null) {
private val providedDataDesc: IndexedSeq[DataDesc],
private val providedLabelDesc: IndexedSeq[DataDesc]) {
// TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)]
// However, since the data and label can be accessed publicly (no getter and setter)
// the change on this will break BC
def this(data: IndexedSeq[NDArray],
label: IndexedSeq[NDArray],
index: IndexedSeq[Long],
pad: Int,
// the key for the bucket that should be used for this batch,
// for bucketing io only
bucketKey: AnyRef = null,
// use ListMap to indicate the order of data/label loading
// (must match the order of input data/label)
providedData: ListMap[String, Shape] = null,
providedLabel: ListMap[String, Shape] = null) {
this(data, label, index, pad, bucketKey,
DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel))
}
/**
* Dispose its data and labels
* The object shall never be used after it is disposed.
Expand All @@ -155,10 +175,29 @@ class DataBatch(val data: IndexedSeq[NDArray],
}

// The name and shape of data
def provideData: ListMap[String, Shape] = providedData
// Name -> shape view of the data descriptors, in descriptor order.
// Returns null when no data descriptors were supplied.
def provideData: ListMap[String, Shape] = {
  if (providedDataDesc != null) {
    providedDataDesc.foldLeft(ListMap[String, Shape]()) { (shapes, desc) =>
      shapes + (desc.name -> desc.shape)
    }
  } else {
    null
  }
}

// The name and shape of label
def provideLabel: ListMap[String, Shape] = providedLabel
// Name -> shape view of the label descriptors, in descriptor order.
// Returns null when no label descriptors were supplied.
def provideLabel: ListMap[String, Shape] = {
  if (providedLabelDesc != null) {
    providedLabelDesc.foldLeft(ListMap[String, Shape]()) { (shapes, desc) =>
      shapes + (desc.name -> desc.shape)
    }
  } else {
    null
  }
}

def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc

def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc

}

object DataBatch {
Expand All @@ -171,8 +210,8 @@ object DataBatch {
private var index: IndexedSeq[Long] = null
private var pad: Int = 0
private var bucketKey: AnyRef = null
private var datatShapes: ListMap[String, Shape] = null
private var labelShapes: ListMap[String, Shape] = null
private var dataDesc: IndexedSeq[DataDesc] = null
private var labelDesc: IndexedSeq[DataDesc] = null

/**
* Set the input data.
Expand Down Expand Up @@ -228,37 +267,27 @@ object DataBatch {

/**
* Provide the descriptors of the data.
* @param name data name.
* @param shape data shape.
* @param dataDesc DataDescriptor
* @return this.
*/
def provideDataShape(name: String, shape: Shape): Builder = {
if (datatShapes == null) {
datatShapes = ListMap((name, shape))
} else {
datatShapes = datatShapes.updated(name, shape)
}
def provideDataDesc(dataDesc: IndexedSeq[DataDesc]): Builder = {
this.dataDesc = dataDesc
this
}

/**
* Provide the descriptors of the label.
* @param name label name.
* @param shape label shape.
* @param labelDesc LabelDescriptor
* @return this.
*/
def provideLabelShape(name: String, shape: Shape): Builder = {
if (labelShapes == null) {
labelShapes = ListMap((name, shape))
} else {
labelShapes = labelShapes.updated(name, shape)
}
def provideLabelDesc(labelDesc: IndexedSeq[DataDesc]): Builder = {
this.labelDesc = labelDesc
this
}

def build(): DataBatch = {
require(data != null, "data is required.")
new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes)
new DataBatch(data, label, index, pad, bucketKey, dataDesc, labelDesc)
}
}
}
Expand All @@ -280,7 +309,8 @@ abstract class DataIter extends Iterator[DataBatch] {
*/
@throws(classOf[NoSuchElementException])
def next(): DataBatch = {
new DataBatch(getData(), getLabel(), getIndex(), getPad())
new DataBatch(getData(), getLabel(), getIndex(), getPad(),
null, null, null)
}

/**
Expand Down Expand Up @@ -309,11 +339,19 @@ abstract class DataIter extends Iterator[DataBatch] {
def getIndex(): IndexedSeq[Long]

// The name and shape of data provided by this iterator.
// Prefer provideDataDesc: its DataDesc entries carry dtype and layout
// in addition to name and shape.
@deprecated("Use provideDataDesc instead")
def provideData: ListMap[String, Shape]

// The name and shape of label provided by this iterator.
// Prefer provideLabelDesc: its DataDesc entries carry dtype and layout
// in addition to name and shape.
@deprecated("Use provideLabelDesc instead")
def provideLabel: ListMap[String, Shape]

// Provide type:DataDesc of the data
def provideDataDesc: IndexedSeq[DataDesc]

// Provide type:DataDesc of the label
def provideLabelDesc: IndexedSeq[DataDesc]

// For bucketing io only
// The bucket key for the default symbol.
def defaultBucketKey: AnyRef = null
Expand All @@ -332,8 +370,9 @@ abstract class DataPack() extends Iterable[DataBatch] {

// Named data desc description contains name, shape, type and other extended attributes.
case class DataDesc(name: String, shape: Shape,
dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") {
require(shape.length == layout.length, ("number of dimensions in shape :%d with" +
dtype: DType = DType.Float32, layout: String = Layout.UNDEFINED) {
require(layout == Layout.UNDEFINED || shape.length == layout.length,
("number of dimensions in shape :%d with" +
" shape: %s should match the length of the layout: %d with layout: %s").
format(shape.length, shape.toString, layout.length, layout))

Expand All @@ -343,6 +382,8 @@ case class DataDesc(name: String, shape: Shape,
}

object DataDesc {

private val logger = LoggerFactory.getLogger(classOf[DataDesc])
/**
* Get the dimension that corresponds to the batch size.
* @param layout layout string. For example, "NCHW".
Expand All @@ -352,9 +393,19 @@ object DataDesc {
* for each data-parallelism device.
*/
def getBatchAxis(layout: Option[String]): Int = {
layout.map(_.indexOf('N')).getOrElse(0)
if (layout.isEmpty|| layout.get == Layout.UNDEFINED) {
logger.warn("Found Undefined Layout, will use default index 0 for batch axis")
0
} else {
if (layout.get.contains('N')) {
layout.get.indexOf("N")
} else {
throw new IllegalArgumentException("no Batch Axis('N') found in Layout!")
}
}
}

@deprecated
implicit def ListMap2Descs(shapes: ListMap[String, Shape]): IndexedSeq[DataDesc] = {
if (shapes != null) {
shapes.map { case (k, s) => new DataDesc(k, s) }.toIndexedSeq
Expand Down
35 changes: 35 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet

/**
 * Layout strings used by DataDesc.
 * Axis letters:
 * N Batch size
 * C channels
 * H Height
 * W Width
 * T sequence length
 * __undefined__ default value of Layout
 */
object Layout {
  val UNDEFINED: String = "__undefined__"
  val NCHW: String = "NCHW"
  val NTC: String = "NTC"
  val NT: String = "NT"
  val N: String = "N"
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ import java.io.ByteArrayInputStream

/**
* Scala interface for read/write RecordIO data format
*
* @author Depeng Liang
*
* @param uri, path to recordIO file.
* @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
*/
Expand Down Expand Up @@ -144,7 +141,7 @@ object MXRecordIO {
*
* @author Depeng Liang
*
* @param idx_path, path to index file
* @param idxPath, path to index file
* @param uri, path to recordIO file.
* @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
* @param keyType, data type for keys.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
package org.apache.mxnet.io

import org.apache.mxnet.Base._
import org.apache.mxnet.{DataBatch, DataIter, DataPack, NDArray, Shape, WarnIfNotDisposed}
import org.apache.mxnet.DType.DType
import org.apache.mxnet._
import org.apache.mxnet.IO._
import org.slf4j.LoggerFactory

Expand All @@ -41,21 +42,31 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
// fix me if any better way found)
private var currentBatch: DataBatch = null

private val (_provideData: ListMap[String, Shape],
private val (_provideDataDesc: IndexedSeq[DataDesc],
_provideLabelDesc: IndexedSeq[DataDesc],
_provideData: ListMap[String, Shape],
_provideLabel: ListMap[String, Shape],
_batchSize: Int) =
_batchSize: Int) = {
if (hasNext) {
iterNext()
val data = currentBatch.data(0)
val label = currentBatch.label(0)
// properties
val res = (ListMap(dataName -> data.shape), ListMap(labelName -> label.shape), data.shape(0))
val res = (
// TODO: need to allow user to specify DType and Layout
IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32, Layout.UNDEFINED)),
IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32, Layout.UNDEFINED)),
ListMap(dataName -> data.shape),
ListMap(labelName -> label.shape),
data.shape(0))
currentBatch.dispose()
reset()
res
} else {
(null, null, 0)
(null, null, null, null, 0)
}
}


private var disposed = false
protected def isDisposed = disposed
Expand Down Expand Up @@ -101,10 +112,12 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
private def iterNext(): Boolean = {
val next = new RefInt
checkCall(_LIB.mxDataIterNext(handle, next))
currentBatch = null
if (next.value > 0) {
currentBatch = new DataBatch(data = getData(), label = getLabel(),
index = getIndex(), pad = getPad())
index = getIndex(), pad = getPad(),
null, null, null)
} else {
currentBatch = null
}
next.value > 0
}
Expand Down Expand Up @@ -152,11 +165,19 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
}

// The name and shape of data provided by this iterator.
// Returns the value captured when the iterator was constructed.
// Prefer provideDataDesc, which returns DataDesc entries.
@deprecated("Use provideDataDesc instead")
override def provideData: ListMap[String, Shape] = _provideData

// The name and shape of label provided by this iterator.
// Returns the value captured when the iterator was constructed.
// Prefer provideLabelDesc, which returns DataDesc entries.
@deprecated("Use provideLabelDesc instead")
override def provideLabel: ListMap[String, Shape] = _provideLabel

// Provide type:DataDesc of the data
override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc

// Provide type:DataDesc of the label
override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc

override def hasNext: Boolean = {
if (currentBatch != null) {
true
Expand Down
Loading

0 comments on commit 3d8627c

Please sign in to comment.