diff --git a/CHANGELOG.md b/CHANGELOG.md
index 27e5f4be2d04a..db1f3970e0e6f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -101,6 +101,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)
 
 
+- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))
+
+
 ## [1.2.1] - 2021-02-23
 
 ### Fixed
diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index c2b16303e5d4e..af8cfa7755974 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -19,6 +19,7 @@
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 
 class DataParallelPlugin(ParallelPlugin):
@@ -46,8 +47,13 @@ def reduce(self, tensor, *args, **kwargs):
         if isinstance(tensor, Result):
             tensor.dp_reduce()
 
-        elif isinstance(tensor, torch.Tensor):
-            tensor = tensor.mean()
+        else:
+
+            def _reduce(tensor: torch.Tensor):
+                dtype_tensor = tensor.dtype
+                return tensor.float().mean().type(dtype_tensor)
+
+            tensor = apply_to_collection(tensor, torch.Tensor, _reduce)
 
         return tensor
 
diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py
index 8aeb687f1c927..4736c6788c208 100644
--- a/tests/accelerators/test_dp.py
+++ b/tests/accelerators/test_dp.py
@@ -123,3 +123,22 @@ def test_dp_test(tmpdir):
     new_weights = model.layer_0.weight.clone().detach().cpu()
 
     assert torch.all(torch.eq(old_weights, new_weights))
+
+
+@RunIf(min_gpus=2)
+def test_dp_training_step_dict(tmpdir):
+    """
+    This test verifies that DP properly reduces dictionaries.
+    """
+
+    model = BoringModel()
+    model.training_step_end = None
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=2,
+        limit_val_batches=0,
+        gpus=2,
+        accelerator='dp',
+    )
+    trainer.fit(model)
diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py
index 23b2fcbb52235..0e63fc29d49b1 100644
--- a/tests/callbacks/test_pruning.py
+++ b/tests/callbacks/test_pruning.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from collections import OrderedDict
 from logging import INFO
 
@@ -22,7 +21,7 @@
 from torch.nn import Sequential
 
 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import ModelPruning, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -274,6 +273,7 @@ def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
     seed_everything(0)
 
     class TestPruning(ModelPruning):
+
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            super().on_save_checkpoint(trainer, pl_module, checkpoint)
            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
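
Note (illustration, not part of the patch): a minimal sketch of what the new collection branch in DataParallelPlugin.reduce does, assuming a training_step that returns a dict whose tensors were stacked across two GPUs under DP. The dict keys and values below are made up for the example; apply_to_collection is the same utility imported in the diff above.

# Illustrative sketch only. _reduce mirrors the helper added to
# DataParallelPlugin.reduce: cast to float for the mean, then restore the
# original dtype so integer-valued metrics survive the reduction.
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection


def _reduce(tensor: torch.Tensor) -> torch.Tensor:
    dtype_tensor = tensor.dtype
    return tensor.float().mean().type(dtype_tensor)


# Stand-in for a training_step output stacked across 2 DP replicas.
output = {"loss": torch.tensor([0.5, 0.7]), "count": torch.tensor([2, 4])}
reduced = apply_to_collection(output, torch.Tensor, _reduce)
# reduced -> {"loss": tensor(0.6000), "count": tensor(3)}

Before this change, only a bare torch.Tensor was averaged, so a dict returned from training_step was passed through unreduced; applying the same mean element-wise over the collection is what the new test_dp_training_step_dict test exercises.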