forked from open-mmlab/mmpretrain
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support MPS device. (open-mmlab#894)
* [Feature] Support MPS device. * Add `auto_select_device` * Add unit tests
- Loading branch information
Showing
10 changed files
with
136 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,12 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .device import auto_select_device
from .distribution import wrap_distributed_model, wrap_non_distributed_model
from .logger import get_root_logger, load_json_log
from .setup_env import setup_multi_processes

# Public API of mmcls.utils. NOTE: the diff residue showed both the
# pre-change entry (missing trailing comma) and its replacement; keeping
# both would implicitly concatenate the adjacent string literals and
# corrupt the exported names, so only the corrected line is kept.
__all__ = [
    'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes',
    'wrap_non_distributed_model', 'wrap_distributed_model',
    'auto_select_device'
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import mmcv | ||
import torch | ||
from mmcv.utils import digit_version | ||
|
||
|
||
def auto_select_device() -> str:
    """Return the device string to run on.

    Delegates to ``mmcv.device.get_device`` when the installed mmcv is
    at least 1.6.0 (the version gate the original commit introduced);
    otherwise falls back to ``'cuda'`` when CUDA is available, else
    ``'cpu'``.
    """
    if digit_version(mmcv.__version__) >= digit_version('1.6.0'):
        # Newer mmcv knows about additional backends (e.g. MPS), so let
        # it make the choice.
        from mmcv.device import get_device
        return get_device()
    return 'cuda' if torch.cuda.is_available() else 'cpu'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
|
||
def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
    """Wrap module in non-distributed environment by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDataParallel`.
    - For MPS, wrap as :obj:`mmcv.device.mps.MPSDataParallel`.
    - For CPU & IPU, not wrap the model.

    Args:
        model(:class:`nn.Module`): model to be parallelized.
        device(str): device type, 'cuda', 'cpu', 'ipu' or 'mps'.
            Defaults to 'cuda'. (The previous docstring said "mlu",
            which has no corresponding branch below.)
        dim(int): Dimension used to scatter the data. Defaults to 0.

    Returns:
        model(nn.Module): the model to be parallelized.

    Raises:
        RuntimeError: If ``device`` is none of the supported types.
    """
    if device == 'cuda':
        from mmcv.parallel import MMDataParallel
        model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
    elif device in ('cpu', 'ipu'):
        # Both branches did exactly the same thing; merged. The bare
        # module is moved to CPU and returned without a parallel wrapper.
        model = model.cpu()
    elif device == 'mps':
        from mmcv.device import mps
        model = mps.MPSDataParallel(model.to('mps'), dim=dim, *args, **kwargs)
    else:
        raise RuntimeError(f'Unavailable device "{device}"')

    return model
|
||
|
||
def wrap_distributed_model(model, device='cuda', *args, **kwargs):
    """Build DistributedDataParallel module by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDistributedDataParallel`.
    - Other device types are not supported by now.

    Args:
        model(:class:`nn.Module`): module to be parallelized.
        device(str): device type; only ``'cuda'`` is supported. Defaults
            to 'cuda'. (The previous docstring advertised "mlu", but the
            code below has no mlu branch.)

    Returns:
        model(:class:`nn.Module`): the module to be parallelized

    Raises:
        RuntimeError: If ``device`` is anything other than ``'cuda'``.

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
               DistributedDataParallel.html
    """
    if device == 'cuda':
        from mmcv.parallel import MMDistributedDataParallel
        model = MMDistributedDataParallel(model.cuda(), *args, **kwargs)
    else:
        raise RuntimeError(f'Unavailable device "{device}"')

    return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
mmcv-full>=1.4.2,<1.6.0 | ||
mmcv-full>=1.4.2,<1.7.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from unittest import TestCase | ||
from unittest.mock import patch | ||
|
||
import mmcv | ||
|
||
from mmcls.utils import auto_select_device | ||
|
||
|
||
class TestAutoSelectDevice(TestCase):
    """Unit tests for :func:`mmcls.utils.auto_select_device`."""

    def test_mmcv(self):
        # With mmcv >= 1.6.0, selection is delegated to
        # ``mmcv.device.get_device`` (patched with create=True since
        # older mmcv builds lack the attribute).
        with patch.object(mmcv, '__version__', '1.6.0'), \
                patch('mmcv.device.get_device', create=True) as get_device:
            auto_select_device()
            get_device.assert_called_once()

    def test_cuda(self):
        # Older mmcv: falls back to 'cuda' when CUDA is available.
        with patch.object(mmcv, '__version__', '1.5.0'), \
                patch('torch.cuda.is_available', return_value=True):
            self.assertEqual(auto_select_device(), 'cuda')

    def test_cpu(self):
        # Older mmcv and no CUDA: falls back to 'cpu'.
        with patch.object(mmcv, '__version__', '1.5.0'), \
                patch('torch.cuda.is_available', return_value=False):
            self.assertEqual(auto_select_device(), 'cpu')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters