From 8573cf5ef34650c8224fceda713965b53ead9a2f Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 9 Oct 2018 11:29:44 -0700 Subject: [PATCH 01/11] clean history and add commit --- scala-package/core/pom.xml | 5 +- .../org/apache/mxnet/javaapi/Context.scala | 1 - .../org/apache/mxnet/javaapi/DType.scala | 2 + .../scala/org/apache/mxnet/javaapi/IO.scala | 1 + .../org/apache/mxnet/javaapi/NDArray.scala | 194 +++++++++++++++++ .../org/apache/mxnet/javaapi/Shape.scala | 1 + .../apache/mxnet/javaapi/JavaContextTest.java | 38 ++++ .../apache/mxnet/javaapi/JavaNDArrayTest.java | 72 ++++++ .../apache/mxnet/javaapi/JavaShapeTest.java | 121 ++++++++++ .../mxnet/javaapi/JavaNDArrayMacro.scala | 206 ++++++++++++++++++ .../apache/mxnet/utils/CToScalaUtils.scala | 22 +- .../scala/org/apache/mxnet/MacrosSuite.scala | 2 +- .../native/org_apache_mxnet_native_c_api.cc | 39 ++-- scala-package/scalastyle-config.xml | 2 +- 14 files changed, 670 insertions(+), 36 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala create mode 100644 scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaContextTest.java create mode 100644 scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaNDArrayTest.java create mode 100644 scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaShapeTest.java create mode 100644 scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index ea3a2d68c9f4..6e2d8d6e9cc7 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -86,7 +86,10 @@ maven-surefire-plugin 2.22.0 - false + + -Djava.library.path=${project.parent.basedir}/native/${platform}/target + + ${skipTests} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala index 5f0caedcc402..2f4f3e6409ed 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala @@ -42,6 +42,5 @@ object Context { val gpu: Context = org.apache.mxnet.Context.gpu() val devtype2str = org.apache.mxnet.Context.devstr2type.asJava val devstr2type = org.apache.mxnet.Context.devstr2type.asJava - def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala index e25cdde7ac73..aafad424ffeb 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala @@ -24,4 +24,6 @@ object DType extends Enumeration { val UInt8 = org.apache.mxnet.DType.UInt8 val Int32 = org.apache.mxnet.DType.Int32 val Unknown = org.apache.mxnet.DType.Unknown + + private[mxnet] def numOfBytes(dType: DType): Int = org.apache.mxnet.DType.numOfBytes(dType) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala index 47b1c367c1c2..2a2fdd8046f8 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala @@ -32,3 +32,4 @@ object DataDesc{ def getBatchAxis(layout: String): Int = org.apache.mxnet.DataDesc.getBatchAxis(Some(layout)) } + diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala new file mode 100644 index 000000000000..1d26694d095d --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi + +import org.apache.mxnet.javaapi.DType.DType + +import collection.JavaConverters._ + +@AddJNDArrayAPIs(false) +object NDArray { + implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd) + + implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd + + def waitall(): Unit = org.apache.mxnet.NDArray.waitall() + + def onehotEncode(indices: NDArray, out: NDArray): NDArray = org.apache.mxnet.NDArray.onehotEncode(indices, out) + + def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray + = org.apache.mxnet.NDArray.empty(shape, ctx, dtype) + def empty(ctx: Context, shape: Array[Int]): NDArray + = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx) + def empty(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray + = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx) + def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray + = org.apache.mxnet.NDArray.zeros(shape, ctx, dtype) + def zeros(ctx: Context, shape: Array[Int]): NDArray + = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx) + def zeros(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray + = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx) + def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray + = org.apache.mxnet.NDArray.ones(shape, ctx, dtype) + def ones(ctx: Context, shape: Array[Int]): NDArray + = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx) + def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray + = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx) + def full(shape: Shape, value: Float, ctx: Context): NDArray = org.apache.mxnet.NDArray.full(shape, value, ctx) + + def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) + def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) + def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) + + def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs) + def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs) + def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs) + + def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs) + def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs) + def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs) + + def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs) + def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs) + + def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs) + def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs) + + def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs) + def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs) + + def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs) + def greaterEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs) + + def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs) + def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs) + + def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs) + def lesserEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs) + + def array(sourceArr: java.util.List[java.lang.Float], shape: Shape, ctx: Context = null): NDArray + = org.apache.mxnet.NDArray.array(sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx) + + def arange(start: Float, stop: Float, step: Float, repeat: Int, + ctx: Context, dType: DType.DType): NDArray = + org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType) +} + +class NDArray(val nd : org.apache.mxnet.NDArray ) { + + def this(arr : Array[Float], shape : Shape, ctx : Context) = { + this(org.apache.mxnet.NDArray.array(arr, shape, ctx)) + } + + def this(arr : java.util.List[java.lang.Float], shape : Shape, ctx : Context) = { + this(NDArray.array(arr, shape, ctx)) + } + + def serialize() : Array[Byte] = nd.serialize() + + def dispose() : Unit = nd.dispose() + def disposeDeps() : NDArray = nd.disposeDepsExcept() + // def disposeDepsExcept(arr : Array[NDArray]) : NDArray = nd.disposeDepsExcept() + + def slice(start : Int, stop : Int) : NDArray = nd.slice(start, stop) + + def slice (i : Int) : NDArray = nd.slice(i) + + def at(idx : Int) : NDArray = nd.at(idx) + + def T : NDArray = nd.T + + def dtype : DType = nd.dtype + + def asType(dtype : DType) : NDArray = nd.asType(dtype) + + def reshape(dims : Array[Int]) : NDArray = nd.reshape(dims) + + def waitToRead(): Unit = nd.waitToRead() + + def context : Context = nd.context + + def set(value : Float) : NDArray = nd.set(value) + def set(other : NDArray) : NDArray = nd.set(other) + def set(other : Array[Float]) : NDArray = nd.set(other) + + def add(other : NDArray) : NDArray = this.nd + other.nd + def add(other : Float) : NDArray = this.nd + other + def _add(other : NDArray) : NDArray = this.nd += other + def _add(other : Float) : NDArray = this.nd += other + def subtract(other : NDArray) : NDArray = this.nd - other + def subtract(other : Float) : NDArray = this.nd - other + def _subtract(other : NDArray) : NDArray = this.nd -= other + def _subtract(other : Float) : NDArray = this.nd -= other + def multiply(other : NDArray) : NDArray = this.nd * other + def multiply(other : Float) : NDArray = this.nd * other + def _multiply(other : NDArray) : NDArray = this.nd *= other + def _multiply(other : Float) : NDArray = this.nd *= other + def div(other : NDArray) : NDArray = this.nd / other + def div(other : Float) : NDArray = this.nd / other + def _div(other : NDArray) : NDArray = this.nd /= other + def _div(other : Float) : NDArray = this.nd /= other + def pow(other : NDArray) : NDArray = this.nd ** other + def pow(other : Float) : NDArray = this.nd ** other + def _pow(other : NDArray) : NDArray = this.nd **= other + def _pow(other : Float) : NDArray = this.nd **= other + def mod(other : NDArray) : NDArray = this.nd % other + def mod(other : Float) : NDArray = this.nd % other + def _mod(other : NDArray) : NDArray = this.nd %= other + def _mod(other : Float) : NDArray = this.nd %= other + def greater(other : NDArray) : NDArray = this.nd > other + def greater(other : Float) : NDArray = this.nd > other + def greaterEqual(other : NDArray) : NDArray = this.nd >= other + def greaterEqual(other : Float) : NDArray = this.nd >= other + def lesser(other : NDArray) : NDArray = this.nd < other + def lesser(other : Float) : NDArray = this.nd < other + def lesserEqual(other : NDArray) : NDArray = this.nd <= other + def lesserEqual(other : Float) : NDArray = this.nd <= other + + def toArray : Array[Float] = nd.toArray + + def toScalar : Float = nd.toScalar + + def copyTo(other : NDArray) : NDArray = nd.copyTo(other) + + def copyTo(ctx : Context) : NDArray = nd.copyTo(ctx) + + def copy() : NDArray = copyTo(this.context) + + def shape : Shape = nd.shape + + def size : Int = shape.product + + def asInContext(context: Context): NDArray = nd.asInContext(context) + + override def equals(obj: Any): Boolean = nd.equals(obj) + override def hashCode(): Int = nd.hashCode +} + +object NDArrayFuncReturn { + implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn) : org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn + implicit def toJavaNDFuncReturn(ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) : NDArrayFuncReturn = + new NDArrayFuncReturn(ndFuncReturn) +} + +private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) { + def head : NDArray = ndFuncReturn.head + def get : NDArray = ndFuncReturn.get + def apply(i : Int) : NDArray = ndFuncReturn.apply(i) + // TODO: Add JavaNDArray operational stuff +} \ No newline at end of file diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala index 594e3a60578f..80c719763b90 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala @@ -50,3 +50,4 @@ object Shape { implicit def toShape(jShape: Shape): org.apache.mxnet.Shape = jShape.shape } + diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaContextTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaContextTest.java new file mode 100644 index 000000000000..ceae1b452301 --- /dev/null +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaContextTest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi; + +import org.junit.Test; + +public class JavaContextTest { + + @Test + public void testCPU() { + Context.cpu(); + } + + @Test + public void testDefault() { + Context.defaultCtx(); + } + + @Test + public void testConstructor() { + new Context("cpu", 0); + } +} \ No newline at end of file diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaNDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaNDArrayTest.java new file mode 100644 index 000000000000..88bfa11d01e0 --- /dev/null +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaNDArrayTest.java @@ -0,0 +1,72 @@ +package org.apache.mxnet.javaapi; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertTrue; + +public class JavaNDArrayTest { + @Test + public void testCreateNDArray() { + NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f}, + new Shape(new int[]{1, 3}), + new Context("cpu", 0)); + int[] arr = new int[]{1, 3}; + assertTrue(Arrays.equals(nd.shape().toArray(), arr)); + assertTrue(nd.at(0).at(0).toArray()[0] == 1.0f); + List list = new ArrayList(); + list.add(1.0f); + list.add(2.0f); + list.add(3.0f); + nd.dispose(); + // Second way creating NDArray + nd = NDArray.array(list, + new Shape(new int[]{1, 3}), + new Context("cpu", 0)); + assertTrue(Arrays.equals(nd.shape().toArray(), arr)); + } + + @Test + public void testZeroOneEmpty(){ + NDArray ones = NDArray.ones(new Context("cpu", 0), new int[]{100, 100}); + NDArray zeros = NDArray.zeros(new Context("cpu", 0), new int[]{100, 100}); + NDArray empty = NDArray.zeros(new Context("cpu", 0), new int[]{100, 100}); + int[] arr = new int[]{100, 100}; + assertTrue(Arrays.equals(ones.shape().toArray(), arr)); + assertTrue(Arrays.equals(zeros.shape().toArray(), arr)); + assertTrue(Arrays.equals(empty.shape().toArray(), arr)); + } + + @Test + public void testComparison(){ + NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f}, new Shape(new int[]{3}), new Context("cpu", 0)); + NDArray nd2 = new NDArray(new float[]{3.0f, 4.0f, 5.0f}, new Shape(new int[]{3}), new Context("cpu", 0)); + nd = nd.add(nd2); + float[] greater = new float[]{1, 1, 1}; + assertTrue(Arrays.equals(nd.greater(nd2).toArray(), greater)); + nd = nd.subtract(nd2); + nd = nd.subtract(nd2); + float[] lesser = new float[]{0, 0, 0}; + assertTrue(Arrays.equals(nd.greater(nd2).toArray(), lesser)); + } + + @Test + public void testGenerated(){ + NDArray$ NDArray = NDArray$.MODULE$; + float[] arr = new float[]{1.0f, 2.0f, 3.0f}; + NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0)); + float result = NDArray.norm(nd).invoke().get().toArray()[0]; + float cal = 0.0f; + for (float ele : arr) { + cal += ele * ele; + } + cal = (float) Math.sqrt(cal); + assertTrue(Math.abs(result - cal) < 1e-5); + NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0)); + NDArray.dot(nd, nd).setout(dotResult).invoke().get(); + assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f})); + } +} diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaShapeTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaShapeTest.java new file mode 100644 index 000000000000..38ea24783efa --- /dev/null +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaShapeTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import org.junit.Test; +import java.util.ArrayList; +import java.util.Arrays; + +public class JavaShapeTest { + @Test + public void testArrayConstructor() + { + new Shape(new int[] {3, 4, 5}); + } + + @Test + public void testListConstructor() + { + ArrayList arrList = new ArrayList(); + arrList.add(3); + arrList.add(4); + arrList.add(5); + new Shape(arrList); + } + + @Test + public void testApply() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.apply(1), 4); + } + + @Test + public void testGet() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.get(1), 4); + } + + @Test + public void testSize() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.size(), 3); + } + + @Test + public void testLength() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.length(), 3); + } + + @Test + public void testDrop() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + ArrayList l = new ArrayList(); + l.add(4); + l.add(5); + assertTrue(jS.drop(1).toVector().equals(l)); + } + + @Test + public void testSlice() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + ArrayList l = new ArrayList(); + l.add(4); + assertTrue(jS.slice(1,2).toVector().equals(l)); + } + + @Test + public void testProduct() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.product(), 60); + } + + @Test + public void testHead() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.head(), 3); + } + + @Test + public void testToArray() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertTrue(Arrays.equals(jS.toArray(), new int[] {3,4,5})); + } + + @Test + public void testToVector() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + ArrayList l = new ArrayList(); + l.add(3); + l.add(4); + l.add(5); + assertTrue(jS.toVector().equals(l)); + } +} \ No newline at end of file diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala new file mode 100644 index 000000000000..bb39e81e19ed --- /dev/null +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi + +import org.apache.mxnet.init.Base._ +import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} + +import scala.annotation.StaticAnnotation +import scala.collection.mutable.ListBuffer +import scala.language.experimental.macros +import scala.reflect.macros.blackbox + +private[mxnet] class AddJNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation { + private[mxnet] def macroTransform(annottees: Any*) = macro JavaNDArrayMacro.typeSafeAPIDefs +} + +private[mxnet] object JavaNDArrayMacro { + case class NDArrayArg(argName: String, argType: String, isOptional : Boolean) + case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg]) + + // scalastyle:off havetype + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { + typeSafeAPIImpl(c)(annottees: _*) + } + // scalastyle:off havetype + + private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule() + + private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = { + import c.universe._ + + val isContrib: Boolean = c.prefix.tree match { + case q"new AddJNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) + } + // Defines Operators that should not generated + val notGenerated = Set("Custom") + + val newNDArrayFunctions = { + if (isContrib) ndarrayFunctions.filter( + func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) + else ndarrayFunctions.filterNot(_.name.startsWith("_")) + }.filterNot(ele => notGenerated.contains(ele.name)).groupBy(_.name.toLowerCase).map(ele => { + // Pattern matching for not generating depreciated method + if (ele._2.length == 1) ele._2.head + else { + if (ele._2.head.name.head.isLower) ele._2.head + else ele._2.last + } + }) + + val functionDefs = ListBuffer[DefDef]() + val classDefs = ListBuffer[ClassDef]() + + newNDArrayFunctions.foreach { ndarrayfunction => + + // Construct argument field with all required args + var argDef = ListBuffer[String]() + // Construct Optional Arg + var OptionArgDef = ListBuffer[String]() + // Construct function Implementation field (e.g norm) + var impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + // scalastyle:off + impl += "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]" + // scalastyle:on + // Construct Class Implementation (e.g normBuilder) + var classImpl = ListBuffer[String]() + ndarrayfunction.listOfArgs.foreach({ ndarrayarg => + // var is a special word used to define variable in Scala, + // need to changed to something else in order to make it work + var currArgName = ndarrayarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case _ => ndarrayarg.argName + } + if (ndarrayarg.isOptional) { + OptionArgDef += s"private var $currArgName : ${ndarrayarg.argType} = null" + val tempDef = s"def set$currArgName($currArgName : ${ndarrayarg.argType})" + val tempImpl = s"this.$currArgName = $currArgName\nthis" + classImpl += s"$tempDef = {$tempImpl}" + } else { + argDef += s"$currArgName : ${ndarrayarg.argType}" + } + // NDArray arg implementation + val returnType = "org.apache.mxnet.javaapi.NDArray" + val base = + if (ndarrayarg.argType.equals(returnType)) { + s"args += this.$currArgName" + } else if (ndarrayarg.argType.equals(s"Array[$returnType]")){ + s"this.$currArgName.foreach(args+=_)" + } else { + "map(\"" + ndarrayarg.argName + "\") = this." + currArgName + } + impl.append( + if (ndarrayarg.isOptional) s"if (this.$currArgName != null) $base" + else base + ) + }) + // add default out parameter + classImpl += + "def setout(out : org.apache.mxnet.javaapi.NDArray) = {this.out = out\nthis}" + impl += "if (this.out != null) map(\"out\") = this.out" + OptionArgDef += "private var out : org.apache.mxnet.NDArray = null" + val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn" + // scalastyle:off + // Combine and build the function string + impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)" + val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")})" + val classBody = s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}" + val classFinal = s"$classDef {$classBody}" + val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})" + val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})" + val functionFinal = s"$functionDef = $functionBody" + // scalastyle:on + functionDefs += c.parse(functionFinal).asInstanceOf[DefDef] + classDefs += c.parse(classFinal).asInstanceOf[ClassDef] + } + + structGeneration(c)(functionDefs.toList, classDefs.toList, annottees : _*) + } + + private def structGeneration(c: blackbox.Context) + (funcDef : List[c.universe.DefDef], + classDef : List[c.universe.ClassDef], + annottees: c.Expr[Any]*) + : c.Expr[Any] = { + import c.universe._ + val inputs = annottees.map(_.tree).toList + // pattern match on the inputs + var modDefs = inputs map { + case ClassDef(mods, name, something, template) => + val q = template match { + case Template(superMaybe, emptyValDef, defs) => + Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef) + case ex => + throw new IllegalArgumentException(s"Invalid template: $ex") + } + ClassDef(mods, name, something, q) + case ModuleDef(mods, name, template) => + val q = template match { + case Template(superMaybe, emptyValDef, defs) => + Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef) + case ex => + throw new IllegalArgumentException(s"Invalid template: $ex") + } + ModuleDef(mods, name, q) + case ex => + throw new IllegalArgumentException(s"Invalid macro input: $ex") + } + // modDefs ++= classDef + // wrap the result up in an Expr, and return it + val result = c.Expr(Block(modDefs, Literal(Constant()))) + result + } + + + + + // List and add all the atomic symbol functions to current module. + private def initNDArrayModule(): List[NDArrayFunction] = { + val opNames = ListBuffer.empty[String] + _LIB.mxListAllOpNames(opNames) + opNames.map(opName => { + val opHandle = new RefLong + _LIB.nnGetOpHandle(opName, opHandle) + makeNDArrayFunction(opHandle.value, opName) + }).toList + } + + // Create an atomic symbol function by handle and function name. + private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String) + : NDArrayFunction = { + val name = new RefString + val desc = new RefString + val keyVarNumArgs = new RefString + val numArgs = new RefInt + val argNames = ListBuffer.empty[String] + val argTypes = ListBuffer.empty[String] + val argDescs = ListBuffer.empty[String] + + _LIB.mxSymbolGetAtomicSymbolInfo( + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) + val argList = argNames zip argTypes map { case (argName, argType) => + val typeAndOption = + CToScalaUtils.argumentCleaner(argName, argType, + "org.apache.mxnet.javaapi.NDArray", "javaapi.Shape") + new NDArrayArg(argName, typeAndOption._1, typeAndOption._2) + } + new NDArrayFunction(aliasName, argList.toList) + } +} \ No newline at end of file diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala index d0ebe5b1d2cb..48d8fdf38bc4 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -21,19 +21,19 @@ private[mxnet] object CToScalaUtils { // Convert C++ Types to Scala Types - def typeConversion(in : String, argType : String = "", - argName : String, returnType : String) : String = { + def typeConversion(in : String, argType : String = "", argName : String, + returnType : String, shapeType : String = "Shape") : String = { in match { - case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" + case "Shape(tuple)" | "ShapeorNone" => s"org.apache.mxnet.$shapeType" case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" => s"Array[$returnType]" - case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" - case "int" | "intorNone" | "int(non-negative)" => "Int" - case "long" | "long(non-negative)" => "Long" - case "double" | "doubleorNone" => "Double" + case "float" | "real_t" | "floatorNone" => "java.lang.Float" + case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer" + case "long" | "long(non-negative)" => "java.lang.Long" + case "double" | "doubleorNone" => "java.lang.Double" case "string" => "String" - case "boolean" | "booleanorNone" => "Boolean" + case "boolean" | "booleanorNone" => "java.lang.Boolean" case "tupleof" | "tupleof" | "tupleof<>" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( s"Invalid type for args: $default\nString argType: $argType\nargName: $argName") @@ -52,8 +52,8 @@ private[mxnet] object CToScalaUtils { * @param argType Raw arguement Type description * @return (Scala_Type, isOptional) */ - def argumentCleaner(argName: String, - argType : String, returnType : String) : (String, Boolean) = { + def argumentCleaner(argName: String, argType : String, + returnType : String, shapeType : String = "Shape") : (String, Boolean) = { val spaceRemoved = argType.replaceAll("\\s+", "") var commaRemoved : Array[String] = new Array[String](0) // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} @@ -73,7 +73,7 @@ private[mxnet] object CToScalaUtils { s"""expected "default=..." got ${commaRemoved(2)}""") (typeConversion(commaRemoved(0), argType, argName, returnType), true) } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { - val tempType = typeConversion(commaRemoved(0), argType, argName, returnType) + val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, shapeType) val tempOptional = tempType.equals("org.apache.mxnet.Symbol") (tempType, tempOptional) } else { diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala index c3a7c58c1afc..4404b0885d57 100644 --- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala +++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala @@ -36,7 +36,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll { ) val output = List( ("org.apache.mxnet.Symbol", true), - ("Int", false), + ("java.lang.Integer", false), ("org.apache.mxnet.Shape", true), ("String", true), ("Any", false) diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc index 17d166eac345..95325f3a6a2e 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc @@ -1581,29 +1581,26 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape env->ReleaseIntArrayElements(jargShapeData, argShapeData, 0); env->ReleaseIntArrayElements(jargIndPtr, argIndPtr, 0); - if (ret == 0) { - jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); - jmethodID listAppend = env->GetMethodID(listClass, - "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); - - if (FillSymbolInferShape( - env, listAppend, jinShapeData, inShapeSize, inShapeNdim, inShapeData)) { - // TODO(Yizhi): out of memory error thrown, return a specific error code ? - return -1; - } - if (FillSymbolInferShape( - env, listAppend, joutShapeData, outShapeSize, outShapeNdim, outShapeData)) { - // TODO(Yizhi): out of memory error thrown, return a specific error code ? - return -1; - } - if (FillSymbolInferShape( - env, listAppend, jauxShapeData, auxShapeSize, auxShapeNdim, auxShapeData)) { - // TODO(Yizhi): out of memory error thrown, return a specific error code ? - return -1; - } + jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); + jmethodID listAppend = env->GetMethodID(listClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); - SetIntField(env, jcomplete, complete); + if (FillSymbolInferShape(env, listAppend, jinShapeData, inShapeSize, inShapeNdim, inShapeData)) { + // TODO(Yizhi): out of memory error thrown, return a specific error code ? + return -1; } + if (FillSymbolInferShape( + env, listAppend, joutShapeData, outShapeSize, outShapeNdim, outShapeData)) { + // TODO(Yizhi): out of memory error thrown, return a specific error code ? + return -1; + } + if (FillSymbolInferShape( + env, listAppend, jauxShapeData, auxShapeSize, auxShapeNdim, auxShapeData)) { + // TODO(Yizhi): out of memory error thrown, return a specific error code ? + return -1; + } + + SetIntField(env, jcomplete, complete); // release allocated memory if (jkeys != NULL) { diff --git a/scala-package/scalastyle-config.xml b/scala-package/scalastyle-config.xml index 71e246490fe2..f37522df44e8 100644 --- a/scala-package/scalastyle-config.xml +++ b/scala-package/scalastyle-config.xml @@ -64,7 +64,7 @@ You can also disable only one rule, by specifying its rule id, as specified in: - + true From fc70ab437fdd87fd21bd071691a72617c808e16d Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 9 Oct 2018 11:45:27 -0700 Subject: [PATCH 02/11] add lint header --- .../apache/mxnet/javaapi/JavaNDArrayTest.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaNDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaNDArrayTest.java index 88bfa11d01e0..5b03ad2d7be2 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaNDArrayTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaNDArrayTest.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.mxnet.javaapi; import org.junit.Test; From b82b772fbc9222cbe5167635d9c769aa592db46e Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 9 Oct 2018 14:05:31 -0700 Subject: [PATCH 03/11] bypass the java unittest when make the package --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index a4b41b8d8371..fe2df2c20afa 100644 --- a/Makefile +++ b/Makefile @@ -606,7 +606,7 @@ scalaclean: scalapkg: (cd $(ROOTDIR)/scala-package; \ - mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE) -Dcxx="$(CXX)" \ + mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE),integrationtest -Dcxx="$(CXX)" \ -Dbuild.platform="$(SCALA_PKG_PROFILE)" \ -Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \ -Dcurrent_libdir="$(ROOTDIR)/lib" \ From 09963c69ec0a896a13c46d59b9b8daf88eba96cd Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 9 Oct 2018 14:39:14 -0700 Subject: [PATCH 04/11] clean up redundant test --- .../apache/mxnet/javaapi/JavaContextTest.java | 38 ------ .../apache/mxnet/javaapi/JavaShapeTest.java | 121 ------------------ ...{JavaNDArrayTest.java => NDArrayTest.java} | 2 +- 3 files changed, 1 insertion(+), 160 deletions(-) delete mode 100644 scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaContextTest.java delete mode 100644 scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaShapeTest.java rename scala-package/core/src/test/java/org/apache/mxnet/javaapi/{JavaNDArrayTest.java => NDArrayTest.java} (99%) diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaContextTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaContextTest.java deleted file mode 100644 index ceae1b452301..000000000000 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaContextTest.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mxnet.javaapi; - -import org.junit.Test; - -public class JavaContextTest { - - @Test - public void testCPU() { - Context.cpu(); - } - - @Test - public void testDefault() { - Context.defaultCtx(); - } - - @Test - public void testConstructor() { - new Context("cpu", 0); - } -} \ No newline at end of file diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaShapeTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaShapeTest.java deleted file mode 100644 index 38ea24783efa..000000000000 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaShapeTest.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mxnet.javaapi; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import org.junit.Test; -import java.util.ArrayList; -import java.util.Arrays; - -public class JavaShapeTest { - @Test - public void testArrayConstructor() - { - new Shape(new int[] {3, 4, 5}); - } - - @Test - public void testListConstructor() - { - ArrayList arrList = new ArrayList(); - arrList.add(3); - arrList.add(4); - arrList.add(5); - new Shape(arrList); - } - - @Test - public void testApply() - { - Shape jS = new Shape(new int[] {3, 4, 5}); - assertEquals(jS.apply(1), 4); - } - - @Test - public void testGet() - { - Shape jS = new Shape(new int[] {3, 4, 5}); - assertEquals(jS.get(1), 4); - } - - @Test - public void testSize() - { - Shape jS = new Shape(new int[] {3, 4, 5}); - assertEquals(jS.size(), 3); - } - - @Test - public void testLength() - { - Shape jS = new Shape(new int[] {3, 4, 5}); - assertEquals(jS.length(), 3); - } - - @Test - public void testDrop() - { - Shape jS = new Shape(new int[] {3, 4, 5}); - ArrayList l = new ArrayList(); - l.add(4); - l.add(5); - assertTrue(jS.drop(1).toVector().equals(l)); - } - - @Test - public void testSlice() - { - Shape jS = new Shape(new int[] {3, 4, 5}); - ArrayList l = new ArrayList(); - l.add(4); - assertTrue(jS.slice(1,2).toVector().equals(l)); - } - - @Test - public void testProduct() - { - Shape jS = new Shape(new int[] {3, 4, 5}); - assertEquals(jS.product(), 60); - } - - @Test - public void testHead() - { - Shape jS = new Shape(new int[] {3, 4, 5}); - assertEquals(jS.head(), 3); - } - - @Test - public void testToArray() - { - Shape jS = new Shape(new int[] {3, 4, 5}); - assertTrue(Arrays.equals(jS.toArray(), new int[] {3,4,5})); - } - - @Test - public void testToVector() - { - Shape jS = new Shape(new int[] {3, 4, 5}); - ArrayList l = new ArrayList(); - l.add(3); - l.add(4); - l.add(5); - assertTrue(jS.toVector().equals(l)); - } -} \ No newline at end of file diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaNDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java similarity index 99% rename from scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaNDArrayTest.java rename to scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java index 5b03ad2d7be2..e0b8179a2368 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/JavaNDArrayTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java @@ -25,7 +25,7 @@ import static org.junit.Assert.assertTrue; -public class JavaNDArrayTest { +public class NDArrayTest { @Test public void testCreateNDArray() { NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f}, From e486b1bf37fddee38bc9b7f5d08c850f84f6dee5 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 9 Oct 2018 14:42:05 -0700 Subject: [PATCH 05/11] clean spacing issue --- .../main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala index bb39e81e19ed..ea2e5abe40d6 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala @@ -203,4 +203,4 @@ private[mxnet] object JavaNDArrayMacro { } new NDArrayFunction(aliasName, argList.toList) } -} \ No newline at end of file +} From fff5c7912698127733b20d5b0d0f37277e31d18b Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 9 Oct 2018 14:43:53 -0700 Subject: [PATCH 06/11] revert the change --- .../native/org_apache_mxnet_native_c_api.cc | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc index 95325f3a6a2e..17d166eac345 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc @@ -1581,26 +1581,29 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape env->ReleaseIntArrayElements(jargShapeData, argShapeData, 0); env->ReleaseIntArrayElements(jargIndPtr, argIndPtr, 0); - jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); - jmethodID listAppend = env->GetMethodID(listClass, - "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + if (ret == 0) { + jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); + jmethodID listAppend = env->GetMethodID(listClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); - if (FillSymbolInferShape(env, listAppend, jinShapeData, inShapeSize, inShapeNdim, inShapeData)) { - // TODO(Yizhi): out of memory error thrown, return a specific error code ? - return -1; - } - if (FillSymbolInferShape( - env, listAppend, joutShapeData, outShapeSize, outShapeNdim, outShapeData)) { - // TODO(Yizhi): out of memory error thrown, return a specific error code ? - return -1; - } - if (FillSymbolInferShape( - env, listAppend, jauxShapeData, auxShapeSize, auxShapeNdim, auxShapeData)) { - // TODO(Yizhi): out of memory error thrown, return a specific error code ? - return -1; - } + if (FillSymbolInferShape( + env, listAppend, jinShapeData, inShapeSize, inShapeNdim, inShapeData)) { + // TODO(Yizhi): out of memory error thrown, return a specific error code ? + return -1; + } + if (FillSymbolInferShape( + env, listAppend, joutShapeData, outShapeSize, outShapeNdim, outShapeData)) { + // TODO(Yizhi): out of memory error thrown, return a specific error code ? + return -1; + } + if (FillSymbolInferShape( + env, listAppend, jauxShapeData, auxShapeSize, auxShapeNdim, auxShapeData)) { + // TODO(Yizhi): out of memory error thrown, return a specific error code ? + return -1; + } - SetIntField(env, jcomplete, complete); + SetIntField(env, jcomplete, complete); + } // release allocated memory if (jkeys != NULL) { From b9449e54a00a5e00ac39bb1038e4ffd5240854b0 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 9 Oct 2018 14:46:06 -0700 Subject: [PATCH 07/11] clean up --- .../core/src/main/scala/org/apache/mxnet/javaapi/IO.scala | 1 - .../core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala | 1 - 2 files changed, 2 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala index 2a2fdd8046f8..47b1c367c1c2 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala @@ -32,4 +32,3 @@ object DataDesc{ def getBatchAxis(layout: String): Int = org.apache.mxnet.DataDesc.getBatchAxis(Some(layout)) } - diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala index 80c719763b90..594e3a60578f 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala @@ -50,4 +50,3 @@ object Shape { implicit def toShape(jShape: Shape): org.apache.mxnet.Shape = jShape.shape } - From 2f2a4d09683d0bd1ff2a26b04e87825b69dd7c42 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 9 Oct 2018 14:46:20 -0700 Subject: [PATCH 08/11] cleanup the JMacros --- .../scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala index ea2e5abe40d6..de1c8058ab95 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala @@ -18,7 +18,7 @@ package org.apache.mxnet.javaapi import org.apache.mxnet.init.Base._ -import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} +import org.apache.mxnet.utils.CToScalaUtils import scala.annotation.StaticAnnotation import scala.collection.mutable.ListBuffer @@ -168,9 +168,6 @@ private[mxnet] object JavaNDArrayMacro { result } - - - // List and add all the atomic symbol functions to current module. private def initNDArrayModule(): List[NDArrayFunction] = { val opNames = ListBuffer.empty[String] From 50ffbfb7913aa76ddc98460dc7ab68083ee25534 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 9 Oct 2018 16:22:04 -0700 Subject: [PATCH 09/11] adding line escape --- .../core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala index 1d26694d095d..f838747b6dad 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala @@ -191,4 +191,4 @@ private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArr def get : NDArray = ndFuncReturn.get def apply(i : Int) : NDArray = ndFuncReturn.apply(i) // TODO: Add JavaNDArray operational stuff -} \ No newline at end of file +} From cd93cc3f7b7167621ab6a6c78ef745aed9e31f1e Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 11 Oct 2018 10:55:27 -0700 Subject: [PATCH 10/11] revert some changes and fix scala style --- .../org/apache/mxnet/javaapi/DType.scala | 2 -- .../org/apache/mxnet/javaapi/NDArray.scala | 28 ++++++++++++------- scala-package/scalastyle-config.xml | 2 +- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala index aafad424ffeb..e25cdde7ac73 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala @@ -24,6 +24,4 @@ object DType extends Enumeration { val UInt8 = org.apache.mxnet.DType.UInt8 val Int32 = org.apache.mxnet.DType.Int32 val Unknown = org.apache.mxnet.DType.Unknown - - private[mxnet] def numOfBytes(dType: DType): Int = org.apache.mxnet.DType.numOfBytes(dType) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala index f838747b6dad..c77b440d8802 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala @@ -29,7 +29,8 @@ object NDArray { def waitall(): Unit = org.apache.mxnet.NDArray.waitall() - def onehotEncode(indices: NDArray, out: NDArray): NDArray = org.apache.mxnet.NDArray.onehotEncode(indices, out) + def onehotEncode(indices: NDArray, out: NDArray): NDArray + = org.apache.mxnet.NDArray.onehotEncode(indices, out) def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray = org.apache.mxnet.NDArray.empty(shape, ctx, dtype) @@ -49,7 +50,8 @@ object NDArray { = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx) def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx) - def full(shape: Shape, value: Float, ctx: Context): NDArray = org.apache.mxnet.NDArray.full(shape, value, ctx) + def full(shape: Shape, value: Float, ctx: Context): NDArray + = org.apache.mxnet.NDArray.full(shape, value, ctx) def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) @@ -72,17 +74,22 @@ object NDArray { def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs) def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs) - def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs) - def greaterEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs) + def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray + = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs) + def greaterEqual(lhs: NDArray, rhs: Float): NDArray + = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs) def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs) def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs) - def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs) - def lesserEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs) + def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray + = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs) + def lesserEqual(lhs: NDArray, rhs: Float): NDArray + = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs) def array(sourceArr: java.util.List[java.lang.Float], shape: Shape, ctx: Context = null): NDArray - = org.apache.mxnet.NDArray.array(sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx) + = org.apache.mxnet.NDArray.array( + sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx) def arange(start: Float, stop: Float, step: Float, repeat: Int, ctx: Context, dType: DType.DType): NDArray = @@ -181,9 +188,10 @@ class NDArray(val nd : org.apache.mxnet.NDArray ) { } object NDArrayFuncReturn { - implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn) : org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn - implicit def toJavaNDFuncReturn(ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) : NDArrayFuncReturn = - new NDArrayFuncReturn(ndFuncReturn) + implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn) + : org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn + implicit def toJavaNDFuncReturn(ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) + : NDArrayFuncReturn = new NDArrayFuncReturn(ndFuncReturn) } private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) { diff --git a/scala-package/scalastyle-config.xml b/scala-package/scalastyle-config.xml index f37522df44e8..71e246490fe2 100644 --- a/scala-package/scalastyle-config.xml +++ b/scala-package/scalastyle-config.xml @@ -64,7 +64,7 @@ You can also disable only one rule, by specifying its rule id, as specified in: - + true From cf6dd1007973d3ae40bb6bc18684ef8439c4f1c2 Mon Sep 17 00:00:00 2001 From: Qing Date: Sun, 14 Oct 2018 23:07:48 -0700 Subject: [PATCH 11/11] fixes regarding to Naveen's comment --- .../org/apache/mxnet/javaapi/NDArrayTest.java | 8 ++----- .../mxnet/javaapi/JavaNDArrayMacro.scala | 22 +++++++++---------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java index e0b8179a2368..a9bad83f62d6 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java @@ -34,11 +34,7 @@ public void testCreateNDArray() { int[] arr = new int[]{1, 3}; assertTrue(Arrays.equals(nd.shape().toArray(), arr)); assertTrue(nd.at(0).at(0).toArray()[0] == 1.0f); - List list = new ArrayList(); - list.add(1.0f); - list.add(2.0f); - list.add(3.0f); - nd.dispose(); + List list = Arrays.asList(1.0f, 2.0f, 3.0f); // Second way creating NDArray nd = NDArray.array(list, new Shape(new int[]{1, 3}), @@ -50,7 +46,7 @@ public void testCreateNDArray() { public void testZeroOneEmpty(){ NDArray ones = NDArray.ones(new Context("cpu", 0), new int[]{100, 100}); NDArray zeros = NDArray.zeros(new Context("cpu", 0), new int[]{100, 100}); - NDArray empty = NDArray.zeros(new Context("cpu", 0), new int[]{100, 100}); + NDArray empty = NDArray.empty(new Context("cpu", 0), new int[]{100, 100}); int[] arr = new int[]{100, 100}; assertTrue(Arrays.equals(ones.shape().toArray(), arr)); assertTrue(Arrays.equals(zeros.shape().toArray(), arr)); diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala index de1c8058ab95..c530c730a449 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala @@ -80,34 +80,34 @@ private[mxnet] object JavaNDArrayMacro { // scalastyle:on // Construct Class Implementation (e.g normBuilder) var classImpl = ListBuffer[String]() - ndarrayfunction.listOfArgs.foreach({ ndarrayarg => + ndarrayfunction.listOfArgs.foreach({ ndarrayArg => // var is a special word used to define variable in Scala, // need to changed to something else in order to make it work - var currArgName = ndarrayarg.argName match { + var currArgName = ndarrayArg.argName match { case "var" => "vari" case "type" => "typeOf" - case _ => ndarrayarg.argName + case _ => ndarrayArg.argName } - if (ndarrayarg.isOptional) { - OptionArgDef += s"private var $currArgName : ${ndarrayarg.argType} = null" - val tempDef = s"def set$currArgName($currArgName : ${ndarrayarg.argType})" + if (ndarrayArg.isOptional) { + OptionArgDef += s"private var $currArgName : ${ndarrayArg.argType} = null" + val tempDef = s"def set$currArgName($currArgName : ${ndarrayArg.argType})" val tempImpl = s"this.$currArgName = $currArgName\nthis" classImpl += s"$tempDef = {$tempImpl}" } else { - argDef += s"$currArgName : ${ndarrayarg.argType}" + argDef += s"$currArgName : ${ndarrayArg.argType}" } // NDArray arg implementation val returnType = "org.apache.mxnet.javaapi.NDArray" val base = - if (ndarrayarg.argType.equals(returnType)) { + if (ndarrayArg.argType.equals(returnType)) { s"args += this.$currArgName" - } else if (ndarrayarg.argType.equals(s"Array[$returnType]")){ + } else if (ndarrayArg.argType.equals(s"Array[$returnType]")){ s"this.$currArgName.foreach(args+=_)" } else { - "map(\"" + ndarrayarg.argName + "\") = this." + currArgName + "map(\"" + ndarrayArg.argName + "\") = this." + currArgName } impl.append( - if (ndarrayarg.isOptional) s"if (this.$currArgName != null) $base" + if (ndarrayArg.isOptional) s"if (this.$currArgName != null) $base" else base ) })