[Fix] fix train example (#1502)
* [Fix] fix train example

* [Fix] fix a detail in train example and add warning in MMDP

* [Fix] fix docstring

* [Fix] fix docstring
teamwong111 authored Nov 23, 2021
1 parent e85c43a commit b57825f
Showing 2 changed files with 12 additions and 2 deletions.
examples/train.py: 3 additions & 1 deletion
@@ -42,7 +42,9 @@ def train_step(self, data, optimizer):
 if __name__ == '__main__':
     model = Model()
     if torch.cuda.is_available():
-        model = MMDataParallel(model.cuda())
+        # only use gpu:0 to train
+        # Solved issue https://github.com/open-mmlab/mmcv/issues/1470
+        model = MMDataParallel(model.cuda(), device_ids=[0])
 
     # dataset and dataloader
     transform = transforms.Compose([
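For context, here is a minimal, self-contained sketch of the two single-GPU workarounds this commit's warning describes: hiding extra GPUs via an environment variable, or passing ``device_ids=[0]`` as the fix above does. The ``Model`` class below is a stand-in for the toy model defined in examples/train.py, not the real one.

import os

# Option 1: hide every GPU except gpu:0 *before* torch initializes CUDA,
# so DataParallel cannot scatter across devices it cannot see.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.nn as nn

from mmcv.parallel import MMDataParallel


class Model(nn.Module):  # stand-in for the toy model in examples/train.py
    def forward(self, x):
        return x


model = Model()
if torch.cuda.is_available():
    # Option 2 (this commit's fix): pin the wrapper to a single device
    # explicitly instead of letting it enumerate all visible GPUs.
    model = MMDataParallel(model.cuda(), device_ids=[0])

Setting ``CUDA_VISIBLE_DEVICES`` must happen before any CUDA call, which is why option 1 precedes the torch import in this sketch.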
mmcv/parallel/data_parallel.py: 9 additions & 1 deletion
@@ -15,6 +15,14 @@ class MMDataParallel(DataParallel):
       flexible control of input data during both GPU and CPU inference.
     - It implements two more APIs ``train_step()`` and ``val_step()``.
 
+    .. warning::
+        MMDataParallel only supports single GPU training; if you need to
+        train with multiple GPUs, please use MMDistributedDataParallel
+        instead. If you have multiple GPUs and you just want to use
+        MMDataParallel, you can set the environment variable
+        ``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel``
+        with ``device_ids=[0]``.
+
     Args:
         module (:class:`nn.Module`): Module to be encapsulated.
         device_ids (list[int]): Device IDs of modules to be scattered to.
@@ -54,7 +62,7 @@ def train_step(self, *inputs, **kwargs):
         assert len(self.device_ids) == 1, \
             ('MMDataParallel only supports single GPU training, if you need to'
              ' train with multiple GPUs, please use MMDistributedDataParallel'
-             'instead.')
+             ' instead.')
 
         for t in chain(self.module.parameters(), self.module.buffers()):
             if t.device != self.src_device_obj:
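The one-character change above matters because Python joins adjacent string literals with no separator, so the original message ran two words together. A quick illustration, independent of mmcv:

# Adjacent string literals are concatenated verbatim; a missing
# leading space glues the last two words together.
msg_before = ('please use MMDistributedDataParallel'
              'instead.')
msg_after = ('please use MMDistributedDataParallel'
             ' instead.')
print(msg_before)  # please use MMDistributedDataParallelinstead.
print(msg_after)   # please use MMDistributedDataParallel instead.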
