From abbe283ef8b1d78b002cb492651f002ae27ba544 Mon Sep 17 00:00:00 2001 From: Lanking Date: Fri, 17 Aug 2018 10:07:30 -0700 Subject: [PATCH] [MXNET-689] add DataDesc type for the Scala Package (#11844) * add dataDesc * Add amend * add changes with dataLayout and labelLayout * add depreciate and example changes * Gan and Customop fixes * change the DType * add one more class to convert Strings to DTypes * convert layout to global * scala style fix * Revert to 8c7d1f8 * fix coding style issue * print full stacktraces * apply changes to new constructor * add databatch bcc * introduce undefined field * Fix crashes when change provideData to provideDataDesc It looks like if we want to force conversion from Float32 to Int32 will cause a crash on JVM. Need to be addressed. * change spacing and revert test * apply DataDesc on DataBatch * unit test for NDArrayIter and MXDataiter * apply changes on CR * change NDArrayIter and revert the rest * revert change on examples * apply final changes * remove the provideLabelShape * add TODO about the findings --- .../main/scala/org/apache/mxnet/DType.scala | 11 ++ .../src/main/scala/org/apache/mxnet/IO.scala | 121 +++++++++++++----- .../main/scala/org/apache/mxnet/Layout.scala | 35 +++++ .../scala/org/apache/mxnet/RecordIO.scala | 5 +- .../org/apache/mxnet/io/MXDataIter.scala | 35 ++++- .../org/apache/mxnet/io/NDArrayIter.scala | 91 +++++++++---- .../org/apache/mxnet/io/PrefetchingIter.scala | 69 ++++++++-- .../org/apache/mxnet/io/ResizeIter.scala | 15 ++- .../test/scala/org/apache/mxnet/IOSuite.scala | 18 ++- .../scala/org/apache/mxnet/ModuleSuite.scala | 6 +- .../apache/mxnetexamples/multitask/Data.scala | 3 - .../multitask/ExampleMultiTask.scala | 31 +++-- .../apache/mxnetexamples/rnn/BucketIo.scala | 54 +++++--- .../mxnet/infer/ObjectDetectorSuite.scala | 8 +- .../apache/mxnet/infer/PredictorSuite.scala | 16 ++- scala-package/pom.xml | 1 + .../mxnet/spark/io/LabeledPointIter.scala | 16 ++- .../mxnet/spark/io/LongLivingDataBatch.scala | 6 +- .../org/apache/mxnet/spark/io/PointIter.scala | 16 ++- 19 files changed, 425 insertions(+), 132 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala index 4458a7c7aeb8..f3a8e8e9a4a5 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala @@ -35,4 +35,15 @@ object DType extends Enumeration { case DType.Unknown => 0 } } + private[mxnet] def getType(dtypeStr: String): DType = { + dtypeStr match { + case "UInt8" => DType.UInt8 + case "Int32" => DType.Int32 + case "Float16" => DType.Float16 + case "Float32" => DType.Float32 + case "Float64" => DType.Float64 + case _ => throw new IllegalArgumentException( + s"DType: $dtypeStr not found! please set it in DType.scala") + } + } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 47fd4eee939a..a1095cf04833 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -25,7 +25,6 @@ import org.slf4j.LoggerFactory import scala.annotation.varargs import scala.collection.immutable.ListMap import scala.collection.mutable.ListBuffer - /** * IO iterators for loading training & validation data */ @@ -110,18 +109,22 @@ object IO { } // Convert data into canonical form. - private[mxnet] def initData(data: IndexedSeq[NDArray], - allowEmpty: Boolean, - defaultName: String): IndexedSeq[(String, NDArray)] = { + private[mxnet] def initDataDesc(data: IndexedSeq[NDArray], + allowEmpty: Boolean, + defaultName: String, + defaultDType: DType, + defaultLayout: String): IndexedSeq[(DataDesc, NDArray)] = { require(data != null) require(data != IndexedSeq.empty || allowEmpty) if (data == IndexedSeq.empty) { IndexedSeq() } else if (data.length == 1) { - IndexedSeq((defaultName, data(0))) + IndexedSeq((new DataDesc(defaultName, data(0).shape, + defaultDType, defaultLayout), data(0))) } else { data.zipWithIndex.map(item => { - (defaultName + "_" + item._2, item._1) + (new DataDesc(defaultName + "_" + item._2, item._1.shape, + defaultDType, defaultLayout), item._1) }).toIndexedSeq } } @@ -136,11 +139,28 @@ class DataBatch(val data: IndexedSeq[NDArray], val pad: Int, // the key for the bucket that should be used for this batch, // for bucketing io only - val bucketKey: AnyRef = null, - // use ListMap to indicate the order of data/label loading + val bucketKey: AnyRef, + // use DataDesc to indicate the order of data/label loading // (must match the order of input data/label) - private val providedData: ListMap[String, Shape] = null, - private val providedLabel: ListMap[String, Shape] = null) { + private val providedDataDesc: IndexedSeq[DataDesc], + private val providedLabelDesc: IndexedSeq[DataDesc]) { + // TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)] + // However, since the data and label can be accessed publicly (no getter and setter) + // the change on this will break BC + def this(data: IndexedSeq[NDArray], + label: IndexedSeq[NDArray], + index: IndexedSeq[Long], + pad: Int, + // the key for the bucket that should be used for this batch, + // for bucketing io only + bucketKey: AnyRef = null, + // use ListMap to indicate the order of data/label loading + // (must match the order of input data/label) + providedData: ListMap[String, Shape] = null, + providedLabel: ListMap[String, Shape] = null) { + this(data, label, index, pad, bucketKey, + DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel)) + } /** * Dispose its data and labels * The object shall never be used after it is disposed. @@ -155,10 +175,29 @@ class DataBatch(val data: IndexedSeq[NDArray], } // The name and shape of data - def provideData: ListMap[String, Shape] = providedData + def provideData: ListMap[String, Shape] = { + var temp = ListMap[String, Shape]() + if (providedDataDesc == null) null + else { + providedDataDesc.foreach(ele => temp = temp + (ele.name -> ele.shape)) + temp + } + } // The name and shape of label - def provideLabel: ListMap[String, Shape] = providedLabel + def provideLabel: ListMap[String, Shape] = { + var temp = ListMap[String, Shape]() + if (providedLabelDesc == null) null + else { + providedLabelDesc.foreach(ele => temp = temp + (ele.name -> ele.shape)) + temp + } + } + + def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc + + def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc + } object DataBatch { @@ -171,8 +210,8 @@ object DataBatch { private var index: IndexedSeq[Long] = null private var pad: Int = 0 private var bucketKey: AnyRef = null - private var datatShapes: ListMap[String, Shape] = null - private var labelShapes: ListMap[String, Shape] = null + private var dataDesc: IndexedSeq[DataDesc] = null + private var labelDesc: IndexedSeq[DataDesc] = null /** * Set the input data. @@ -228,37 +267,27 @@ object DataBatch { /** * Provide the shape of a data. - * @param name data name. - * @param shape data shape. + * @param dataDesc DataDescriptor * @return this. */ - def provideDataShape(name: String, shape: Shape): Builder = { - if (datatShapes == null) { - datatShapes = ListMap((name, shape)) - } else { - datatShapes = datatShapes.updated(name, shape) - } + def provideDataDesc(dataDesc: IndexedSeq[DataDesc]): Builder = { + this.dataDesc = dataDesc this } /** * Provide the shape of a label. - * @param name label name. - * @param shape label shape. + * @param labelDesc LabelDescriptor * @return this. */ - def provideLabelShape(name: String, shape: Shape): Builder = { - if (labelShapes == null) { - labelShapes = ListMap((name, shape)) - } else { - labelShapes = labelShapes.updated(name, shape) - } + def provideLabelDesc(labelDesc: IndexedSeq[DataDesc]): Builder = { + this.labelDesc = labelDesc this } def build(): DataBatch = { require(data != null, "data is required.") - new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes) + new DataBatch(data, label, index, pad, bucketKey, dataDesc, labelDesc) } } } @@ -280,7 +309,8 @@ abstract class DataIter extends Iterator[DataBatch] { */ @throws(classOf[NoSuchElementException]) def next(): DataBatch = { - new DataBatch(getData(), getLabel(), getIndex(), getPad()) + new DataBatch(getData(), getLabel(), getIndex(), getPad(), + null, null, null) } /** @@ -309,11 +339,19 @@ abstract class DataIter extends Iterator[DataBatch] { def getIndex(): IndexedSeq[Long] // The name and shape of data provided by this iterator + @deprecated def provideData: ListMap[String, Shape] // The name and shape of label provided by this iterator + @deprecated def provideLabel: ListMap[String, Shape] + // Provide type:DataDesc of the data + def provideDataDesc: IndexedSeq[DataDesc] + + // Provide type:DataDesc of the label + def provideLabelDesc: IndexedSeq[DataDesc] + // For bucketing io only // The bucket key for the default symbol. def defaultBucketKey: AnyRef = null @@ -332,8 +370,9 @@ abstract class DataPack() extends Iterable[DataBatch] { // Named data desc description contains name, shape, type and other extended attributes. case class DataDesc(name: String, shape: Shape, - dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") { - require(shape.length == layout.length, ("number of dimensions in shape :%d with" + + dtype: DType = DType.Float32, layout: String = Layout.UNDEFINED) { + require(layout == Layout.UNDEFINED || shape.length == layout.length, + ("number of dimensions in shape :%d with" + " shape: %s should match the length of the layout: %d with layout: %s"). format(shape.length, shape.toString, layout.length, layout)) @@ -343,6 +382,8 @@ case class DataDesc(name: String, shape: Shape, } object DataDesc { + + private val logger = LoggerFactory.getLogger(classOf[DataDesc]) /** * Get the dimension that corresponds to the batch size. * @param layout layout string. For example, "NCHW". @@ -352,9 +393,19 @@ object DataDesc { * for each data-parallelism device. */ def getBatchAxis(layout: Option[String]): Int = { - layout.map(_.indexOf('N')).getOrElse(0) + if (layout.isEmpty|| layout.get == Layout.UNDEFINED) { + logger.warn("Found Undefined Layout, will use default index 0 for batch axis") + 0 + } else { + if (layout.get.contains('N')) { + layout.get.indexOf("N") + } else { + throw new IllegalArgumentException("no Batch Axis('N') found in Layout!") + } + } } + @deprecated implicit def ListMap2Descs(shapes: ListMap[String, Shape]): IndexedSeq[DataDesc] = { if (shapes != null) { shapes.map { case (k, s) => new DataDesc(k, s) }.toIndexedSeq diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala new file mode 100644 index 000000000000..cb75dbc40803 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +/** + * Layout definition of DataDesc + * N Batch size + * C channels + * H Height + * W Weight + * T sequence length + * __undefined__ default value of Layout + */ +object Layout { + val UNDEFINED = "__undefined__" + val NCHW = "NCHW" + val NTC = "NTC" + val NT = "NT" + val N = "N" +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala index ee3e950512e7..578f00a76f9a 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala @@ -28,9 +28,6 @@ import java.io.ByteArrayInputStream /** * Scala interface for read/write RecordIO data format - * - * @author Depeng Liang - * * @param uri, path to recordIO file. * @param flag, RecordIO.IORead for reading or RecordIO.Write for writing. */ @@ -144,7 +141,7 @@ object MXRecordIO { * * @author Depeng Liang * - * @param idx_path, path to index file + * @param idxPath, path to index file * @param uri, path to recordIO file. * @param flag, RecordIO.IORead for reading or RecordIO.Write for writing. * @param keyType, data type for keys. diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 2a0c333ebf10..f7f858deb82d 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -18,7 +18,8 @@ package org.apache.mxnet.io import org.apache.mxnet.Base._ -import org.apache.mxnet.{DataBatch, DataIter, DataPack, NDArray, Shape, WarnIfNotDisposed} +import org.apache.mxnet.DType.DType +import org.apache.mxnet._ import org.apache.mxnet.IO._ import org.slf4j.LoggerFactory @@ -41,21 +42,31 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, // fix me if any better way found) private var currentBatch: DataBatch = null - private val (_provideData: ListMap[String, Shape], + private val (_provideDataDesc: IndexedSeq[DataDesc], + _provideLabelDesc: IndexedSeq[DataDesc], + _provideData: ListMap[String, Shape], _provideLabel: ListMap[String, Shape], - _batchSize: Int) = + _batchSize: Int) = { if (hasNext) { iterNext() val data = currentBatch.data(0) val label = currentBatch.label(0) // properties - val res = (ListMap(dataName -> data.shape), ListMap(labelName -> label.shape), data.shape(0)) + val res = ( + // TODO: need to allow user to specify DType and Layout + IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32, Layout.UNDEFINED)), + IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32, Layout.UNDEFINED)), + ListMap(dataName -> data.shape), + ListMap(labelName -> label.shape), + data.shape(0)) currentBatch.dispose() reset() res } else { - (null, null, 0) + (null, null, null, null, 0) } + } + private var disposed = false protected def isDisposed = disposed @@ -101,10 +112,12 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, private def iterNext(): Boolean = { val next = new RefInt checkCall(_LIB.mxDataIterNext(handle, next)) - currentBatch = null if (next.value > 0) { currentBatch = new DataBatch(data = getData(), label = getLabel(), - index = getIndex(), pad = getPad()) + index = getIndex(), pad = getPad(), + null, null, null) + } else { + currentBatch = null } next.value > 0 } @@ -152,11 +165,19 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, } // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = _provideData // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = _provideLabel + // Provide type:DataDesc of the data + override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc + + // Provide type:DataDesc of the label + override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc + override def hasNext: Boolean = { if (currentBatch != null) { true diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 10461315c198..e6be0ad02f83 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -20,6 +20,7 @@ package org.apache.mxnet.io import java.util.NoSuchElementException import org.apache.mxnet.Base._ +import org.apache.mxnet.DType.DType import org.apache.mxnet._ import org.slf4j.LoggerFactory @@ -39,35 +40,35 @@ import scala.collection.immutable.ListMap * the size of data does not match batch_size. Roll over is intended * for training and can cause problems if used for prediction. */ -class NDArrayIter(data: IndexedSeq[(String, NDArray)], - label: IndexedSeq[(String, NDArray)], +class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)], + label: IndexedSeq[(DataDesc, NDArray)], private val dataBatchSize: Int, shuffle: Boolean, lastBatchHandle: String) extends DataIter { /** - * @param data Specify the data. Data names will be data_0, data_1, ..., etc. - * @param label Same as data, but is not fed to the model during testing. - * Label names will be label_0, label_1, ..., etc. - * @param dataBatchSize Batch Size - * @param shuffle Whether to shuffle the data - * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch - * - * This iterator will pad, discard or roll over the last batch if - * the size of data does not match batch_size. Roll over is intended - * for training and can cause problems if used for prediction. - */ + * @param data Specify the data. Data names will be data_0, data_1, ..., etc. + * @param label Same as data, but is not fed to the model during testing. + * Label names will be label_0, label_1, ..., etc. + * @param dataBatchSize Batch Size + * @param shuffle Whether to shuffle the data + * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch + * + * This iterator will pad, discard or roll over the last batch if + * the size of data does not match batch_size. Roll over is intended + * for training and can cause problems if used for prediction. + */ def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", dataName: String = "data", labelName: String = "label") { - this(IO.initData(data, allowEmpty = false, dataName), - IO.initData(label, allowEmpty = true, labelName), + this(IO.initDataDesc(data, allowEmpty = false, dataName, MX_REAL_TYPE, Layout.UNDEFINED), + IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED), dataBatchSize, shuffle, lastBatchHandle) } private val logger = LoggerFactory.getLogger(classOf[NDArrayIter]) - val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = { + val (initData: IndexedSeq[(DataDesc, NDArray)], initLabel: IndexedSeq[(DataDesc, NDArray)]) = { // data should not be null and size > 0 require(data != null && data.size > 0, "data should not be null and data.size should not be zero") @@ -101,20 +102,30 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], private var cursor = -dataBatchSize private val (_provideData: ListMap[String, Shape], - _provideLabel: ListMap[String, Shape]) = { + _provideLabel: ListMap[String, Shape], + _provideDataDesc: IndexedSeq[DataDesc], + _provideLabelDesc: IndexedSeq[DataDesc]) = { val pData = ListMap.empty[String, Shape] ++ initData.map(getShape) val pLabel = ListMap.empty[String, Shape] ++ initLabel.map(getShape) - (pData, pLabel) + val pDData = IndexedSeq.empty[DataDesc] ++ initData.map(ele => { + val temp = getShape(ele) + new DataDesc(temp._1, temp._2, ele._1.dtype, ele._1.layout) + }) + val pDLabel = IndexedSeq.empty[DataDesc] ++ initLabel.map(ele => { + val temp = getShape(ele) + new DataDesc(temp._1, temp._2, ele._1.dtype, ele._1.layout) + }) + (pData, pLabel, pDData, pDLabel) } /** * get shape via dataBatchSize * @param dataItem */ - private def getShape(dataItem: (String, NDArray)): (String, Shape) = { + private def getShape(dataItem: (DataDesc, NDArray)): (String, Shape) = { val len = dataItem._2.shape.size val newShape = dataItem._2.shape.slice(1, len) - (dataItem._1, Shape(Array[Int](dataBatchSize)) ++ newShape) + (dataItem._1.name, Shape(Array[Int](dataBatchSize)) ++ newShape) } @@ -148,7 +159,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], override def next(): DataBatch = { if (hasNext) { cursor += dataBatchSize - new DataBatch(getData(), getLabel(), getIndex(), getPad()) + new DataBatch(getData(), getLabel(), getIndex(), getPad(), + null, null, null) } else { throw new NoSuchElementException } @@ -172,7 +184,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], } } - private def _getData(data: IndexedSeq[(String, NDArray)]): IndexedSeq[NDArray] = { + private def _getData(data: IndexedSeq[(DataDesc, NDArray)]): IndexedSeq[NDArray] = { require(cursor < numData, "DataIter needs reset.") if (data == null) { null @@ -223,12 +235,21 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], } } + // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = _provideData // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = _provideLabel + // Provide type:DataDesc of the data + override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc + + // Provide type:DataDesc of the label + override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc + override def batchSize: Int = dataBatchSize } @@ -238,8 +259,8 @@ object NDArrayIter { * Builder class for NDArrayIter. */ class Builder() { - private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty - private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty + private var data: IndexedSeq[(DataDesc, NDArray)] = IndexedSeq.empty + private var label: IndexedSeq[(DataDesc, NDArray)] = IndexedSeq.empty private var dataBatchSize: Int = 1 private var lastBatchHandle: String = "pad" @@ -250,7 +271,8 @@ object NDArrayIter { * @return The builder object itself. */ def addData(name: String, data: NDArray): Builder = { - this.data = this.data ++ IndexedSeq((name, data)) + this.data = this.data ++ IndexedSeq((new DataDesc(name, + data.shape, DType.Float32, Layout.UNDEFINED), data)) this } @@ -261,7 +283,24 @@ object NDArrayIter { * @return The builder object itself. */ def addLabel(name: String, label: NDArray): Builder = { - this.label = this.label ++ IndexedSeq((name, label)) + this.label = this.label ++ IndexedSeq((new DataDesc(name, + label.shape, DType.Float32, Layout.UNDEFINED), label)) + this + } + + /** + * Add one data input with its DataDesc + */ + def addDataWithDesc(dataDesc: DataDesc, data: NDArray): Builder = { + this.data = this.data ++ IndexedSeq((dataDesc, data)) + this + } + + /** + * Add one label input with its DataDesc + */ + def addLabelWithDesc(labelDesc: DataDesc, label: NDArray): Builder = { + this.data = this.data ++ IndexedSeq((labelDesc, label)) this } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index c0c0d1793b54..e59e3706317d 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -17,10 +17,12 @@ package org.apache.mxnet.io -import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape} +import org.apache.mxnet._ import org.slf4j.LoggerFactory import java.util.concurrent.Semaphore +import org.apache.mxnet.DType.DType + import scala.collection.immutable.ListMap /** @@ -68,6 +70,42 @@ class PrefetchingIter( } } + private val _provideDataDesc: IndexedSeq[DataDesc] = { + if (dataNames == null) { + iters.map(_.provideDataDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => + acc ++ elem + } + } else { + iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2)) + .map(m => + m._1.map(t => + new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout) + ) + ) + .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => + acc ++ elem + } + } + } + + private val _provideLabelDesc: IndexedSeq[DataDesc] = { + if (labelNames == null) { + iters.map(_.provideLabelDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => + acc ++ elem + } + } else { + iters.zipWithIndex.map(tu => (tu._1.provideLabelDesc, tu._2)) + .map(m => + m._1.map(t => + new DataDesc(labelNames(m._2)(t.name), t.shape, t.dtype, t.layout) + ) + ) + .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => + acc ++ elem + } + } + } + private val _batchSize: Int = this._provideData.toList(0)._2(0) private val dataReady: IndexedSeq[Semaphore] = (0 until iters.length).map(i => new Semaphore(0)) @@ -132,19 +170,27 @@ class PrefetchingIter( */ override def getIndex(): IndexedSeq[Long] = currentBatch.index - // The name and shape of label provided by this iterator - override def provideLabel: ListMap[String, Shape] = this._provideLabel - /** - * get the number of padding examples - * in current batch - * @return number of padding examples in current batch - */ + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ override def getPad(): Int = this.currentBatch.pad + // The name and shape of label provided by this iterator + @deprecated + override def provideLabel: ListMap[String, Shape] = this._provideLabel + // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = this._provideData + // Provide type:DataDesc of the data + override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc + + // Provide type:DataDesc of the label + override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc + override def hasNext: Boolean = { for (e <- dataReady) e.acquire() if (nextBatch(0) == null) { @@ -161,9 +207,10 @@ class PrefetchingIter( val datas = for (batch <- nextBatch) yield batch.data val labels = for (batch <- nextBatch) yield batch.label currentBatch = new DataBatch(datas.toIndexedSeq.flatten, - labels.toIndexedSeq.flatten, - nextBatch(0).index, - nextBatch(0).pad) + labels.toIndexedSeq.flatten, + nextBatch(0).index, + nextBatch(0).pad, + null, null, null) for (e <- dataTaken) e.release() true } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala index 75d88d1ae72f..e840af9395f7 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala @@ -19,7 +19,8 @@ package org.apache.mxnet.io import java.util.NoSuchElementException -import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape} +import org.apache.mxnet.DType.DType +import org.apache.mxnet._ import org.slf4j.LoggerFactory import scala.collection.immutable.ListMap @@ -133,12 +134,24 @@ class ResizeIter( } // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = { dataIter.provideData } // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = { dataIter.provideLabel } + + // The name and shape of data provided by this iterator + override def provideDataDesc: IndexedSeq[DataDesc] = { + dataIter.provideDataDesc + } + + // The name and shape of label provided by this iterator + override def provideLabelDesc: IndexedSeq[DataDesc] = { + dataIter.provideLabelDesc + } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 1b922b3c05b6..2ec6f668dbcc 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -243,7 +243,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { val batchLabel = NDArray.ones(Shape(Array(128, 1))) // test pad - val dataIter0 = new NDArrayIter(data, label, 128, false, "pad") + val dataIter0 = new NDArrayIter(data, label, 128, false, "pad", + dataName = "data", labelName = "label") var batchCount = 0 val nBatch0 = 8 while(dataIter0.hasNext) { @@ -277,7 +278,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(batchCount === nBatch1) // test empty label (for prediction) - val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard") + val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, shuffle = false, + lastBatchHandle = "discard") batchCount = 0 while(dataIter2.hasNext) { val tBatch = dataIter2.next() @@ -289,5 +291,17 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(batchCount === nBatch1) assert(dataIter2.initLabel == IndexedSeq.empty) + + // test implementation with DataDesc + val dataIter3 = new NDArrayIter( + IO.initDataDesc(data, false, "data", DType.Float32, Layout.NTC), + IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT), + 128, false, "pad") + val dataDesc = dataIter3.provideDataDesc + val labelDesc = dataIter3.provideLabelDesc + assert(dataDesc(0).dtype == DType.Float32) + assert(dataDesc(0).layout == Layout.NTC) + assert(labelDesc(0).dtype == DType.Int32) + assert(labelDesc(0).layout == Layout.NT) } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index 8234568d7d9f..88e314e2a72c 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -24,7 +24,7 @@ import org.apache.mxnet.io._ class ModuleSuite extends FunSuite with BeforeAndAfterAll { test ("model dtype") { - val dType = DType.Float16 + val dType = DType.Float32 val dShape = Shape(3, 8, 7) var sym = Symbol.Variable("data") @@ -196,8 +196,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // create module val mod = new Module(x, contexts = Array(Context.cpu())) - mod.bind(dataShapes = trainData.provideData, - Option(trainData.provideLabel)) + mod.bind(dataShapes = trainData.provideDataDesc, + Option(trainData.provideLabelDesc)) val argParamsCorrect = Map( "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)), "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)), diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala index bb17046b8b2b..068aa6314f89 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala @@ -21,9 +21,6 @@ import org.apache.mxnet.Shape import org.apache.mxnet.IO import org.apache.mxnet.DataIter -/** - * @author Depeng Liang - */ object Data { // return train and val iterators for mnist diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index 9df2bcc0566d..825e46596755 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -25,14 +25,9 @@ import org.slf4j.LoggerFactory import scala.collection.JavaConverters._ import org.apache.commons.io.FileUtils -import org.apache.mxnet.Symbol -import org.apache.mxnet.DataIter -import org.apache.mxnet.DataBatch -import org.apache.mxnet.NDArray -import org.apache.mxnet.Shape -import org.apache.mxnet.EvalMetric -import org.apache.mxnet.Context -import org.apache.mxnet.Xavier + +import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric, NDArray, Shape, Symbol, Xavier} +import org.apache.mxnet.DType.DType import org.apache.mxnet.optimizer.RMSProp import org.apache.mxnet.Executor import org.apache.mxnetexamples.Util @@ -70,9 +65,9 @@ object ExampleMultiTask { val batch = this.dataIter.next() val label = batch.label(0) new DataBatch(batch.data, - IndexedSeq(label, label), - batch.index, - batch.pad) + IndexedSeq(label, label), + batch.index, + batch.pad, null, null, null) } else { throw new NoSuchElementException } @@ -107,6 +102,7 @@ object ExampleMultiTask { override def getIndex(): IndexedSeq[Long] = this.dataIter.getIndex() // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = { val provideLabel = this.dataIter.provideLabel.toArray // Different labels should be used here for actual application @@ -114,6 +110,16 @@ object ExampleMultiTask { "softmax2_label" -> provideLabel(0)._2) } + // The name and shape of label provided by this iterator + override def provideLabelDesc: IndexedSeq[DataDesc] = { + val head = this.dataIter.provideLabelDesc(0) + // Different labels should be used here for actual application + IndexedSeq( + new DataDesc("softmax1_label", head.shape, head.dtype, head.layout), + new DataDesc("softmax2_label", head.shape, head.dtype, head.layout) + ) + } + /** * get the number of padding examples * in current batch @@ -122,8 +128,11 @@ object ExampleMultiTask { override def getPad(): Int = this.dataIter.getPad() // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = this.dataIter.provideData + override def provideDataDesc: IndexedSeq[DataDesc] = this.dataIter.provideDataDesc + override def hasNext: Boolean = this.dataIter.hasNext } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index f0eae6890c52..d4b17074d48c 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -18,17 +18,16 @@ package org.apache.mxnetexamples.rnn -import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape} +import org.apache.mxnet.DType.DType +import org.apache.mxnet._ import org.slf4j.LoggerFactory + import scala.collection.immutable.ListMap import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.util.Random import scala.collection.mutable -/** - * @author Depeng Liang - */ object BucketIo { type Text2Id = (String, Map[String, Int]) => Array[Int] @@ -92,11 +91,14 @@ object BucketIo { } class BucketSentenceIter( - path: String, vocab: Map[String, Int], var buckets: IndexedSeq[Int], - _batchSize: Int, private val initStates: IndexedSeq[(String, (Int, Int))], - seperateChar: String = " ", text2Id: Text2Id = defaultText2Id, + path: String, + vocab: Map[String, Int], + var buckets: IndexedSeq[Int], + _batchSize: Int, + private val initStates: IndexedSeq[(String, (Int, Int))], + seperateChar: String = " ", + text2Id: Text2Id = defaultText2Id, readContent: ReadContent = defaultReadContent) extends DataIter { - private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter]) private val content = readContent(path) @@ -165,8 +167,22 @@ object BucketIo { private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey)) tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2)) } + private val _provideLabel = ListMap("softmax_label" -> Shape(_batchSize, _defaultBucketKey)) + private val _provideDataDesc = { + // TODO: need to allow user to specify DType and Layout + val tmp = IndexedSeq(new DataDesc("data", + Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED)) + tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), + DType.Float32, Layout.UNDEFINED)) + } + + private val _provideLabelDesc = IndexedSeq( + // TODO: need to allow user to specify DType and Layout + new DataDesc("softmax_label", + Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED)) + private var iBucket = 0 override def next(): DataBatch = { @@ -228,19 +244,27 @@ object BucketIo { */ override def getIndex(): IndexedSeq[Long] = IndexedSeq[Long]() - // The name and shape of label provided by this iterator - override def provideLabel: ListMap[String, Shape] = this._provideLabel - /** - * get the number of padding examples - * in current batch - * @return number of padding examples in current batch - */ + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ override def getPad(): Int = 0 + // The name and shape of label provided by this iterator + @deprecated + override def provideLabel: ListMap[String, Shape] = this._provideLabel + // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = this._provideData + // Provide type:DataDesc of the data + override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc + + // Provide type:DataDesc of the label + override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc + override def hasNext: Boolean = { iBucket < bucketPlan.length } diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala index 8160f0f6eb41..39139f8d3d2e 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala @@ -19,6 +19,8 @@ package org.apache.mxnet.infer // scalastyle:off import java.awt.image.BufferedImage + +import org.apache.mxnet.{DType, Layout} // scalastyle:on import org.apache.mxnet.Context import org.apache.mxnet.DataDesc @@ -69,7 +71,8 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll { } test("objectDetectWithInputImage") { - val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512))) + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512), + DType.Float32, Layout.NCHW)) val inputImage = new BufferedImage(512, 512, BufferedImage.TYPE_INT_RGB) val testObjectDetector: ObjectDetector = new MyObjectDetector(modelPath, inputDescriptor) @@ -109,7 +112,8 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll { } test("objectDetectWithBatchImages") { - val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512))) + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512), + DType.Float32, Layout.NCHW)) val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB) val imageBatch = IndexedSeq[BufferedImage](inputImage, inputImage) diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala index 53fd7f310689..509ffb35db8d 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala @@ -19,7 +19,7 @@ package org.apache.mxnet.infer import org.apache.mxnet.io.NDArrayIter import org.apache.mxnet.module.{BaseModule, Module} -import org.apache.mxnet.{DataDesc, NDArray, Shape} +import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape} import org.mockito.Matchers._ import org.mockito.Mockito import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -40,15 +40,17 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll { } test("PredictorSuite-testPredictorConstruction") { - val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2))) + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2), + layout = Layout.NCHW)) val mockPredictor = new MyPredictor("xyz", inputDescriptor) assert(mockPredictor.getBatchSize == 1) assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.indexOf('N')) - val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)), - new DataDesc("data", Shape(2, 3, 2, 2))) + val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2), + layout = Layout.NCHW), + new DataDesc("data", Shape(2, 3, 2, 2), layout = Layout.NCHW)) assertThrows[IllegalArgumentException] { val mockPredictor = new MyPredictor("xyz", inputDescriptor2) @@ -63,7 +65,8 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll { test("PredictorSuite-testWithFlatArrays") { - val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2))) + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2), + layout = Layout.NCHW)) val inputData = Array.fill[Float](12)(1) // this will disposed at the end of the predict call on Predictor. @@ -89,7 +92,8 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll { } test("PredictorSuite-testWithNDArray") { - val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2))) + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2), + layout = Layout.NCHW)) val inputData = NDArray.ones(Shape(1, 3, 2, 2)) // this will disposed at the end of the predict call on Predictor. diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 3511f4acfffd..c221b4721d81 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -231,6 +231,7 @@ ${skipTests} ${project.build.directory}/surefire-reports . + F WDF TestSuite.txt diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index adc723ecdacb..bf1b26e4b48d 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -17,7 +17,8 @@ package org.apache.mxnet.spark.io -import org.apache.mxnet.{DataBatch, NDArray, Shape, DataIter} +import org.apache.mxnet.DType.DType +import org.apache.mxnet._ import org.apache.spark.mllib.regression.LabeledPoint import scala.collection.immutable.ListMap @@ -25,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer /** * A helper converter for LabeledPoint - * @author Yizhi Liu */ class LabeledPointIter private[mxnet]( private val points: Iterator[LabeledPoint], @@ -115,15 +115,27 @@ class LabeledPointIter private[mxnet]( } // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = { ListMap(labelName -> Shape(_batchSize)) } // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = { ListMap(dataName -> dataShape) } + override def provideDataDesc: IndexedSeq[DataDesc] = { + // TODO: need to allow user to specify DType and Layout + IndexedSeq(new DataDesc(dataName, dataShape, DType.Float32, Layout.UNDEFINED)) + } + + override def provideLabelDesc: IndexedSeq[DataDesc] = { + // TODO: need to allow user to specify DType and Layout + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), DType.Float32, Layout.UNDEFINED)) + } + /** * Get the number of padding examples * in current batch diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala index 339f7e2d76ca..e3272a4066b5 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala @@ -17,7 +17,8 @@ package org.apache.mxnet.spark.io -import org.apache.mxnet.{NDArray, DataBatch} +import org.apache.mxnet.DType.DType +import org.apache.mxnet.{DataBatch, NDArray} /** * Dispose only when 'disposeForce' called @@ -27,7 +28,8 @@ class LongLivingDataBatch( override val data: IndexedSeq[NDArray], override val label: IndexedSeq[NDArray], override val index: IndexedSeq[Long], - override val pad: Int) extends DataBatch(data, label, index, pad) { + override val pad: Int) extends DataBatch(data, label, index, pad, + null, null, null) { override def dispose(): Unit = {} def disposeForce(): Unit = super.dispose() } diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index 21329291cfb5..a955ee74e7e2 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -17,7 +17,8 @@ package org.apache.mxnet.spark.io -import org.apache.mxnet.{NDArray, DataBatch, DataIter, Shape} +import org.apache.mxnet.DType.DType +import org.apache.mxnet._ import org.apache.spark.mllib.linalg.Vector import scala.collection.immutable.ListMap @@ -25,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer /** * A temporary helper implementation for predicting Vectors - * @author Yizhi Liu */ class PointIter private[mxnet]( private val points: Iterator[Vector], @@ -114,15 +114,27 @@ class PointIter private[mxnet]( } // The name and shape of label provided by this iterator + @deprecated override def provideLabel: ListMap[String, Shape] = { ListMap(labelName -> Shape(_batchSize)) } // The name and shape of data provided by this iterator + @deprecated override def provideData: ListMap[String, Shape] = { ListMap(dataName -> dataShape) } + override def provideDataDesc: IndexedSeq[DataDesc] = { + // TODO: Make DType, Layout configurable + IndexedSeq(new DataDesc(dataName, dataShape, DType.Float32, Layout.UNDEFINED)) + } + + override def provideLabelDesc: IndexedSeq[DataDesc] = { + IndexedSeq(new DataDesc(labelName, Shape(_batchSize), + DType.Float32, Layout.UNDEFINED)) + } + /** * Get the number of padding examples * in current batch