Skip to content

Commit

Permalink
[bugfix] Perform reduction for dict in training_step and DP (#6324)
Browse files Browse the repository at this point in the history
* fix

* update

* update

* add changelog

* Update CHANGELOG.md

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update tests/accelerators/test_dp.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* update changelog

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit 248a8e8)
  • Loading branch information
tchaton authored and SeanNaren committed Mar 16, 2021
1 parent 4fc6e51 commit d202151
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)


- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))


## [1.2.1] - 2021-02-23

### Fixed
Expand Down
29 changes: 23 additions & 6 deletions pytorch_lightning/plugins/training_type/dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection


class DataParallelPlugin(ParallelPlugin):
Expand All @@ -31,14 +32,30 @@ def setup(self, model):
model.to(self.root_device)
self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)

def reduce(self, output, *args, **kwargs):
if isinstance(output, Result):
output.dp_reduce()
def reduce(self, tensor, *args, **kwargs):
"""
Reduces a tensor from all parallel processes to one aggregated tensor.
elif isinstance(output, torch.Tensor):
output = output.mean()
Args:
tensor: the tensor to sync and reduce
*args: ignored for DP
**kwargs: ignored for DP
return output
Return:
reduced value, except when the input was not a tensor the output remains is unchanged
"""
if isinstance(tensor, Result):
tensor.dp_reduce()

else:

def _reduce(tensor: torch.Tensor):
dtype_tensor = tensor.dtype
return tensor.float().mean().type(dtype_tensor)

tensor = apply_to_collection(tensor, torch.Tensor, _reduce)

return tensor

@property
def root_device(self):
Expand Down
19 changes: 19 additions & 0 deletions tests/accelerators/test_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,22 @@ def test_dp_test(tmpdir):
new_weights = model.c_d1.weight.clone().detach().cpu()

assert torch.all(torch.eq(old_weights, new_weights))


@RunIf(min_gpus=2)
def test_dp_training_step_dict(tmpdir):
"""
This test verify dp properly reduce dictionaries
"""

model = BoringModel()
model.training_step_end = None
trainer = pl.Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=2,
limit_val_batches=0,
gpus=2,
accelerator='dp',
)
trainer.fit(model)

0 comments on commit d202151

Please sign in to comment.