[Feature] Support MLU backend. (#1159)
* Training on MLU is available
Qiza-lyhm authored Nov 15, 2022
1 parent 05e4bc1 commit dc8691e
Showing 4 changed files with 23 additions and 7 deletions.
9 changes: 5 additions & 4 deletions mmcls/apis/train.py
@@ -10,11 +10,11 @@
 
 from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
 from mmcls.datasets import build_dataloader, build_dataset
-from mmcls.utils import (get_root_logger, wrap_distributed_model,
-                         wrap_non_distributed_model)
+from mmcls.utils import (auto_select_device, get_root_logger,
+                         wrap_distributed_model, wrap_non_distributed_model)
 
 
-def init_random_seed(seed=None, device='cuda'):
+def init_random_seed(seed=None, device=None):
     """Initialize random seed.
 
     If the seed is not set, the seed will be automatically randomized,
@@ -30,7 +30,8 @@ def init_random_seed(seed=None, device='cuda'):
     """
     if seed is not None:
         return seed
-
+    if device is None:
+        device = auto_select_device()
     # Make sure all ranks share the same random seed to prevent
     # some potential bugs. Please refer to
     # https://github.com/open-mmlab/mmdetection/issues/6339
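
Note: these changes lean on auto_select_device() to pick a backend when the caller does not pass one. A minimal sketch of what such a helper can look like (illustrative only; the real implementation lives in mmcls.utils and its probe order may differ):

import torch


def auto_select_device():
    # Sketch only: the probe order (npu > mlu > cuda > cpu) is an assumption,
    # not necessarily what mmcls.utils.auto_select_device actually does.
    if hasattr(torch, 'npu') and torch.npu.is_available():
        return 'npu'
    # torch_mlu (the Cambricon extension) registers torch.is_mlu_available.
    if hasattr(torch, 'is_mlu_available') and torch.is_mlu_available():
        return 'mlu'
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'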
6 changes: 5 additions & 1 deletion mmcls/core/utils/dist_utils.py
@@ -8,6 +8,8 @@
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
 
+from mmcls.utils import auto_select_device
+
 
 def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
     if bucket_size_mb > 0:
@@ -59,7 +61,7 @@ def after_train_iter(self, runner):
         runner.optimizer.step()
 
 
-def sync_random_seed(seed=None, device='cuda'):
+def sync_random_seed(seed=None, device=None):
     """Make sure different ranks share the same seed.
 
     All workers must call this function, otherwise it will deadlock.
@@ -81,6 +83,8 @@ def sync_random_seed(seed=None, device='cuda'):
     Returns:
         int: Seed to be used.
     """
+    if device is None:
+        device = auto_select_device()
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)
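
Note: the device argument matters because the body of sync_random_seed (unchanged in this diff) broadcasts the chosen seed as a tensor that has to live on the active backend. Roughly this pattern, written here as a standalone sketch rather than the verbatim mmcls code:

import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def sync_random_seed_sketch(seed=None, device='mlu'):
    # Sketch of the broadcast pattern; details in mmcls may differ.
    if seed is None:
        seed = np.random.randint(2**31)
    rank, world_size = get_dist_info()
    if world_size == 1:
        return seed
    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    # Rank 0's value wins on every process, so all ranks end up with one seed.
    dist.broadcast(random_num, src=0)
    return random_num.item()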
3 changes: 1 addition & 2 deletions mmcls/datasets/samplers/distributed_sampler.py
@@ -4,7 +4,6 @@
 
 from mmcls.core.utils import sync_random_seed
 from mmcls.datasets import SAMPLERS
-from mmcls.utils import auto_select_device
 
 
 @SAMPLERS.register_module()
@@ -31,7 +30,7 @@ def __init__(self,
         # in the same order based on the same seed. Then different ranks
         # could use different indices to select non-overlapped data from the
         # same data list.
-        self.seed = sync_random_seed(seed, device=auto_select_device())
+        self.seed = sync_random_seed(seed)
 
     def __iter__(self):
         # deterministically shuffle based on epoch
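
Note: dropping the explicit device= argument works because sync_random_seed now resolves the device itself. The synchronized seed then drives the usual rank-wise shuffling pattern, roughly as below (hypothetical standalone helper, not part of mmcls):

import torch


def shuffled_indices_for_rank(num_samples, epoch, seed, rank, world_size):
    # Same seed on every rank gives an identical permutation; each rank then
    # takes a disjoint strided slice, so no two ranks read the same sample.
    g = torch.Generator()
    g.manual_seed(epoch + seed)
    indices = torch.randperm(num_samples, generator=g).tolist()
    return indices[rank::world_size]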
12 changes: 12 additions & 0 deletions mmcls/utils/distribution.py
@@ -19,6 +19,9 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     if device == 'npu':
         from mmcv.device.npu import NPUDataParallel
         model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
+    elif device == 'mlu':
+        from mmcv.device.mlu import MLUDataParallel
+        model = MLUDataParallel(model.mlu(), dim=dim, *args, **kwargs)
     elif device == 'cuda':
         from mmcv.parallel import MMDataParallel
         model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
@@ -57,6 +60,15 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
         from torch.npu import current_device
         model = NPUDistributedDataParallel(
             model.npu(), *args, device_ids=[current_device()], **kwargs)
+    elif device == 'mlu':
+        import os
+
+        from mmcv.device.mlu import MLUDistributedDataParallel
+        model = MLUDistributedDataParallel(
+            model.mlu(),
+            *args,
+            device_ids=[int(os.environ['LOCAL_RANK'])],
+            **kwargs)
     elif device == 'cuda':
         from mmcv.parallel import MMDistributedDataParallel
         from torch.cuda import current_device
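
Note: a hedged sketch of how the two wrappers are meant to be called from a training script (the helper name below is invented for illustration; mmcls' train loop passes its configured device string in a similar way):

from mmcls.utils import wrap_distributed_model, wrap_non_distributed_model


def build_runtime_model(model, device='mlu', distributed=False):
    # Hypothetical helper: dispatch on the device string, as in this diff.
    if distributed:
        # The MLU branch reads LOCAL_RANK itself, so no device_ids are needed.
        return wrap_distributed_model(
            model, device=device, broadcast_buffers=False)
    # Single-process data parallel on card 0.
    return wrap_non_distributed_model(model, device=device, device_ids=[0])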
