From 7b38ba7423c654ec279f4eed8c3f4b1a98fd0aa3 Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Mon, 14 Dec 2020 18:40:26 -0800
Subject: [PATCH] [Feat] Added uneven input support/sync with upstream DDP

---
 pytorch_lightning/overrides/data_parallel.py | 28 +++++++++++++++++++-
 pytorch_lightning/utilities/__init__.py      |  2 ++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 393138fff9248..631c55090115d 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -13,17 +13,20 @@
 # limitations under the License.
 
 import itertools
+import logging
 import threading
 from collections.abc import Iterable, Mapping
 from itertools import chain
 
 import torch
+import torch.distributed as dist
 from torch.cuda._utils import _get_device_index
 from torch.nn import DataParallel
 from torch.nn.parallel import DistributedDataParallel
 from torch.nn.parallel._functions import Gather
 
 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.utilities import DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE
 from pytorch_lightning.utilities.warning_utils import WarningCache
 
 
@@ -161,7 +164,30 @@ def parallel_apply(self, replicas, inputs, kwargs):
         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
 
     def forward(self, *inputs, **kwargs):  # pragma: no-cover
-        self._sync_params()
+        # TODO: Update uneven inputs code path when PyTorch 1.8 is released.
+        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.ddp_join_enabled:
+            ones = torch.ones(
+                1, device=self.device
+            )
+            work = dist.all_reduce(ones, group=self.process_group, async_op=True)
+            self.reducer._set_forward_pass_work_handle(
+                work, self.ddp_join_divide_by_initial_world_size
+            )
+
+        # Call _rebuild_buckets before the forward computation;
+        # it may allocate new buckets before deallocating old buckets
+        # inside _rebuild_buckets. To save peak memory usage,
+        # call _rebuild_buckets before the peak memory usage increases
+        # during the forward computation.
+        # This should be called only once during the whole training period.
+        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.reducer._rebuild_buckets():
+            logging.info("Reducer buckets have been rebuilt in this iteration.")
+
+        if self.require_forward_param_sync:
+            self._sync_params()
+        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.ddp_join_enabled:
+            # Notify joined ranks whether they should sync in the backward pass or not.
+            self._check_global_requires_backward_grad_sync(is_joined_rank=False)
 
         self.reducer_reset_hooks()
         fx_called: str = ''
 
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 6cb2cac438714..f06fd0a8417c1 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -62,6 +62,8 @@ def _module_available(module_path: str) -> bool:
 _FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
 _BOLTS_AVAILABLE = _module_available('pl_bolts')
 
+DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE = LooseVersion(torch.__version__) >= LooseVersion("1.7.0")
+
 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
 FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
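
Note for reviewers: a minimal standalone sketch of the feature this patch gates on
DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE. With PyTorch >= 1.7, uneven per-rank inputs
are handled by the stock DistributedDataParallel.join() context manager, which sets
the ddp_join_enabled / ddp_join_divide_by_initial_world_size state that the
overridden forward above reads. Everything below (world size, gloo backend, model,
batch counts, port) is an illustrative assumption, not part of this patch.

    # Illustrative sketch only; assumes a CPU-only machine and the gloo backend.
    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch.nn.parallel import DistributedDataParallel


    def run(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        model = DistributedDataParallel(torch.nn.Linear(1, 1))
        # Uneven inputs: each rank iterates over a different number of batches.
        num_batches = 5 + rank
        with model.join():  # torch >= 1.7; shadows collectives for ranks that finish early
            for _ in range(num_batches):
                model(torch.ones(1, 1)).sum().backward()

        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = 2
        mp.spawn(run, args=(world_size,), nprocs=world_size)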