diff --git a/CHANGELOG.md b/CHANGELOG.md index c2b504f1c7fc7..ffd88e751bd3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -169,6 +169,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297) +- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324)) + + ## [1.2.1] - 2021-02-23 ### Fixed diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index e1002faf8a3b4..af8cfa7755974 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -19,6 +19,7 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.utilities.apply_func import apply_to_collection class DataParallelPlugin(ParallelPlugin): @@ -31,14 +32,30 @@ def setup(self, model): model.to(self.root_device) self._model = DataParallel(LightningParallelModule(model), self.parallel_devices) - def reduce(self, output, *args, **kwargs): - if isinstance(output, Result): - output.dp_reduce() + def reduce(self, tensor, *args, **kwargs): + """ + Reduces a tensor from all parallel processes to one aggregated tensor. - elif isinstance(output, torch.Tensor): - output = output.mean() + Args: + tensor: the tensor to sync and reduce + *args: ignored for DP + **kwargs: ignored for DP - return output + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, Result): + tensor.dp_reduce() + + else: + + def _reduce(tensor: torch.Tensor): + dtype_tensor = tensor.dtype + return tensor.float().mean().type(dtype_tensor) + + tensor = apply_to_collection(tensor, torch.Tensor, _reduce) + + return tensor @property def root_device(self): diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index 15faf98d94d57..6e6b1be6254e2 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -235,3 +235,22 @@ def test_dp_test(tmpdir): new_weights = model.c_d1.weight.clone().detach().cpu() assert torch.all(torch.eq(old_weights, new_weights)) + + +@RunIf(min_gpus=2) +def test_dp_training_step_dict(tmpdir): + """ + This test verify dp properly reduce dictionaries + """ + + model = BoringModel() + model.training_step_end = None + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=0, + gpus=2, + accelerator='dp', + ) + trainer.fit(model)