-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-1182] Predictor example #13237
Changes from all commits
91a37c2
2cb07cf
f34ce4e
ce66813
9f9e13b
60ca8c1
823ae31
ca51c24
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#!/bin/bash | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# Usage: run_predictor_java_example.sh <model-path-prefix> <input-image> [gpu]
#   $1  model directory plus file-name prefix of the model
#   $2  image to run inference on
#   $3  optional; pass "gpu" to use the GPU assembly jar (defaults to cpu)
if [[ $# -lt 2 ]]
then
    echo "Usage: $0 <model-path-prefix> <input-image> [gpu]" >&2
    exit 1
fi

hw_type=cpu
if [[ "${3:-}" == gpu ]]
then
    hw_type=gpu
fi

platform=linux-x86_64

# Match the literal prefix "darwin"; the former pattern [darwin]* was a
# bracket expression that matched any value starting with d, a, r, w, i or n.
if [[ "$OSTYPE" == darwin* ]]
then
    platform=osx-x86_64
fi

# Resolve the repository root relative to this script; "&&" prevents pwd from
# reporting the caller's directory when the cd fails.
MXNET_ROOT=$(cd "$(dirname "$0")/../../../../../" && pwd)
# The "*" entries are class-path wildcards expanded by the JVM, so they are kept literal.
CLASS_PATH="$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*"

# model dir and prefix
MODEL_DIR=$1
# input image
INPUT_IMG=$2

# Quote every expansion so paths containing spaces are passed as single arguments.
java -Xmx8G -cp "$CLASS_PATH" \
    org.apache.mxnetexamples.javaapi.infer.predictor.PredictorExample \
    --model-path-prefix "$MODEL_DIR" \
    --input-image "$INPUT_IMG"
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.mxnetexamples.javaapi.infer.predictor; | ||
|
||
import org.apache.mxnet.infer.javaapi.Predictor; | ||
import org.apache.mxnet.javaapi.*; | ||
import org.kohsuke.args4j.CmdLineParser; | ||
import org.kohsuke.args4j.Option; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import javax.imageio.ImageIO; | ||
import java.awt.Graphics2D; | ||
import java.awt.image.BufferedImage; | ||
import java.io.BufferedReader; | ||
import java.io.File; | ||
import java.io.FileReader; | ||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
/** | ||
* This Class is a demo to show how users can use Predictor APIs to do | ||
* Image Classification with all hand-crafted Pre-processing. | ||
* All helper functions for image pre-processing are | ||
* currently available in ObjectDetector class. | ||
*/ | ||
public class PredictorExample { | ||
@Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model") | ||
private String modelPathPrefix = "/model/ssd_resnet50_512"; | ||
@Option(name = "--input-image", usage = "the input image") | ||
private String inputImagePath = "/images/dog.jpg"; | ||
|
||
final static Logger logger = LoggerFactory.getLogger(PredictorExample.class); | ||
|
||
/** | ||
* Load the image from file to buffered image | ||
* It can be replaced by loadImageFromFile from ObjectDetector | ||
* @param inputImagePath input image Path in String | ||
* @return Buffered image | ||
*/ | ||
private static BufferedImage loadIamgeFromFile(String inputImagePath) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This same logic is in ObjectDetector. Is there any reason we aren't just reusing those methods to make this example simpler? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason to do this is show how users can build their own preprocessing stages with Predictor API as well as there is no single method that can bring me a float array unless I do an NDArray extraction to get a float array |
||
BufferedImage buf = null; | ||
try { | ||
buf = ImageIO.read(new File(inputImagePath)); | ||
} catch (IOException e) { | ||
System.err.println(e); | ||
} | ||
return buf; | ||
} | ||
|
||
/** | ||
* Reshape the current image using ImageIO and Graph2D | ||
* It can be replaced by reshapeImage from ObjectDetector | ||
* @param buf Buffered image | ||
* @param newWidth desired width | ||
* @param newHeight desired height | ||
* @return a reshaped bufferedImage | ||
*/ | ||
private static BufferedImage reshapeImage(BufferedImage buf, int newWidth, int newHeight) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB); | ||
Graphics2D g = resizedImage.createGraphics(); | ||
g.drawImage(buf, 0, 0, newWidth, newHeight, null); | ||
g.dispose(); | ||
return resizedImage; | ||
} | ||
|
||
/** | ||
* Convert an image from a buffered image into pixels float array | ||
* It can be replaced by bufferedImageToPixels from ObjectDetector | ||
* @param buf buffered image | ||
* @return Float array | ||
*/ | ||
private static float[] imagePreprocess(BufferedImage buf) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
// Get height and width of the image | ||
int w = buf.getWidth(); | ||
int h = buf.getHeight(); | ||
|
||
// get an array of integer pixels in the default RGB color mode | ||
int[] pixels = buf.getRGB(0, 0, w, h, null, 0, w); | ||
|
||
// 3 times height and width for R,G,B channels | ||
float[] result = new float[3 * h * w]; | ||
|
||
int row = 0; | ||
// copy pixels to array vertically | ||
while (row < h) { | ||
int col = 0; | ||
// copy pixels to array horizontally | ||
while (col < w) { | ||
int rgb = pixels[row * w + col]; | ||
// getting red color | ||
result[0 * h * w + row * w + col] = (rgb >> 16) & 0xFF; | ||
// getting green color | ||
result[1 * h * w + row * w + col] = (rgb >> 8) & 0xFF; | ||
// getting blue color | ||
result[2 * h * w + row * w + col] = rgb & 0xFF; | ||
col += 1; | ||
} | ||
row += 1; | ||
} | ||
buf.flush(); | ||
return result; | ||
} | ||
|
||
/** | ||
* Helper class to print the maximum prediction result | ||
* @param probabilities The float array of probability | ||
* @param modelPathPrefix model Path needs to load the synset.txt | ||
*/ | ||
private static String printMaximumClass(float[] probabilities, | ||
String modelPathPrefix) throws IOException { | ||
String synsetFilePath = modelPathPrefix.substring(0, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should not take the synsetFilePath from the modelPathPrefix. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with you with this part, but they usually ship together in the same location. Most of the examples, SSD/Image classifier do this... |
||
1 + modelPathPrefix.lastIndexOf(File.separator)) + "/synset.txt"; | ||
BufferedReader reader = new BufferedReader(new FileReader(synsetFilePath)); | ||
ArrayList<String> list = new ArrayList<>(); | ||
String line = reader.readLine(); | ||
|
||
while (line != null){ | ||
list.add(line); | ||
line = reader.readLine(); | ||
} | ||
reader.close(); | ||
|
||
int maxIdx = 0; | ||
for (int i = 1;i<probabilities.length;i++) { | ||
if (probabilities[i] > probabilities[maxIdx]) { | ||
maxIdx = i; | ||
} | ||
} | ||
|
||
return "Probability : " + probabilities[maxIdx] + " Class : " + list.get(maxIdx) ; | ||
} | ||
|
||
public static void main(String[] args) { | ||
PredictorExample inst = new PredictorExample(); | ||
CmdLineParser parser = new CmdLineParser(inst); | ||
try { | ||
parser.parseArgument(args); | ||
} catch (Exception e) { | ||
logger.error(e.getMessage(), e); | ||
parser.printUsage(System.err); | ||
System.exit(1); | ||
} | ||
// Prepare the model | ||
List<Context> context = new ArrayList<Context>(); | ||
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && | ||
Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) { | ||
context.add(Context.gpu()); | ||
} else { | ||
context.add(Context.cpu()); | ||
} | ||
List<DataDesc> inputDesc = new ArrayList<>(); | ||
Shape inputShape = new Shape(new int[]{1, 3, 224, 224}); | ||
inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW")); | ||
Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0); | ||
// Prepare data | ||
BufferedImage img = loadIamgeFromFile(inst.inputImagePath); | ||
|
||
img = reshapeImage(img, 224, 224); | ||
// predict | ||
float[][] result = predictor.predict(new float[][]{imagePreprocess(img)}); | ||
try { | ||
System.out.println("Predict with Float input"); | ||
System.out.println(printMaximumClass(result[0], inst.modelPathPrefix)); | ||
} catch (IOException e) { | ||
System.err.println(e); | ||
} | ||
// predict with NDArray | ||
NDArray nd = new NDArray( | ||
imagePreprocess(img), | ||
new Shape(new int[]{1, 3, 224, 224}), | ||
Context.cpu()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Quick question : Should this context be cpu ? Or determined based on the environment variable you have ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, I will fix here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, this one should be CPU, it defines where we keep this NDArray. It doesn't matter much to keep it cpu or gpu as it would copy around |
||
List<NDArray> ndList = new ArrayList<>(); | ||
ndList.add(nd); | ||
List<NDArray> ndResult = predictor.predictWithNDArray(ndList); | ||
try { | ||
System.out.println("Predict with NDArray"); | ||
System.out.println(printMaximumClass(ndResult.get(0).toArray(), inst.modelPathPrefix)); | ||
} catch (IOException e) { | ||
System.err.println(e); | ||
} | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Image Classification using Java Predictor | ||
|
||
In this example, you will learn how to use the Java Inference API to | ||
run a pre-trained ResNet-18 model. | ||
|
||
## Contents | ||
|
||
1. [Prerequisites](#prerequisites) | ||
2. [Download artifacts](#download-artifacts) | ||
3. [Setup datapath and parameters](#setup-datapath-and-parameters) | ||
4. [Run the image classifier example](#run-the-image-classifier-example) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why Image Classification in Prediction? Is it possible to use another example even may be a simple Linear Regression model? |
||
|
||
## Prerequisites | ||
|
||
1. Build from source with [MXNet](https://mxnet.incubator.apache.org/install/index.html) | ||
2. [IntelliJ IDE (or alternative IDE) project setup](https://github.com/apache/incubator-mxnet/blob/master/docs/tutorials/java/mxnet_java_on_intellij.md) with the MXNet Java Package | ||
3. wget | ||
|
||
## Download Artifacts | ||
|
||
For this tutorial, you can get the model and sample input image by running the following bash script. This script uses `wget` to download these artifacts from AWS S3. | ||
|
||
From the `scala-package/examples/scripts/infer/imageclassifier/` folder run: | ||
|
||
```bash | ||
./get_resnet_18_data.sh | ||
``` | ||
|
||
**Note**: You may need to run `chmod +x get_resnet_18_data.sh` before running this script. | ||
|
||
### Setup Datapath and Parameters | ||
|
||
The available arguments are as follows: | ||
|
||
| Argument | Comments | | ||
| ----------------------------- | ---------------------------------------- | | ||
| `model-path-prefix`          | Folder path plus file-name prefix of the model (the folder contains the json, params, and synset files). | | ||
| `input-image` | The image to run inference on. | | ||
|
||
## Run the image classifier example | ||
|
||
After the previous steps, you should be able to run the code using the following script that will pass all of the required parameters to the Predictor API. | ||
|
||
From the `scala-package/examples/scripts/infer/predictor/` folder run: | ||
|
||
```bash | ||
bash run_predictor_java_example.sh ../models/resnet-18/resnet-18 ../images/kitten.jpg | ||
``` | ||
|
||
**Notes**: | ||
* These paths are relative to this script. | ||
* You may need to run `chmod +x run_predictor_java_example.sh` before running this script. | ||
|
||
The example should give an output similar to the one shown below: | ||
``` | ||
Predict with Float input | ||
Probability : 0.30337515 Class : n02123159 tiger cat | ||
Predict with NDArray | ||
Probability : 0.30337515 Class : n02123159 tiger cat | ||
lanking520 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
``` | ||
The outputs are computed from the input image; only the top-1 prediction is shown. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this class should be renamed to ImageClassificationExample
And in comments we can write that this is an example code which uses Predictor API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As we will have an ImageClassification Example once we introduce Image Classifier...