diff --git a/.gitignore b/.gitignore
index e871f10b4..3f65d6213 100644
--- a/.gitignore
+++ b/.gitignore
@@ -114,8 +114,12 @@ data
 *.log.json
 work_dirs/
 *.DS_Store
+
 # PyTorch
 *.pth
 mmgen/configs/
 mmgen/tools/
 runs/
+
+# PyTorch Server
+*.mar
diff --git a/setup.cfg b/setup.cfg
index 03e746ec7..3c88b6c89 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -14,6 +14,6 @@ line_length=79
 multi_line_output=0
 known_standard_library=argparse,inspect,contextlib,hashlib,subprocess,unittest,tempfile,copy,pkg_resources,logging,pickle,platform,setuptools,abc,collections,functools,os,math,time,warnings,random,shutil,sys
 known_first_party=mmgen
-known_third_party=PIL,click,cv2,m2r,mmcls,mmcv,numpy,prettytable,pytest,pytorch_sphinx_theme,recommonmark,requests,scipy,torch,torchvision,tqdm
+known_third_party=PIL,click,cv2,m2r,mmcls,mmcv,numpy,prettytable,pytest,pytorch_sphinx_theme,recommonmark,requests,scipy,torch,torchvision,tqdm,ts
 no_lines_before=STDLIB,LOCALFOLDER
 default_section=THIRDPARTY
diff --git a/tools/deployment/mmgen2torchserver.py b/tools/deployment/mmgen2torchserver.py
new file mode 100644
index 000000000..dea56ad5b
--- /dev/null
+++ b/tools/deployment/mmgen2torchserver.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import mmcv
+
+try:
+    from model_archiver.model_packaging import package_model
+    from model_archiver.model_packaging_utils import ModelExportUtils
+except ImportError:
+    package_model = None
+
+
+def mmgen2torchserver(config_file: str,
+                      checkpoint_file: str,
+                      output_folder: str,
+                      model_name: str,
+                      model_version: str = '1.0',
+                      model_type: str = 'unconditional',
+                      force: bool = False):
+    """Convert an MMGeneration model (config + checkpoint) to a TorchServe
+    `.mar` archive.
+
+    Args:
+        config_file (str): Path of the config file. The config should be in
+            MMGeneration format.
+        checkpoint_file (str): Path of the checkpoint file. The checkpoint
+            should be in MMGeneration checkpoint format.
+        output_folder (str): Folder where `{model_name}.mar` will be created.
+            The created file will be in TorchServe archive format.
+        model_name (str): Name of the generated ``'mar'`` file. If not None,
+            used for naming the `{model_name}.mar` file that will be created
+            under `output_folder`. If None, `{Path(checkpoint_file).stem}`
+            will be used.
+        model_version (str, optional): Model's version. Defaults to '1.0'.
+        model_type (str, optional): Type of the model to be converted. The
+            handler named ``{model_type}_handler`` will be used to generate
+            the ``mar`` file. Defaults to 'unconditional'.
+        force (bool, optional): If True, an existing `{model_name}.mar` will
+            be overwritten. Defaults to False.
+ """ + mmcv.mkdir_or_exist(output_folder) + + config = mmcv.Config.fromfile(config_file) + + with TemporaryDirectory() as tmpdir: + config.dump(f'{tmpdir}/config.py') + + args = Namespace( + **{ + 'model_file': f'{tmpdir}/config.py', + 'serialized_file': checkpoint_file, + 'handler': + f'{Path(__file__).parent}/mmgen_{model_type}_handler.py', + 'model_name': model_name or Path(checkpoint_file).stem, + 'version': model_version, + 'export_path': output_folder, + 'force': force, + 'requirements_file': None, + 'extra_files': None, + 'runtime': 'python', + 'archive_format': 'default' + }) + manifest = ModelExportUtils.generate_manifest_json(args) + package_model(args, manifest) + + +def parse_args(): + parser = ArgumentParser( + description='Convert MMGeneration models to TorchServe `.mar` format.') + parser.add_argument('config', type=str, help='config file path') + parser.add_argument('checkpoint', type=str, help='checkpoint file path') + parser.add_argument( + '--output-folder', + type=str, + required=True, + help='Folder where `{model_name}.mar` will be created.') + parser.add_argument( + '--model-name', + type=str, + default=None, + help='If not None, used for naming the `{model_name}.mar`' + 'file that will be created under `output_folder`.' + 'If None, `{Path(checkpoint_file).stem}` will be used.') + parser.add_argument( + '--model-type', + type=str, + default='unconditional', + help='Which model type and handler to be used.') + parser.add_argument( + '--model-version', + type=str, + default='1.0', + help='Number used for versioning.') + parser.add_argument( + '-f', + '--force', + action='store_true', + help='overwrite the existing `{model_name}.mar`') + args = parser.parse_args() + + return args + + +if __name__ == '__main__': + args = parse_args() + + if package_model is None: + raise ImportError('`torch-model-archiver` is required.' + 'Try: pip install torch-model-archiver') + + mmgen2torchserver(args.config, args.checkpoint, args.output_folder, + args.model_name, args.model_version, args.model_type, + args.force) diff --git a/tools/deployment/mmgen_unconditional_handler.py b/tools/deployment/mmgen_unconditional_handler.py new file mode 100644 index 000000000..0c42d1f87 --- /dev/null +++ b/tools/deployment/mmgen_unconditional_handler.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os + +import numpy as np +import torch +from ts.torch_handler.base_handler import BaseHandler + +from mmgen.apis import init_model + + +class MMGenUnconditionalHandler(BaseHandler): + + def initialize(self, context): + properties = context.system_properties + self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = torch.device(self.map_location + ':' + + str(properties.get('gpu_id')) if torch.cuda. 
+            if torch.cuda.is_available() else self.map_location)
+        self.manifest = context.manifest
+
+        model_dir = properties.get('model_dir')
+        serialized_file = self.manifest['model']['serializedFile']
+        checkpoint = os.path.join(model_dir, serialized_file)
+        self.config_file = os.path.join(model_dir, 'config.py')
+
+        self.model = init_model(self.config_file, checkpoint, self.device)
+        self.initialized = True
+
+    def preprocess(self, data, *args, **kwargs):
+        data_decode = dict()
+        # `data` is of type `list[dict]`
+        for k, v in data[0].items():
+            # decode byte strings
+            if isinstance(v, bytearray):
+                data_decode[k] = v.decode()
+        return data_decode
+
+    def inference(self, data, *args, **kwargs):
+        sample_model = data['sample_model']
+        results = self.model.sample_from_noise(
+            None, num_batches=1, sample_model=sample_model, **kwargs)
+        return results
+
+    def postprocess(self, data):
+        # convert torch tensors to numpy arrays and then to bytes
+        output_list = []
+        for data_ in data:
+            # map pixel values from [-1, 1] to [0, 1] and reorder the
+            # channels from BGR to RGB
+            data_ = (data_ + 1) / 2
+            data_ = data_[[2, 1, 0], ...]
+            data_ = data_.clamp_(0, 1)
+            # scale to [0, 255] and move channels to the last dimension
+            data_ = (data_ * 255).permute(1, 2, 0)
+            data_np = data_.detach().cpu().numpy().astype(np.uint8)
+            data_byte = data_np.tobytes()
+            output_list.append(data_byte)
+
+        return output_list
diff --git a/tools/deployment/test_torchserver.py b/tools/deployment/test_torchserver.py
new file mode 100644
index 000000000..d10c4a145
--- /dev/null
+++ b/tools/deployment/test_torchserver.py
@@ -0,0 +1,58 @@
+from argparse import ArgumentParser
+
+import numpy as np
+import requests
+from PIL import Image
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument('model_name', help='The model name in the server')
+    parser.add_argument(
+        '--inference-addr',
+        default='127.0.0.1:8080',
+        help='Address and port of the inference server')
+    parser.add_argument(
+        '--img-path',
+        type=str,
+        default='demo.png',
+        help='Path to save the generated image.')
+    parser.add_argument(
+        '--img-size', type=int, default=128, help='Size of the output image.')
+    parser.add_argument(
+        '--sample-model',
+        type=str,
+        default='ema/orig',
+        help='Which sample model to use: "ema", "orig", or "ema/orig" '
+        'for a side-by-side comparison of both.')
+    args = parser.parse_args()
+    return args
+
+
+def save_results(contents, img_path, img_size):
+    # a single result is saved directly; a list of results is
+    # concatenated horizontally before saving
+    if not isinstance(contents, list):
+        Image.frombytes('RGB', (img_size, img_size), contents).save(img_path)
+        return
+
+    imgs = []
+    for content in contents:
+        imgs.append(
+            np.array(Image.frombytes('RGB', (img_size, img_size), content)))
+    Image.fromarray(np.concatenate(imgs, axis=1)).save(img_path)
+
+
+def main(args):
+    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
+
+    if args.sample_model == 'ema/orig':
+        cont_ema = requests.post(url, {'sample_model': 'ema'}).content
+        cont_orig = requests.post(url, {'sample_model': 'orig'}).content
+        save_results([cont_ema, cont_orig], args.img_path, args.img_size)
+        return
+
+    response = requests.post(url, {'sample_model': args.sample_model})
+    save_results(response.content, args.img_path, args.img_size)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)
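
End-to-end, the three scripts above are meant to be chained as in the minimal sketch below. This is not part of the patch: the config/checkpoint paths and the model name are placeholder assumptions, the decode step assumes the served model generates 128x128 images, and TorchServe must be started separately between the packaging and request steps.

    # Hypothetical smoke test; all paths and names are placeholders.
    # Run from tools/deployment/ so the converter module is importable.
    import requests
    from PIL import Image

    from mmgen2torchserver import mmgen2torchserver

    # 1. Package config + checkpoint into model_store/demo.mar.
    mmgen2torchserver(
        'demo_config.py', 'demo_ckpt.pth',
        output_folder='model_store', model_name='demo', force=True)

    # 2. (outside Python) start the server against the new archive:
    #    torchserve --start --model-store model_store --models demo.mar

    # 3. Request one EMA sample and decode the raw RGB bytes returned by
    #    MMGenUnconditionalHandler.postprocess, assuming 128x128 output.
    resp = requests.post(
        'http://127.0.0.1:8080/predictions/demo', {'sample_model': 'ema'})
    Image.frombytes('RGB', (128, 128), resp.content).save('demo.png')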