[Feature] support torchserve (open-mmlab#568)
* Support torchserve

* Fix

* Update torchserve after test

* update

Co-authored-by: liyinshuo <liyinshuo@sensetime.com>
Yshuo-Li and liyinshuo authored Oct 25, 2021
1 parent 8b09b49 commit 906e21f
Showing 5 changed files with 215 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
@@ -124,3 +124,6 @@ mmedit/.mim

# local history
.history/**

# TorchServe
*.mar
2 changes: 1 addition & 1 deletion setup.cfg
@@ -19,6 +19,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmedit
-known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,packaging,pymatting,pytest,scipy,titlecase,torch,torchvision
+known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,packaging,pymatting,pytest,requests,scipy,titlecase,torch,torchvision,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
110 changes: 110 additions & 0 deletions tools/deployment/mmedit2torchserve.py
@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory

import mmcv

try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    package_model = None


def mmedit2torchserve(
    config_file: str,
    checkpoint_file: str,
    output_folder: str,
    model_name: str,
    model_version: str = '1.0',
    force: bool = False,
):
    """Convert an MMEditing model (config + checkpoint) to a TorchServe
    `.mar` archive.

    Args:
        config_file:
            In MMEditing config format.
            The contents vary for each task repository.
        checkpoint_file:
            In MMEditing checkpoint format.
            The contents vary for each task repository.
        output_folder:
            Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
        model_name:
            If not None, used for naming the `{model_name}.mar` file
            that will be created under `output_folder`.
            If None, `{Path(checkpoint_file).stem}` will be used.
        model_version:
            Model's version.
        force:
            If True, an existing `{model_name}.mar` file under
            `output_folder` will be overwritten.
    """
    mmcv.mkdir_or_exist(output_folder)

    config = mmcv.Config.fromfile(config_file)

    with TemporaryDirectory() as tmpdir:
        config.dump(f'{tmpdir}/config.py')

        args_ = Namespace(
            **{
                'model_file': f'{tmpdir}/config.py',
                'serialized_file': checkpoint_file,
                'handler': f'{Path(__file__).parent}/mmedit_handler.py',
                'model_name': model_name or Path(checkpoint_file).stem,
                'version': model_version,
                'export_path': output_folder,
                'force': force,
                'requirements_file': None,
                'extra_files': None,
                'runtime': 'python',
                'archive_format': 'default'
            })
        print(args_.model_name)
        manifest = ModelExportUtils.generate_manifest_json(args_)
        package_model(args_, manifest)


def parse_args():
    parser = ArgumentParser(
        description='Convert MMEditing models to TorchServe `.mar` format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        type=str,
        required=True,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        type=str,
        default=None,
        help='If not None, used for naming the `{model_name}.mar` '
        'file that will be created under `output_folder`. '
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='1.0',
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    args_ = parser.parse_args()

    return args_


if __name__ == '__main__':
    args = parse_args()

    if package_model is None:
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mmedit2torchserve(args.config, args.checkpoint, args.output_folder,
                      args.model_name, args.model_version, args.force)
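
For reference, the converter can also be driven from Python rather than from the command line. A minimal sketch, assuming it is run from tools/deployment so the module is importable; the config and checkpoint paths below are placeholders, not part of this commit:

from mmedit2torchserve import mmedit2torchserve

# Placeholder paths -- substitute your own MMEditing config and checkpoint.
mmedit2torchserve(
    config_file='configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py',
    checkpoint_file='checkpoints/srcnn.pth',
    output_folder='model_store',   # produces model_store/srcnn.mar
    model_name='srcnn',
    model_version='1.0',
    force=True,                    # overwrite an existing srcnn.mar
)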
59 changes: 59 additions & 0 deletions tools/deployment/mmedit_handler.py
@@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import random
import string
from io import BytesIO

import PIL.Image as Image
import torch
from ts.torch_handler.base_handler import BaseHandler

from mmedit.apis import init_model, restoration_inference
from mmedit.core import tensor2img


class MMEditHandler(BaseHandler):

    def initialize(self, context):
        print('MMEditHandler.initialize is called')
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(
            self.map_location + ':' + str(properties.get('gpu_id'))
            if torch.cuda.is_available() else self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        checkpoint = os.path.join(model_dir, serialized_file)
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_model(self.config_file, checkpoint, self.device)
        self.initialized = True

    def preprocess(self, data, *args, **kwargs):
        body = data[0].get('data') or data[0].get('body')
        result = Image.open(BytesIO(body))
        # The actual preprocessing is done inside restoration_inference.
        return result

    def inference(self, data, *args, **kwargs):
        # Generate a temporary image path for restoration_inference.
        temp_name = ''.join(
            random.sample(string.ascii_letters + string.digits, 18))
        temp_path = f'./{temp_name}.png'
        data.save(temp_path)
        results = restoration_inference(self.model, temp_path)
        # Delete the temporary image.
        os.remove(temp_path)
        return results

    def postprocess(self, data):
        # Convert each torch tensor to a numpy array, then to raw bytes.
        output_list = []
        for data_ in data:
            data_np = tensor2img(data_)
            data_byte = data_np.tobytes()
            output_list.append(data_byte)

        return output_list
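
Note that postprocess returns raw BGR bytes with no shape metadata, so a client must already know the output resolution to rebuild the image. A minimal decoding sketch, assuming a hypothetical 128x128 input and a 4x super-resolution model (neither assumption comes from the handler itself):

import numpy as np

# Hypothetical sizes: a 128x128 input through a 4x model yields a
# 512*512*3-byte buffer. In practice the bytes come from the HTTP
# response; here we fabricate a zeroed buffer of the right length.
h, w, scale = 128, 128, 4
response_bytes = bytes(h * scale * w * scale * 3)

img = np.frombuffer(response_bytes, dtype=np.uint8)
img = img.reshape(h * scale, w * scale, 3)  # BGR order, as tensor2img produces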
42 changes: 42 additions & 0 deletions tools/deployment/test_torchserver.py
@@ -0,0 +1,42 @@
from argparse import ArgumentParser

import cv2
import numpy as np
import requests
from PIL import Image


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('model_name', help='The model name in the server')
    parser.add_argument(
        '--inference-addr',
        default='127.0.0.1:8080',
        help='Address and port of the inference server')
    parser.add_argument(
        '--img-path', type=str, help='Path to the input LQ (low-quality) image.')
    parser.add_argument(
        '--save-path', type=str, help='Path to save the generated GT image.')
    args = parser.parse_args()
    return args


def save_results(content, save_path, ori_shape):
    # Infer the upscaling factor from the size of the returned buffer.
    ori_len = np.prod(ori_shape)
    scale = int(np.sqrt(len(content) / ori_len))
    target_size = [int(size * scale) for size in ori_shape[:2][::-1]]
    # Convert the raw BGR bytes to an RGB image and save it.
    img = Image.frombytes('RGB', target_size, content, 'raw', 'BGR', 0, 0)
    img.save(save_path)


def main(args):
    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
    ori_shape = cv2.imread(args.img_path).shape
    with open(args.img_path, 'rb') as image:
        response = requests.post(url, data=image)
    save_results(response.content, args.save_path, ori_shape)


if __name__ == '__main__':
    parsed_args = parse_args()
    main(parsed_args)
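
As a sanity check on save_results: it infers the upscaling factor purely from buffer sizes, which only works when the server returns exactly scale**2 * ori_len bytes. A small worked example with hypothetical sizes:

import numpy as np

# Hypothetical: a 100x50 BGR input (shape (50, 100, 3)) upscaled 4x.
ori_shape = (50, 100, 3)
content_len = (50 * 4) * (100 * 4) * 3       # bytes the server would return

ori_len = np.prod(ori_shape)                 # 15000
scale = int(np.sqrt(content_len / ori_len))  # sqrt(16) -> 4
target_size = [int(s * scale) for s in ori_shape[:2][::-1]]
assert scale == 4 and target_size == [400, 200]  # (width, height) for PIL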
