-
Notifications
You must be signed in to change notification settings - Fork 233
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add function to meet mmdeploy support (#102)
* add init_model function for mmdeploy * fix lint * add unittest for init_xxx_model * fix lint * mv test_inference.py to test_apis directory
- Loading branch information
1 parent
81e0e34
commit 7e251d8
Showing
7 changed files
with
249 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .inference import init_mmcls_model | ||
from .train import set_random_seed, train_model | ||
|
||
__all__ = [ | ||
'set_random_seed', | ||
'train_model', | ||
] | ||
__all__ = ['set_random_seed', 'train_model', 'init_mmcls_model'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import warnings | ||
from typing import Dict, Optional, Union | ||
|
||
import mmcv | ||
from mmcv.runner import load_checkpoint | ||
from torch import nn | ||
|
||
from mmrazor.models import build_algorithm | ||
|
||
|
||
def init_mmcls_model(config: Union[str, mmcv.Config], | ||
checkpoint: Optional[str] = None, | ||
device: str = 'cuda:0', | ||
cfg_options: Optional[Dict] = None) -> nn.Module: | ||
"""Initialize a mmcls model from config file. | ||
Args: | ||
config (str or :obj:`mmcv.Config`): Config file path or the config | ||
object. | ||
checkpoint (str, optional): Checkpoint path. If left as None, the model | ||
will not load any weights. | ||
cfg_options (dict): cfg_options to override some settings in the used | ||
config. | ||
Returns: | ||
nn.Module: The constructed classifier. | ||
""" | ||
if isinstance(config, str): | ||
config = mmcv.Config.fromfile(config) | ||
elif not isinstance(config, mmcv.Config): | ||
raise TypeError('config must be a filename or Config object, ' | ||
f'but got {type(config)}') | ||
if cfg_options is not None: | ||
config.merge_from_dict(cfg_options) | ||
|
||
model_cfg = config.algorithm.architecture.model | ||
model_cfg.pretrained = None | ||
algorithm = build_algorithm(config.algorithm) | ||
model = algorithm.architecture.model | ||
|
||
if checkpoint is not None: | ||
# Mapping the weights to GPU may cause unexpected video memory leak | ||
# which refers to https://github.com/open-mmlab/mmdetection/pull/6405 | ||
checkpoint = load_checkpoint(algorithm, checkpoint, map_location='cpu') | ||
if 'CLASSES' in checkpoint.get('meta', {}): | ||
model.CLASSES = checkpoint['meta']['CLASSES'] | ||
else: | ||
from mmcls.datasets import ImageNet | ||
warnings.simplefilter('once') | ||
warnings.warn('Class names are not saved in the checkpoint\'s ' | ||
'meta data, use imagenet by default.') | ||
model.CLASSES = ImageNet.CLASSES | ||
model.cfg = config # save the config in the model for convenience | ||
model.to(device) | ||
model.eval() | ||
|
||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import warnings | ||
from typing import Dict, Optional, Union | ||
|
||
import mmcv | ||
from mmcv.runner import load_checkpoint | ||
from mmdet.core import get_classes | ||
from torch import nn | ||
|
||
from mmrazor.models import build_algorithm | ||
|
||
|
||
def init_mmdet_model(config: Union[str, mmcv.Config], | ||
checkpoint: Optional[str] = None, | ||
device: str = 'cuda:0', | ||
cfg_options: Optional[Dict] = None) -> nn.Module: | ||
"""Initialize a mmdet model from config file. | ||
Args: | ||
config (str or :obj:`mmcv.Config`): Config file path or the config | ||
object. | ||
checkpoint (str, optional): Checkpoint path. If left as None, the model | ||
will not load any weights. | ||
cfg_options (dict): Options to override some settings in the used | ||
config. | ||
Returns: | ||
nn.Module: The constructed detector. | ||
""" | ||
if isinstance(config, str): | ||
config = mmcv.Config.fromfile(config) | ||
elif not isinstance(config, mmcv.Config): | ||
raise TypeError('config must be a filename or Config object, ' | ||
f'but got {type(config)}') | ||
if cfg_options is not None: | ||
config.merge_from_dict(cfg_options) | ||
|
||
model_cfg = config.algorithm.architecture.model | ||
if 'pretrained' in model_cfg: | ||
model_cfg.pretrained = None | ||
elif 'init_cfg' in model_cfg.backbone: | ||
model_cfg.backbone.init_cfg = None | ||
|
||
config.model.train_cfg = None | ||
algorithm = build_algorithm(config.algorithm) | ||
model = algorithm.architecture.model | ||
|
||
if checkpoint is not None: | ||
checkpoint = load_checkpoint(algorithm, checkpoint, map_location='cpu') | ||
if 'CLASSES' in checkpoint.get('meta', {}): | ||
model.CLASSES = checkpoint['meta']['CLASSES'] | ||
else: | ||
warnings.simplefilter('once') | ||
warnings.warn('Class names are not saved in the checkpoint\'s ' | ||
'meta data, use COCO classes by default.') | ||
model.CLASSES = get_classes('coco') | ||
model.cfg = config # save the config in the model for convenience | ||
model.to(device) | ||
model.eval() | ||
|
||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Optional, Union | ||
|
||
import mmcv | ||
from mmcv.runner import load_checkpoint | ||
from torch import nn | ||
|
||
from mmrazor.models import build_algorithm | ||
|
||
|
||
def init_mmseg_model(config: Union[str, mmcv.Config], | ||
checkpoint: Optional[str] = None, | ||
device: str = 'cuda:0') -> nn.Module: | ||
"""Initialize a mmseg model from config file. | ||
Args: | ||
config (str or :obj:`mmcv.Config`): Config file path or the config | ||
object. | ||
checkpoint (str, optional): Checkpoint path. If left as None, the model | ||
will not load any weights. | ||
device (str, optional) CPU/CUDA device option. Default 'cuda:0'. | ||
Use 'cpu' for loading model on CPU. | ||
Returns: | ||
nn.Module: The constructed segmentor. | ||
""" | ||
if isinstance(config, str): | ||
config = mmcv.Config.fromfile(config) | ||
elif not isinstance(config, mmcv.Config): | ||
raise TypeError('config must be a filename or Config object, ' | ||
'but got {}'.format(type(config))) | ||
|
||
model_cfg = config.algorithm.architecture.model | ||
model_cfg.pretrained = None | ||
model_cfg.train_cfg = None | ||
algorithm = build_algorithm(config.algorithm) | ||
model = algorithm.architecture.model | ||
|
||
if checkpoint is not None: | ||
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') | ||
model.CLASSES = checkpoint['meta']['CLASSES'] | ||
model.PALETTE = checkpoint['meta']['PALETTE'] | ||
model.cfg = config # save the config in the model for convenience | ||
model.to(device) | ||
model.eval() | ||
|
||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from pathlib import Path | ||
|
||
import mmcv | ||
from mmcls.apis import inference_model | ||
from mmdet.apis import inference_detector | ||
from mmseg.apis import inference_segmentor | ||
|
||
from mmrazor.apis import init_mmcls_model, init_mmdet_model, init_mmseg_model | ||
|
||
|
||
def test_init_mmcls_model(): | ||
from mmcls.datasets import ImageNet | ||
|
||
config_file = 'configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k.py' | ||
config = mmcv.Config.fromfile(config_file) | ||
|
||
mutable_file = 'configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml' | ||
model = init_mmcls_model( | ||
config, | ||
device='cpu', | ||
cfg_options={'algorithm.mutable_cfg': mutable_file}) | ||
model.CLASSES = ImageNet.CLASSES | ||
assert not hasattr(model, 'architecture') | ||
assert hasattr(model, 'backbone') | ||
assert hasattr(model, 'neck') | ||
assert hasattr(model, 'head') | ||
|
||
img = mmcv.imread(Path(__file__).parent.parent / 'data/color.jpg', 'color') | ||
result = inference_model(model, img) | ||
assert isinstance(result, dict) | ||
assert result.get('pred_label') is not None | ||
assert result.get('pred_score') is not None | ||
assert result.get('pred_class') is not None | ||
|
||
|
||
def test_init_mmdet_model(): | ||
config_file = \ | ||
'configs/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py' | ||
config = mmcv.Config.fromfile(config_file) | ||
|
||
mutable_file = \ | ||
'configs/nas/detnas/DETNAS_FRCNN_SHUFFLENETV2_340M_COCO_MMRAZOR.yaml' | ||
model = init_mmdet_model( | ||
config, | ||
device='cpu', | ||
cfg_options={'algorithm.mutable_cfg': mutable_file}) | ||
assert not hasattr(model, 'architecture') | ||
|
||
img = mmcv.imread(Path(__file__).parent.parent / 'data/color.jpg', 'color') | ||
result = inference_detector(model, img) | ||
assert isinstance(result, list) | ||
|
||
|
||
def test_init_mmseg_model(): | ||
config_file = 'configs/distill/cwd/' \ | ||
'cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k.py' | ||
config = mmcv.Config.fromfile(config_file) | ||
|
||
# Replace SyncBN with BN to inference on CPU | ||
norm_cfg = dict(type='BN', requires_grad=True) | ||
model_config = config.algorithm.architecture | ||
model_config.model.backbone.norm_cfg = norm_cfg | ||
model_config.model.decode_head.norm_cfg = norm_cfg | ||
model_config.model.auxiliary_head.norm_cfg = norm_cfg | ||
|
||
# Enable test time augmentation | ||
config.data.test.pipeline[1].flip = True | ||
|
||
model = init_mmseg_model(config, device='cpu') | ||
assert not hasattr(model, 'architecture') | ||
assert hasattr(model, 'backbone') | ||
assert hasattr(model, 'decode_head') | ||
assert hasattr(model, 'auxiliary_head') | ||
|
||
img = mmcv.imread(Path(__file__).parent.parent / 'data/color.jpg', 'color') | ||
result = inference_segmentor(model, img) | ||
assert result[0].shape == (300, 400) |