[Feature] support torchserver for unconditional models #131

Merged: 2 commits, Oct 13, 2021
4 changes: 4 additions & 0 deletions .gitignore
@@ -114,8 +114,12 @@ data
*.log.json
work_dirs/
*.DS_Store

# PyTorch
*.pth
mmgen/configs/
mmgen/tools/
runs/

# TorchServe
*.mar
2 changes: 1 addition & 1 deletion setup.cfg
@@ -14,6 +14,6 @@ line_length=79
multi_line_output=0
known_standard_library=argparse,inspect,contextlib,hashlib,subprocess,unittest,tempfile,copy,pkg_resources,logging,pickle,platform,setuptools,abc,collections,functools,os,math,time,warnings,random,shutil,sys
known_first_party=mmgen
-known_third_party=PIL,click,cv2,m2r,mmcls,mmcv,numpy,prettytable,pytest,pytorch_sphinx_theme,recommonmark,requests,scipy,torch,torchvision,tqdm
+known_third_party=PIL,click,cv2,m2r,mmcls,mmcv,numpy,prettytable,pytest,pytorch_sphinx_theme,recommonmark,requests,scipy,torch,torchvision,tqdm,ts
no_lines_before=STDLIB,LOCALFOLDER
default_section=THIRDPARTY
114 changes: 114 additions & 0 deletions tools/deployment/mmgen2torchserver.py
@@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory

import mmcv

try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    package_model = None


def mmgen2torchserver(config_file: str,
                      checkpoint_file: str,
                      output_folder: str,
                      model_name: str,
                      model_version: str = '1.0',
                      model_type: str = 'unconditional',
                      force: bool = False):
    """Convert an MMGeneration model (config + checkpoint) to a TorchServe
    ``.mar`` archive.

    Args:
        config_file (str): Path of the config file. The config should be in
            MMGeneration format.
        checkpoint_file (str): Path of the checkpoint. The checkpoint should
            be in MMGeneration checkpoint format.
        output_folder (str): Folder where ``{model_name}.mar`` will be
            created. The file created will be in TorchServe archive format.
        model_name (str): Name of the generated ``.mar`` file. If not None,
            used for naming the ``{model_name}.mar`` file that will be
            created under ``output_folder``. If None,
            ``{Path(checkpoint_file).stem}`` will be used.
        model_version (str, optional): Model's version. Defaults to '1.0'.
        model_type (str, optional): Type of the model to be converted. The
            handler named ``mmgen_{model_type}_handler.py`` will be used to
            generate the ``.mar`` file. Defaults to 'unconditional'.
        force (bool, optional): If True, an existing ``{model_name}.mar``
            will be overwritten. Defaults to False.
    """
    mmcv.mkdir_or_exist(output_folder)

    config = mmcv.Config.fromfile(config_file)

    with TemporaryDirectory() as tmpdir:
        # Dump the (possibly inherited) config into a single self-contained
        # file so it can be packaged into the archive.
        config.dump(f'{tmpdir}/config.py')

        args = Namespace(
            **{
                'model_file': f'{tmpdir}/config.py',
                'serialized_file': checkpoint_file,
                'handler':
                f'{Path(__file__).parent}/mmgen_{model_type}_handler.py',
                'model_name': model_name or Path(checkpoint_file).stem,
                'version': model_version,
                'export_path': output_folder,
                'force': force,
                'requirements_file': None,
                'extra_files': None,
                'runtime': 'python',
                'archive_format': 'default'
            })
        manifest = ModelExportUtils.generate_manifest_json(args)
        package_model(args, manifest)


def parse_args():
    parser = ArgumentParser(
        description='Convert MMGeneration models to TorchServe `.mar` '
        'format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        type=str,
        required=True,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        type=str,
        default=None,
        help='If not None, used for naming the `{model_name}.mar` '
        'file that will be created under `output_folder`. '
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--model-type',
        type=str,
        default='unconditional',
        help='Which model type and handler to be used.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='1.0',
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = parse_args()

    if package_model is None:
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mmgen2torchserver(args.config, args.checkpoint, args.output_folder,
                      args.model_name, args.model_version, args.model_type,
                      args.force)
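
For reference, the conversion can also be driven from Python instead of the CLI. A minimal sketch, assuming it is run from `tools/deployment/` and using a hypothetical config/checkpoint pair (substitute real MMGeneration paths):

    from mmgen2torchserver import mmgen2torchserver  # this module

    mmgen2torchserver(
        config_file='configs/styleganv2/stylegan2_ffhq_256.py',  # hypothetical
        checkpoint_file='work_dirs/stylegan2_ffhq_256.pth',      # hypothetical
        output_folder='model_store',
        model_name='stylegan2_ffhq',  # creates model_store/stylegan2_ffhq.mar
        model_type='unconditional',   # selects mmgen_unconditional_handler.py
        force=True)                   # overwrite an existing archive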
57 changes: 57 additions & 0 deletions tools/deployment/mmgen_unconditional_handler.py
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os

import numpy as np
import torch
from ts.torch_handler.base_handler import BaseHandler

from mmgen.apis import init_model


class MMGenUnconditionalHandler(BaseHandler):

    def initialize(self, context):
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        if torch.cuda.is_available():
            self.device = torch.device(
                f'{self.map_location}:{properties.get("gpu_id")}')
        else:
            self.device = torch.device(self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        checkpoint = os.path.join(model_dir, serialized_file)
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_model(self.config_file, checkpoint, self.device)
        self.initialized = True

    def preprocess(self, data, *args, **kwargs):
        data_decode = dict()
        # `data` is a `list[dict]`; only the first request is handled
        for k, v in data[0].items():
            # decode strings that arrive as raw bytes
            if isinstance(v, bytearray):
                data_decode[k] = v.decode()
            else:
                data_decode[k] = v
        return data_decode

    def inference(self, data, *args, **kwargs):
        sample_model = data['sample_model']
        results = self.model.sample_from_noise(
            None, num_batches=1, sample_model=sample_model, **kwargs)
        return results

    def postprocess(self, data):
        # Convert each torch tensor to numpy and then to raw bytes.
        output_list = []
        for data_ in data:
            # map pixel values from [-1, 1] to [0, 1]
            data_ = (data_ + 1) / 2
            # reorder channels from BGR to RGB
            data_ = data_[[2, 1, 0], ...]
            data_ = data_.clamp_(0, 1)
            # scale to [0, 255] and move channels last (HWC)
            data_ = (data_ * 255).permute(1, 2, 0)
            data_np = data_.detach().cpu().numpy().astype(np.uint8)
            data_byte = data_np.tobytes()
            output_list.append(data_byte)

        return output_list
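
The bytes returned by `postprocess` are a headerless H x W x 3 uint8 RGB buffer, so a client must already know the image resolution to decode them. A minimal round-trip sketch, assuming a 128x128 output; the payload below is a synthetic stand-in for the handler's response:

    import numpy as np
    from PIL import Image

    img_size = 128  # must match the model's output resolution
    # Synthetic stand-in for the bytes returned by postprocess().
    payload = np.random.randint(
        0, 256, (img_size, img_size, 3), dtype=np.uint8).tobytes()

    # Rebuild the HWC uint8 array from the raw buffer and save it.
    img = np.frombuffer(payload, dtype=np.uint8).reshape(img_size, img_size, 3)
    Image.fromarray(img).save('sample.png')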
58 changes: 58 additions & 0 deletions tools/deployment/test_torchserver.py
@@ -0,0 +1,58 @@
from argparse import ArgumentParser

import numpy as np
import requests
from PIL import Image


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('model_name', help='The model name in the server')
    parser.add_argument(
        '--inference-addr',
        default='127.0.0.1:8080',
        help='Address and port of the inference server')
    parser.add_argument(
        '--img-path',
        type=str,
        default='demo.png',
        help='Path to save the generated image.')
    parser.add_argument(
        '--img-size', type=int, default=128, help='Size of the output image.')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema/orig',
        help='Which model you want to use.')
    args = parser.parse_args()
    return args


def save_results(contents, img_path, img_size):
    if not isinstance(contents, list):
        Image.frombytes('RGB', (img_size, img_size), contents).save(img_path)
        return

    # Concatenate multiple samples side by side before saving.
    imgs = []
    for content in contents:
        imgs.append(
            np.array(Image.frombytes('RGB', (img_size, img_size), content)))
    Image.fromarray(np.concatenate(imgs, axis=1)).save(img_path)


def main(args):
    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name

    if args.sample_model == 'ema/orig':
        # Query both the EMA and the original generator, then save the two
        # samples side by side.
        cont_ema = requests.post(url, {'sample_model': 'ema'}).content
        cont_orig = requests.post(url, {'sample_model': 'orig'}).content
        save_results([cont_ema, cont_orig], args.img_path, args.img_size)
        return

    response = requests.post(url, {'sample_model': args.sample_model})
    save_results(response.content, args.img_path, args.img_size)


if __name__ == '__main__':
    args = parse_args()
    main(args)
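
As a quick offline sanity check of the byte layout `save_results` expects (no running server needed), a minimal sketch with two synthetic 8x8 images; it assumes `save_results` from this script is in scope:

    import numpy as np

    size = 8
    white = np.full((size, size, 3), 255, dtype=np.uint8).tobytes()
    black = np.zeros((size, size, 3), dtype=np.uint8).tobytes()
    # Produces a 16x8 image: a white and a black tile side by side.
    save_results([white, black], 'check.png', size)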