Skip to content

Commit

Permalink
[Feature] Support MPS device. (open-mmlab#894)
Browse files Browse the repository at this point in the history
* [Feature] Support MPS device.

* Add `auto_select_device`

* Add unit tests
  • Loading branch information
mzr1996 authored and Ezra-Yu committed Sep 6, 2022
1 parent 587c269 commit 4cb6332
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 47 deletions.
2 changes: 1 addition & 1 deletion mmcls/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def digit_version(version_str: str, length: int = 4):


mmcv_minimum_version = '1.4.2'
mmcv_maximum_version = '1.6.0'
mmcv_maximum_version = '1.7.0'
mmcv_version = digit_version(mmcv.__version__)


Expand Down
26 changes: 7 additions & 19 deletions mmcls/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
build_optimizer, build_runner, get_dist_info)

from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.utils import get_root_logger
from mmcls.utils import (get_root_logger, wrap_distributed_model,
wrap_non_distributed_model)


def init_random_seed(seed=None, device='cuda'):
Expand Down Expand Up @@ -128,27 +128,15 @@ def train_model(model,
find_unused_parameters = cfg.get('find_unused_parameters', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model = MMDistributedDataParallel(
model.cuda(),
model = wrap_distributed_model(
model,
cfg.device,
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
if device == 'cpu':
warnings.warn(
'The argument `device` is deprecated. To use cpu to train, '
'please refers to https://mmclassification.readthedocs.io/en'
'/latest/getting_started.html#train-a-model')
model = model.cpu()
elif device == 'ipu':
model = model.cpu()
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
if not model.device_ids:
from mmcv import __version__, digit_version
assert digit_version(__version__) >= (1, 4, 4), \
'To train with CPU, please confirm your mmcv version ' \
'is not lower than v1.4.4'
model = wrap_non_distributed_model(
model, cfg.device, device_ids=cfg.gpu_ids)

# build runner
optimizer = build_optimizer(model, cfg.optimizer)
Expand Down
6 changes: 5 additions & 1 deletion mmcls/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .device import auto_select_device
from .distribution import wrap_distributed_model, wrap_non_distributed_model
from .logger import get_root_logger, load_json_log
from .setup_env import setup_multi_processes

__all__ = [
'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes'
'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes',
'wrap_non_distributed_model', 'wrap_distributed_model',
'auto_select_device'
]
15 changes: 15 additions & 0 deletions mmcls/utils/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.utils import digit_version


def auto_select_device() -> str:
mmcv_version = digit_version(mmcv.__version__)
if mmcv_version >= digit_version('1.6.0'):
from mmcv.device import get_device
return get_device()
elif torch.cuda.is_available():
return 'cuda'
else:
return 'cpu'
58 changes: 58 additions & 0 deletions mmcls/utils/distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.


def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
"""Wrap module in non-distributed environment by device type.
- For CUDA, wrap as :obj:`mmcv.parallel.MMDataParallel`.
- For MPS, wrap as :obj:`mmcv.device.mps.MPSDataParallel`.
- For CPU & IPU, not wrap the model.
Args:
model(:class:`nn.Module`): model to be parallelized.
device(str): device type, cuda, cpu or mlu. Defaults to cuda.
dim(int): Dimension used to scatter the data. Defaults to 0.
Returns:
model(nn.Module): the model to be parallelized.
"""
if device == 'cuda':
from mmcv.parallel import MMDataParallel
model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
elif device == 'cpu':
model = model.cpu()
elif device == 'ipu':
model = model.cpu()
elif device == 'mps':
from mmcv.device import mps
model = mps.MPSDataParallel(model.to('mps'), dim=dim, *args, **kwargs)
else:
raise RuntimeError(f'Unavailable device "{device}"')

return model


def wrap_distributed_model(model, device='cuda', *args, **kwargs):
"""Build DistributedDataParallel module by device type.
- For CUDA, wrap as :obj:`mmcv.parallel.MMDistributedDataParallel`.
- Other device types are not supported by now.
Args:
model(:class:`nn.Module`): module to be parallelized.
device(str): device type, mlu or cuda.
Returns:
model(:class:`nn.Module`): the module to be parallelized
References:
.. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
DistributedDataParallel.html
"""
if device == 'cuda':
from mmcv.parallel import MMDistributedDataParallel
model = MMDistributedDataParallel(model.cuda(), *args, **kwargs)
else:
raise RuntimeError(f'Unavailable device "{device}"')

return model
2 changes: 1 addition & 1 deletion requirements/mminstall.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
mmcv-full>=1.4.2,<1.6.0
mmcv-full>=1.4.2,<1.7.0
3 changes: 2 additions & 1 deletion tests/test_models/test_backbones/test_timm_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def test_timm_backbone():
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 1
assert feat[0].shape == torch.Size((1, 192))
# Disable the test since TIMM's behavior changes between 0.5.4 and 0.5.5
# assert feat[0].shape == torch.Size((1, 197, 192))


def test_timm_backbone_features_only():
Expand Down
28 changes: 28 additions & 0 deletions tests/test_utils/test_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import patch

import mmcv

from mmcls.utils import auto_select_device


class TestAutoSelectDevice(TestCase):

@patch.object(mmcv, '__version__', '1.6.0')
@patch('mmcv.device.get_device', create=True)
def test_mmcv(self, mock):
auto_select_device()
mock.assert_called_once()

@patch.object(mmcv, '__version__', '1.5.0')
@patch('torch.cuda.is_available', return_value=True)
def test_cuda(self, mock):
device = auto_select_device()
self.assertEqual(device, 'cuda')

@patch.object(mmcv, '__version__', '1.5.0')
@patch('torch.cuda.is_available', return_value=False)
def test_cpu(self, mock):
device = auto_select_device()
self.assertEqual(device, 'cpu')
35 changes: 14 additions & 21 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
import numpy as np
import torch
from mmcv import DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)

from mmcls.apis import multi_gpu_test, single_gpu_test
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.models import build_classifier
from mmcls.utils import get_root_logger, setup_multi_processes
from mmcls.utils import (auto_select_device, get_root_logger,
setup_multi_processes, wrap_distributed_model,
wrap_non_distributed_model)


def parse_args():
Expand Down Expand Up @@ -92,11 +93,7 @@ def parse_args():
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--device',
choices=['cpu', 'cuda', 'ipu'],
default='cuda',
help='device used for testing')
parser.add_argument('--device', help='device used for testing')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
Expand Down Expand Up @@ -130,6 +127,7 @@ def main():
'in `gpu_ids` now.')
else:
cfg.gpu_ids = [args.gpu_id]
cfg.device = args.device or auto_select_device()

# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
Expand All @@ -144,7 +142,7 @@ def main():
# The default loader config
loader_cfg = dict(
# cfg.gpus will be ignored if distributed
num_gpus=1 if args.device == 'ipu' else len(cfg.gpu_ids),
num_gpus=1 if cfg.device == 'ipu' else len(cfg.gpu_ids),
dist=distributed,
round_up=True,
)
Expand Down Expand Up @@ -182,29 +180,24 @@ def main():
CLASSES = ImageNet.CLASSES

if not distributed:
if args.device == 'cpu':
model = model.cpu()
elif args.device == 'ipu':
model = wrap_non_distributed_model(
model, device=cfg.device, device_ids=cfg.gpu_ids)
if cfg.device == 'ipu':
from mmcv.device.ipu import cfg2options, ipu_model_wrapper
opts = cfg2options(cfg.runner.get('options_cfg', {}))
if fp16_cfg is not None:
model.half()
model = ipu_model_wrapper(model, opts, fp16_cfg=fp16_cfg)
data_loader.init(opts['inference'])
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
if not model.device_ids:
assert mmcv.digit_version(mmcv.__version__) >= (1, 4, 4), \
'To test with CPU, please confirm your mmcv version ' \
'is not lower than v1.4.4'
model.CLASSES = CLASSES
show_kwargs = {} if args.show_options is None else args.show_options
show_kwargs = args.show_options or {}
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
**show_kwargs)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
model = wrap_distributed_model(
model,
device=cfg.device,
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False)
outputs = multi_gpu_test(model, data_loader, args.tmpdir,
args.gpu_collect)
Expand Down
8 changes: 5 additions & 3 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from mmcls.apis import init_random_seed, set_random_seed, train_model
from mmcls.datasets import build_dataset
from mmcls.models import build_classifier
from mmcls.utils import collect_env, get_root_logger, setup_multi_processes
from mmcls.utils import (auto_select_device, collect_env, get_root_logger,
setup_multi_processes)


def parse_args():
Expand Down Expand Up @@ -162,7 +163,8 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')

# set random seeds
seed = init_random_seed(args.seed)
cfg.device = args.device or auto_select_device()
seed = init_random_seed(args.seed, device=cfg.device)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
Expand Down Expand Up @@ -195,7 +197,7 @@ def main():
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
device=args.device,
device=cfg.device,
meta=meta)


Expand Down

0 comments on commit 4cb6332

Please sign in to comment.