Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-984] Add Java NDArray and introduce Java Operator Builder class (
Browse files Browse the repository at this point in the history
#12816)

* clean history and add commit

* add lint header

* bypass the java unittest when make the package

* clean up redundant test

* clean spacing issue

* revert the change

* clean up

* cleanup the JMacros

* adding line escape

* revert some changes and fix scala style

* fixes regarding to Naveen's comment
  • Loading branch information
lanking520 authored and nswamy committed Oct 15, 2018
1 parent 06bf600 commit 64b9ac7
Show file tree
Hide file tree
Showing 8 changed files with 507 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ scalaclean:

scalapkg:
(cd $(ROOTDIR)/scala-package; \
mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE) -Dcxx="$(CXX)" \
mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE),integrationtest -Dcxx="$(CXX)" \
-Dbuild.platform="$(SCALA_PKG_PROFILE)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dcurrent_libdir="$(ROOTDIR)/lib" \
Expand Down
5 changes: 4 additions & 1 deletion scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.0</version>
<configuration>
<skipTests>false</skipTests>
<argLine>
-Djava.library.path=${project.parent.basedir}/native/${platform}/target
</argLine>
<skipTests>${skipTests}</skipTests>
</configuration>
</plugin>
<plugin>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,5 @@ object Context {
val gpu: Context = org.apache.mxnet.Context.gpu()
val devtype2str = org.apache.mxnet.Context.devstr2type.asJava
val devstr2type = org.apache.mxnet.Context.devstr2type.asJava

def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.javaapi

import org.apache.mxnet.javaapi.DType.DType

import collection.JavaConverters._

@AddJNDArrayAPIs(false)
object NDArray {
implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd)

implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd

def waitall(): Unit = org.apache.mxnet.NDArray.waitall()

def onehotEncode(indices: NDArray, out: NDArray): NDArray
= org.apache.mxnet.NDArray.onehotEncode(indices, out)

def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.empty(shape, ctx, dtype)
def empty(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
def empty(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.zeros(shape, ctx, dtype)
def zeros(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
def zeros(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.ones(shape, ctx, dtype)
def ones(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
def full(shape: Shape, value: Float, ctx: Context): NDArray
= org.apache.mxnet.NDArray.full(shape, value, ctx)

def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)

def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)

def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)

def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)

def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)

def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)

def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
def greaterEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)

def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)

def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
def lesserEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)

def array(sourceArr: java.util.List[java.lang.Float], shape: Shape, ctx: Context = null): NDArray
= org.apache.mxnet.NDArray.array(
sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)

def arange(start: Float, stop: Float, step: Float, repeat: Int,
ctx: Context, dType: DType.DType): NDArray =
org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType)
}

class NDArray(val nd : org.apache.mxnet.NDArray ) {

def this(arr : Array[Float], shape : Shape, ctx : Context) = {
this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
}

def this(arr : java.util.List[java.lang.Float], shape : Shape, ctx : Context) = {
this(NDArray.array(arr, shape, ctx))
}

def serialize() : Array[Byte] = nd.serialize()

def dispose() : Unit = nd.dispose()
def disposeDeps() : NDArray = nd.disposeDepsExcept()
// def disposeDepsExcept(arr : Array[NDArray]) : NDArray = nd.disposeDepsExcept()

def slice(start : Int, stop : Int) : NDArray = nd.slice(start, stop)

def slice (i : Int) : NDArray = nd.slice(i)

def at(idx : Int) : NDArray = nd.at(idx)

def T : NDArray = nd.T

def dtype : DType = nd.dtype

def asType(dtype : DType) : NDArray = nd.asType(dtype)

def reshape(dims : Array[Int]) : NDArray = nd.reshape(dims)

def waitToRead(): Unit = nd.waitToRead()

def context : Context = nd.context

def set(value : Float) : NDArray = nd.set(value)
def set(other : NDArray) : NDArray = nd.set(other)
def set(other : Array[Float]) : NDArray = nd.set(other)

def add(other : NDArray) : NDArray = this.nd + other.nd
def add(other : Float) : NDArray = this.nd + other
def _add(other : NDArray) : NDArray = this.nd += other
def _add(other : Float) : NDArray = this.nd += other
def subtract(other : NDArray) : NDArray = this.nd - other
def subtract(other : Float) : NDArray = this.nd - other
def _subtract(other : NDArray) : NDArray = this.nd -= other
def _subtract(other : Float) : NDArray = this.nd -= other
def multiply(other : NDArray) : NDArray = this.nd * other
def multiply(other : Float) : NDArray = this.nd * other
def _multiply(other : NDArray) : NDArray = this.nd *= other
def _multiply(other : Float) : NDArray = this.nd *= other
def div(other : NDArray) : NDArray = this.nd / other
def div(other : Float) : NDArray = this.nd / other
def _div(other : NDArray) : NDArray = this.nd /= other
def _div(other : Float) : NDArray = this.nd /= other
def pow(other : NDArray) : NDArray = this.nd ** other
def pow(other : Float) : NDArray = this.nd ** other
def _pow(other : NDArray) : NDArray = this.nd **= other
def _pow(other : Float) : NDArray = this.nd **= other
def mod(other : NDArray) : NDArray = this.nd % other
def mod(other : Float) : NDArray = this.nd % other
def _mod(other : NDArray) : NDArray = this.nd %= other
def _mod(other : Float) : NDArray = this.nd %= other
def greater(other : NDArray) : NDArray = this.nd > other
def greater(other : Float) : NDArray = this.nd > other
def greaterEqual(other : NDArray) : NDArray = this.nd >= other
def greaterEqual(other : Float) : NDArray = this.nd >= other
def lesser(other : NDArray) : NDArray = this.nd < other
def lesser(other : Float) : NDArray = this.nd < other
def lesserEqual(other : NDArray) : NDArray = this.nd <= other
def lesserEqual(other : Float) : NDArray = this.nd <= other

def toArray : Array[Float] = nd.toArray

def toScalar : Float = nd.toScalar

def copyTo(other : NDArray) : NDArray = nd.copyTo(other)

def copyTo(ctx : Context) : NDArray = nd.copyTo(ctx)

def copy() : NDArray = copyTo(this.context)

def shape : Shape = nd.shape

def size : Int = shape.product

def asInContext(context: Context): NDArray = nd.asInContext(context)

override def equals(obj: Any): Boolean = nd.equals(obj)
override def hashCode(): Int = nd.hashCode
}

object NDArrayFuncReturn {
implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn)
: org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn
implicit def toJavaNDFuncReturn(ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn)
: NDArrayFuncReturn = new NDArrayFuncReturn(ndFuncReturn)
}

private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) {
def head : NDArray = ndFuncReturn.head
def get : NDArray = ndFuncReturn.get
def apply(i : Int) : NDArray = ndFuncReturn.apply(i)
// TODO: Add JavaNDArray operational stuff
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.javaapi;

import org.junit.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.junit.Assert.assertTrue;

public class NDArrayTest {
@Test
public void testCreateNDArray() {
NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f},
new Shape(new int[]{1, 3}),
new Context("cpu", 0));
int[] arr = new int[]{1, 3};
assertTrue(Arrays.equals(nd.shape().toArray(), arr));
assertTrue(nd.at(0).at(0).toArray()[0] == 1.0f);
List<Float> list = Arrays.asList(1.0f, 2.0f, 3.0f);
// Second way creating NDArray
nd = NDArray.array(list,
new Shape(new int[]{1, 3}),
new Context("cpu", 0));
assertTrue(Arrays.equals(nd.shape().toArray(), arr));
}

@Test
public void testZeroOneEmpty(){
NDArray ones = NDArray.ones(new Context("cpu", 0), new int[]{100, 100});
NDArray zeros = NDArray.zeros(new Context("cpu", 0), new int[]{100, 100});
NDArray empty = NDArray.empty(new Context("cpu", 0), new int[]{100, 100});
int[] arr = new int[]{100, 100};
assertTrue(Arrays.equals(ones.shape().toArray(), arr));
assertTrue(Arrays.equals(zeros.shape().toArray(), arr));
assertTrue(Arrays.equals(empty.shape().toArray(), arr));
}

@Test
public void testComparison(){
NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f}, new Shape(new int[]{3}), new Context("cpu", 0));
NDArray nd2 = new NDArray(new float[]{3.0f, 4.0f, 5.0f}, new Shape(new int[]{3}), new Context("cpu", 0));
nd = nd.add(nd2);
float[] greater = new float[]{1, 1, 1};
assertTrue(Arrays.equals(nd.greater(nd2).toArray(), greater));
nd = nd.subtract(nd2);
nd = nd.subtract(nd2);
float[] lesser = new float[]{0, 0, 0};
assertTrue(Arrays.equals(nd.greater(nd2).toArray(), lesser));
}

@Test
public void testGenerated(){
NDArray$ NDArray = NDArray$.MODULE$;
float[] arr = new float[]{1.0f, 2.0f, 3.0f};
NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
float result = NDArray.norm(nd).invoke().get().toArray()[0];
float cal = 0.0f;
for (float ele : arr) {
cal += ele * ele;
}
cal = (float) Math.sqrt(cal);
assertTrue(Math.abs(result - cal) < 1e-5);
NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
NDArray.dot(nd, nd).setout(dotResult).invoke().get();
assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
}
}
Loading

0 comments on commit 64b9ac7

Please sign in to comment.