Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Yshuo-Li committed Oct 14, 2021
1 parent ee27468 commit 8bb8420
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 19 deletions.
1 change: 1 addition & 0 deletions tools/deployment/mmedit2torchserve.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def mmedit2torchserve(
'runtime': 'python',
'archive_format': 'default'
})
print(args_.model_name)
manifest = ModelExportUtils.generate_manifest_json(args_)
package_model(args_, manifest)

Expand Down
18 changes: 16 additions & 2 deletions tools/deployment/mmedit_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import random
import string
from io import BytesIO

import PIL.Image as Image
import torch
from ts.torch_handler.base_handler import BaseHandler

Expand All @@ -11,6 +15,7 @@
class MMEditHandler(BaseHandler):

def initialize(self, context):
print('MMEditHandler.initialize is called')
properties = context.system_properties
self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = torch.device(self.map_location + ':' +
Expand All @@ -27,11 +32,20 @@ def initialize(self, context):
self.initialized = True

def preprocess(self, data, *args, **kwargs):
body = data[0].get('data') or data[0].get('body')
result = Image.open(BytesIO(body))
# data preprocess is in inference.
return data
return result

def inference(self, data, *args, **kwargs):
results = restoration_inference(self.model, data)
# generate temp image path for restoration_inference
temp_name = ''.join(
random.sample(string.ascii_letters + string.digits, 18))
temp_path = f'./{temp_name}.png'
data.save(temp_path)
results = restoration_inference(self.model, temp_path)
# delete the temp image path
os.remove(temp_path)
return results

def postprocess(self, data):
Expand Down
31 changes: 14 additions & 17 deletions tools/deployment/test_torchserver.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,40 @@
from argparse import ArgumentParser

import cv2
import numpy as np
import requests
from PIL import Image


def parse_args():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument('model_name', help='The model name in the server')
parser.add_argument(
'--inference-addr',
default='127.0.0.1:8080',
help='Address and port of the inference server')
parser.add_argument('--img-path', type=str, help='The input LQ image.')
parser.add_argument(
'--img-path',
type=str,
default='demo.png',
help='Path to save generated image.')
parser.add_argument(
'--img-size',
type=list,
default=(256, 256),
help='Size of the output image.')
'--save-path', type=str, help='Path to save the generated GT image.')
args = parser.parse_args()
return args


def save_results(content, img_path, img_size):
print(content)
img = Image.frombytes('RGB', img_size, content)
img.save(img_path)
def save_results(content, save_path, ori_shape):
ori_len = np.prod(ori_shape)
scale = int(np.sqrt(len(content) / ori_len))
target_size = [int(size * scale) for size in ori_shape[:2][::-1]]
# Convert to RGB and save image
img = Image.frombytes('RGB', target_size, content, 'raw', 'BGR', 0, 0)
img.save(save_path)


def main(args):
url = 'http://' + args.inference_addr + '/predictions/' + args.model_name

with open(args.img, 'rb') as image:
ori_shape = cv2.imread(args.img_path).shape
with open(args.img_path, 'rb') as image:
response = requests.post(url, image)
save_results(response.content, args.img_path, args.img_size)
save_results(response.content, args.save_path, ori_shape)


if __name__ == '__main__':
Expand Down

0 comments on commit 8bb8420

Please sign in to comment.