Skip to content

Commit

Permalink
Update torchserve after test
Browse files Browse the repository at this point in the history
  • Loading branch information
liyinshuo committed Oct 12, 2021
1 parent fde23a6 commit ee27468
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 23 deletions.
16 changes: 2 additions & 14 deletions tools/deployment/mmedit2torchserve.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def mmedit2torchserve(
output_folder: str,
model_name: str,
model_version: str = '1.0',
model_type: str = 'unconditional',
force: bool = False,
):
"""Converts MMEditing model (config + checkpoint) to TorchServe `.mar`.
Expand All @@ -38,8 +37,6 @@ def mmedit2torchserve(
If None, `{Path(checkpoint_file).stem}` will be used.
model_version:
Model's version.
model_type:
Model's type.
force:
If True, if there is an existing `{model_name}.mar`
file under `output_folder` it will be overwritten.
Expand All @@ -54,11 +51,8 @@ def mmedit2torchserve(
args_ = Namespace(
**{
'model_file': f'{tmpdir}/config.py',
'model_path': output_folder,
'convert': True,
'serialized_file': checkpoint_file,
'handler':
f'{Path(__file__).parent}/mmedit_{model_type}_handler.py',
'handler': f'{Path(__file__).parent}/mmedit_handler.py',
'model_name': model_name or Path(checkpoint_file).stem,
'version': model_version,
'export_path': output_folder,
Expand Down Expand Up @@ -89,11 +83,6 @@ def parse_args():
help='If not None, used for naming the `{model_name}.mar`'
'file that will be created under `output_folder`.'
'If None, `{Path(checkpoint_file).stem}` will be used.')
parser.add_argument(
'--model-type',
type=str,
default='unconditional',
help='Which model type and handler to be used.')
parser.add_argument(
'--model-version',
type=str,
Expand All @@ -117,5 +106,4 @@ def parse_args():
'Try: pip install torch-model-archiver')

mmedit2torchserve(args.config, args.checkpoint, args.output_folder,
args.model_name, args.model_version, args.model_type,
args.force)
args.model_name, args.model_version, args.force)
7 changes: 2 additions & 5 deletions tools/deployment/mmedit_handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os

import numpy as np
import torch
from ts.torch_handler.base_handler import BaseHandler

from mmedit.apis import init_model, restoration_inference
from mmedit.core import tensor2img


class MMEditHandler(BaseHandler):
Expand Down Expand Up @@ -38,10 +38,7 @@ def postprocess(self, data):
# convert torch tensor to numpy and then covert to bytes
output_list = []
for data_ in data:
data_ = data_[[2, 1, 0], ...] # RGB to BGR
data_ = data_.clamp_(0, 1)
data_ = (data_ * 255).permute(1, 2, 0)
data_np = data_.detach().cpu().numpy().astype(np.uint8)
data_np = tensor2img(data_)
data_byte = data_np.tobytes()
output_list.append(data_byte)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

def parse_args():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument('model_name', help='The model name in the server')
parser.add_argument(
'--inference-addr',
Expand All @@ -17,21 +18,25 @@ def parse_args():
default='demo.png',
help='Path to save generated image.')
parser.add_argument(
'--img-size', type=int, default=128, help='Size of the output image.')
'--img-size',
type=list,
default=(256, 256),
help='Size of the output image.')
args = parser.parse_args()
return args


def save_results(content, img_path, img_size):
img = Image.frombytes('RGB', (img_size, img_size), content)
print(content)
img = Image.frombytes('RGB', img_size, content)
img.save(img_path)


def main(args):
url = 'http://' + args.inference_addr + '/predictions/' + args.model_name

# just post a meanless dict
response = requests.post(url, {'key': 'value'})
with open(args.img, 'rb') as image:
response = requests.post(url, image)
save_results(response.content, args.img_path, args.img_size)


Expand Down

0 comments on commit ee27468

Please sign in to comment.