Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-984] Add Java NDArray and introduce Java Operator Builder class #12816

Merged
merged 11 commits into from
Oct 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ scalaclean:

scalapkg:
(cd $(ROOTDIR)/scala-package; \
mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE) -Dcxx="$(CXX)" \
mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE),integrationtest -Dcxx="$(CXX)" \
-Dbuild.platform="$(SCALA_PKG_PROFILE)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dcurrent_libdir="$(ROOTDIR)/lib" \
Expand Down
5 changes: 4 additions & 1 deletion scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.0</version>
<configuration>
<skipTests>false</skipTests>
<argLine>
-Djava.library.path=${project.parent.basedir}/native/${platform}/target
</argLine>
<skipTests>${skipTests}</skipTests>
</configuration>
</plugin>
<plugin>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,5 @@ object Context {
val gpu: Context = org.apache.mxnet.Context.gpu()
val devtype2str = org.apache.mxnet.Context.devstr2type.asJava
val devstr2type = org.apache.mxnet.Context.devstr2type.asJava

def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.javaapi

import org.apache.mxnet.javaapi.DType.DType

import collection.JavaConverters._

@AddJNDArrayAPIs(false)
object NDArray {
implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd)

implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd

def waitall(): Unit = org.apache.mxnet.NDArray.waitall()

def onehotEncode(indices: NDArray, out: NDArray): NDArray
= org.apache.mxnet.NDArray.onehotEncode(indices, out)

def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.empty(shape, ctx, dtype)
def empty(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
def empty(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.zeros(shape, ctx, dtype)
def zeros(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
def zeros(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
= org.apache.mxnet.NDArray.ones(shape, ctx, dtype)
def ones(ctx: Context, shape: Array[Int]): NDArray
= org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
= org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
def full(shape: Shape, value: Float, ctx: Context): NDArray
= org.apache.mxnet.NDArray.full(shape, value, ctx)

def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)

def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)

def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)

def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)

def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)

def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)

def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
def greaterEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)

def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)

def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
def lesserEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)

def array(sourceArr: java.util.List[java.lang.Float], shape: Shape, ctx: Context = null): NDArray
= org.apache.mxnet.NDArray.array(
sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)

def arange(start: Float, stop: Float, step: Float, repeat: Int,
ctx: Context, dType: DType.DType): NDArray =
org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType)
}

class NDArray(val nd : org.apache.mxnet.NDArray ) {

def this(arr : Array[Float], shape : Shape, ctx : Context) = {
this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
}

def this(arr : java.util.List[java.lang.Float], shape : Shape, ctx : Context) = {
this(NDArray.array(arr, shape, ctx))
}

def serialize() : Array[Byte] = nd.serialize()

def dispose() : Unit = nd.dispose()
def disposeDeps() : NDArray = nd.disposeDepsExcept()
// def disposeDepsExcept(arr : Array[NDArray]) : NDArray = nd.disposeDepsExcept()

def slice(start : Int, stop : Int) : NDArray = nd.slice(start, stop)

def slice (i : Int) : NDArray = nd.slice(i)

def at(idx : Int) : NDArray = nd.at(idx)

def T : NDArray = nd.T

def dtype : DType = nd.dtype

def asType(dtype : DType) : NDArray = nd.asType(dtype)

def reshape(dims : Array[Int]) : NDArray = nd.reshape(dims)

def waitToRead(): Unit = nd.waitToRead()

def context : Context = nd.context

def set(value : Float) : NDArray = nd.set(value)
def set(other : NDArray) : NDArray = nd.set(other)
def set(other : Array[Float]) : NDArray = nd.set(other)

def add(other : NDArray) : NDArray = this.nd + other.nd
def add(other : Float) : NDArray = this.nd + other
def _add(other : NDArray) : NDArray = this.nd += other
def _add(other : Float) : NDArray = this.nd += other
def subtract(other : NDArray) : NDArray = this.nd - other
def subtract(other : Float) : NDArray = this.nd - other
def _subtract(other : NDArray) : NDArray = this.nd -= other
def _subtract(other : Float) : NDArray = this.nd -= other
def multiply(other : NDArray) : NDArray = this.nd * other
def multiply(other : Float) : NDArray = this.nd * other
def _multiply(other : NDArray) : NDArray = this.nd *= other
def _multiply(other : Float) : NDArray = this.nd *= other
def div(other : NDArray) : NDArray = this.nd / other
def div(other : Float) : NDArray = this.nd / other
def _div(other : NDArray) : NDArray = this.nd /= other
def _div(other : Float) : NDArray = this.nd /= other
def pow(other : NDArray) : NDArray = this.nd ** other
def pow(other : Float) : NDArray = this.nd ** other
def _pow(other : NDArray) : NDArray = this.nd **= other
def _pow(other : Float) : NDArray = this.nd **= other
def mod(other : NDArray) : NDArray = this.nd % other
def mod(other : Float) : NDArray = this.nd % other
def _mod(other : NDArray) : NDArray = this.nd %= other
def _mod(other : Float) : NDArray = this.nd %= other
def greater(other : NDArray) : NDArray = this.nd > other
def greater(other : Float) : NDArray = this.nd > other
def greaterEqual(other : NDArray) : NDArray = this.nd >= other
def greaterEqual(other : Float) : NDArray = this.nd >= other
def lesser(other : NDArray) : NDArray = this.nd < other
def lesser(other : Float) : NDArray = this.nd < other
def lesserEqual(other : NDArray) : NDArray = this.nd <= other
def lesserEqual(other : Float) : NDArray = this.nd <= other

def toArray : Array[Float] = nd.toArray

def toScalar : Float = nd.toScalar

def copyTo(other : NDArray) : NDArray = nd.copyTo(other)

def copyTo(ctx : Context) : NDArray = nd.copyTo(ctx)

def copy() : NDArray = copyTo(this.context)

def shape : Shape = nd.shape

def size : Int = shape.product

def asInContext(context: Context): NDArray = nd.asInContext(context)

override def equals(obj: Any): Boolean = nd.equals(obj)
override def hashCode(): Int = nd.hashCode
}

object NDArrayFuncReturn {
implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn)
: org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn
implicit def toJavaNDFuncReturn(ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn)
: NDArrayFuncReturn = new NDArrayFuncReturn(ndFuncReturn)
}

private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) {
def head : NDArray = ndFuncReturn.head
def get : NDArray = ndFuncReturn.get
def apply(i : Int) : NDArray = ndFuncReturn.apply(i)
// TODO: Add JavaNDArray operational stuff
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.javaapi;

import org.junit.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.junit.Assert.assertTrue;

public class NDArrayTest {
@Test
public void testCreateNDArray() {
NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f},
new Shape(new int[]{1, 3}),
new Context("cpu", 0));
int[] arr = new int[]{1, 3};
assertTrue(Arrays.equals(nd.shape().toArray(), arr));
assertTrue(nd.at(0).at(0).toArray()[0] == 1.0f);
List<Float> list = Arrays.asList(1.0f, 2.0f, 3.0f);
// Second way creating NDArray
nd = NDArray.array(list,
new Shape(new int[]{1, 3}),
new Context("cpu", 0));
assertTrue(Arrays.equals(nd.shape().toArray(), arr));
}

@Test
public void testZeroOneEmpty(){
NDArray ones = NDArray.ones(new Context("cpu", 0), new int[]{100, 100});
NDArray zeros = NDArray.zeros(new Context("cpu", 0), new int[]{100, 100});
NDArray empty = NDArray.empty(new Context("cpu", 0), new int[]{100, 100});
int[] arr = new int[]{100, 100};
assertTrue(Arrays.equals(ones.shape().toArray(), arr));
assertTrue(Arrays.equals(zeros.shape().toArray(), arr));
assertTrue(Arrays.equals(empty.shape().toArray(), arr));
}

@Test
public void testComparison(){
NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f}, new Shape(new int[]{3}), new Context("cpu", 0));
NDArray nd2 = new NDArray(new float[]{3.0f, 4.0f, 5.0f}, new Shape(new int[]{3}), new Context("cpu", 0));
nd = nd.add(nd2);
float[] greater = new float[]{1, 1, 1};
assertTrue(Arrays.equals(nd.greater(nd2).toArray(), greater));
nd = nd.subtract(nd2);
nd = nd.subtract(nd2);
float[] lesser = new float[]{0, 0, 0};
assertTrue(Arrays.equals(nd.greater(nd2).toArray(), lesser));
}

@Test
public void testGenerated(){
NDArray$ NDArray = NDArray$.MODULE$;
float[] arr = new float[]{1.0f, 2.0f, 3.0f};
NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
float result = NDArray.norm(nd).invoke().get().toArray()[0];
float cal = 0.0f;
for (float ele : arr) {
cal += ele * ele;
}
cal = (float) Math.sqrt(cal);
assertTrue(Math.abs(result - cal) < 1e-5);
NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
NDArray.dot(nd, nd).setout(dotResult).invoke().get();
assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
}
}
Loading