diff --git a/Makefile b/Makefile index 620679a44f80..ab19607a5aa3 100644 --- a/Makefile +++ b/Makefile @@ -604,7 +604,7 @@ scalaclean: scalapkg: (cd $(ROOTDIR)/scala-package; \ - mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE) -Dcxx="$(CXX)" \ + mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE),integrationtest -Dcxx="$(CXX)" \ -Dbuild.platform="$(SCALA_PKG_PROFILE)" \ -Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \ -Dcurrent_libdir="$(ROOTDIR)/lib" \ diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index e93169f08faa..d5396dab1e67 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -86,7 +86,10 @@ maven-surefire-plugin 2.22.0 - false + + -Djava.library.path=${project.parent.basedir}/native/${platform}/target + + ${skipTests} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala index 5f0caedcc402..2f4f3e6409ed 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala @@ -42,6 +42,5 @@ object Context { val gpu: Context = org.apache.mxnet.Context.gpu() val devtype2str = org.apache.mxnet.Context.devstr2type.asJava val devstr2type = org.apache.mxnet.Context.devstr2type.asJava - def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala new file mode 100644 index 000000000000..c77b440d8802 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi + +import org.apache.mxnet.javaapi.DType.DType + +import collection.JavaConverters._ + +@AddJNDArrayAPIs(false) +object NDArray { + implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd) + + implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd + + def waitall(): Unit = org.apache.mxnet.NDArray.waitall() + + def onehotEncode(indices: NDArray, out: NDArray): NDArray + = org.apache.mxnet.NDArray.onehotEncode(indices, out) + + def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray + = org.apache.mxnet.NDArray.empty(shape, ctx, dtype) + def empty(ctx: Context, shape: Array[Int]): NDArray + = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx) + def empty(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray + = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx) + def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray + = org.apache.mxnet.NDArray.zeros(shape, ctx, dtype) + def zeros(ctx: Context, shape: Array[Int]): NDArray + = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx) + def zeros(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray + = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx) + def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray + = org.apache.mxnet.NDArray.ones(shape, ctx, dtype) + def ones(ctx: Context, shape: Array[Int]): NDArray + = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx) + def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray + = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx) + def full(shape: Shape, value: Float, ctx: Context): NDArray + = org.apache.mxnet.NDArray.full(shape, value, ctx) + + def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) + def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) + def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) + + def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs) + def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs) + def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs) + + def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs) + def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs) + def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs) + + def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs) + def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs) + + def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs) + def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs) + + def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs) + def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs) + + def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray + = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs) + def greaterEqual(lhs: NDArray, rhs: Float): NDArray + = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs) + + def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs) + def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs) + + def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray + = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs) + def lesserEqual(lhs: NDArray, rhs: Float): NDArray + = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs) + + def array(sourceArr: java.util.List[java.lang.Float], shape: Shape, ctx: Context = null): NDArray + = org.apache.mxnet.NDArray.array( + sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx) + + def arange(start: Float, stop: Float, step: Float, repeat: Int, + ctx: Context, dType: DType.DType): NDArray = + org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType) +} + +class NDArray(val nd : org.apache.mxnet.NDArray ) { + + def this(arr : Array[Float], shape : Shape, ctx : Context) = { + this(org.apache.mxnet.NDArray.array(arr, shape, ctx)) + } + + def this(arr : java.util.List[java.lang.Float], shape : Shape, ctx : Context) = { + this(NDArray.array(arr, shape, ctx)) + } + + def serialize() : Array[Byte] = nd.serialize() + + def dispose() : Unit = nd.dispose() + def disposeDeps() : NDArray = nd.disposeDepsExcept() + // def disposeDepsExcept(arr : Array[NDArray]) : NDArray = nd.disposeDepsExcept() + + def slice(start : Int, stop : Int) : NDArray = nd.slice(start, stop) + + def slice (i : Int) : NDArray = nd.slice(i) + + def at(idx : Int) : NDArray = nd.at(idx) + + def T : NDArray = nd.T + + def dtype : DType = nd.dtype + + def asType(dtype : DType) : NDArray = nd.asType(dtype) + + def reshape(dims : Array[Int]) : NDArray = nd.reshape(dims) + + def waitToRead(): Unit = nd.waitToRead() + + def context : Context = nd.context + + def set(value : Float) : NDArray = nd.set(value) + def set(other : NDArray) : NDArray = nd.set(other) + def set(other : Array[Float]) : NDArray = nd.set(other) + + def add(other : NDArray) : NDArray = this.nd + other.nd + def add(other : Float) : NDArray = this.nd + other + def _add(other : NDArray) : NDArray = this.nd += other + def _add(other : Float) : NDArray = this.nd += other + def subtract(other : NDArray) : NDArray = this.nd - other + def subtract(other : Float) : NDArray = this.nd - other + def _subtract(other : NDArray) : NDArray = this.nd -= other + def _subtract(other : Float) : NDArray = this.nd -= other + def multiply(other : NDArray) : NDArray = this.nd * other + def multiply(other : Float) : NDArray = this.nd * other + def _multiply(other : NDArray) : NDArray = this.nd *= other + def _multiply(other : Float) : NDArray = this.nd *= other + def div(other : NDArray) : NDArray = this.nd / other + def div(other : Float) : NDArray = this.nd / other + def _div(other : NDArray) : NDArray = this.nd /= other + def _div(other : Float) : NDArray = this.nd /= other + def pow(other : NDArray) : NDArray = this.nd ** other + def pow(other : Float) : NDArray = this.nd ** other + def _pow(other : NDArray) : NDArray = this.nd **= other + def _pow(other : Float) : NDArray = this.nd **= other + def mod(other : NDArray) : NDArray = this.nd % other + def mod(other : Float) : NDArray = this.nd % other + def _mod(other : NDArray) : NDArray = this.nd %= other + def _mod(other : Float) : NDArray = this.nd %= other + def greater(other : NDArray) : NDArray = this.nd > other + def greater(other : Float) : NDArray = this.nd > other + def greaterEqual(other : NDArray) : NDArray = this.nd >= other + def greaterEqual(other : Float) : NDArray = this.nd >= other + def lesser(other : NDArray) : NDArray = this.nd < other + def lesser(other : Float) : NDArray = this.nd < other + def lesserEqual(other : NDArray) : NDArray = this.nd <= other + def lesserEqual(other : Float) : NDArray = this.nd <= other + + def toArray : Array[Float] = nd.toArray + + def toScalar : Float = nd.toScalar + + def copyTo(other : NDArray) : NDArray = nd.copyTo(other) + + def copyTo(ctx : Context) : NDArray = nd.copyTo(ctx) + + def copy() : NDArray = copyTo(this.context) + + def shape : Shape = nd.shape + + def size : Int = shape.product + + def asInContext(context: Context): NDArray = nd.asInContext(context) + + override def equals(obj: Any): Boolean = nd.equals(obj) + override def hashCode(): Int = nd.hashCode +} + +object NDArrayFuncReturn { + implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn) + : org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn + implicit def toJavaNDFuncReturn(ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) + : NDArrayFuncReturn = new NDArrayFuncReturn(ndFuncReturn) +} + +private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) { + def head : NDArray = ndFuncReturn.head + def get : NDArray = ndFuncReturn.get + def apply(i : Int) : NDArray = ndFuncReturn.apply(i) + // TODO: Add JavaNDArray operational stuff +} diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java new file mode 100644 index 000000000000..a9bad83f62d6 --- /dev/null +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertTrue; + +public class NDArrayTest { + @Test + public void testCreateNDArray() { + NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f}, + new Shape(new int[]{1, 3}), + new Context("cpu", 0)); + int[] arr = new int[]{1, 3}; + assertTrue(Arrays.equals(nd.shape().toArray(), arr)); + assertTrue(nd.at(0).at(0).toArray()[0] == 1.0f); + List list = Arrays.asList(1.0f, 2.0f, 3.0f); + // Second way creating NDArray + nd = NDArray.array(list, + new Shape(new int[]{1, 3}), + new Context("cpu", 0)); + assertTrue(Arrays.equals(nd.shape().toArray(), arr)); + } + + @Test + public void testZeroOneEmpty(){ + NDArray ones = NDArray.ones(new Context("cpu", 0), new int[]{100, 100}); + NDArray zeros = NDArray.zeros(new Context("cpu", 0), new int[]{100, 100}); + NDArray empty = NDArray.empty(new Context("cpu", 0), new int[]{100, 100}); + int[] arr = new int[]{100, 100}; + assertTrue(Arrays.equals(ones.shape().toArray(), arr)); + assertTrue(Arrays.equals(zeros.shape().toArray(), arr)); + assertTrue(Arrays.equals(empty.shape().toArray(), arr)); + } + + @Test + public void testComparison(){ + NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f}, new Shape(new int[]{3}), new Context("cpu", 0)); + NDArray nd2 = new NDArray(new float[]{3.0f, 4.0f, 5.0f}, new Shape(new int[]{3}), new Context("cpu", 0)); + nd = nd.add(nd2); + float[] greater = new float[]{1, 1, 1}; + assertTrue(Arrays.equals(nd.greater(nd2).toArray(), greater)); + nd = nd.subtract(nd2); + nd = nd.subtract(nd2); + float[] lesser = new float[]{0, 0, 0}; + assertTrue(Arrays.equals(nd.greater(nd2).toArray(), lesser)); + } + + @Test + public void testGenerated(){ + NDArray$ NDArray = NDArray$.MODULE$; + float[] arr = new float[]{1.0f, 2.0f, 3.0f}; + NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0)); + float result = NDArray.norm(nd).invoke().get().toArray()[0]; + float cal = 0.0f; + for (float ele : arr) { + cal += ele * ele; + } + cal = (float) Math.sqrt(cal); + assertTrue(Math.abs(result - cal) < 1e-5); + NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0)); + NDArray.dot(nd, nd).setout(dotResult).invoke().get(); + assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f})); + } +} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala new file mode 100644 index 000000000000..c530c730a449 --- /dev/null +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi + +import org.apache.mxnet.init.Base._ +import org.apache.mxnet.utils.CToScalaUtils + +import scala.annotation.StaticAnnotation +import scala.collection.mutable.ListBuffer +import scala.language.experimental.macros +import scala.reflect.macros.blackbox + +private[mxnet] class AddJNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation { + private[mxnet] def macroTransform(annottees: Any*) = macro JavaNDArrayMacro.typeSafeAPIDefs +} + +private[mxnet] object JavaNDArrayMacro { + case class NDArrayArg(argName: String, argType: String, isOptional : Boolean) + case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg]) + + // scalastyle:off havetype + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { + typeSafeAPIImpl(c)(annottees: _*) + } + // scalastyle:off havetype + + private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule() + + private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = { + import c.universe._ + + val isContrib: Boolean = c.prefix.tree match { + case q"new AddJNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) + } + // Defines Operators that should not generated + val notGenerated = Set("Custom") + + val newNDArrayFunctions = { + if (isContrib) ndarrayFunctions.filter( + func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) + else ndarrayFunctions.filterNot(_.name.startsWith("_")) + }.filterNot(ele => notGenerated.contains(ele.name)).groupBy(_.name.toLowerCase).map(ele => { + // Pattern matching for not generating depreciated method + if (ele._2.length == 1) ele._2.head + else { + if (ele._2.head.name.head.isLower) ele._2.head + else ele._2.last + } + }) + + val functionDefs = ListBuffer[DefDef]() + val classDefs = ListBuffer[ClassDef]() + + newNDArrayFunctions.foreach { ndarrayfunction => + + // Construct argument field with all required args + var argDef = ListBuffer[String]() + // Construct Optional Arg + var OptionArgDef = ListBuffer[String]() + // Construct function Implementation field (e.g norm) + var impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + // scalastyle:off + impl += "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]" + // scalastyle:on + // Construct Class Implementation (e.g normBuilder) + var classImpl = ListBuffer[String]() + ndarrayfunction.listOfArgs.foreach({ ndarrayArg => + // var is a special word used to define variable in Scala, + // need to changed to something else in order to make it work + var currArgName = ndarrayArg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case _ => ndarrayArg.argName + } + if (ndarrayArg.isOptional) { + OptionArgDef += s"private var $currArgName : ${ndarrayArg.argType} = null" + val tempDef = s"def set$currArgName($currArgName : ${ndarrayArg.argType})" + val tempImpl = s"this.$currArgName = $currArgName\nthis" + classImpl += s"$tempDef = {$tempImpl}" + } else { + argDef += s"$currArgName : ${ndarrayArg.argType}" + } + // NDArray arg implementation + val returnType = "org.apache.mxnet.javaapi.NDArray" + val base = + if (ndarrayArg.argType.equals(returnType)) { + s"args += this.$currArgName" + } else if (ndarrayArg.argType.equals(s"Array[$returnType]")){ + s"this.$currArgName.foreach(args+=_)" + } else { + "map(\"" + ndarrayArg.argName + "\") = this." + currArgName + } + impl.append( + if (ndarrayArg.isOptional) s"if (this.$currArgName != null) $base" + else base + ) + }) + // add default out parameter + classImpl += + "def setout(out : org.apache.mxnet.javaapi.NDArray) = {this.out = out\nthis}" + impl += "if (this.out != null) map(\"out\") = this.out" + OptionArgDef += "private var out : org.apache.mxnet.NDArray = null" + val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn" + // scalastyle:off + // Combine and build the function string + impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)" + val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")})" + val classBody = s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}" + val classFinal = s"$classDef {$classBody}" + val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})" + val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})" + val functionFinal = s"$functionDef = $functionBody" + // scalastyle:on + functionDefs += c.parse(functionFinal).asInstanceOf[DefDef] + classDefs += c.parse(classFinal).asInstanceOf[ClassDef] + } + + structGeneration(c)(functionDefs.toList, classDefs.toList, annottees : _*) + } + + private def structGeneration(c: blackbox.Context) + (funcDef : List[c.universe.DefDef], + classDef : List[c.universe.ClassDef], + annottees: c.Expr[Any]*) + : c.Expr[Any] = { + import c.universe._ + val inputs = annottees.map(_.tree).toList + // pattern match on the inputs + var modDefs = inputs map { + case ClassDef(mods, name, something, template) => + val q = template match { + case Template(superMaybe, emptyValDef, defs) => + Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef) + case ex => + throw new IllegalArgumentException(s"Invalid template: $ex") + } + ClassDef(mods, name, something, q) + case ModuleDef(mods, name, template) => + val q = template match { + case Template(superMaybe, emptyValDef, defs) => + Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef) + case ex => + throw new IllegalArgumentException(s"Invalid template: $ex") + } + ModuleDef(mods, name, q) + case ex => + throw new IllegalArgumentException(s"Invalid macro input: $ex") + } + // modDefs ++= classDef + // wrap the result up in an Expr, and return it + val result = c.Expr(Block(modDefs, Literal(Constant()))) + result + } + + // List and add all the atomic symbol functions to current module. + private def initNDArrayModule(): List[NDArrayFunction] = { + val opNames = ListBuffer.empty[String] + _LIB.mxListAllOpNames(opNames) + opNames.map(opName => { + val opHandle = new RefLong + _LIB.nnGetOpHandle(opName, opHandle) + makeNDArrayFunction(opHandle.value, opName) + }).toList + } + + // Create an atomic symbol function by handle and function name. + private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String) + : NDArrayFunction = { + val name = new RefString + val desc = new RefString + val keyVarNumArgs = new RefString + val numArgs = new RefInt + val argNames = ListBuffer.empty[String] + val argTypes = ListBuffer.empty[String] + val argDescs = ListBuffer.empty[String] + + _LIB.mxSymbolGetAtomicSymbolInfo( + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) + val argList = argNames zip argTypes map { case (argName, argType) => + val typeAndOption = + CToScalaUtils.argumentCleaner(argName, argType, + "org.apache.mxnet.javaapi.NDArray", "javaapi.Shape") + new NDArrayArg(argName, typeAndOption._1, typeAndOption._2) + } + new NDArrayFunction(aliasName, argList.toList) + } +} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala index d0ebe5b1d2cb..48d8fdf38bc4 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -21,19 +21,19 @@ private[mxnet] object CToScalaUtils { // Convert C++ Types to Scala Types - def typeConversion(in : String, argType : String = "", - argName : String, returnType : String) : String = { + def typeConversion(in : String, argType : String = "", argName : String, + returnType : String, shapeType : String = "Shape") : String = { in match { - case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" + case "Shape(tuple)" | "ShapeorNone" => s"org.apache.mxnet.$shapeType" case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" => s"Array[$returnType]" - case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" - case "int" | "intorNone" | "int(non-negative)" => "Int" - case "long" | "long(non-negative)" => "Long" - case "double" | "doubleorNone" => "Double" + case "float" | "real_t" | "floatorNone" => "java.lang.Float" + case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer" + case "long" | "long(non-negative)" => "java.lang.Long" + case "double" | "doubleorNone" => "java.lang.Double" case "string" => "String" - case "boolean" | "booleanorNone" => "Boolean" + case "boolean" | "booleanorNone" => "java.lang.Boolean" case "tupleof" | "tupleof" | "tupleof<>" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( s"Invalid type for args: $default\nString argType: $argType\nargName: $argName") @@ -52,8 +52,8 @@ private[mxnet] object CToScalaUtils { * @param argType Raw arguement Type description * @return (Scala_Type, isOptional) */ - def argumentCleaner(argName: String, - argType : String, returnType : String) : (String, Boolean) = { + def argumentCleaner(argName: String, argType : String, + returnType : String, shapeType : String = "Shape") : (String, Boolean) = { val spaceRemoved = argType.replaceAll("\\s+", "") var commaRemoved : Array[String] = new Array[String](0) // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} @@ -73,7 +73,7 @@ private[mxnet] object CToScalaUtils { s"""expected "default=..." got ${commaRemoved(2)}""") (typeConversion(commaRemoved(0), argType, argName, returnType), true) } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { - val tempType = typeConversion(commaRemoved(0), argType, argName, returnType) + val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, shapeType) val tempOptional = tempType.equals("org.apache.mxnet.Symbol") (tempType, tempOptional) } else { diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala index c3a7c58c1afc..4404b0885d57 100644 --- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala +++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala @@ -36,7 +36,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll { ) val output = List( ("org.apache.mxnet.Symbol", true), - ("Int", false), + ("java.lang.Integer", false), ("org.apache.mxnet.Shape", true), ("String", true), ("Any", false)