[Feature] Support MPS device. #894

Merged
merged 4 commits into from Jul 28, 2022
Changes from 2 commits
2 changes: 1 addition & 1 deletion mmcls/__init__.py
@@ -48,7 +48,7 @@ def digit_version(version_str: str, length: int = 4):


mmcv_minimum_version = '1.4.2'
mmcv_maximum_version = '1.6.0'
mmcv_maximum_version = '1.7.0'
mmcv_version = digit_version(mmcv.__version__)


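For context, a minimal sketch of the version gate these bounds feed into; the exact comparison and error message in mmcls/__init__.py may differ from this sketch:

```python
import mmcv
from mmcv.utils import digit_version

mmcv_minimum_version = '1.4.2'
mmcv_maximum_version = '1.7.0'  # raised from '1.6.0' by this PR

mmcv_version = digit_version(mmcv.__version__)

# Reject mmcv releases outside the supported window.
assert (digit_version(mmcv_minimum_version) <= mmcv_version
        <= digit_version(mmcv_maximum_version)), (
    f'mmcv=={mmcv.__version__} is incompatible; please install '
    f'mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.')
```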
26 changes: 7 additions & 19 deletions mmcls/apis/train.py
@@ -5,13 +5,13 @@
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
build_optimizer, build_runner, get_dist_info)

from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.utils import get_root_logger
from mmcls.utils import (get_root_logger, wrap_distributed_model,
wrap_non_distributed_model)


def init_random_seed(seed=None, device='cuda'):
@@ -128,27 +128,15 @@ def train_model(model,
find_unused_parameters = cfg.get('find_unused_parameters', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model = MMDistributedDataParallel(
model.cuda(),
model = wrap_distributed_model(
model,
cfg.device,
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
if device == 'cpu':
warnings.warn(
'The argument `device` is deprecated. To use cpu to train, '
'please refers to https://mmclassification.readthedocs.io/en'
'/latest/getting_started.html#train-a-model')
model = model.cpu()
elif device == 'ipu':
model = model.cpu()
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
if not model.device_ids:
from mmcv import __version__, digit_version
assert digit_version(__version__) >= (1, 4, 4), \
'To train with CPU, please confirm your mmcv version ' \
'is not lower than v1.4.4'
model = wrap_non_distributed_model(
model, cfg.device, device_ids=cfg.gpu_ids)

# build runner
optimizer = build_optimizer(model, cfg.optimizer)
6 changes: 5 additions & 1 deletion mmcls/utils/__init__.py
@@ -1,8 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .device import auto_select_device
from .distribution import wrap_distributed_model, wrap_non_distributed_model
from .logger import get_root_logger, load_json_log
from .setup_env import setup_multi_processes

__all__ = [
'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes'
'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes',
'wrap_non_distributed_model', 'wrap_distributed_model',
'auto_select_device'
]
15 changes: 15 additions & 0 deletions mmcls/utils/device.py
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.utils import digit_version


def auto_select_device() -> str:
mmcv_version = digit_version(mmcv.__version__)
if mmcv_version >= digit_version('1.6.0'):
from mmcv.device import get_device
return get_device()
elif torch.cuda.is_available():
return 'cuda'
else:
return 'cpu'
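A quick usage sketch of the new helper; the printed value is illustrative and depends on the local torch/mmcv build:

```python
from mmcls.utils import auto_select_device

# With mmcv >= 1.6.0 this defers to mmcv.device.get_device(), which can report
# 'mps' on Apple silicon; older mmcv falls back to 'cuda' or 'cpu'.
device = auto_select_device()
print(device)
```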
58 changes: 58 additions & 0 deletions mmcls/utils/distribution.py
@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.


def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
"""Wrap module in non-distributed environment by device type.

- For CUDA, wrap as :obj:`mmcv.parallel.MMDataParallel`.
- For MPS, wrap as :obj:`mmcv.device.mps.MPSDataParallel`.
- For CPU & IPU, the model is not wrapped.

Args:
model(:class:`nn.Module`): model to be parallelized.
device(str): device type, ``cuda``, ``cpu``, ``ipu`` or ``mps``. Defaults to ``cuda``.
dim(int): Dimension used to scatter the data. Defaults to 0.

Returns:
model(nn.Module): the wrapped model.
"""
if device == 'cuda':
from mmcv.parallel import MMDataParallel
model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
elif device == 'cpu':
model = model.cpu()
elif device == 'ipu':
model = model.cpu()
elif device == 'mps':
from mmcv.device import mps
model = mps.MPSDataParallel(model.to('mps'), dim=dim, *args, **kwargs)
else:
raise RuntimeError(f'Unavailable device "{device}"')

return model


def wrap_distributed_model(model, device='cuda', *args, **kwargs):
"""Build DistributedDataParallel module by device type.

- For CUDA, wrap as :obj:`mmcv.parallel.MMDistributedDataParallel`.
- Other device types are not supported yet.

Args:
model(:class:`nn.Module`): module to be parallelized.
device(str): device type; only ``cuda`` is supported for now.

Returns:
model(:class:`nn.Module`): the wrapped module.

References:
.. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
DistributedDataParallel.html
"""
if device == 'cuda':
from mmcv.parallel import MMDistributedDataParallel
model = MMDistributedDataParallel(model.cuda(), *args, **kwargs)
else:
raise RuntimeError(f'Unavailable device "{device}"')

return model
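A hedged sketch of calling the non-distributed wrapper directly; the toy model is a stand-in for a classifier built by mmcls, and an MPS-capable torch/mmcv build is assumed:

```python
import torch.nn as nn

from mmcls.utils import wrap_non_distributed_model

model = nn.Linear(8, 2)  # stand-in for a real classifier

# Moves the model to MPS and wraps it in mmcv's MPSDataParallel; device='cuda'
# would use MMDataParallel instead, while 'cpu' and 'ipu' return it unwrapped.
model = wrap_non_distributed_model(model, device='mps', device_ids=[0])
```

wrap_distributed_model follows the same pattern but currently accepts only 'cuda', as in the call from mmcls/apis/train.py above.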
35 changes: 14 additions & 21 deletions tools/test.py
@@ -8,14 +8,15 @@
import numpy as np
import torch
from mmcv import DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)

from mmcls.apis import multi_gpu_test, single_gpu_test
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.models import build_classifier
from mmcls.utils import get_root_logger, setup_multi_processes
from mmcls.utils import (auto_select_device, get_root_logger,
setup_multi_processes, wrap_distributed_model,
wrap_non_distributed_model)


def parse_args():
@@ -92,11 +93,7 @@ def parse_args():
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--device',
choices=['cpu', 'cuda', 'ipu'],
default='cuda',
help='device used for testing')
parser.add_argument('--device', help='device used for testing')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -130,6 +127,7 @@ def main():
'in `gpu_ids` now.')
else:
cfg.gpu_ids = [args.gpu_id]
cfg.device = args.device or auto_select_device()

# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
@@ -144,7 +142,7 @@
# The default loader config
loader_cfg = dict(
# cfg.gpus will be ignored if distributed
num_gpus=1 if args.device == 'ipu' else len(cfg.gpu_ids),
num_gpus=1 if cfg.device == 'ipu' else len(cfg.gpu_ids),
dist=distributed,
round_up=True,
)
@@ -182,29 +180,24 @@ def main():
CLASSES = ImageNet.CLASSES

if not distributed:
if args.device == 'cpu':
model = model.cpu()
elif args.device == 'ipu':
model = wrap_non_distributed_model(
model, device=cfg.device, device_ids=cfg.gpu_ids)
if cfg.device == 'ipu':
from mmcv.device.ipu import cfg2options, ipu_model_wrapper
opts = cfg2options(cfg.runner.get('options_cfg', {}))
if fp16_cfg is not None:
model.half()
model = ipu_model_wrapper(model, opts, fp16_cfg=fp16_cfg)
data_loader.init(opts['inference'])
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
if not model.device_ids:
assert mmcv.digit_version(mmcv.__version__) >= (1, 4, 4), \
'To test with CPU, please confirm your mmcv version ' \
'is not lower than v1.4.4'
model.CLASSES = CLASSES
show_kwargs = {} if args.show_options is None else args.show_options
show_kwargs = args.show_options or {}
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
**show_kwargs)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
model = wrap_distributed_model(
model,
device=cfg.device,
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False)
outputs = multi_gpu_test(model, data_loader, args.tmpdir,
args.gpu_collect)
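A small sketch of the loosened --device flag: any string (or none) is accepted and falls back to auto_select_device(); the parsed value here is illustrative:

```python
import argparse

from mmcls.utils import auto_select_device

parser = argparse.ArgumentParser()
parser.add_argument('--device', help='device used for testing')
args = parser.parse_args(['--device', 'mps'])

# Mirrors `cfg.device = args.device or auto_select_device()` in tools/test.py.
device = args.device or auto_select_device()
print(device)  # 'mps'
```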
8 changes: 5 additions & 3 deletions tools/train.py
@@ -16,7 +16,8 @@
from mmcls.apis import init_random_seed, set_random_seed, train_model
from mmcls.datasets import build_dataset
from mmcls.models import build_classifier
from mmcls.utils import collect_env, get_root_logger, setup_multi_processes
from mmcls.utils import (auto_select_device, collect_env, get_root_logger,
setup_multi_processes)


def parse_args():
@@ -162,7 +163,8 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')

# set random seeds
seed = init_random_seed(args.seed)
cfg.device = args.device or auto_select_device()
seed = init_random_seed(args.seed, device=cfg.device)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
@@ -195,7 +197,7 @@ def main():
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
device=args.device,
device=cfg.device,
meta=meta)


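A minimal sketch of the updated seeding flow in tools/train.py, run outside a real config/launcher, so the device is auto-selected and no distributed backend is involved:

```python
from mmcls.apis import init_random_seed, set_random_seed
from mmcls.utils import auto_select_device

device = auto_select_device()
# The seed is now initialized on whatever device was selected, so MPS and CPU
# runs no longer implicitly assume CUDA.
seed = init_random_seed(None, device=device)
set_random_seed(seed, deterministic=False)
print(device, seed)
```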