forked from open-mmlab/mmagic
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] support torchserve (open-mmlab#568)
* Support torchserver * Fix * Update torchserve after test * update Co-authored-by: liyinshuo <liyinshuo@sensetime.com>
- Loading branch information
Showing
5 changed files
with
215 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -124,3 +124,6 @@ mmedit/.mim | |
|
||
# local history | ||
.history/** | ||
|
||
# Pytorch Server | ||
*.mar |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from argparse import ArgumentParser, Namespace | ||
from pathlib import Path | ||
from tempfile import TemporaryDirectory | ||
|
||
import mmcv | ||
|
||
try: | ||
from model_archiver.model_packaging import package_model | ||
from model_archiver.model_packaging_utils import ModelExportUtils | ||
except ImportError: | ||
package_model = None | ||
|
||
|
||
def mmedit2torchserve( | ||
config_file: str, | ||
checkpoint_file: str, | ||
output_folder: str, | ||
model_name: str, | ||
model_version: str = '1.0', | ||
force: bool = False, | ||
): | ||
"""Converts MMEditing model (config + checkpoint) to TorchServe `.mar`. | ||
Args: | ||
config_file: | ||
In MMEditing config format. | ||
The contents vary for each task repository. | ||
checkpoint_file: | ||
In MMEditing checkpoint format. | ||
The contents vary for each task repository. | ||
output_folder: | ||
Folder where `{model_name}.mar` will be created. | ||
The file created will be in TorchServe archive format. | ||
model_name: | ||
If not None, used for naming the `{model_name}.mar` file | ||
that will be created under `output_folder`. | ||
If None, `{Path(checkpoint_file).stem}` will be used. | ||
model_version: | ||
Model's version. | ||
force: | ||
If True, if there is an existing `{model_name}.mar` | ||
file under `output_folder` it will be overwritten. | ||
""" | ||
mmcv.mkdir_or_exist(output_folder) | ||
|
||
config = mmcv.Config.fromfile(config_file) | ||
|
||
with TemporaryDirectory() as tmpdir: | ||
config.dump(f'{tmpdir}/config.py') | ||
|
||
args_ = Namespace( | ||
**{ | ||
'model_file': f'{tmpdir}/config.py', | ||
'serialized_file': checkpoint_file, | ||
'handler': f'{Path(__file__).parent}/mmedit_handler.py', | ||
'model_name': model_name or Path(checkpoint_file).stem, | ||
'version': model_version, | ||
'export_path': output_folder, | ||
'force': force, | ||
'requirements_file': None, | ||
'extra_files': None, | ||
'runtime': 'python', | ||
'archive_format': 'default' | ||
}) | ||
print(args_.model_name) | ||
manifest = ModelExportUtils.generate_manifest_json(args_) | ||
package_model(args_, manifest) | ||
|
||
|
||
def parse_args(): | ||
parser = ArgumentParser( | ||
description='Convert MMEditing models to TorchServe `.mar` format.') | ||
parser.add_argument('config', type=str, help='config file path') | ||
parser.add_argument('checkpoint', type=str, help='checkpoint file path') | ||
parser.add_argument( | ||
'--output-folder', | ||
type=str, | ||
required=True, | ||
help='Folder where `{model_name}.mar` will be created.') | ||
parser.add_argument( | ||
'--model-name', | ||
type=str, | ||
default=None, | ||
help='If not None, used for naming the `{model_name}.mar`' | ||
'file that will be created under `output_folder`.' | ||
'If None, `{Path(checkpoint_file).stem}` will be used.') | ||
parser.add_argument( | ||
'--model-version', | ||
type=str, | ||
default='1.0', | ||
help='Number used for versioning.') | ||
parser.add_argument( | ||
'-f', | ||
'--force', | ||
action='store_true', | ||
help='overwrite the existing `{model_name}.mar`') | ||
args_ = parser.parse_args() | ||
|
||
return args_ | ||
|
||
|
||
if __name__ == '__main__': | ||
args = parse_args() | ||
|
||
if package_model is None: | ||
raise ImportError('`torch-model-archiver` is required.' | ||
'Try: pip install torch-model-archiver') | ||
|
||
mmedit2torchserve(args.config, args.checkpoint, args.output_folder, | ||
args.model_name, args.model_version, args.force) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import os | ||
import random | ||
import string | ||
from io import BytesIO | ||
|
||
import PIL.Image as Image | ||
import torch | ||
from ts.torch_handler.base_handler import BaseHandler | ||
|
||
from mmedit.apis import init_model, restoration_inference | ||
from mmedit.core import tensor2img | ||
|
||
|
||
class MMEditHandler(BaseHandler): | ||
|
||
def initialize(self, context): | ||
print('MMEditHandler.initialize is called') | ||
properties = context.system_properties | ||
self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
self.device = torch.device(self.map_location + ':' + | ||
str(properties.get('gpu_id')) if torch.cuda. | ||
is_available() else self.map_location) | ||
self.manifest = context.manifest | ||
|
||
model_dir = properties.get('model_dir') | ||
serialized_file = self.manifest['model']['serializedFile'] | ||
checkpoint = os.path.join(model_dir, serialized_file) | ||
self.config_file = os.path.join(model_dir, 'config.py') | ||
|
||
self.model = init_model(self.config_file, checkpoint, self.device) | ||
self.initialized = True | ||
|
||
def preprocess(self, data, *args, **kwargs): | ||
body = data[0].get('data') or data[0].get('body') | ||
result = Image.open(BytesIO(body)) | ||
# data preprocess is in inference. | ||
return result | ||
|
||
def inference(self, data, *args, **kwargs): | ||
# generate temp image path for restoration_inference | ||
temp_name = ''.join( | ||
random.sample(string.ascii_letters + string.digits, 18)) | ||
temp_path = f'./{temp_name}.png' | ||
data.save(temp_path) | ||
results = restoration_inference(self.model, temp_path) | ||
# delete the temp image path | ||
os.remove(temp_path) | ||
return results | ||
|
||
def postprocess(self, data): | ||
# convert torch tensor to numpy and then covert to bytes | ||
output_list = [] | ||
for data_ in data: | ||
data_np = tensor2img(data_) | ||
data_byte = data_np.tobytes() | ||
output_list.append(data_byte) | ||
|
||
return output_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from argparse import ArgumentParser | ||
|
||
import cv2 | ||
import numpy as np | ||
import requests | ||
from PIL import Image | ||
|
||
|
||
def parse_args(): | ||
parser = ArgumentParser() | ||
parser.add_argument('model_name', help='The model name in the server') | ||
parser.add_argument( | ||
'--inference-addr', | ||
default='127.0.0.1:8080', | ||
help='Address and port of the inference server') | ||
parser.add_argument('--img-path', type=str, help='The input LQ image.') | ||
parser.add_argument( | ||
'--save-path', type=str, help='Path to save the generated GT image.') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def save_results(content, save_path, ori_shape): | ||
ori_len = np.prod(ori_shape) | ||
scale = int(np.sqrt(len(content) / ori_len)) | ||
target_size = [int(size * scale) for size in ori_shape[:2][::-1]] | ||
# Convert to RGB and save image | ||
img = Image.frombytes('RGB', target_size, content, 'raw', 'BGR', 0, 0) | ||
img.save(save_path) | ||
|
||
|
||
def main(args): | ||
url = 'http://' + args.inference_addr + '/predictions/' + args.model_name | ||
ori_shape = cv2.imread(args.img_path).shape | ||
with open(args.img_path, 'rb') as image: | ||
response = requests.post(url, image) | ||
save_results(response.content, args.save_path, ori_shape) | ||
|
||
|
||
if __name__ == '__main__': | ||
parsed_args = parse_args() | ||
main(parsed_args) |