+## Java Tutorials
+* Getting Started
+ * [Developer Environment Setup on IntelliJ IDE](/tutorials/java/mxnet_java_on_intellij.html)
+* [Multi Object Detection using pre-trained Single Shot Detector (SSD) Model](/tutorials/java/ssd_inference.html)
+* [MXNet-Java Examples](https://github.com/apache/incubator-mxnet/tree/master/scala-package/examples/src/main/java/org/apache/mxnetexamples)
+
+
## C++ Tutorials
* Models
diff --git a/docs/tutorials/java/mxnet_java_on_intellij.md b/docs/tutorials/java/mxnet_java_on_intellij.md
new file mode 100644
index 000000000000..d9a215998005
--- /dev/null
+++ b/docs/tutorials/java/mxnet_java_on_intellij.md
@@ -0,0 +1,171 @@
+# Run MXNet Java Examples Using the IntelliJ IDE (macOS)
+
+This tutorial guides you through setting up a simple Java project in IntelliJ IDE on macOS and demonstrates usage of the MXNet Java APIs.
+
+## Prerequisites:
+To use this tutorial you need the following pre-requisites:
+
+- [Java 8 JDK](http://www.oracle.com/technetwork/java/javase/downloads/index.html)
+- [Maven](https://maven.apache.org/install.html)
+- [OpenCV](https://opencv.org/)
+- [IntelliJ IDEA](https://www.jetbrains.com/idea/) (One can download the community edition from [here](https://www.jetbrains.com/idea/download))
+
+### MacOS Prerequisites
+
+You can run the following commands to install the prerequisites.
+```
+/usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
+brew update
+brew tap caskroom/versions
+brew cask install java8
+brew install maven
+brew install opencv
+```
+
+You can also run this tutorial on an Ubuntu machine after installing the following prerequisites.
+### Ubuntu Prerequisites
+
+Run the following commands to install the prerequisites.
+
+```bash
+wget https://github.com/apache/incubator-mxnet/blob/master/ci/docker/install/ubuntu_core.sh
+sudo ./ubuntu_core.sh
+wget https://github.com/apache/incubator-mxnet/blob/master/ci/docker/install/ubuntu_scala.sh
+sudo ./ubuntu_scala.sh
+```
+
+Note : You might need to run `chmod u+x ubuntu_core.sh` and `chmod u+x ubuntu_scala` before running the scripts.
+
+The `ubuntu_scala.sh` installs the common dependencies required for both MXNet Scala and MXNet Java packages.
+
+## Set Up Your Project
+
+**Step 1.** Install and setup [IntelliJ IDEA](https://www.jetbrains.com/idea/)
+
+**Step 2.** Create a new Project:
+
+![intellij welcome](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/scala/intellij-welcome.png)
+
+From the IntelliJ welcome screen, select "Create New Project".
+
+Choose the Maven project type.
+
+Select the checkbox for `Create from archetype`, then choose `org.apache.maven.archetypes:maven-archetype-quickstart` from the list below. More on this can be found on a Maven tutorial : [Maven in 5 Minutes](https://maven.apache.org/guides/getting-started/maven-in-five-minutes.html).
+
+![maven project type - archetype](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/java/project-archetype.png)
+
+click `Next`.
+
+![project metadata](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/java/intellij-project-metadata.png)
+
+Set the project's metadata. For this tutorial, use the following:
+
+**GroupId**
+```
+mxnet
+```
+**ArtifactId**
+```
+ArtifactId: javaMXNet
+```
+**Version**
+```
+1.0-SNAPSHOT
+```
+
+![project properties](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/java/intellij-project-properties.png)
+
+Review the project's properties. The settings can be left as their default.
+
+![project location](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/java/intellij-project-location.png)
+
+Set the project's location. The rest of the settings can be left as their default.
+
+![project 1](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/java/intellij-project-pom.png)
+
+After clicking Finish, you will be presented with the project's first view.
+The project's `pom.xml` will be open for editing.
+
+**Step 3.** Add the following Maven dependency to your `pom.xml` file under the `dependencies` tag:
+
+```html
+
+ org.apache.mxnet
+ mxnet-full_2.11-osx-x86_64-cpu
+ 1.4.0
+
+```
+
+To view the latest MXNet Maven packages, you can check [MXNet Maven package repository](https://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.mxnet%22)
+
+Note :
+- Change the osx-x86_64 to linux-x86_64 if your platform is linux.
+- Change cpu into gpu if you have a gpu backed machine and want to use gpu.
+
+
+**Step 4.** Import dependencies with Maven:
+
+ - Note the prompt in the lower right corner that states "Maven projects need to be imported". If this is not visible, click on the little greed balloon that appears in the lower right corner.
+
+![import_dependencies](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/java/project-import-changes.png)
+
+Click "Import Changes" in this prompt.
+
+**Step 5.** Build the project:
+- To build the project, from the menu choose Build, and then choose Build Project.
+
+**Step 6.** Navigate to the App.java class in the project and paste the code from HelloWorld.java from [Java Demo project](https://github.com/apache/incubator-mxnet/blob/java-api/scala-package/mxnet-demo/java-demo/src/main/java/sample/HelloWorld.java) on MXNet repository, overwriting the original hello world code.
+You can also grab the entire [Java Demo project](https://github.com/apache/incubator-mxnet/tree/java-api/scala-package/mxnet-demo/java-demo) and run it by following the instructions on the [README](https://github.com/apache/incubator-mxnet/blob/java-api/scala-package/mxnet-demo/java-demo/README.md)
+
+**Step 7.** Now run the App.java.
+
+The result should be something similar to this:
+
+```
+Hello World!
+(1,2)
+Process finished with exit code 0
+```
+
+### Troubleshooting
+
+If you get an error, check the dependencies at the beginning of this tutorial. For example, you might see the following in the middle of the error messages, where `x.x` would the version it's looking for.
+
+```
+...
+Library not loaded: /usr/local/opt/opencv/lib/libopencv_calib3d.x.x.dylib
+...
+```
+
+This can be resolved be installing OpenCV.
+
+### Command Line Build Option
+
+- You can also compile the project by using the following command at the command line. Change directories to this project's root folder then run the following:
+
+```bash
+mvn clean install dependency:copy-dependencies
+```
+If the command succeeds, you should see a lot of info and some warning messages, followed by:
+
+```bash
+[INFO] ------------------------------------------------------------------------
+[INFO] BUILD SUCCESS
+[INFO] ------------------------------------------------------------------------
+[INFO] Total time: 3.475 s
+[INFO] Finished at: 2018-11-08T05:06:31-08:00
+[INFO] ------------------------------------------------------------------------
+```
+The build generates a new jar file in the `target` folder called `javaMXNet-1.0-SNAPSHOT.jar`.
+
+To run the App.java use the following command from the project's root folder and you should see the same output as we got when the project was run from IntelliJ.
+```bash
+java -cp target/javaMXNet-1.0-SNAPSHOT.jar:target/dependency/* mxnet.App
+```
+
+## Next Steps
+For more information about MXNet Java resources, see the following:
+
+* [Java Inference API](https://mxnet.incubator.apache.org/api/java/infer.html)
+* [Java Inference Examples](https://github.com/apache/incubator-mxnet/tree/java-api/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/)
+* [MXNet Tutorials Index](http://mxnet.io/tutorials/index.html)
diff --git a/docs/tutorials/java/ssd_inference.md b/docs/tutorials/java/ssd_inference.md
new file mode 100644
index 000000000000..6bcaaa2504a4
--- /dev/null
+++ b/docs/tutorials/java/ssd_inference.md
@@ -0,0 +1,186 @@
+# Multi Object Detection using pre-trained SSD Model via Java Inference APIs
+
+This tutorial shows how to use MXNet Java Inference APIs to run inference on a pre-trained Single Shot Detector (SSD) Model.
+
+The SSD model is trained on the Pascal VOC 2012 dataset. The network is a SSD model built on Resnet50 as the base network to extract image features. The model is trained to detect the following entities (classes): ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']. For more details about the model, you can refer to the [MXNet SSD example](https://github.com/apache/incubator-mxnet/tree/master/example/ssd).
+
+## Prerequisites
+
+To complete this tutorial, you need the following:
+* [MXNet Java Setup on IntelliJ IDEA](/java/mxnet_java_on_intellij.html) (Optional)
+* [wget](https://www.gnu.org/software/wget/) To download model artifacts
+* SSD Model artifacts
+ * Use the following script to get the SSD Model files :
+```bash
+data_path=/tmp/resnet50_ssd
+mkdir -p "$data_path"
+wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-symbol.json -P $data_path
+wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-0000.params -P $data_path
+wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/synset.txt -P $data_path
+```
+* Test images : A few sample images to run inference on.
+ * Use the following script to download sample images :
+```bash
+image_path=/tmp/resnet50_ssd/images
+mkdir -p "$image_path"
+cd $image_path
+wget https://cloud.githubusercontent.com/assets/3307514/20012567/cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg -O dog.jpg
+wget https://cloud.githubusercontent.com/assets/3307514/20012563/cbb41382-a27d-11e6-92a9-18dab4fd1ad3.jpg -O person.jpg
+```
+
+Alternately, you can get the entire SSD Model artifacts + images in one single script from the MXNet Repository by running [get_ssd_data.sh script](https://github.com/apache/incubator-mxnet/blob/master/scala-package/examples/scripts/infer/objectdetector/get_ssd_data.sh)
+
+## Time to code!
+1\. Following the [MXNet Java Setup on IntelliJ IDEA](/java/mxnet_java_on_intellij.html) tutorial, in the same project `JavaMXNet`, create a new empty class called : `ObjectDetectionTutorial.java`.
+
+2\. In the `main` function of `ObjectDetectionTutorial.java` define the downloaded model path and the image data paths. This is the same path where we downloaded the model artifacts and images in a previous step.
+
+```java
+String modelPathPrefix = "/tmp/resnet50_ssd/resnet50_ssd_model";
+String inputImagePath = "/tmp/resnet50_ssd/images/dog.jpg";
+```
+
+3\. We can run the inference code in this example on either CPU or GPU (if you have a GPU backed machine) by choosing the appropriate context.
+
+```java
+
+List context = getContext();
+...
+
+private static List getContext() {
+List ctx = new ArrayList<>();
+ctx.add(Context.cpu()); // Choosing CPU Context here
+
+return ctx;
+}
+```
+
+4\. To provide an input to the model, define the input shape to the model and the Input Data Descriptor (DataDesc) as shown below :
+
+```java
+Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+List inputDescriptors = new ArrayList();
+inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+```
+
+The input shape can be interpreted as follows : The input has a batch size of 1, with 3 RGB channels in the image, and the height and width of the image is 512 each.
+
+5\. To run an actual inference on the given image, add the following lines to the `ObjectDetectionTutorial.java` class :
+
+```java
+BufferedImage img = ObjectDetector.loadImageFromFile(inputImagePath);
+ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
+List> output = objDet.imageObjectDetect(img, 3); // Top 3 objects detected will be returned
+```
+
+6\. Let's piece all of the above steps together by showing the final contents of the `ObjectDetectionTutorial.java`.
+
+```java
+package mxnet;
+
+import org.apache.mxnet.infer.javaapi.ObjectDetector;
+import org.apache.mxnet.infer.javaapi.ObjectDetectorOutput;
+import org.apache.mxnet.javaapi.Context;
+import org.apache.mxnet.javaapi.DType;
+import org.apache.mxnet.javaapi.DataDesc;
+import org.apache.mxnet.javaapi.Shape;
+
+import java.awt.image.BufferedImage;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class ObjectDetectionTutorial {
+
+ public static void main(String[] args) {
+
+ String modelPathPrefix = "/tmp/resnet50_ssd/resnet50_ssd_model";
+
+ String inputImagePath = "/tmp/resnet50_ssd/images/dog.jpg";
+
+ List context = getContext();
+
+ Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+
+ List inputDescriptors = new ArrayList();
+ inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+
+ BufferedImage img = ObjectDetector.loadImageFromFile(inputImagePath);
+ ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
+ List> output = objDet.imageObjectDetect(img, 3);
+
+ printOutput(output, inputShape);
+ }
+
+
+ private static List getContext() {
+ List ctx = new ArrayList<>();
+ ctx.add(Context.cpu());
+
+ return ctx;
+ }
+
+ private static void printOutput(List> output, Shape inputShape) {
+
+ StringBuilder outputStr = new StringBuilder();
+
+ int width = inputShape.get(3);
+ int height = inputShape.get(2);
+
+ for (List ele : output) {
+ for (ObjectDetectorOutput i : ele) {
+ outputStr.append("Class: " + i.getClassName() + "\n");
+ outputStr.append("Probabilties: " + i.getProbability() + "\n");
+
+ List coord = Arrays.asList(i.getXMin() * width,
+ i.getXMax() * height, i.getYMin() * width, i.getYMax() * height);
+ StringBuilder sb = new StringBuilder();
+ for (float c: coord) {
+ sb.append(", ").append(c);
+ }
+ outputStr.append("Coord:" + sb.substring(2)+ "\n");
+ }
+ }
+ System.out.println(outputStr);
+
+ }
+}
+```
+
+7\. To compile and run this code, change directories to this project's root folder, then run the following:
+```bash
+mvn clean install dependency:copy-dependencies
+```
+
+The build generates a new jar file in the `target` folder called `javaMXNet-1.0-SNAPSHOT.jar`.
+
+To run the ObjectDetectionTutorial.java use the following command from the project's root folder.
+```bash
+java -cp target/javaMXNet-1.0-SNAPSHOT.jar:target/dependency/* mxnet.ObjectDetectionTutorial
+```
+
+You should see a similar output being generated for the dog image that we used:
+```bash
+Class: car
+Probabilties: 0.99847263
+Coord:312.21335, 72.02908, 456.01443, 150.66176
+Class: bicycle
+Probabilties: 0.9047381
+Coord:155.9581, 149.96365, 383.83694, 418.94516
+Class: dog
+Probabilties: 0.82268167
+Coord:83.82356, 179.14001, 206.63783, 476.78754
+```
+
+![dog_1](https://cloud.githubusercontent.com/assets/3307514/20012567/cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg)
+
+The results returned by the inference call translate into the regions in the image where the model detected objects.
+
+![dog_2](https://cloud.githubusercontent.com/assets/3307514/19171063/91ec2792-8be0-11e6-983c-773bd6868fa8.png)
+
+## Next Steps
+For more information about MXNet Java resources, see the following:
+
+* [Java Inference API](/api/java/infer.html)
+* [Java Inference Examples](https://github.com/apache/incubator-mxnet/tree/java-api/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/)
+* [MXNet Tutorials Index](/tutorials/index.html)
diff --git a/scala-package/.gitignore b/scala-package/.gitignore
index 0f860e62836a..6aa4da6b1cfc 100644
--- a/scala-package/.gitignore
+++ b/scala-package/.gitignore
@@ -1,5 +1,8 @@
.flattened-pom.xml
core/src/main/scala/org/apache/mxnet/NDArrayAPIBase.scala
core/src/main/scala/org/apache/mxnet/NDArrayBase.scala
+core/src/main/scala/org/apache/mxnet/javaapi/NDArrayBase.scala
core/src/main/scala/org/apache/mxnet/SymbolAPIBase.scala
core/src/main/scala/org/apache/mxnet/SymbolBase.scala
+examples/scripts/infer/images/
+examples/scripts/infer/models/
diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index 56ff4db1408b..daf8c389cbeb 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -10,6 +10,10 @@
../pom.xml
+
+ true
+
+
mxnet-core_2.11MXNet Scala Package - Core
@@ -20,12 +24,6 @@
false
-
- integrationtest
-
- true
-
- osx-x86_64-cpu
@@ -86,7 +84,10 @@
maven-surefire-plugin2.22.0
- false
+
+ -Djava.library.path=${project.parent.basedir}/native/${platform}/target
+
+ ${skipTests}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
index acae8bf5994c..3d397e3fc496 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
@@ -19,7 +19,14 @@ package org.apache.mxnet.javaapi
import collection.JavaConverters._
import scala.language.implicitConversions
-class Context(val context: org.apache.mxnet.Context) {
+/**
+ * Constructing a context which is used to specify the device and device type that will
+ * be utilized by the engine.
+ *
+ * @param deviceTypeName {'cpu', 'gpu'} String representing the device type
+ * @param deviceId The device id of the device, needed for GPU
+ */
+class Context private[mxnet] (val context: org.apache.mxnet.Context) {
val deviceTypeid: Int = context.deviceTypeid
@@ -27,6 +34,11 @@ class Context(val context: org.apache.mxnet.Context) {
= this(new org.apache.mxnet.Context(deviceTypeName, deviceId))
def withScope[T](body: => T): T = context.withScope(body)
+
+ /**
+ * Return device type of current context.
+ * @return device_type
+ */
def deviceType: String = context.deviceType
override def toString: String = context.toString
@@ -43,6 +55,5 @@ object Context {
val gpu: Context = org.apache.mxnet.Context.gpu()
val devtype2str = org.apache.mxnet.Context.devstr2type.asJava
val devstr2type = org.apache.mxnet.Context.devstr2type.asJava
-
def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
index 888a5d812c7a..d0e10815a1e6 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
@@ -16,10 +16,9 @@
*/
package org.apache.mxnet.javaapi
-
import scala.language.implicitConversions
-class DataDesc(val dataDesc: org.apache.mxnet.DataDesc) {
+class DataDesc private[mxnet] (val dataDesc: org.apache.mxnet.DataDesc) {
def this(name: String, shape: Shape, dType: DType.DType, layout: String) =
this(new org.apache.mxnet.DataDesc(name, shape, dType, layout))
@@ -32,5 +31,13 @@ object DataDesc{
implicit def toDataDesc(dataDesc: DataDesc): org.apache.mxnet.DataDesc = dataDesc.dataDesc
+ /**
+ * Get the dimension that corresponds to the batch size.
+ * @param layout layout string. For example, "NCHW".
+ * @return An axis indicating the batch_size dimension. When data-parallelism is used,
+ * the data will be automatically split and concatenate along the batch_size dimension.
+ * Axis can be -1, which means the whole array will be copied
+ * for each data-parallelism device.
+ */
def getBatchAxis(layout: String): Int = org.apache.mxnet.DataDesc.getBatchAxis(Some(layout))
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
new file mode 100644
index 000000000000..6b4f4bdebda5
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -0,0 +1,397 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+import org.apache.mxnet.javaapi.DType.DType
+
+import collection.JavaConverters._
+
+@AddJNDArrayAPIs(false)
+object NDArray extends NDArrayBase {
+ implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd)
+
+ implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd
+
+ def waitall(): Unit = org.apache.mxnet.NDArray.waitall()
+
+ /**
+ * One hot encoding indices into matrix out.
+ * @param indices An NDArray containing indices of the categorical features.
+ * @param out The result holder of the encoding.
+ * @return Same as out.
+ */
+ def onehotEncode(indices: NDArray, out: NDArray): NDArray
+ = org.apache.mxnet.NDArray.onehotEncode(indices, out)
+
+ /**
+ * Create an empty uninitialized new NDArray, with specified shape.
+ *
+ * @param shape shape of the NDArray.
+ * @param ctx The context of the NDArray.
+ *
+ * @return The created NDArray.
+ */
+ def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+ = org.apache.mxnet.NDArray.empty(shape, ctx, dtype)
+ def empty(ctx: Context, shape: Array[Int]): NDArray
+ = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
+ def empty(ctx: Context, shape: java.util.List[java.lang.Integer]): NDArray
+ = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
+
+ /**
+ * Create a new NDArray filled with 0, with specified shape.
+ *
+ * @param shape shape of the NDArray.
+ * @param ctx The context of the NDArray.
+ *
+ * @return The created NDArray.
+ */
+ def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+ = org.apache.mxnet.NDArray.zeros(shape, ctx, dtype)
+ def zeros(ctx: Context, shape: Array[Int]): NDArray
+ = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
+ def zeros(ctx: Context, shape: java.util.List[java.lang.Integer]): NDArray
+ = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
+
+ /**
+ * Create a new NDArray filled with 1, with specified shape.
+ * @param shape shape of the NDArray.
+ * @param ctx The context of the NDArray.
+ * @return The created NDArray.
+ */
+ def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+ = org.apache.mxnet.NDArray.ones(shape, ctx, dtype)
+ def ones(ctx: Context, shape: Array[Int]): NDArray
+ = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
+ def ones(ctx: Context, shape: java.util.List[java.lang.Integer]): NDArray
+ = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
+
+ /**
+ * Create a new NDArray filled with given value, with specified shape.
+ * @param shape shape of the NDArray.
+ * @param value value to be filled with
+ * @param ctx The context of the NDArray
+ */
+ def full(shape: Shape, value: Float, ctx: Context): NDArray
+ = org.apache.mxnet.NDArray.full(shape, value, ctx)
+
+ def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+ def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+ def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+
+ def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+ def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+ def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+
+ def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+ def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+ def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+
+
+ /**
+ * Returns the result of element-wise **equal to** (==) comparison operation with broadcasting.
+ * For each element in input arrays, return 1(true) if corresponding elements are same,
+ * otherwise return 0(false).
+ */
+ def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+ def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+
+ /**
+ * Returns the result of element-wise **not equal to** (!=) comparison operation
+ * with broadcasting.
+ * For each element in input arrays, return 1(true) if corresponding elements are different,
+ * otherwise return 0(false).
+ */
+ def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+ def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+
+ /**
+ * Returns the result of element-wise **greater than** (>) comparison operation
+ * with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are greater than rhs,
+ * otherwise return 0(false).
+ */
+ def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+ def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+
+ /**
+ * Returns the result of element-wise **greater than or equal to** (>=) comparison
+ * operation with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are greater than equal to rhs
+ * otherwise return 0(false).
+ */
+ def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray
+ = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+ def greaterEqual(lhs: NDArray, rhs: Float): NDArray
+ = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+
+ /**
+ * Returns the result of element-wise **lesser than** (<) comparison operation
+ * with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are less than rhs,
+ * otherwise return 0(false).
+ */
+ def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+ def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+
+ /**
+ * Returns the result of element-wise **lesser than or equal to** (<=) comparison
+ * operation with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are
+ * lesser than equal to rhs, otherwise return 0(false).
+ */
+ def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray
+ = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+ def lesserEqual(lhs: NDArray, rhs: Float): NDArray
+ = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+
+ /**
+ * Create a new NDArray that copies content from source_array.
+ * @param sourceArr Source data to create NDArray from.
+ * @param shape shape of the NDArray
+ * @param ctx The context of the NDArray, default to current default context.
+ * @return The created NDArray.
+ */
+ def array(sourceArr: java.util.List[java.lang.Float], shape: Shape, ctx: Context = null): NDArray
+ = org.apache.mxnet.NDArray.array(
+ sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)
+
+ /**
+ * Returns evenly spaced values within a given interval.
+ * Values are generated within the half-open interval [`start`, `stop`). In other
+ * words, the interval includes `start` but excludes `stop`.
+ * @param start Start of interval.
+ * @param stop End of interval.
+ * @param step Spacing between values.
+ * @param repeat Number of times to repeat each element.
+ * @param ctx Device context.
+ * @param dType The data type of the `NDArray`.
+ * @return NDArray of evenly spaced values in the specified range.
+ */
+ def arange(start: Float, stop: Float, step: Float, repeat: Int,
+ ctx: Context, dType: DType.DType): NDArray =
+ org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType)
+}
+
+/**
+ * NDArray object in mxnet.
+ * NDArray is basic ndarray/Tensor like data structure in mxnet.
+ *
+ * NOTE: NDArray is stored in native memory. Use NDArray in a try-with-resources() construct
+ * or a [[org.apache.mxnet.ResourceScope]] in a try-with-resource to have them
+ * automatically disposed. You can explicitly control the lifetime of NDArray
+ * by calling dispose manually. Failure to do this will result in leaking native memory.
+ *
+ */
+class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
+
+ def this(arr: Array[Float], shape: Shape, ctx: Context) = {
+ this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
+ }
+
+ def this(arr: java.util.List[java.lang.Float], shape: Shape, ctx: Context) = {
+ this(NDArray.array(arr, shape, ctx))
+ }
+
+ def serialize(): Array[Byte] = nd.serialize()
+
+ /**
+ * Release the native memory.
+ * The NDArrays it depends on will NOT be disposed.
+ * The object shall never be used after it is disposed.
+ */
+ def dispose(): Unit = nd.dispose()
+
+ /**
+ * Dispose all NDArrays who help to construct this array.
+ * e.g. (a * b + c).disposeDeps() will dispose a, b, c (including their deps) and a * b
+ * @return this array
+ */
+ def disposeDeps(): NDArray = nd.disposeDepsExcept()
+
+ /**
+ * Dispose all NDArrays who help to construct this array, excepts those in the arguments.
+ * e.g. (a * b + c).disposeDepsExcept(a, b)
+ * will dispose c and a * b.
+ * Note that a, b's dependencies will not be disposed either.
+ * @param arr the Array of NDArray not to dispose
+ * @return this array
+ */
+ def disposeDepsExcept(arr: Array[NDArray]): NDArray =
+ nd.disposeDepsExcept(arr.map(NDArray.toNDArray): _*)
+
+ /**
+ * Return a sliced NDArray that shares memory with current one.
+ * NDArray only support continuous slicing on axis 0
+ *
+ * @param start Starting index of slice.
+ * @param stop Finishing index of slice.
+ *
+ * @return a sliced NDArray that shares memory with current one.
+ */
+ def slice(start: Int, stop: Int): NDArray = nd.slice(start, stop)
+
+ /**
+ * Return a sliced NDArray at the ith position of axis0
+ * @param i
+ * @return a sliced NDArray that shares memory with current one.
+ */
+ def slice (i: Int): NDArray = nd.slice(i)
+
+ /**
+ * Return a sub NDArray that shares memory with current one.
+ * the first axis will be rolled up, which causes its shape different from slice(i, i+1)
+ * @param idx index of sub array.
+ */
+ def at(idx: Int): NDArray = nd.at(idx)
+
+ def T: NDArray = nd.T
+
+ /**
+ * Get data type of current NDArray.
+ * @return class representing type of current ndarray
+ */
+ def dtype: DType = nd.dtype
+
+ /**
+ * Return a copied numpy array of current array with specified type.
+ * @param dtype Desired type of result array.
+ * @return A copy of array content.
+ */
+ def asType(dtype: DType): NDArray = nd.asType(dtype)
+
+ /**
+ * Return a reshaped NDArray that shares memory with current one.
+ * @param dims New shape.
+ *
+ * @return a reshaped NDArray that shares memory with current one.
+ */
+ def reshape(dims: Array[Int]): NDArray = nd.reshape(dims)
+
+ /**
+ * Block until all pending writes operations on current NDArray are finished.
+ * This function will return when all the pending writes to the current
+ * NDArray finishes. There can still be pending read going on when the
+ * function returns.
+ */
+ def waitToRead(): Unit = nd.waitToRead()
+
+ /**
+ * Get context of current NDArray.
+ * @return The context of current NDArray.
+ */
+ def context: Context = nd.context
+
+ /**
+ * Set the values of the NDArray
+ * @param value Value to set
+ * @return Current NDArray
+ */
+ def set(value: Float): NDArray = nd.set(value)
+ def set(other: NDArray): NDArray = nd.set(other)
+ def set(other: Array[Float]): NDArray = nd.set(other)
+
+ def add(other: NDArray): NDArray = this.nd + other.nd
+ def add(other: Float): NDArray = this.nd + other
+ def addInplace(other: NDArray): NDArray = this.nd += other
+ def addInplace(other: Float): NDArray = this.nd += other
+ def subtract(other: NDArray): NDArray = this.nd - other
+ def subtract(other: Float): NDArray = this.nd - other
+ def subtractInplace(other: NDArray): NDArray = this.nd -= other
+ def subtractInplace(other: Float): NDArray = this.nd -= other
+ def multiply(other: NDArray): NDArray = this.nd * other
+ def multiply(other: Float): NDArray = this.nd * other
+ def multiplyInplace(other: NDArray): NDArray = this.nd *= other
+ def multiplyInplace(other: Float): NDArray = this.nd *= other
+ def div(other: NDArray): NDArray = this.nd / other
+ def div(other: Float): NDArray = this.nd / other
+ def divInplace(other: NDArray): NDArray = this.nd /= other
+ def divInplace(other: Float): NDArray = this.nd /= other
+ def pow(other: NDArray): NDArray = this.nd ** other
+ def pow(other: Float): NDArray = this.nd ** other
+ def powInplace(other: NDArray): NDArray = this.nd **= other
+ def powInplace(other: Float): NDArray = this.nd **= other
+ def mod(other: NDArray): NDArray = this.nd % other
+ def mod(other: Float): NDArray = this.nd % other
+ def modInplace(other: NDArray): NDArray = this.nd %= other
+ def modInplace(other: Float): NDArray = this.nd %= other
+ def greater(other: NDArray): NDArray = this.nd > other
+ def greater(other: Float): NDArray = this.nd > other
+ def greaterEqual(other: NDArray): NDArray = this.nd >= other
+ def greaterEqual(other: Float): NDArray = this.nd >= other
+ def lesser(other: NDArray): NDArray = this.nd < other
+ def lesser(other: Float): NDArray = this.nd < other
+ def lesserEqual(other: NDArray): NDArray = this.nd <= other
+ def lesserEqual(other: Float): NDArray = this.nd <= other
+
+ /**
+ * Return a copied flat java array of current array (row-major).
+ * @return A copy of array content.
+ */
+ def toArray: Array[Float] = nd.toArray
+
+ /**
+ * Return a CPU scalar(float) of current ndarray.
+ * This ndarray must have shape (1,)
+ *
+ * @return The scalar representation of the ndarray.
+ */
+ def toScalar: Float = nd.toScalar
+
+ /**
+ * Copy the content of current array to other.
+ *
+ * @param other Target NDArray or context we want to copy data to.
+ * @return The copy target NDArray
+ */
+ def copyTo(other: NDArray): NDArray = nd.copyTo(other)
+
+ /**
+ * Copy the content of current array to a new NDArray in the context.
+ *
+ * @param ctx Target context we want to copy data to.
+ * @return The copy target NDArray
+ */
+ def copyTo(ctx: Context): NDArray = nd.copyTo(ctx)
+
+ /**
+ * Clone the current array
+ * @return the copied NDArray in the same context
+ */
+ def copy(): NDArray = copyTo(this.context)
+
+ /**
+ * Get shape of current NDArray.
+ * @return an array representing shape of current ndarray
+ */
+ def shape: Shape = nd.shape
+
+
+ def size: Int = shape.product
+
+ /**
+ * Return an `NDArray` that lives in the target context. If the array
+ * is already in that context, `self` is returned. Otherwise, a copy is made.
+ * @param context The target context we want the return value to live in.
+ * @return A copy or `self` as an `NDArray` that lives in the target context.
+ */
+ def asInContext(context: Context): NDArray = nd.asInContext(context)
+
+ override def equals(obj: Any): Boolean = nd.equals(obj)
+ override def hashCode(): Int = nd.hashCode
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala
index 5c4464f84211..b795fe31f726 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala
@@ -24,7 +24,7 @@ import scala.language.implicitConversions
* Shape of [[NDArray]] or other data
*/
-class Shape(val shape: org.apache.mxnet.Shape) {
+class Shape private[mxnet] (val shape: org.apache.mxnet.Shape) {
def this(dims: java.util.List[java.lang.Integer])
= this(new org.apache.mxnet.Shape(dims.asScala.map(Int.unbox)))
def this(dims: Array[Int]) = this(new org.apache.mxnet.Shape(dims))
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
new file mode 100644
index 000000000000..2659b7848bc6
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+import org.apache.mxnet.javaapi.NDArrayBase.*;
+
+import static org.junit.Assert.assertTrue;
+
+public class NDArrayTest {
+ @Test
+ public void testCreateNDArray() {
+ NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f},
+ new Shape(new int[]{1, 3}),
+ new Context("cpu", 0));
+ int[] arr = new int[]{1, 3};
+ assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+ assertTrue(nd.at(0).at(0).toArray()[0] == 1.0f);
+ List list = Arrays.asList(1.0f, 2.0f, 3.0f);
+ // Second way creating NDArray
+ nd = NDArray.array(list,
+ new Shape(new int[]{1, 3}),
+ new Context("cpu", 0));
+ assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+ }
+
+ @Test
+ public void testZeroOneEmpty(){
+ NDArray ones = NDArray.ones(new Context("cpu", 0), new int[]{100, 100});
+ NDArray zeros = NDArray.zeros(new Context("cpu", 0), new int[]{100, 100});
+ NDArray empty = NDArray.empty(new Context("cpu", 0), new int[]{100, 100});
+ int[] arr = new int[]{100, 100};
+ assertTrue(Arrays.equals(ones.shape().toArray(), arr));
+ assertTrue(Arrays.equals(zeros.shape().toArray(), arr));
+ assertTrue(Arrays.equals(empty.shape().toArray(), arr));
+ }
+
+ @Test
+ public void testComparison(){
+ NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f}, new Shape(new int[]{3}), new Context("cpu", 0));
+ NDArray nd2 = new NDArray(new float[]{3.0f, 4.0f, 5.0f}, new Shape(new int[]{3}), new Context("cpu", 0));
+ nd = nd.add(nd2);
+ float[] greater = new float[]{1, 1, 1};
+ assertTrue(Arrays.equals(nd.greater(nd2).toArray(), greater));
+ nd = nd.subtract(nd2);
+ nd = nd.subtract(nd2);
+ float[] lesser = new float[]{0, 0, 0};
+ assertTrue(Arrays.equals(nd.greater(nd2).toArray(), lesser));
+ }
+
+ @Test
+ public void testGenerated(){
+ NDArray$ NDArray = NDArray$.MODULE$;
+ float[] arr = new float[]{1.0f, 2.0f, 3.0f};
+ NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
+ float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0];
+ float cal = 0.0f;
+ for (float ele : arr) {
+ cal += ele * ele;
+ }
+ cal = (float) Math.sqrt(cal);
+ assertTrue(Math.abs(result - cal) < 1e-5);
+ NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
+ NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult));
+ assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
+ }
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ResourceScopeTestSuite.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ResourceScopeTestSuite.java
new file mode 100644
index 000000000000..1c246d870e28
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ResourceScopeTestSuite.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.mxnet.javaapi;
+
+import org.apache.mxnet.NativeResourceRef;
+import org.apache.mxnet.ResourceScope;
+import org.junit.Test;
+
+import java.util.*;
+import java.util.concurrent.Callable;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class ResourceScopeTestSuite {
+
+ /**
+ * This is a placeholder class to test out whether NDArray References get collected or not when using
+ * try-with-resources in Java.
+ *
+ */
+ class TestNDArray {
+ NDArray selfArray;
+
+ public TestNDArray(Context context, int[] shape) {
+ this.selfArray = NDArray.ones(context, shape);
+ }
+
+ public boolean verifyIsDisposed() {
+ return this.selfArray.nd().isDisposed();
+ }
+
+ public NativeResourceRef getNDArrayReference() {
+ return this.selfArray.nd().ref();
+ }
+ }
+
+ @Test
+ public void testNDArrayAutoRelease() {
+ TestNDArray test = null;
+
+ try (ResourceScope scope = new ResourceScope()) {
+ test = new TestNDArray(Context.cpu(), new int[]{100, 100});
+ }
+
+ assertTrue(test.verifyIsDisposed());
+ }
+
+ @Test
+ public void testObjectReleaseFromList() {
+ List list = new ArrayList<>();
+
+ try (ResourceScope scope = new ResourceScope()) {
+ for (int i = 0;i < 10; i++) {
+ list.add(new TestNDArray(Context.cpu(), new int[] {100, 100}));
+ }
+ }
+
+ assertEquals(list.size() , 10);
+ for (TestNDArray item : list) {
+ assertTrue(item.verifyIsDisposed());
+ }
+ }
+
+ @Test
+ public void testObjectReleaseFromMap() {
+ Map stringToNDArrayMap = new HashMap<>();
+
+ try (ResourceScope scope = new ResourceScope()) {
+ for (int i = 0;i < 10; i++) {
+ stringToNDArrayMap.put(String.valueOf(i),new TestNDArray(Context.cpu(), new int[] {i, i}));
+ }
+ }
+
+ assertEquals(stringToNDArrayMap.size(), 10);
+ for (Map.Entry entry : stringToNDArrayMap.entrySet()) {
+ assertTrue(entry.getValue().verifyIsDisposed());
+ }
+
+ Map ndArrayToStringMap = new HashMap<>();
+
+ try (ResourceScope scope = new ResourceScope()) {
+ for (int i = 0;i < 10; i++) {
+ ndArrayToStringMap.put(new TestNDArray(Context.cpu(), new int[] {i, i}), String.valueOf(i));
+ }
+ }
+
+ assertEquals(ndArrayToStringMap.size(), 10);
+ for (Map.Entry entry : ndArrayToStringMap.entrySet()) {
+ assertTrue(entry.getKey().verifyIsDisposed());
+ }
+
+ }
+}
diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml
index 436f2992768b..72a40dc01f01 100644
--- a/scala-package/examples/pom.xml
+++ b/scala-package/examples/pom.xml
@@ -13,13 +13,11 @@
mxnet-examples_2.11MXNet Scala Package - Examples
+
+ true
+
+
-
- unittest
-
- true
-
- integrationtest
diff --git a/scala-package/examples/scripts/benchmark/run_java_inference_bm.sh b/scala-package/examples/scripts/benchmark/run_java_inference_bm.sh
new file mode 100644
index 000000000000..5a468e344829
--- /dev/null
+++ b/scala-package/examples/scripts/benchmark/run_java_inference_bm.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+
+hw_type=cpu
+if [ "$USE_GPU" = "1" ]
+then
+ hw_type=gpu
+fi
+
+platform=linux-x86_64
+
+if [[ $OSTYPE = [darwin]* ]]
+then
+ platform=osx-x86_64
+fi
+
+MXNET_ROOT=$(cd "$(dirname $0)/../../../.."; pwd)
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*
+
+java -Xmx8G -Dmxnet.traceLeakedObjects=true -cp $CLASS_PATH \
+ org.apache.mxnetexamples.javaapi.benchmark.JavaBenchmark $@
+
diff --git a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh b/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
index 8cea892b5809..6b4edb7c4c94 100755
--- a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
+++ b/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
@@ -17,9 +17,21 @@
# specific language governing permissions and limitations
# under the License.
+hw_type=cpu
+if [[ $4 = gpu ]]
+then
+ hw_type=gpu
+fi
+
+platform=linux-x86_64
+
+if [[ $OSTYPE = [darwin]* ]]
+then
+ platform=osx-x86_64
+fi
MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
-CLASS_PATH=$MXNET_ROOT/scala-package/assembly/osx-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
# model dir and prefix
MODEL_DIR=$1
diff --git a/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh b/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh
new file mode 100755
index 000000000000..00ed793a7bb5
--- /dev/null
+++ b/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+hw_type=cpu
+if [[ $4 = gpu ]]
+then
+ hw_type=gpu
+fi
+
+platform=linux-x86_64
+
+if [[ $OSTYPE = [darwin]* ]]
+then
+ platform=osx-x86_64
+fi
+
+MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*:$MXNET_ROOT/scala-package/examples/src/main/scala/org/apache/mxnetexamples/api/java/infer/imageclassifier/*
+
+# model dir and prefix
+MODEL_DIR=$1
+# input image
+INPUT_IMG=$2
+# which input image dir
+INPUT_DIR=$3
+
+java -Xmx8G -cp $CLASS_PATH \
+ org.apache.mxnetexamples.javaapi.infer.objectdetector.SSDClassifierExample \
+ --model-path-prefix $MODEL_DIR \
+ --input-image $INPUT_IMG \
+ --input-dir $INPUT_DIR
diff --git a/scala-package/examples/scripts/infer/predictor/run_predictor_java_example.sh b/scala-package/examples/scripts/infer/predictor/run_predictor_java_example.sh
new file mode 100755
index 000000000000..4ebcc3076a78
--- /dev/null
+++ b/scala-package/examples/scripts/infer/predictor/run_predictor_java_example.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+hw_type=cpu
+if [[ $3 = gpu ]]
+then
+ hw_type=gpu
+fi
+
+platform=linux-x86_64
+
+if [[ $OSTYPE = [darwin]* ]]
+then
+ platform=osx-x86_64
+fi
+
+MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
+
+# model dir and prefix
+MODEL_DIR=$1
+# input image
+INPUT_IMG=$2
+
+java -Xmx8G -cp $CLASS_PATH \
+ org.apache.mxnetexamples.javaapi.infer.predictor.PredictorExample \
+ --model-path-prefix $MODEL_DIR \
+ --input-image $INPUT_IMG
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/InferBase.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/InferBase.java
new file mode 100644
index 000000000000..fdcde6b4152c
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/InferBase.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnetexamples.javaapi.benchmark;
+
+import org.apache.mxnet.javaapi.Context;
+import org.kohsuke.args4j.Option;
+
+import java.util.List;
+
+abstract class InferBase {
+ @Option(name = "--num-runs", usage = "Number of runs")
+ public int numRun = 1;
+ @Option(name = "--model-name", usage = "Name of the model")
+ public String modelName = "";
+ @Option(name = "--batchsize", usage = "Size of the batch")
+ public int batchSize = 1;
+
+ public abstract void preProcessModel(List context);
+ public abstract void runSingleInference();
+ public abstract void runBatchInference();
+}
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/JavaBenchmark.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/JavaBenchmark.java
new file mode 100644
index 000000000000..4a6bb2dd38bf
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/JavaBenchmark.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.benchmark;
+
+import org.apache.mxnet.javaapi.Context;
+import org.kohsuke.args4j.CmdLineParser;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class JavaBenchmark {
+
+ private static boolean runBatch = false;
+
+ private static void parse(Object inst, String[] args) {
+ CmdLineParser parser = new CmdLineParser(inst);
+ try {
+ parser.parseArgument(args);
+ } catch (Exception e) {
+ System.err.println(e.getMessage() + e);
+ parser.printUsage(System.err);
+ System.exit(1);
+ }
+ }
+
+ private static long percentile(int p, long[] seq) {
+ Arrays.sort(seq);
+ int k = (int) Math.ceil((seq.length - 1) * (p / 100.0));
+ return seq[k];
+ }
+
+ private static void printStatistics(long[] inferenceTimesRaw, String metricsPrefix) {
+ long[] inferenceTimes = inferenceTimesRaw;
+ // remove head and tail
+ if (inferenceTimes.length > 2) {
+ inferenceTimes = Arrays.copyOfRange(inferenceTimesRaw,
+ 1, inferenceTimesRaw.length - 1);
+ }
+ double p50 = percentile(50, inferenceTimes) / 1.0e6;
+ double p99 = percentile(99, inferenceTimes) / 1.0e6;
+ double p90 = percentile(90, inferenceTimes) / 1.0e6;
+ long sum = 0;
+ for (long time: inferenceTimes) sum += time;
+ double average = sum / (inferenceTimes.length * 1.0e6);
+
+ System.out.println(
+ String.format("\n%s_p99 %fms\n%s_p90 %fms\n%s_p50 %fms\n%s_average %1.2fms",
+ metricsPrefix, p99, metricsPrefix, p90,
+ metricsPrefix, p50, metricsPrefix, average)
+ );
+
+ }
+
+ private static List bindToDevice() {
+ List context = new ArrayList();
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
+ context.add(Context.gpu());
+ } else {
+ context.add(Context.cpu());
+ }
+ return context;
+ }
+
+ public static void main(String[] args) {
+ if (args.length < 2) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("Please follow the format:");
+ sb.append("\n --model-name ");
+ sb.append("\n --num-runs ");
+ sb.append("\n --batchsize ");
+ System.out.println(sb.toString());
+ return;
+ }
+ String modelName = args[1];
+ InferBase model = null;
+ switch(modelName) {
+ case "ObjectDetection":
+ runBatch = true;
+ ObjectDetectionBenchmark inst = new ObjectDetectionBenchmark();
+ parse(inst, args);
+ model = inst;
+ break;
+ default:
+ System.err.println("Model name not found! " + modelName);
+ System.exit(1);
+ }
+ List context = bindToDevice();
+ long[] result = new long[model.numRun];
+ model.preProcessModel(context);
+ if (runBatch) {
+ for (int i =0;i < model.numRun; i++) {
+ long currTime = System.nanoTime();
+ model.runBatchInference();
+ result[i] = System.nanoTime() - currTime;
+ }
+ System.out.println("Batchsize: " + model.batchSize);
+ System.out.println("Num of runs: " + model.numRun);
+ printStatistics(result, modelName +"batch_inference");
+ }
+
+ model.batchSize = 1;
+ model.preProcessModel(context);
+ result = new long[model.numRun];
+ for (int i = 0; i < model.numRun; i++) {
+ long currTime = System.nanoTime();
+ model.runSingleInference();
+ result[i] = System.nanoTime() - currTime;
+ }
+ System.out.println("Num of runs: " + model.numRun);
+ printStatistics(result, modelName + "single_inference");
+ }
+}
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java
new file mode 100644
index 000000000000..257ea3241626
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.benchmark;
+
+import org.apache.mxnet.infer.javaapi.ObjectDetector;
+import org.apache.mxnet.javaapi.*;
+import org.kohsuke.args4j.Option;
+
+import java.util.ArrayList;
+import java.util.List;
+
+class ObjectDetectionBenchmark extends InferBase {
+ @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
+ public String modelPathPrefix = "/model/ssd_resnet50_512";
+ @Option(name = "--input-image", usage = "the input image")
+ public String inputImagePath = "/images/dog.jpg";
+
+ private ObjectDetector objDet;
+ private NDArray img;
+ private NDArray$ NDArray = NDArray$.MODULE$;
+
+ public void preProcessModel(List context) {
+ Shape inputShape = new Shape(new int[] {this.batchSize, 3, 512, 512});
+ List inputDescriptors = new ArrayList<>();
+ inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+ objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
+ img = ObjectDetector.bufferedImageToPixels(
+ ObjectDetector.reshapeImage(
+ ObjectDetector.loadImageFromFile(inputImagePath), 512, 512
+ ),
+ new Shape(new int[] {1, 3, 512, 512})
+ );
+ }
+
+ public void runSingleInference() {
+ List nd = new ArrayList<>();
+ nd.add(img);
+ objDet.objectDetectWithNDArray(nd, 3);
+ }
+
+ public void runBatchInference() {
+ List nd = new ArrayList<>();
+ NDArray[] temp = new NDArray[batchSize];
+ for (int i = 0; i < batchSize; i++) temp[i] = img.copy();
+ NDArray batched = NDArray.concat(temp, batchSize, 0, null)[0];
+ nd.add(batched);
+ objDet.objectDetectWithNDArray(nd, 3);
+ }
+}
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md
new file mode 100644
index 000000000000..681253f39a88
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md
@@ -0,0 +1,97 @@
+# Single Shot Multi Object Detection using Java Inference API
+
+In this example, you will learn how to use Java Inference API to run Inference on pre-trained Single Shot Multi Object Detection (SSD) MXNet model.
+
+The model is trained on the [Pascal VOC 2012 dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html). The network is a SSD model built on Resnet50 as base network to extract image features. The model is trained to detect the following entities (classes): ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']. For more details about the model, you can refer to the [MXNet SSD example](https://github.com/apache/incubator-mxnet/tree/master/example/ssd).
+
+
+## Contents
+
+1. [Prerequisites](#prerequisites)
+2. [Download artifacts](#download-artifacts)
+3. [Setup datapath and parameters](#setup-datapath-and-parameters)
+4. [Run the image inference example](#run-the-image-inference-example)
+5. [Infer APIs](#infer-api-details)
+6. [Next steps](#next-steps)
+
+
+## Prerequisites
+
+1. MXNet
+2. MXNet Scala Package
+3. [IntelliJ IDE (or alternative IDE) project setup](http://mxnet.incubator.apache.org/tutorials/java/mxnet_java_on_intellij.html) with the MXNet Scala/Java Package
+4. wget
+
+
+## Setup Guide
+
+### Download Artifacts
+#### Step 1
+You can download the files using the script `get_ssd_data.sh`. It will download and place the model files in a `model` folder and the test image files in a `image` folder in the current directory.
+From the `scala-package/examples/scripts/infer/objectdetector/` folder run:
+
+```bash
+./get_ssd_data.sh
+```
+
+**Note**: You may need to run `chmod +x get_ssd_data.sh` before running this script.
+
+In the pre-trained model, the `input_name` is `data` and shape is `(1, 3, 512, 512)`.
+This shape translates to: a batch of `1` image, the image has color and uses `3` channels (RGB), and the image has the dimensions of `512` pixels in height by `512` pixels in width.
+
+`image/jpeg` is the expected input type, since this example's image pre-processor only supports the handling of binary JPEG images.
+
+The output shape is `(1, 6132, 6)`. As with the input, the `1` is the number of images. `6132` is the number of prediction results, and `6` is for the size of each prediction. Each prediction contains the following components:
+- `Class`
+- `Accuracy`
+- `Xmin`
+- `Ymin`
+- `Xmax`
+- `Ymax`
+
+
+### Setup Datapath and Parameters
+#### Step 2
+The followings is the parameters defined for this example, you can find more information in the `class SSDClassifierExample`.
+
+| Argument | Comments |
+| ----------------------------- | ---------------------------------------- |
+| `model-path-prefix` | Folder path with prefix to the model (including json, params, and any synset file). |
+| `input-image` | The image to run inference on. |
+| `input-dir` | The directory of images to run inference on. |
+
+
+## How to Run Inference
+After the previous steps, you should be able to run the code using the following script that will pass all of the required parameters to the Infer API.
+
+From the `scala-package/examples/scripts/inferexample/objectdetector/` folder run:
+
+```bash
+./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
+```
+
+**Notes**:
+* These are relative paths to this script.
+* You may need to run `chmod +x run_ssd_example.sh` before running this script.
+
+The example should give expected output as shown below:
+```
+Class: car
+Probabilties: 0.99847263
+(Coord:,312.21335,72.0291,456.01443,150.66176)
+Class: bicycle
+Probabilties: 0.90473825
+(Coord:,155.95807,149.96362,383.8369,418.94513)
+Class: dog
+Probabilties: 0.8226818
+(Coord:,83.82353,179.13998,206.63783,476.7875)
+```
+the outputs come from the the input image, with top3 predictions picked.
+
+
+## Infer API Details
+This example uses ObjectDetector class provided by MXNet's Java Infer APIs. It provides methods to load the images, create NDArray out of Java BufferedImage and run prediction using Classifier and Predictor APIs.
+
+
+## References
+This documentation used the model and inference setup guide from the [MXNet Model Server SSD example](https://github.com/awslabs/mxnet-model-server/blob/master/examples/ssd/README.md).
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java
new file mode 100644
index 000000000000..a9c00f7f1d81
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.infer.objectdetector;
+
+import org.apache.mxnet.infer.javaapi.ObjectDetectorOutput;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.mxnet.javaapi.*;
+import org.apache.mxnet.infer.javaapi.ObjectDetector;
+
+// scalastyle:off
+import java.awt.image.BufferedImage;
+// scalastyle:on
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import java.io.File;
+
+public class SSDClassifierExample {
+ @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
+ private String modelPathPrefix = "/model/ssd_resnet50_512";
+ @Option(name = "--input-image", usage = "the input image")
+ private String inputImagePath = "/images/dog.jpg";
+ @Option(name = "--input-dir", usage = "the input batch of images directory")
+ private String inputImageDir = "/images/";
+
+ final static Logger logger = LoggerFactory.getLogger(SSDClassifierExample.class);
+
+ static List>
+ runObjectDetectionSingle(String modelPathPrefix, String inputImagePath, List context) {
+ Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
+ List inputDescriptors = new ArrayList();
+ inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+ BufferedImage img = ObjectDetector.loadImageFromFile(inputImagePath);
+ ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
+ return objDet.imageObjectDetect(img, 3);
+ }
+
+ static List>>
+ runObjectDetectionBatch(String modelPathPrefix, String inputImageDir, List context) {
+ Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
+ List inputDescriptors = new ArrayList();
+ inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+ ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
+
+ // Loading batch of images from the directory path
+ List> batchFiles = generateBatches(inputImageDir, 20);
+ List>> outputList
+ = new ArrayList>>();
+
+ for (List batchFile : batchFiles) {
+ List imgList = ObjectDetector.loadInputBatch(batchFile);
+ // Running inference on batch of images loaded in previous step
+ List> tmp
+ = objDet.imageBatchObjectDetect(imgList, 5);
+ outputList.add(tmp);
+ }
+ return outputList;
+ }
+
+ static List> generateBatches(String inputImageDirPath, int batchSize) {
+ File dir = new File(inputImageDirPath);
+
+ List> output = new ArrayList>();
+ List batch = new ArrayList();
+ for (File imgFile : dir.listFiles()) {
+ batch.add(imgFile.getPath());
+ if (batch.size() == batchSize) {
+ output.add(batch);
+ batch = new ArrayList();
+ }
+ }
+ if (batch.size() > 0) {
+ output.add(batch);
+ }
+ return output;
+ }
+
+ public static void main(String[] args) {
+ SSDClassifierExample inst = new SSDClassifierExample();
+ CmdLineParser parser = new CmdLineParser(inst);
+ try {
+ parser.parseArgument(args);
+ } catch (Exception e) {
+ logger.error(e.getMessage(), e);
+ parser.printUsage(System.err);
+ System.exit(1);
+ }
+
+ String mdprefixDir = inst.modelPathPrefix;
+ String imgPath = inst.inputImagePath;
+ String imgDir = inst.inputImageDir;
+
+ if (!checkExist(Arrays.asList(mdprefixDir + "-symbol.json", imgDir, imgPath))) {
+ logger.error("Model or input image path does not exist");
+ System.exit(1);
+ }
+
+ List context = new ArrayList();
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
+ context.add(Context.gpu());
+ } else {
+ context.add(Context.cpu());
+ }
+
+ try {
+ Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
+ Shape outputShape = new Shape(new int[]{1, 6132, 6});
+
+
+ int width = inputShape.get(2);
+ int height = inputShape.get(3);
+ StringBuilder outputStr = new StringBuilder().append("\n");
+
+ List> output
+ = runObjectDetectionSingle(mdprefixDir, imgPath, context);
+
+ for (List ele : output) {
+ for (ObjectDetectorOutput i : ele) {
+ outputStr.append("Class: " + i.getClassName() + "\n");
+ outputStr.append("Probabilties: " + i.getProbability() + "\n");
+
+ List coord = Arrays.asList(i.getXMin() * width,
+ i.getXMax() * height, i.getYMin() * width, i.getYMax() * height);
+ StringBuilder sb = new StringBuilder();
+ for (float c : coord) {
+ sb.append(", ").append(c);
+ }
+ outputStr.append("Coord:" + sb.substring(2) + "\n");
+ }
+ }
+ logger.info(outputStr.toString());
+
+ List>> outputList =
+ runObjectDetectionBatch(mdprefixDir, imgDir, context);
+
+ outputStr = new StringBuilder().append("\n");
+ int index = 0;
+ for (List> i : outputList) {
+ for (List j : i) {
+ outputStr.append("*** Image " + (index + 1) + "***" + "\n");
+ for (ObjectDetectorOutput k : j) {
+ outputStr.append("Class: " + k.getClassName() + "\n");
+ outputStr.append("Probabilties: " + k.getProbability() + "\n");
+ List coord = Arrays.asList(k.getXMin() * width,
+ k.getXMax() * height, k.getYMin() * width, k.getYMax() * height);
+
+ StringBuilder sb = new StringBuilder();
+ for (float c : coord) {
+ sb.append(", ").append(c);
+ }
+ outputStr.append("Coord:" + sb.substring(2) + "\n");
+ }
+ index++;
+ }
+ }
+ logger.info(outputStr.toString());
+
+ } catch (Exception e) {
+ logger.error(e.getMessage(), e);
+ parser.printUsage(System.err);
+ System.exit(1);
+ }
+ System.exit(0);
+ }
+
+ static Boolean checkExist(List arr) {
+ Boolean exist = true;
+ for (String item : arr) {
+ if (!(new File(item).exists())) {
+ logger.error("Cannot find: " + item);
+ exist = false;
+ }
+ }
+ return exist;
+ }
+}
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java
new file mode 100644
index 000000000000..c9b4426f52b3
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.infer.predictor;
+
+import org.apache.mxnet.infer.javaapi.Predictor;
+import org.apache.mxnet.javaapi.*;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.imageio.ImageIO;
+import java.awt.Graphics2D;
+import java.awt.image.BufferedImage;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * This Class is a demo to show how users can use Predictor APIs to do
+ * Image Classification with all hand-crafted Pre-processing.
+ * All helper functions for image pre-processing are
+ * currently available in ObjectDetector class.
+ */
+public class PredictorExample {
+ @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
+ private String modelPathPrefix = "/model/ssd_resnet50_512";
+ @Option(name = "--input-image", usage = "the input image")
+ private String inputImagePath = "/images/dog.jpg";
+
+ final static Logger logger = LoggerFactory.getLogger(PredictorExample.class);
+
+ /**
+ * Load the image from file to buffered image
+ * It can be replaced by loadImageFromFile from ObjectDetector
+ * @param inputImagePath input image Path in String
+ * @return Buffered image
+ */
+ private static BufferedImage loadIamgeFromFile(String inputImagePath) {
+ BufferedImage buf = null;
+ try {
+ buf = ImageIO.read(new File(inputImagePath));
+ } catch (IOException e) {
+ System.err.println(e);
+ }
+ return buf;
+ }
+
+ /**
+ * Reshape the current image using ImageIO and Graph2D
+ * It can be replaced by reshapeImage from ObjectDetector
+ * @param buf Buffered image
+ * @param newWidth desired width
+ * @param newHeight desired height
+ * @return a reshaped bufferedImage
+ */
+ private static BufferedImage reshapeImage(BufferedImage buf, int newWidth, int newHeight) {
+ BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB);
+ Graphics2D g = resizedImage.createGraphics();
+ g.drawImage(buf, 0, 0, newWidth, newHeight, null);
+ g.dispose();
+ return resizedImage;
+ }
+
+ /**
+ * Convert an image from a buffered image into pixels float array
+ * It can be replaced by bufferedImageToPixels from ObjectDetector
+ * @param buf buffered image
+ * @return Float array
+ */
+ private static float[] imagePreprocess(BufferedImage buf) {
+ // Get height and width of the image
+ int w = buf.getWidth();
+ int h = buf.getHeight();
+
+ // get an array of integer pixels in the default RGB color mode
+ int[] pixels = buf.getRGB(0, 0, w, h, null, 0, w);
+
+ // 3 times height and width for R,G,B channels
+ float[] result = new float[3 * h * w];
+
+ int row = 0;
+ // copy pixels to array vertically
+ while (row < h) {
+ int col = 0;
+ // copy pixels to array horizontally
+ while (col < w) {
+ int rgb = pixels[row * w + col];
+ // getting red color
+ result[0 * h * w + row * w + col] = (rgb >> 16) & 0xFF;
+ // getting green color
+ result[1 * h * w + row * w + col] = (rgb >> 8) & 0xFF;
+ // getting blue color
+ result[2 * h * w + row * w + col] = rgb & 0xFF;
+ col += 1;
+ }
+ row += 1;
+ }
+ buf.flush();
+ return result;
+ }
+
+ /**
+ * Helper class to print the maximum prediction result
+ * @param probabilities The float array of probability
+ * @param modelPathPrefix model Path needs to load the synset.txt
+ */
+ private static String printMaximumClass(float[] probabilities,
+ String modelPathPrefix) throws IOException {
+ String synsetFilePath = modelPathPrefix.substring(0,
+ 1 + modelPathPrefix.lastIndexOf(File.separator)) + "/synset.txt";
+ BufferedReader reader = new BufferedReader(new FileReader(synsetFilePath));
+ ArrayList list = new ArrayList<>();
+ String line = reader.readLine();
+
+ while (line != null){
+ list.add(line);
+ line = reader.readLine();
+ }
+ reader.close();
+
+ int maxIdx = 0;
+ for (int i = 1;i probabilities[maxIdx]) {
+ maxIdx = i;
+ }
+ }
+
+ return "Probability : " + probabilities[maxIdx] + " Class : " + list.get(maxIdx) ;
+ }
+
+ public static void main(String[] args) {
+ PredictorExample inst = new PredictorExample();
+ CmdLineParser parser = new CmdLineParser(inst);
+ try {
+ parser.parseArgument(args);
+ } catch (Exception e) {
+ logger.error(e.getMessage(), e);
+ parser.printUsage(System.err);
+ System.exit(1);
+ }
+ // Prepare the model
+ List context = new ArrayList();
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
+ context.add(Context.gpu());
+ } else {
+ context.add(Context.cpu());
+ }
+ List inputDesc = new ArrayList<>();
+ Shape inputShape = new Shape(new int[]{1, 3, 224, 224});
+ inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+ Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0);
+ // Prepare data
+ BufferedImage img = loadIamgeFromFile(inst.inputImagePath);
+
+ img = reshapeImage(img, 224, 224);
+ // predict
+ float[][] result = predictor.predict(new float[][]{imagePreprocess(img)});
+ try {
+ System.out.println("Predict with Float input");
+ System.out.println(printMaximumClass(result[0], inst.modelPathPrefix));
+ } catch (IOException e) {
+ System.err.println(e);
+ }
+ // predict with NDArray
+ NDArray nd = new NDArray(
+ imagePreprocess(img),
+ new Shape(new int[]{1, 3, 224, 224}),
+ Context.cpu());
+ List ndList = new ArrayList<>();
+ ndList.add(nd);
+ List ndResult = predictor.predictWithNDArray(ndList);
+ try {
+ System.out.println("Predict with NDArray");
+ System.out.println(printMaximumClass(ndResult.get(0).toArray(), inst.modelPathPrefix));
+ } catch (IOException e) {
+ System.err.println(e);
+ }
+ }
+
+}
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/README.md
new file mode 100644
index 000000000000..1f2c9e0e813c
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/README.md
@@ -0,0 +1,61 @@
+# Image Classification using Java Predictor
+
+In this example, you will learn how to use Java Inference API to
+build and run pre-trained Resnet 18 model.
+
+## Contents
+
+1. [Prerequisites](#prerequisites)
+2. [Download artifacts](#download-artifacts)
+3. [Setup datapath and parameters](#setup-datapath-and-parameters)
+4. [Run the image classifier example](#run-the-image-inference-example)
+
+## Prerequisites
+
+1. Build from source with [MXNet](https://mxnet.incubator.apache.org/install/index.html)
+2. [IntelliJ IDE (or alternative IDE) project setup](https://github.com/apache/incubator-mxnet/blob/master/docs/tutorials/java/mxnet_java_on_intellij.md) with the MXNet Java Package
+3. wget
+
+## Download Artifacts
+
+For this tutorial, you can get the model and sample input image by running following bash file. This script will use `wget` to download these artifacts from AWS S3.
+
+From the `scala-package/examples/scripts/infer/imageclassifier/` folder run:
+
+```bash
+./get_resnet_18_data.sh
+```
+
+**Note**: You may need to run `chmod +x get_resnet_18_data.sh` before running this script.
+
+### Setup Datapath and Parameters
+
+The available arguments are as follows:
+
+| Argument | Comments |
+| ----------------------------- | ---------------------------------------- |
+| `model-dir` | Folder path with prefix to the model (including json, params, and any synset file). |
+| `input-image` | The image to run inference on. |
+
+## Run the image classifier example
+
+After the previous steps, you should be able to run the code using the following script that will pass all of the required parameters to the Predictor API.
+
+From the `scala-package/examples/scripts/infer/predictor/` folder run:
+
+```bash
+bash run_predictor_java_example.sh ../models/resnet-18/resnet-18 ../images/kitten.jpg
+```
+
+**Notes**:
+* These are relative paths to this script.
+* You may need to run `chmod +x run_predictor_java_example.sh` before running this script.
+
+The example should give an output similar to the one shown below:
+```
+Predict with Float input
+Probability : 0.30337515 Class : n02123159 tiger cat
+Predict with NDArray
+Probability : 0.30337515 Class : n02123159 tiger cat
+```
+the outputs come from the the input image, with top1 predictions picked.
\ No newline at end of file
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
index 69328a44bab6..77aec7bb5dee 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
@@ -28,18 +28,13 @@ The model is trained on the [Pascal VOC 2012 dataset](http://host.robots.ox.ac.u
### Download Artifacts
#### Step 1
You can download the files using the script `get_ssd_data.sh`. It will download and place the model files in a `model` folder and the test image files in a `image` folder in the current directory.
-From the `scala-package/examples/scripts/infer/imageclassifier/` folder run:
+From the `scala-package/examples/scripts/infer/objectdetector/` folder run:
```bash
-./get_resnet_data.sh
+./get_ssd_data.sh
```
-**Note**: You may need to run `chmod +x get_resnet_data.sh` before running this script.
-
-Alternatively use the following links to download the Symbol and Params files via your browser:
-- [resnet50_ssd_model-symbol.json](https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-symbol.json)
-- [resnet50_ssd_model-0000.params](https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-0000.params)
-- [synset.txt](https://github.com/awslabs/mxnet-model-server/blob/master/examples/ssd/synset.txt)
+**Note**: You may need to run `chmod +x get_ssd_data.sh` before running this script.
In the pre-trained model, the `input_name` is `data` and shape is `(1, 3, 512, 512)`.
This shape translates to: a batch of `1` image, the image has color and uses `3` channels (RGB), and the image has the dimensions of `512` pixels in height by `512` pixels in width.
@@ -57,13 +52,6 @@ The output shape is `(1, 6132, 6)`. As with the input, the `1` is the number of
### Setup Datapath and Parameters
#### Step 2
-The code `Line 31: val baseDir = System.getProperty("user.dir")` in the example will automatically searches the work directory you have defined. Please put the files in your [work directory](https://stackoverflow.com/questions/16239130/java-user-dir-property-what-exactly-does-it-mean).
-
-Alternatively, if you would like to use your own path, please change line 31 into your own path
-```scala
-val baseDir =
-```
-
The followings is the parameters defined for this example, you can find more information in the `class SSDClassifierExample`.
| Argument | Comments |
@@ -79,7 +67,7 @@ After the previous steps, you should be able to run the code using the following
From the `scala-package/examples/scripts/inferexample/objectdetector/` folder run:
```bash
-./run_ssd_example.sh ../models/resnet50_ssd_model ../images/dog.jpg ../images
+./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
```
**Notes**:
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
index f752ef6dab58..07d1cc82e927 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
@@ -182,9 +182,9 @@ object SSDClassifierExample {
def checkExist(arr : Array[String]) : Boolean = {
var exist : Boolean = true
for (item <- arr) {
- exist = Files.exists(Paths.get(item)) && exist
- if (!exist) {
+ if (!(Files.exists(Paths.get(item)))) {
logger.error("Cannot find: " + item)
+ exist = false
}
}
exist
diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml
index e50100169328..91a1e1b30d2f 100644
--- a/scala-package/infer/pom.xml
+++ b/scala-package/infer/pom.xml
@@ -13,6 +13,10 @@
mxnet-infer_2.11MXNet Scala Package - Inference
+
+ true
+
+
unittest
@@ -20,12 +24,6 @@
false
-
- integrationtest
-
- true
-
- osx-x86_64-cpu
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
new file mode 100644
index 000000000000..08fffb410adf
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+// scalastyle:off
+import java.awt.image.BufferedImage
+// scalastyle:on
+
+import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray, Shape}
+
+import scala.collection.JavaConverters
+import scala.collection.JavaConverters._
+
+/**
+ * The ObjectDetector class helps to run ObjectDetection tasks where the goal
+ * is to find bounding boxes and corresponding labels for objects in a image.
+ *
+ * @param modelPathPrefix Path prefix from where to load the model artifacts.
+ * These include the symbol, parameters, and synset.txt.
+ * Example: file://model-dir/ssd_resnet50_512 (containing
+ * ssd_resnet50_512-symbol.json, ssd_resnet50_512-0000.params,
+ * and synset.txt)
+ * @param inputDescriptors Descriptors defining the input node names, shape,
+ * layout and type parameters
+ * @param contexts Device contexts on which you want to run inference.
+ * Defaults to CPU.
+ * @param epoch Model epoch to load; defaults to 0
+ */
+class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.ObjectDetector){
+
+ def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts:
+ java.util.List[Context], epoch: Int)
+ = this {
+ val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
+ .asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
+ val inContexts = (contexts.asScala.toList map {a => a: org.apache.mxnet.Context}).toArray
+ // scalastyle:off
+ new org.apache.mxnet.infer.ObjectDetector(modelPathPrefix, informationDesc, inContexts, Some(epoch))
+ // scalastyle:on
+ }
+
+ /**
+ * Detects objects and returns bounding boxes with corresponding class/label
+ *
+ * @param inputImage Path prefix of the input image
+ * @param topK Number of result elements to return, sorted by probability
+ * @return List of list of tuples of
+ * (class, [probability, xmin, ymin, xmax, ymax])
+ */
+ def imageObjectDetect(inputImage: BufferedImage, topK: Int):
+ java.util.List[java.util.List[ObjectDetectorOutput]] = {
+ val ret = objDetector.imageObjectDetect(inputImage, Some(topK))
+ (ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
+ }
+
+ /**
+ * Takes input images as NDArrays. Useful when you want to perform multiple operations on
+ * the input array, or when you want to pass a batch of input images.
+ *
+ * @param input Indexed Sequence of NDArrays
+ * @param topK (Optional) How many top_k (sorting will be based on the last axis)
+ * elements to return. If not passed, returns all unsorted output.
+ * @return List of list of tuples of
+ * (class, [probability, xmin, ymin, xmax, ymax])
+ */
+ def objectDetectWithNDArray(input: java.util.List[NDArray], topK: Int):
+ java.util.List[java.util.List[ObjectDetectorOutput]] = {
+ val ret = objDetector.objectDetectWithNDArray(convert(input.asScala.toIndexedSeq), Some(topK))
+ (ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
+ }
+
+ /**
+ * To classify batch of input images according to the provided model
+ *
+ * @param inputBatch Input array of buffered images
+ * @param topK Number of result elements to return, sorted by probability
+ * @return List of list of tuples of (class, probability)
+ */
+ def imageBatchObjectDetect(inputBatch: java.util.List[BufferedImage], topK: Int):
+ java.util.List[java.util.List[ObjectDetectorOutput]] = {
+ val ret = objDetector.imageBatchObjectDetect(inputBatch.asScala, Some(topK))
+ (ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
+ }
+
+ def convert[B, A <% B](l: IndexedSeq[A]): IndexedSeq[B] = l map { a => a: B }
+
+}
+
+
+object ObjectDetector {
+ implicit def fromObjectDetector(OD: org.apache.mxnet.infer.ObjectDetector):
+ ObjectDetector = new ObjectDetector(OD)
+
+ implicit def toObjectDetector(jOD: ObjectDetector):
+ org.apache.mxnet.infer.ObjectDetector = jOD.objDetector
+
+ def loadImageFromFile(inputImagePath: String): BufferedImage = {
+ org.apache.mxnet.infer.ImageClassifier.loadImageFromFile(inputImagePath)
+ }
+
+ def reshapeImage(img : BufferedImage, newWidth: Int, newHeight: Int): BufferedImage = {
+ org.apache.mxnet.infer.ImageClassifier.reshapeImage(img, newWidth, newHeight)
+ }
+
+ def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape): NDArray = {
+ org.apache.mxnet.infer.ImageClassifier.bufferedImageToPixels(resizedImage, inputImageShape)
+ }
+
+ def loadInputBatch(inputImagePaths: java.util.List[String]): java.util.List[BufferedImage] = {
+ org.apache.mxnet.infer.ImageClassifier
+ .loadInputBatch(inputImagePaths.asScala.toList).toList.asJava
+ }
+}
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
new file mode 100644
index 000000000000..13369c8fcef5
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+class ObjectDetectorOutput (className: String, args: Array[Float]){
+
+ def getClassName: String = className
+
+ def getProbability: Float = args(0)
+
+ def getXMin: Float = args(1)
+
+ def getXMax: Float = args(2)
+
+ def getYMin: Float = args(3)
+
+ def getYMax: Float = args(4)
+
+}
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
new file mode 100644
index 000000000000..a5428e1c8219
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
+
+import scala.collection.JavaConverters
+import scala.collection.JavaConverters._
+
+/**
+ * Implementation of prediction routines.
+ *
+ * @param modelPathPrefix Path prefix from where to load the model artifacts.
+ * These include the symbol, parameters, and synset.txt
+ * Example: file://model-dir/resnet-152 (containing
+ * resnet-152-symbol.json, resnet-152-0000.params, and synset.txt).
+ * @param inputDescriptors Descriptors defining the input node names, shape,
+ * layout and type parameters
+ *
Note: If the input Descriptors is missing batchSize
+ * ('N' in layout), a batchSize of 1 is assumed for the model.
+ * @param contexts Device contexts on which you want to run inference; defaults to CPU
+ * @param epoch Model epoch to load; defaults to 0
+
+ */
+
+// JavaDoc description of class to be updated in https://issues.apache.org/jira/browse/MXNET-1178
+class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor){
+ def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc],
+ contexts: java.util.List[Context], epoch: Int)
+ = this {
+ val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
+ .asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
+ val inContexts = (contexts.asScala.toList map {a => a: org.apache.mxnet.Context}).toArray
+ new org.apache.mxnet.infer.Predictor(modelPathPrefix, informationDesc, inContexts, Some(epoch))
+ }
+
+ /**
+ * Takes input as Array of one dimensional arrays and creates the NDArray needed for inference
+ * The array will be reshaped based on the input descriptors.
+ *
+ * @param input: An Array of a one-dimensional array.
+ An extra Array is needed for when the model has more than one input.
+ * @return Indexed sequence array of outputs
+ */
+ def predict(input: Array[Array[Float]]):
+ Array[Array[Float]] = {
+ predictor.predict(input).toArray
+ }
+
+ /**
+ * Takes input as List of one dimensional arrays and creates the NDArray needed for inference
+ * The array will be reshaped based on the input descriptors.
+ *
+ * @param input: A List of a one-dimensional array.
+ An extra List is needed for when the model has more than one input.
+ * @return Indexed sequence array of outputs
+ */
+ def predict(input: java.util.List[java.util.List[Float]]):
+ java.util.List[java.util.List[Float]] = {
+ val in = JavaConverters.asScalaIteratorConverter(input.iterator).asScala.toIndexedSeq
+ (predictor.predict(in map {a => a.asScala.toArray}) map {b => b.toList.asJava}).asJava
+ }
+
+
+
+ /**
+ * Predict using NDArray as input
+ * This method is useful when the input is a batch of data
+ * Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
+ *
+ * @param input List of NDArrays
+ * @return Output of predictions as NDArrays
+ */
+ def predictWithNDArray(input: java.util.List[NDArray]):
+ java.util.List[NDArray] = {
+ val ret = predictor.predictWithNDArray(convert(JavaConverters
+ .asScalaIteratorConverter(input.iterator).asScala.toIndexedSeq))
+ // TODO: For some reason the implicit wasn't working here when trying to use convert.
+ // So did it this way. Needs to be figured out
+ (ret map {a => new NDArray(a)}).asJava
+ }
+
+ private def convert[B, A <% B](l: IndexedSeq[A]): IndexedSeq[B] = l map { a => a: B }
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index bfa378ea9e95..7c1edb5df5c1 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -36,6 +36,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
hashCollector += typeSafeClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
+ hashCollector += javaClassGen(FILE_PATH)
val finalHash = hashCollector.mkString("\n")
}
@@ -89,6 +90,29 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
absFuncs)
}
+ def javaClassGen(filePath : String) : String = {
+ val notGenerated = Set("Custom")
+ val absClassFunctions = functionsToGenerate(false, false, true)
+ val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
+ .groupBy(_.name.toLowerCase).map(ele => {
+ /* Pattern matching for not generating deprecated method
+ * Group all method name in lowercase
+ * Kill the capital lettered method such as Cast vs cast
+ * As it defined by default it deprecated
+ */
+ if (ele._2.length == 1) ele._2.head
+ else {
+ if (ele._2.head.name.head.isLower) ele._2.head
+ else ele._2.last
+ }
+ }).map(absClassFunction => {
+ generateJavaAPISignature(absClassFunction)
+ }).toSeq
+ val packageName = "NDArrayBase"
+ val packageDef = "package org.apache.mxnet.javaapi"
+ writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
+ }
+
def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
val desc = func.desc.split("\n")
.mkString(" *
\n", "\n * ", " *
\n")
@@ -123,6 +147,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
argDef += "attr : Map[String, String] = null"
} else {
argDef += "out : Option[NDArray] = None"
+
}
val returnType = func.returnType
@@ -131,6 +156,64 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
|def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
}
+ def generateJavaAPISignature(func : Func) : String = {
+ val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
+ var argDef = ListBuffer[String]()
+ var classDef = ListBuffer[String]()
+ var requiredParam = ListBuffer[String]()
+ func.listOfArgs.foreach(absClassArg => {
+ val currArgName = absClassArg.safeArgName
+ // scalastyle:off
+ if (absClassArg.isOptional && useParamObject) {
+ classDef +=
+ s"""private var $currArgName: ${absClassArg.argType} = null
+ |/**
+ | * @param $currArgName\t\t${absClassArg.argDesc}
+ | */
+ |def set${currArgName.capitalize}($currArgName : ${absClassArg.argType}): ${func.name}Param = {
+ | this.$currArgName = $currArgName
+ | this
+ | }""".stripMargin
+ }
+ else {
+ requiredParam += s" * @param $currArgName\t\t${absClassArg.argDesc}"
+ argDef += s"$currArgName : ${absClassArg.argType}"
+ }
+ classDef += s"def get${currArgName.capitalize}() = this.$currArgName"
+ // scalastyle:on
+ })
+ val experimentalTag = "@Experimental"
+ val returnType = "Array[NDArray]"
+ val scalaDoc = generateAPIDocFromBackend(func)
+ val scalaDocNoParam = generateAPIDocFromBackend(func, false)
+ if(useParamObject) {
+ classDef +=
+ s"""private var out : org.apache.mxnet.NDArray = null
+ |def setOut(out : NDArray) : ${func.name}Param = {
+ | this.out = out
+ | this
+ | }
+ | def getOut() = this.out
+ | """.stripMargin
+ s"""$scalaDocNoParam
+ | $experimentalTag
+ | def ${func.name}(po: ${func.name}Param) : $returnType
+ | /**
+ | * This Param Object is specifically used for ${func.name}
+ | ${requiredParam.mkString("\n")}
+ | */
+ | class ${func.name}Param(${argDef.mkString(",")}) {
+ | ${classDef.mkString("\n ")}
+ | }""".stripMargin
+ } else {
+ argDef += "out : NDArray"
+ s"""$scalaDoc
+ |$experimentalTag
+ | def ${func.name}(${argDef.mkString(", ")}) : $returnType
+ | """.stripMargin
+ }
+ }
+
def writeFile(FILE_PATH: String, className: String, packageDef: String,
absFuncs: Seq[String]): String = {
@@ -161,6 +244,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
|${absFuncs.mkString("\n")}
|}""".stripMargin
+
val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
pw.write(finalStr)
pw.close()
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
index f4c4a91bdf9a..9245ef1b437f 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -36,8 +36,9 @@ abstract class GeneratorBase {
case class Func(name: String, desc: String, listOfArgs: List[Arg], returnType: String)
- def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = {
- val l = getBackEndFunctions(isSymbol)
+ def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean,
+ isJava: Boolean = false): List[Func] = {
+ val l = getBackEndFunctions(isSymbol, isJava)
if (isContrib) {
l.filter(func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
} else {
@@ -58,17 +59,18 @@ abstract class GeneratorBase {
res.filterNot(ele => notGenerated.contains(ele.name))
}
- protected def getBackEndFunctions(isSymbol: Boolean): List[Func] = {
+ protected def getBackEndFunctions(isSymbol: Boolean, isJava: Boolean = false): List[Func] = {
val opNames = ListBuffer.empty[String]
_LIB.mxListAllOpNames(opNames)
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
- makeAtomicFunction(opHandle.value, opName, isSymbol)
+ makeAtomicFunction(opHandle.value, opName, isSymbol, isJava)
}).toList
}
- private def makeAtomicFunction(handle: Handle, aliasName: String, isSymbol: Boolean): Func = {
+ private def makeAtomicFunction(handle: Handle, aliasName: String,
+ isSymbol: Boolean, isJava: Boolean): Func = {
val name = new RefString
val desc = new RefString
val keyVarNumArgs = new RefString
@@ -89,13 +91,17 @@ abstract class GeneratorBase {
val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
- val family = if (isSymbol) "org.apache.mxnet.Symbol" else "org.apache.mxnet.NDArray"
+ val family = if (isJava) "org.apache.mxnet.javaapi.NDArray"
+ else if (isSymbol) "org.apache.mxnet.Symbol"
+ else "org.apache.mxnet.NDArray"
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType, family)
Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
val returnType =
- if (isSymbol) "org.apache.mxnet.Symbol" else "org.apache.mxnet.NDArrayFuncReturn"
+ if (isJava) "Array[org.apache.mxnet.javaapi.NDArray]"
+ else if (isSymbol) "org.apache.mxnet.Symbol"
+ else "org.apache.mxnet.NDArrayFuncReturn"
Func(aliasName, desc.value, argList.toList, returnType)
}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
new file mode 100644
index 000000000000..4dfd6eb044a1
--- /dev/null
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+import org.apache.mxnet.GeneratorBase
+
+import scala.annotation.StaticAnnotation
+import scala.collection.mutable.ListBuffer
+import scala.language.experimental.macros
+import scala.reflect.macros.blackbox
+
+private[mxnet] class AddJNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation {
+ private[mxnet] def macroTransform(annottees: Any*) = macro JavaNDArrayMacro.typeSafeAPIDefs
+}
+
+private[mxnet] object JavaNDArrayMacro extends GeneratorBase {
+
+ // scalastyle:off havetype
+ def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
+ typeSafeAPIImpl(c)(annottees: _*)
+ }
+ // scalastyle:off havetype
+
+ private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
+ import c.universe._
+
+ val isContrib: Boolean = c.prefix.tree match {
+ case q"new AddJNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
+ }
+ // Defines Operators that should not generated
+ val notGenerated = Set("Custom")
+
+ val newNDArrayFunctions = functionsToGenerate(false, false, true)
+ .filterNot(ele => notGenerated.contains(ele.name)).groupBy(_.name.toLowerCase).map(ele => {
+ /* Pattern matching for not generating deprecated method
+ * Group all method name in lowercase
+ * Kill the capital lettered method such as Cast vs cast
+ * As it defined by default it deprecated
+ */
+ if (ele._2.length == 1) ele._2.head
+ else {
+ if (ele._2.head.name.head.isLower) ele._2.head
+ else ele._2.last
+ }
+ })
+
+ val functionDefs = ListBuffer[DefDef]()
+ val classDefs = ListBuffer[ClassDef]()
+
+ newNDArrayFunctions.foreach { ndarrayfunction =>
+
+ val useParamObject = ndarrayfunction.listOfArgs.count(arg => arg.isOptional) >= 2
+ // Construct argument field with all required args
+ var argDef = ListBuffer[String]()
+ // Construct function Implementation field (e.g norm)
+ var impl = ListBuffer[String]()
+ impl += "val map = scala.collection.mutable.Map[String, Any]()"
+ impl +=
+ "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]"
+ ndarrayfunction.listOfArgs.foreach({ ndarrayArg =>
+ // var is a special word used to define variable in Scala,
+ // need to changed to something else in order to make it work
+ var currArgName = ndarrayArg.safeArgName
+ if (useParamObject) currArgName = s"po.get${currArgName.capitalize}()"
+ argDef += s"$currArgName : ${ndarrayArg.argType}"
+ // NDArray arg implementation
+ val returnType = "org.apache.mxnet.javaapi.NDArray"
+ val base =
+ if (ndarrayArg.argType.equals(returnType)) {
+ s"args += $currArgName"
+ } else if (ndarrayArg.argType.equals(s"Array[$returnType]")){
+ s"$currArgName.foreach(args+=_)"
+ } else {
+ "map(\"" + ndarrayArg.argName + "\") = " + currArgName
+ }
+ impl.append(
+ if (ndarrayArg.isOptional) s"if ($currArgName != null) $base"
+ else base
+ )
+ })
+ // add default out parameter
+ argDef += s"out: org.apache.mxnet.javaapi.NDArray"
+ if (useParamObject) {
+ impl += "if (po.getOut() != null) map(\"out\") = po.getOut()"
+ } else {
+ impl += "if (out != null) map(\"out\") = out"
+ }
+ val returnType = "Array[org.apache.mxnet.javaapi.NDArray]"
+ // scalastyle:off
+ // Combine and build the function string
+ impl += "val finalArr = org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" +
+ ndarrayfunction.name + "\", args.toSeq, map.toMap).arr"
+ impl += "finalArr.map(ele => new NDArray(ele))"
+ if (useParamObject) {
+ val funcDef =
+ s"""def ${ndarrayfunction.name}(po: ${ndarrayfunction.name}Param): $returnType = {
+ | ${impl.mkString("\n")}
+ | }""".stripMargin
+ functionDefs += c.parse(funcDef).asInstanceOf[DefDef]
+ } else {
+ val funcDef =
+ s"""def ${ndarrayfunction.name}(${argDef.mkString(",")}): $returnType = {
+ | ${impl.mkString("\n")}
+ | }""".stripMargin
+ functionDefs += c.parse(funcDef).asInstanceOf[DefDef]
+ }
+ }
+ structGeneration(c)(functionDefs.toList, annottees : _*)
+ }
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index d0ebe5b1d2cb..2fd8b2e73c7a 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -21,19 +21,20 @@ private[mxnet] object CToScalaUtils {
// Convert C++ Types to Scala Types
- def typeConversion(in : String, argType : String = "",
- argName : String, returnType : String) : String = {
+ def typeConversion(in : String, argType : String = "", argName : String,
+ returnType : String) : String = {
+ val header = returnType.split("\\.").dropRight(1)
in match {
- case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
+ case "Shape(tuple)" | "ShapeorNone" => s"${header.mkString(".")}.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
=> s"Array[$returnType]"
- case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat"
- case "int" | "intorNone" | "int(non-negative)" => "Int"
- case "long" | "long(non-negative)" => "Long"
- case "double" | "doubleorNone" => "Double"
+ case "float" | "real_t" | "floatorNone" => "java.lang.Float"
+ case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer"
+ case "long" | "long(non-negative)" => "java.lang.Long"
+ case "double" | "doubleorNone" => "java.lang.Double"
case "string" => "String"
- case "boolean" | "booleanorNone" => "Boolean"
+ case "boolean" | "booleanorNone" => "java.lang.Boolean"
case "tupleof" | "tupleof" | "tupleof<>" | "ptr" | "" => "Any"
case default => throw new IllegalArgumentException(
s"Invalid type for args: $default\nString argType: $argType\nargName: $argName")
@@ -52,8 +53,8 @@ private[mxnet] object CToScalaUtils {
* @param argType Raw arguement Type description
* @return (Scala_Type, isOptional)
*/
- def argumentCleaner(argName: String,
- argType : String, returnType : String) : (String, Boolean) = {
+ def argumentCleaner(argName: String, argType : String,
+ returnType : String) : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
index c3a7c58c1afc..4404b0885d57 100644
--- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
+++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
@@ -36,7 +36,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll {
)
val output = List(
("org.apache.mxnet.Symbol", true),
- ("Int", false),
+ ("java.lang.Integer", false),
("org.apache.mxnet.Shape", true),
("String", true),
("Any", false)
diff --git a/scala-package/mxnet-demo/java-demo/Makefile b/scala-package/mxnet-demo/java-demo/Makefile
new file mode 100644
index 000000000000..340a50f75965
--- /dev/null
+++ b/scala-package/mxnet-demo/java-demo/Makefile
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+SCALA_VERSION_PROFILE := 2.11
+SCALA_VERSION := 2.11.8
+MXNET_VERSION := 1.3.1-SNAPSHOT
+
+ifeq ($(OS),Windows_NT)
+ UNAME_S := Windows
+else
+ UNAME_S := $(shell uname -s)
+endif
+
+ifeq ($(UNAME_S), Windows)
+ # TODO: currently scala package does not support windows
+ SCALA_PKG_PROFILE := windows
+else
+ ifeq ($(UNAME_S), Darwin)
+ SCALA_PKG_PROFILE := osx-x86_64-cpu
+ else
+ SCALA_PKG_PROFILE := linux-x86_64
+ ifeq ($(USE_CUDA), 1)
+ SCALA_PKG_PROFILE := $(SCALA_PKG_PROFILE)-gpu
+ else
+ SCALA_PKG_PROFILE := $(SCALA_PKG_PROFILE)-cpu
+ endif
+ endif
+endif
+
+javademo:
+ (mvn package -Dmxnet.profile=$(SCALA_PKG_PROFILE) \
+ -Dmxnet.scalaprofile=$(SCALA_VERSION_PROFILE) \
+ -Dmxnet.version=$(MXNET_VERSION) \
+ -Dscala.version=$(SCALA_VERSION))
+
+javaclean:
+ (mvn clean -Dmxnet.profile=$(SCALA_PKG_PROFILE) \
+ -Dmxnet.scalaprofile=$(SCALA_VERSION_PROFILE) \
+ -Dmxnet.version=$(MXNET_VERSION) \
+ -Dscala.version=$(SCALA_VERSION))
\ No newline at end of file
diff --git a/scala-package/mxnet-demo/java-demo/README.md b/scala-package/mxnet-demo/java-demo/README.md
new file mode 100644
index 000000000000..ffe614a29287
--- /dev/null
+++ b/scala-package/mxnet-demo/java-demo/README.md
@@ -0,0 +1,76 @@
+# MXNet Java Sample Project
+This is an project created to use Maven-published Scala/Java package with two Java examples.
+## Setup
+Please copy the downloaded MXNet Java package jar file under the `java-demo` folder.
+
+User are required to use `mvn package` to build the package,
+ which are shown below:
+```Bash
+export SCALA_VERSION_PROFILE=2.11 SCALA_VERSION=2.11.8 MXNET_VERSION=1.3.1-SNAPSHOT
+export SCALA_PKG_PROFILE=
+mvn package -Dmxnet.profile=$(SCALA_PKG_PROFILE) \
+ -Dmxnet.scalaprofile=$(SCALA_VERSION_PROFILE) \
+ -Dmxnet.version=$(MXNET_VERSION) \
+ -Dscala.version=$(SCALA_VERSION)
+```
+These environment variable (`SCALA_PKG_PROFILE`, `SCALA_VERSION_PROFILE`, `MXNET_VERSION`, `SCALA_VERSION`)
+should be set before executing the line above.
+
+You can also use the `Makefile` as an alternative to do the same thing. Simply do the following:
+```Bash
+make javademo
+```
+This will load the default parameter for all the environment variable.
+ If you want to run with GPU on Linux, just simply add `USE_CUDA=1` when you run the make file
+
+## Run
+### Hello World
+The Scala file is being executed using Java. You can execute the helloWorld example as follows:
+```Bash
+java -cp $CLASSPATH sample.HelloWorld
+```
+However, you have to define the Classpath before you run the demo code. More information can be found in the `demo.sh` And you can run the bash script as follows:
+```Bash
+bash bin/java_sample.sh
+```
+It will load the library automatically and run the example
+### Object Detection using Inference API
+We also provide an example to do object detection, which downloads a ImageNet trained resnet50 model and runs inference on an image to return the classification result as
+```Bash
+Class: car
+Probabilties: 0.99847263
+Coord:312.21335, 72.02908, 456.01443, 150.66176
+Class: bicycle
+Probabilties: 0.9047381
+Coord:155.9581, 149.96365, 383.83694, 418.94516
+Class: dog
+Probabilties: 0.82268167
+Coord:83.82356, 179.14001, 206.63783, 476.78754
+```
+
+you can run using the command shown below:
+```Bash
+java -cp $CLASSPATH sample.ObjectDetection
+```
+or script as follows:
+```Bash
+bash bin/run_od.sh
+```
+
+If you want to test run on GPU, you can set a environment variable as follows:
+```Bash
+export SCALA_TEST_ON_GPU=1
+```
+## Clean up
+Clean up for Maven package is simple, you can run the pre-configed `Makefile` as:
+```Bash
+make javaclean
+```
+
+## Q & A
+If you are facing opencv issue on Ubuntu, please try as follows to install opencv 3.4 (required by 1.2.0 package and above)
+```Bash
+sudo add-apt-repository ppa:timsc/opencv-3.4
+sudo apt-get update
+sudo apt install libopencv-imgcodecs3.4
+```
\ No newline at end of file
diff --git a/scala-package/mxnet-demo/java-demo/bin/java_sample.sh b/scala-package/mxnet-demo/java-demo/bin/java_sample.sh
new file mode 100644
index 000000000000..50e7fb9eb97d
--- /dev/null
+++ b/scala-package/mxnet-demo/java-demo/bin/java_sample.sh
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#!/bin/bash
+CURR_DIR=$(cd $(dirname $0)/../; pwd)
+CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/*
+java -Xmx8G -cp $CLASSPATH sample.HelloWorld
\ No newline at end of file
diff --git a/scala-package/mxnet-demo/java-demo/bin/run_od.sh b/scala-package/mxnet-demo/java-demo/bin/run_od.sh
new file mode 100644
index 000000000000..5cbc53fbcefe
--- /dev/null
+++ b/scala-package/mxnet-demo/java-demo/bin/run_od.sh
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#!/bin/bash
+CURR_DIR=$(cd $(dirname $0)/../; pwd)
+
+CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/*
+java -Xmx8G -cp $CLASSPATH sample.ObjectDetection
\ No newline at end of file
diff --git a/scala-package/mxnet-demo/java-demo/pom.xml b/scala-package/mxnet-demo/java-demo/pom.xml
new file mode 100644
index 000000000000..5014d2e09f55
--- /dev/null
+++ b/scala-package/mxnet-demo/java-demo/pom.xml
@@ -0,0 +1,25 @@
+
+
+ 4.0.0
+ Demo
+ mxnet-java-demo
+ 1.0-SNAPSHOT
+ MXNet Java Demo
+
+
+
+ org.apache.mxnet
+ mxnet-full_${mxnet.scalaprofile}-${mxnet.profile}
+ ${mxnet.version}
+ system
+ ${project.basedir}/mxnet-full_${mxnet.scalaprofile}-${mxnet.profile}-${mxnet.version}.jar
+
+
+ commons-io
+ commons-io
+ 2.4
+
+
+
\ No newline at end of file
diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/sample/HelloWorld.java b/scala-package/mxnet-demo/java-demo/src/main/java/sample/HelloWorld.java
new file mode 100644
index 000000000000..60619dc8a806
--- /dev/null
+++ b/scala-package/mxnet-demo/java-demo/src/main/java/sample/HelloWorld.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package sample;
+
+import org.apache.mxnet.javaapi.*;
+import java.util.Arrays;
+
+public class HelloWorld {
+ public static void main(String[] args) {
+ System.out.println("Hello World!");
+ NDArray nd = new NDArray(new float[]{2.0f, 3.0f}, new Shape(new int[]{1, 2}), Context.cpu());
+ System.out.println(nd.shape());
+ }
+}
diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/sample/ObjectDetection.java b/scala-package/mxnet-demo/java-demo/src/main/java/sample/ObjectDetection.java
new file mode 100644
index 000000000000..bf9a93ae8217
--- /dev/null
+++ b/scala-package/mxnet-demo/java-demo/src/main/java/sample/ObjectDetection.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package sample;
+import org.apache.mxnet.infer.javaapi.ObjectDetectorOutput;
+import org.apache.mxnet.javaapi.*;
+import org.apache.mxnet.infer.javaapi.ObjectDetector;
+import org.apache.commons.io.FileUtils;
+import java.io.File;
+import java.net.URL;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class ObjectDetection {
+ private static String modelPath;
+ private static String imagePath;
+
+ private static void downloadUrl(String url, String filePath) {
+ File tmpFile = new File(filePath);
+ if (!tmpFile.exists()) {
+ try {
+ FileUtils.copyURLToFile(new URL(url), tmpFile);
+ } catch (Exception exception) {
+ System.err.println(exception);
+ }
+ }
+ }
+
+ public static void downloadModelImage() {
+ String tempDirPath = System.getProperty("java.io.tmpdir");
+ System.out.println("tempDirPath: %s".format(tempDirPath));
+ imagePath = tempDirPath + "/inputImages/resnetssd/dog-ssd.jpg";
+ String imgURL = "https://s3.amazonaws.com/model-server/inputs/dog-ssd.jpg";
+ downloadUrl(imgURL, imagePath);
+ modelPath = tempDirPath + "resnetssd/resnet50_ssd_model";
+ System.out.println("Download model files, this can take a while...");
+ String modelURL = "https://s3.amazonaws.com/model-server/models/resnet50_ssd/";
+ downloadUrl(modelURL + "resnet50_ssd_model-symbol.json",
+ tempDirPath + "/resnetssd/resnet50_ssd_model-symbol.json");
+ downloadUrl(modelURL + "resnet50_ssd_model-0000.params",
+ tempDirPath + "/resnetssd/resnet50_ssd_model-0000.params");
+ downloadUrl(modelURL + "synset.txt",
+ tempDirPath + "/resnetssd/synset.txt");
+ }
+
+ static List>
+ runObjectDetectionSingle(String modelPathPrefix, String inputImagePath, List context) {
+ Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+ List inputDescriptors = new ArrayList();
+ inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+ ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
+ return objDet.imageObjectDetect(ObjectDetector.loadImageFromFile(inputImagePath), 3);
+ }
+
+ public static void main(String[] args) {
+ List context = new ArrayList();
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
+ context.add(Context.gpu());
+ } else {
+ context.add(Context.cpu());
+ }
+ downloadModelImage();
+ Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+ Shape outputShape = new Shape(new int[] {1, 6132, 6});
+ int width = inputShape.get(2);
+ int height = inputShape.get(3);
+ List> output
+ = runObjectDetectionSingle(modelPath, imagePath, context);
+ String outputStr = "\n";
+ for (List ele : output) {
+ for (ObjectDetectorOutput i : ele) {
+ outputStr += "Class: " + i.getClassName() + "\n";
+ outputStr += "Probabilties: " + i.getProbability() + "\n";
+
+ List coord = Arrays.asList(i.getXMin() * width,
+ i.getXMax() * height, i.getYMin() * width, i.getYMax() * height);
+ StringBuilder sb = new StringBuilder();
+ for (float c: coord) {
+ sb.append(", ").append(c);
+ }
+ outputStr += "Coord:" + sb.substring(2)+ "\n";
+ }
+ }
+ System.out.println(outputStr);
+ }
+}
\ No newline at end of file
diff --git a/scala-package/mxnet-demo/Makefile b/scala-package/mxnet-demo/scala-demo/Makefile
similarity index 98%
rename from scala-package/mxnet-demo/Makefile
rename to scala-package/mxnet-demo/scala-demo/Makefile
index 227697ba2e8a..458077d13904 100644
--- a/scala-package/mxnet-demo/Makefile
+++ b/scala-package/mxnet-demo/scala-demo/Makefile
@@ -17,7 +17,7 @@
SCALA_VERSION_PROFILE := 2.11
SCALA_VERSION := 2.11.8
-MXNET_VERSION := 1.2.0
+MXNET_VERSION := 1.3.0
ifeq ($(OS),Windows_NT)
UNAME_S := Windows
diff --git a/scala-package/mxnet-demo/README.md b/scala-package/mxnet-demo/scala-demo/README.md
similarity index 88%
rename from scala-package/mxnet-demo/README.md
rename to scala-package/mxnet-demo/scala-demo/README.md
index e30a61a2fc13..300fc7b2e108 100644
--- a/scala-package/mxnet-demo/README.md
+++ b/scala-package/mxnet-demo/scala-demo/README.md
@@ -4,7 +4,7 @@ This is an project created to use Maven-published Scala package with two Scala e
User are required to use `mvn package` to build the package,
which are shown below:
```Bash
-export SCALA_VERSION_PROFILE=2.11 SCALA_VERSION=2.11.8 MXNET_VERSION=1.2.0
+export SCALA_VERSION_PROFILE=2.11 SCALA_VERSION=2.11.8 MXNET_VERSION=1.3.0
export SCALA_PKG_PROFILE=
mvn package -Dmxnet.profile=$(SCALA_PKG_PROFILE) \
-Dmxnet.scalaprofile=$(SCALA_VERSION_PROFILE) \
@@ -12,7 +12,9 @@ mvn package -Dmxnet.profile=$(SCALA_PKG_PROFILE) \
-Dscala.version=$(SCALA_VERSION)
```
These environment variable (`SCALA_PKG_PROFILE`, `SCALA_VERSION_PROFILE`, `MXNET_VERSION`, `SCALA_VERSION`)
-should be set before executing the line above.
+should be set before executing the line above.
+
+To obtain the most recent MXNet version, please click [here](https://mvnrepository.com/search?q=org.apache.mxnet)
You can also use the `Makefile` as an alternative to do the same thing. Simply do the following:
```Bash
@@ -25,7 +27,7 @@ This will load the default parameter for all the environment variable.
### Hello World
The Scala file is being executed using Java. You can execute the helloWorld example as follows:
```Bash
-java -Xmx8G -cp $CLASSPATH sample.HelloWorld
+java -cp $CLASSPATH sample.HelloWorld
```
However, you have to define the Classpath before you run the demo code. More information can be found in the `demo.sh` And you can run the bash script as follows:
```Bash
@@ -41,7 +43,7 @@ You can review the complete example [here](https://github.com/apache/incubator-m
you can run using the command shown below:
```Bash
-java -Xmx8G -cp $CLASSPATH sample.ImageClassificationExample
+java -cp $CLASSPATH sample.ImageClassificationExample
```
or script as follows:
```Bash
@@ -59,7 +61,7 @@ make scalaclean
```
## Q & A
-If you are facing opencv issue on Ubuntu, please try as follows to install opencv 3.4 (required by 1.2.0 package)
+If you are facing opencv issue on Ubuntu, please try as follows to install opencv 3.4 (required by 1.2.0 package and above)
```Bash
sudo add-apt-repository ppa:timsc/opencv-3.4
sudo apt-get update
diff --git a/scala-package/mxnet-demo/bin/demo.sh b/scala-package/mxnet-demo/scala-demo/bin/demo.sh
similarity index 100%
rename from scala-package/mxnet-demo/bin/demo.sh
rename to scala-package/mxnet-demo/scala-demo/bin/demo.sh
diff --git a/scala-package/mxnet-demo/bin/run_im.sh b/scala-package/mxnet-demo/scala-demo/bin/run_im.sh
similarity index 100%
rename from scala-package/mxnet-demo/bin/run_im.sh
rename to scala-package/mxnet-demo/scala-demo/bin/run_im.sh
diff --git a/scala-package/mxnet-demo/pom.xml b/scala-package/mxnet-demo/scala-demo/pom.xml
similarity index 100%
rename from scala-package/mxnet-demo/pom.xml
rename to scala-package/mxnet-demo/scala-demo/pom.xml
diff --git a/scala-package/mxnet-demo/src/main/scala/sample/HelloWorld.scala b/scala-package/mxnet-demo/scala-demo/src/main/scala/sample/HelloWorld.scala
similarity index 100%
rename from scala-package/mxnet-demo/src/main/scala/sample/HelloWorld.scala
rename to scala-package/mxnet-demo/scala-demo/src/main/scala/sample/HelloWorld.scala
diff --git a/scala-package/mxnet-demo/src/main/scala/sample/ImageClassificationExample.scala b/scala-package/mxnet-demo/scala-demo/src/main/scala/sample/ImageClassificationExample.scala
similarity index 100%
rename from scala-package/mxnet-demo/src/main/scala/sample/ImageClassificationExample.scala
rename to scala-package/mxnet-demo/scala-demo/src/main/scala/sample/ImageClassificationExample.scala
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index be28f0f09fe2..3240c144e822 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -190,8 +190,8 @@
maven-compiler-plugin3.3
-
- 1.6
+
+ 1.7UTF-8
diff --git a/tests/tutorials/test_sanity_tutorials.py b/tests/tutorials/test_sanity_tutorials.py
index 0ebeb59bf40d..9e5c38abc976 100644
--- a/tests/tutorials/test_sanity_tutorials.py
+++ b/tests/tutorials/test_sanity_tutorials.py
@@ -50,12 +50,15 @@
'scala/mnist.md',
'scala/index.md',
'scala/mxnet_scala_on_intellij.md',
+ 'scala/mxnet_java_install_and_run_examples.md',
'sparse/index.md',
'speech_recognition/index.md',
'unsupervised_learning/index.md',
'vision/index.md',
'tensorrt/index.md',
- 'tensorrt/inference_with_trt.md']
+ 'tensorrt/inference_with_trt.md',
+ 'java/mxnet_java_on_intellij.md',
+ 'java/ssd_inference.md']
whitelist_set = set(whitelist)
def test_tutorial_downloadable():