diff --git a/Makefile b/Makefile
index a4b41b8d8371..fe2df2c20afa 100644
--- a/Makefile
+++ b/Makefile
@@ -606,7 +606,7 @@ scalaclean:
scalapkg:
(cd $(ROOTDIR)/scala-package; \
- mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE) -Dcxx="$(CXX)" \
+ mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE),integrationtest -Dcxx="$(CXX)" \
-Dbuild.platform="$(SCALA_PKG_PROFILE)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dcurrent_libdir="$(ROOTDIR)/lib" \
diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index ea3a2d68c9f4..6e2d8d6e9cc7 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -86,7 +86,10 @@
maven-surefire-plugin
2.22.0
- false
+
+ -Djava.library.path=${project.parent.basedir}/native/${platform}/target
+
+ ${skipTests}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
index 5f0caedcc402..2f4f3e6409ed 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
@@ -42,6 +42,5 @@ object Context {
val gpu: Context = org.apache.mxnet.Context.gpu()
val devtype2str = org.apache.mxnet.Context.devstr2type.asJava
val devstr2type = org.apache.mxnet.Context.devstr2type.asJava
-
def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
new file mode 100644
index 000000000000..c77b440d8802
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+import org.apache.mxnet.javaapi.DType.DType
+
+import collection.JavaConverters._
+
+@AddJNDArrayAPIs(false)
+object NDArray {
+ implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd)
+
+ implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd
+
+ def waitall(): Unit = org.apache.mxnet.NDArray.waitall()
+
+ def onehotEncode(indices: NDArray, out: NDArray): NDArray
+ = org.apache.mxnet.NDArray.onehotEncode(indices, out)
+
+ def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+ = org.apache.mxnet.NDArray.empty(shape, ctx, dtype)
+ def empty(ctx: Context, shape: Array[Int]): NDArray
+ = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
+ def empty(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+ = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
+ def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+ = org.apache.mxnet.NDArray.zeros(shape, ctx, dtype)
+ def zeros(ctx: Context, shape: Array[Int]): NDArray
+ = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
+ def zeros(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+ = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
+ def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+ = org.apache.mxnet.NDArray.ones(shape, ctx, dtype)
+ def ones(ctx: Context, shape: Array[Int]): NDArray
+ = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
+ def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+ = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
+ def full(shape: Shape, value: Float, ctx: Context): NDArray
+ = org.apache.mxnet.NDArray.full(shape, value, ctx)
+
+ def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+ def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+ def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+
+ def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+ def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+ def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+
+ def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+ def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+ def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+
+ def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+ def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+
+ def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+ def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+
+ def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+ def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+
+ def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray
+ = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+ def greaterEqual(lhs: NDArray, rhs: Float): NDArray
+ = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+
+ def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+ def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+
+ def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray
+ = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+ def lesserEqual(lhs: NDArray, rhs: Float): NDArray
+ = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+
+ def array(sourceArr: java.util.List[java.lang.Float], shape: Shape, ctx: Context = null): NDArray
+ = org.apache.mxnet.NDArray.array(
+ sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)
+
+ def arange(start: Float, stop: Float, step: Float, repeat: Int,
+ ctx: Context, dType: DType.DType): NDArray =
+ org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType)
+}
+
+class NDArray(val nd : org.apache.mxnet.NDArray ) {
+
+ def this(arr : Array[Float], shape : Shape, ctx : Context) = {
+ this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
+ }
+
+ def this(arr : java.util.List[java.lang.Float], shape : Shape, ctx : Context) = {
+ this(NDArray.array(arr, shape, ctx))
+ }
+
+ def serialize() : Array[Byte] = nd.serialize()
+
+ def dispose() : Unit = nd.dispose()
+ def disposeDeps() : NDArray = nd.disposeDepsExcept()
+ // def disposeDepsExcept(arr : Array[NDArray]) : NDArray = nd.disposeDepsExcept()
+
+ def slice(start : Int, stop : Int) : NDArray = nd.slice(start, stop)
+
+ def slice (i : Int) : NDArray = nd.slice(i)
+
+ def at(idx : Int) : NDArray = nd.at(idx)
+
+ def T : NDArray = nd.T
+
+ def dtype : DType = nd.dtype
+
+ def asType(dtype : DType) : NDArray = nd.asType(dtype)
+
+ def reshape(dims : Array[Int]) : NDArray = nd.reshape(dims)
+
+ def waitToRead(): Unit = nd.waitToRead()
+
+ def context : Context = nd.context
+
+ def set(value : Float) : NDArray = nd.set(value)
+ def set(other : NDArray) : NDArray = nd.set(other)
+ def set(other : Array[Float]) : NDArray = nd.set(other)
+
+ def add(other : NDArray) : NDArray = this.nd + other.nd
+ def add(other : Float) : NDArray = this.nd + other
+ def _add(other : NDArray) : NDArray = this.nd += other
+ def _add(other : Float) : NDArray = this.nd += other
+ def subtract(other : NDArray) : NDArray = this.nd - other
+ def subtract(other : Float) : NDArray = this.nd - other
+ def _subtract(other : NDArray) : NDArray = this.nd -= other
+ def _subtract(other : Float) : NDArray = this.nd -= other
+ def multiply(other : NDArray) : NDArray = this.nd * other
+ def multiply(other : Float) : NDArray = this.nd * other
+ def _multiply(other : NDArray) : NDArray = this.nd *= other
+ def _multiply(other : Float) : NDArray = this.nd *= other
+ def div(other : NDArray) : NDArray = this.nd / other
+ def div(other : Float) : NDArray = this.nd / other
+ def _div(other : NDArray) : NDArray = this.nd /= other
+ def _div(other : Float) : NDArray = this.nd /= other
+ def pow(other : NDArray) : NDArray = this.nd ** other
+ def pow(other : Float) : NDArray = this.nd ** other
+ def _pow(other : NDArray) : NDArray = this.nd **= other
+ def _pow(other : Float) : NDArray = this.nd **= other
+ def mod(other : NDArray) : NDArray = this.nd % other
+ def mod(other : Float) : NDArray = this.nd % other
+ def _mod(other : NDArray) : NDArray = this.nd %= other
+ def _mod(other : Float) : NDArray = this.nd %= other
+ def greater(other : NDArray) : NDArray = this.nd > other
+ def greater(other : Float) : NDArray = this.nd > other
+ def greaterEqual(other : NDArray) : NDArray = this.nd >= other
+ def greaterEqual(other : Float) : NDArray = this.nd >= other
+ def lesser(other : NDArray) : NDArray = this.nd < other
+ def lesser(other : Float) : NDArray = this.nd < other
+ def lesserEqual(other : NDArray) : NDArray = this.nd <= other
+ def lesserEqual(other : Float) : NDArray = this.nd <= other
+
+ def toArray : Array[Float] = nd.toArray
+
+ def toScalar : Float = nd.toScalar
+
+ def copyTo(other : NDArray) : NDArray = nd.copyTo(other)
+
+ def copyTo(ctx : Context) : NDArray = nd.copyTo(ctx)
+
+ def copy() : NDArray = copyTo(this.context)
+
+ def shape : Shape = nd.shape
+
+ def size : Int = shape.product
+
+ def asInContext(context: Context): NDArray = nd.asInContext(context)
+
+ override def equals(obj: Any): Boolean = nd.equals(obj)
+ override def hashCode(): Int = nd.hashCode
+}
+
+object NDArrayFuncReturn {
+ implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn)
+ : org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn
+ implicit def toJavaNDFuncReturn(ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn)
+ : NDArrayFuncReturn = new NDArrayFuncReturn(ndFuncReturn)
+}
+
+private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) {
+ def head : NDArray = ndFuncReturn.head
+ def get : NDArray = ndFuncReturn.get
+ def apply(i : Int) : NDArray = ndFuncReturn.apply(i)
+ // TODO: Add JavaNDArray operational stuff
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
new file mode 100644
index 000000000000..a9bad83f62d6
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertTrue;
+
+public class NDArrayTest {
+ @Test
+ public void testCreateNDArray() {
+ NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f},
+ new Shape(new int[]{1, 3}),
+ new Context("cpu", 0));
+ int[] arr = new int[]{1, 3};
+ assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+ assertTrue(nd.at(0).at(0).toArray()[0] == 1.0f);
+ List list = Arrays.asList(1.0f, 2.0f, 3.0f);
+ // Second way creating NDArray
+ nd = NDArray.array(list,
+ new Shape(new int[]{1, 3}),
+ new Context("cpu", 0));
+ assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+ }
+
+ @Test
+ public void testZeroOneEmpty(){
+ NDArray ones = NDArray.ones(new Context("cpu", 0), new int[]{100, 100});
+ NDArray zeros = NDArray.zeros(new Context("cpu", 0), new int[]{100, 100});
+ NDArray empty = NDArray.empty(new Context("cpu", 0), new int[]{100, 100});
+ int[] arr = new int[]{100, 100};
+ assertTrue(Arrays.equals(ones.shape().toArray(), arr));
+ assertTrue(Arrays.equals(zeros.shape().toArray(), arr));
+ assertTrue(Arrays.equals(empty.shape().toArray(), arr));
+ }
+
+ @Test
+ public void testComparison(){
+ NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f}, new Shape(new int[]{3}), new Context("cpu", 0));
+ NDArray nd2 = new NDArray(new float[]{3.0f, 4.0f, 5.0f}, new Shape(new int[]{3}), new Context("cpu", 0));
+ nd = nd.add(nd2);
+ float[] greater = new float[]{1, 1, 1};
+ assertTrue(Arrays.equals(nd.greater(nd2).toArray(), greater));
+ nd = nd.subtract(nd2);
+ nd = nd.subtract(nd2);
+ float[] lesser = new float[]{0, 0, 0};
+ assertTrue(Arrays.equals(nd.greater(nd2).toArray(), lesser));
+ }
+
+ @Test
+ public void testGenerated(){
+ NDArray$ NDArray = NDArray$.MODULE$;
+ float[] arr = new float[]{1.0f, 2.0f, 3.0f};
+ NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
+ float result = NDArray.norm(nd).invoke().get().toArray()[0];
+ float cal = 0.0f;
+ for (float ele : arr) {
+ cal += ele * ele;
+ }
+ cal = (float) Math.sqrt(cal);
+ assertTrue(Math.abs(result - cal) < 1e-5);
+ NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
+ NDArray.dot(nd, nd).setout(dotResult).invoke().get();
+ assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
+ }
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
new file mode 100644
index 000000000000..c530c730a449
--- /dev/null
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+import org.apache.mxnet.init.Base._
+import org.apache.mxnet.utils.CToScalaUtils
+
+import scala.annotation.StaticAnnotation
+import scala.collection.mutable.ListBuffer
+import scala.language.experimental.macros
+import scala.reflect.macros.blackbox
+
+private[mxnet] class AddJNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation {
+ private[mxnet] def macroTransform(annottees: Any*) = macro JavaNDArrayMacro.typeSafeAPIDefs
+}
+
+private[mxnet] object JavaNDArrayMacro {
+ case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
+ case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
+
+ // scalastyle:off havetype
+ def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
+ typeSafeAPIImpl(c)(annottees: _*)
+ }
+ // scalastyle:off havetype
+
+ private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
+
+ private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
+ import c.universe._
+
+ val isContrib: Boolean = c.prefix.tree match {
+ case q"new AddJNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
+ }
+ // Defines Operators that should not generated
+ val notGenerated = Set("Custom")
+
+ val newNDArrayFunctions = {
+ if (isContrib) ndarrayFunctions.filter(
+ func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
+ else ndarrayFunctions.filterNot(_.name.startsWith("_"))
+ }.filterNot(ele => notGenerated.contains(ele.name)).groupBy(_.name.toLowerCase).map(ele => {
+ // Pattern matching for not generating depreciated method
+ if (ele._2.length == 1) ele._2.head
+ else {
+ if (ele._2.head.name.head.isLower) ele._2.head
+ else ele._2.last
+ }
+ })
+
+ val functionDefs = ListBuffer[DefDef]()
+ val classDefs = ListBuffer[ClassDef]()
+
+ newNDArrayFunctions.foreach { ndarrayfunction =>
+
+ // Construct argument field with all required args
+ var argDef = ListBuffer[String]()
+ // Construct Optional Arg
+ var OptionArgDef = ListBuffer[String]()
+ // Construct function Implementation field (e.g norm)
+ var impl = ListBuffer[String]()
+ impl += "val map = scala.collection.mutable.Map[String, Any]()"
+ // scalastyle:off
+ impl += "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]"
+ // scalastyle:on
+ // Construct Class Implementation (e.g normBuilder)
+ var classImpl = ListBuffer[String]()
+ ndarrayfunction.listOfArgs.foreach({ ndarrayArg =>
+ // var is a special word used to define variable in Scala,
+ // need to changed to something else in order to make it work
+ var currArgName = ndarrayArg.argName match {
+ case "var" => "vari"
+ case "type" => "typeOf"
+ case _ => ndarrayArg.argName
+ }
+ if (ndarrayArg.isOptional) {
+ OptionArgDef += s"private var $currArgName : ${ndarrayArg.argType} = null"
+ val tempDef = s"def set$currArgName($currArgName : ${ndarrayArg.argType})"
+ val tempImpl = s"this.$currArgName = $currArgName\nthis"
+ classImpl += s"$tempDef = {$tempImpl}"
+ } else {
+ argDef += s"$currArgName : ${ndarrayArg.argType}"
+ }
+ // NDArray arg implementation
+ val returnType = "org.apache.mxnet.javaapi.NDArray"
+ val base =
+ if (ndarrayArg.argType.equals(returnType)) {
+ s"args += this.$currArgName"
+ } else if (ndarrayArg.argType.equals(s"Array[$returnType]")){
+ s"this.$currArgName.foreach(args+=_)"
+ } else {
+ "map(\"" + ndarrayArg.argName + "\") = this." + currArgName
+ }
+ impl.append(
+ if (ndarrayArg.isOptional) s"if (this.$currArgName != null) $base"
+ else base
+ )
+ })
+ // add default out parameter
+ classImpl +=
+ "def setout(out : org.apache.mxnet.javaapi.NDArray) = {this.out = out\nthis}"
+ impl += "if (this.out != null) map(\"out\") = this.out"
+ OptionArgDef += "private var out : org.apache.mxnet.NDArray = null"
+ val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn"
+ // scalastyle:off
+ // Combine and build the function string
+ impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
+ val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")})"
+ val classBody = s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}"
+ val classFinal = s"$classDef {$classBody}"
+ val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})"
+ val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})"
+ val functionFinal = s"$functionDef = $functionBody"
+ // scalastyle:on
+ functionDefs += c.parse(functionFinal).asInstanceOf[DefDef]
+ classDefs += c.parse(classFinal).asInstanceOf[ClassDef]
+ }
+
+ structGeneration(c)(functionDefs.toList, classDefs.toList, annottees : _*)
+ }
+
+ private def structGeneration(c: blackbox.Context)
+ (funcDef : List[c.universe.DefDef],
+ classDef : List[c.universe.ClassDef],
+ annottees: c.Expr[Any]*)
+ : c.Expr[Any] = {
+ import c.universe._
+ val inputs = annottees.map(_.tree).toList
+ // pattern match on the inputs
+ var modDefs = inputs map {
+ case ClassDef(mods, name, something, template) =>
+ val q = template match {
+ case Template(superMaybe, emptyValDef, defs) =>
+ Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef)
+ case ex =>
+ throw new IllegalArgumentException(s"Invalid template: $ex")
+ }
+ ClassDef(mods, name, something, q)
+ case ModuleDef(mods, name, template) =>
+ val q = template match {
+ case Template(superMaybe, emptyValDef, defs) =>
+ Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef)
+ case ex =>
+ throw new IllegalArgumentException(s"Invalid template: $ex")
+ }
+ ModuleDef(mods, name, q)
+ case ex =>
+ throw new IllegalArgumentException(s"Invalid macro input: $ex")
+ }
+ // modDefs ++= classDef
+ // wrap the result up in an Expr, and return it
+ val result = c.Expr(Block(modDefs, Literal(Constant())))
+ result
+ }
+
+ // List and add all the atomic symbol functions to current module.
+ private def initNDArrayModule(): List[NDArrayFunction] = {
+ val opNames = ListBuffer.empty[String]
+ _LIB.mxListAllOpNames(opNames)
+ opNames.map(opName => {
+ val opHandle = new RefLong
+ _LIB.nnGetOpHandle(opName, opHandle)
+ makeNDArrayFunction(opHandle.value, opName)
+ }).toList
+ }
+
+ // Create an atomic symbol function by handle and function name.
+ private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String)
+ : NDArrayFunction = {
+ val name = new RefString
+ val desc = new RefString
+ val keyVarNumArgs = new RefString
+ val numArgs = new RefInt
+ val argNames = ListBuffer.empty[String]
+ val argTypes = ListBuffer.empty[String]
+ val argDescs = ListBuffer.empty[String]
+
+ _LIB.mxSymbolGetAtomicSymbolInfo(
+ handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
+ val argList = argNames zip argTypes map { case (argName, argType) =>
+ val typeAndOption =
+ CToScalaUtils.argumentCleaner(argName, argType,
+ "org.apache.mxnet.javaapi.NDArray", "javaapi.Shape")
+ new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
+ }
+ new NDArrayFunction(aliasName, argList.toList)
+ }
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index d0ebe5b1d2cb..48d8fdf38bc4 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -21,19 +21,19 @@ private[mxnet] object CToScalaUtils {
// Convert C++ Types to Scala Types
- def typeConversion(in : String, argType : String = "",
- argName : String, returnType : String) : String = {
+ def typeConversion(in : String, argType : String = "", argName : String,
+ returnType : String, shapeType : String = "Shape") : String = {
in match {
- case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
+ case "Shape(tuple)" | "ShapeorNone" => s"org.apache.mxnet.$shapeType"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
=> s"Array[$returnType]"
- case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat"
- case "int" | "intorNone" | "int(non-negative)" => "Int"
- case "long" | "long(non-negative)" => "Long"
- case "double" | "doubleorNone" => "Double"
+ case "float" | "real_t" | "floatorNone" => "java.lang.Float"
+ case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer"
+ case "long" | "long(non-negative)" => "java.lang.Long"
+ case "double" | "doubleorNone" => "java.lang.Double"
case "string" => "String"
- case "boolean" | "booleanorNone" => "Boolean"
+ case "boolean" | "booleanorNone" => "java.lang.Boolean"
case "tupleof" | "tupleof" | "tupleof<>" | "ptr" | "" => "Any"
case default => throw new IllegalArgumentException(
s"Invalid type for args: $default\nString argType: $argType\nargName: $argName")
@@ -52,8 +52,8 @@ private[mxnet] object CToScalaUtils {
* @param argType Raw arguement Type description
* @return (Scala_Type, isOptional)
*/
- def argumentCleaner(argName: String,
- argType : String, returnType : String) : (String, Boolean) = {
+ def argumentCleaner(argName: String, argType : String,
+ returnType : String, shapeType : String = "Shape") : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
@@ -73,7 +73,7 @@ private[mxnet] object CToScalaUtils {
s"""expected "default=..." got ${commaRemoved(2)}""")
(typeConversion(commaRemoved(0), argType, argName, returnType), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
- val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
+ val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, shapeType)
val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
(tempType, tempOptional)
} else {
diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
index c3a7c58c1afc..4404b0885d57 100644
--- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
+++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
@@ -36,7 +36,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll {
)
val output = List(
("org.apache.mxnet.Symbol", true),
- ("Int", false),
+ ("java.lang.Integer", false),
("org.apache.mxnet.Shape", true),
("String", true),
("Any", false)