Skip to content

Commit

Permalink
Merge pull request #14 from IBM/png-support
Browse files Browse the repository at this point in the history
Support 'RGBA' mode images by converting to 'RGB'
  • Loading branch information
MLnick authored Mar 6, 2019
2 parents 8d42f85 + 948df98 commit 4f513bd
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 22 deletions.
9 changes: 3 additions & 6 deletions api/predict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from core.model import ModelWrapper
from flask_restplus import fields, abort
from flask_restplus import fields
from werkzeug.datastructures import FileStorage
from maxfw.core import MAX_API, PredictAPI

Expand Down Expand Up @@ -58,11 +58,8 @@ def post(self):
"""Make a prediction given input data"""
result = {'status': 'error'}
args = input_parser.parse_args()
try:
input_data = args['file'].read()
image = self.model_wrapper._read_image(input_data)
except OSError as e:
abort(400, "Please submit a valid image in PNG, Tiff or JPEG format")
input_data = args['file'].read()
image = self.model_wrapper._read_image(input_data)

label_preds = self.model_wrapper.predict(image)
result['predictions'] = label_preds
Expand Down
18 changes: 13 additions & 5 deletions core/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from maxfw.model import MAXModelWrapper

import io
import logging
import time
from PIL import Image
import numpy as np
from maxfw.model import MAXModelWrapper
from flask_restplus import abort

from core.tf_pose.estimator import TfPoseEstimator
from config import DEFAULT_MODEL_PATH, DEFAULT_IMAGE_SIZE, MODEL_NAME
Expand Down Expand Up @@ -35,10 +37,16 @@ def __init__(self, path=DEFAULT_MODEL_PATH):
logger.info("W = {}, H = {} ".format(self.w, self.h))

def _read_image(self, image_data):
image = Image.open(io.BytesIO(image_data))
# Convert RGB to BGR for OpenCV.
image = np.array(image)[:, :, ::-1]
return image
try:
image = Image.open(io.BytesIO(image_data))
if image.mode is not 'RGB':
image = image.convert('RGB')
# Convert RGB to BGR for OpenCV.
image = np.array(image)[:, :, ::-1]
return image
except IOError as e:
logger.error(str(e))
abort(400, "Please submit a valid image in PNG, TIFF or JPEG format")

def _predict(self, x):
t = time.time()
Expand Down
Binary file added tests/Pilots.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/Pilots.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/Pilots.tiff
Binary file not shown.
26 changes: 15 additions & 11 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,7 @@ def test_metadata():
assert metadata['license'] == 'Apache License 2.0'


def test_predict():

model_endpoint = 'http://localhost:5000/model/predict'

# Test by the image with multiple faces
img1_path = 'assets/Pilots.jpg'

with open(img1_path, 'rb') as file:
file_form = {'file': (img1_path, file, 'image/jpeg')}
r = requests.post(url=model_endpoint, files=file_form)

def _check_response(r):
assert r.status_code == 200
response = r.json()

Expand All @@ -49,6 +39,20 @@ def test_predict():
assert len(response['predictions'][0]['pose_lines']) > 0
assert len(response['predictions'][0]['body_parts']) > 0


def test_predict():

model_endpoint = 'http://localhost:5000/model/predict'
formats = ['jpg', 'png', 'tiff']
img_path = 'tests/Pilots.{}'

for f in formats:
p = img_path.format(f)
with open(p, 'rb') as file:
file_form = {'file': (p, file, 'image/{}'.format(f))}
r = requests.post(url=model_endpoint, files=file_form)
_check_response(r)

# Test by the image without faces
img2_path = 'assets/IBM.jpeg'

Expand Down

0 comments on commit 4f513bd

Please sign in to comment.