init npu
wangjiangben-hw committed Oct 5, 2022
1 parent 982cab4 commit 6706ba8
Showing 3 changed files with 23 additions and 4 deletions.
10 changes: 9 additions & 1 deletion mmcls/apis/train.py

@@ -128,10 +128,14 @@ def train_model(model,
         find_unused_parameters = cfg.get('find_unused_parameters', False)
         # Sets the `find_unused_parameters` parameter in
         # torch.nn.parallel.DistributedDataParallel
+        if cfg.device == 'npu':
+            current_device = torch.npu.current_device()
+        else:
+            current_device = torch.cuda.current_device()
         model = wrap_distributed_model(
             model,
             cfg.device,
-            device_ids=[torch.cuda.current_device()],
+            device_ids=[current_device],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
     else:
@@ -173,6 +177,10 @@ def train_model(model,
 
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
+
+    if fp16_cfg is None and device == 'npu':
+        fp16_cfg = {'loss_scale': 'dynamic'}
+
     if fp16_cfg is not None:
         if device == 'ipu':
             from mmcv.device.ipu import IPUFp16OptimizerHook
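
Both hunks follow the same pattern: branch on the configured device string once and leave the rest of the training pipeline untouched. A condensed sketch of that logic, assuming torch_npu (Ascend) is installed so that torch.npu exists; the helper names below are illustrative, not functions that exist in mmcls:

import torch


def select_current_device(device):
    """Pick the accelerator index to pass to DDP's `device_ids`.

    `torch.npu` only exists when the torch_npu package is installed, so
    the NPU branch is guarded by the configured device string rather
    than by feature detection.
    """
    if device == 'npu':
        return torch.npu.current_device()
    return torch.cuda.current_device()


def default_fp16_cfg(fp16_cfg, device):
    """NPU runs default to mixed precision with a dynamic loss scale.

    Dynamic scaling lets the fp16 optimizer hook lower the scale
    automatically when gradient overflows are detected.
    """
    if fp16_cfg is None and device == 'npu':
        return {'loss_scale': 'dynamic'}
    return fp16_cfg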
7 changes: 6 additions & 1 deletion mmcls/datasets/samplers/distributed_sampler.py

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+from mmcv.device.utils import IS_NPU_AVAILABLE
 from torch.utils.data import DistributedSampler as _DistributedSampler
 
 from mmcls.core.utils import sync_random_seed
@@ -30,7 +31,11 @@ def __init__(self,
         # in the same order based on the same seed. Then different ranks
         # could use different indices to select non-overlapped data from the
         # same data list.
-        self.seed = sync_random_seed(seed)
+        if IS_NPU_AVAILABLE:
+            device = 'npu'
+        else:
+            device = 'cuda'
+        self.seed = sync_random_seed(seed, device)
 
     def __iter__(self):
         # deterministically shuffle based on epoch
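
sync_random_seed keeps shuffling consistent across ranks by drawing a seed on rank 0 and broadcasting it as a tensor; the new device argument only controls where that tensor lives, which matters because an HCCL broadcast needs an NPU tensor just as an NCCL broadcast needs a CUDA one. A rough sketch of the idea (not the exact mmcls implementation):

import numpy as np
import torch
import torch.distributed as dist


def sync_random_seed_sketch(seed=None, device='cuda'):
    """Broadcast a random seed from rank 0 so all ranks shuffle alike."""
    if seed is None:
        seed = np.random.randint(2**31)
    if not (dist.is_available() and dist.is_initialized()):
        return seed
    if dist.get_rank() == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    # The tensor must live on the communication backend's device:
    # 'cuda' for NCCL, 'npu' for HCCL.
    dist.broadcast(random_num, src=0)
    return random_num.item()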
10 changes: 8 additions & 2 deletions mmcls/utils/distribution.py

@@ -16,7 +16,10 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     Returns:
         model(nn.Module): the model to be parallelized.
     """
-    if device == 'cuda':
+    if device == 'npu':
+        from mmcv.device.npu import NPUDataParallel
+        model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
+    elif device == 'cuda':
         from mmcv.parallel import MMDataParallel
         model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
     elif device == 'cpu':
@@ -49,7 +52,10 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
         .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
            DistributedDataParallel.html
     """
-    if device == 'cuda':
+    if device == 'npu':
+        from mmcv.device.npu import NPUDistributedDataParallel
+        model = NPUDistributedDataParallel(model.npu(), *args, **kwargs)
+    elif device == 'cuda':
         from mmcv.parallel import MMDistributedDataParallel
         model = MMDistributedDataParallel(model.cuda(), *args, **kwargs)
     else:
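
With both wrappers aware of the 'npu' device string, callers can stay device-agnostic and simply forward cfg.device. A minimal usage sketch, assuming an Ascend environment with torch_npu and an NPU-enabled mmcv build, and assuming the two wrappers are exported from mmcls.utils; the toy model is purely illustrative:

import torch.nn as nn

from mmcls.utils import wrap_distributed_model, wrap_non_distributed_model

model = nn.Linear(8, 2)  # placeholder model

# Non-distributed: returns the model wrapped in NPUDataParallel and
# moved to the NPU, mirroring the MMDataParallel path on CUDA.
model = wrap_non_distributed_model(model, device='npu')

# Distributed: requires an initialized process group (HCCL backend);
# the wrapper then applies NPUDistributedDataParallel.
# model = wrap_distributed_model(model, 'npu', broadcast_buffers=False)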
