Support for uneven inputs in LightningDDP #5141

Closed · wants to merge 1 commit into from
28 changes: 27 additions & 1 deletion pytorch_lightning/overrides/data_parallel.py
@@ -13,17 +13,20 @@
 # limitations under the License.

 import itertools
+import logging
 import threading
 from collections.abc import Iterable, Mapping
 from itertools import chain

 import torch
+import torch.distributed as dist
 from torch.cuda._utils import _get_device_index
 from torch.nn import DataParallel
 from torch.nn.parallel import DistributedDataParallel
 from torch.nn.parallel._functions import Gather

 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.utilities import DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE
 from pytorch_lightning.utilities.warning_utils import WarningCache

@@ -161,7 +164,30 @@ def parallel_apply(self, replicas, inputs, kwargs):
         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

     def forward(self, *inputs, **kwargs):  # pragma: no-cover
-        self._sync_params()
+        # TODO: Update uneven inputs code path when PyTorch 1.8 is released.
+        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.ddp_join_enabled:
+            ones = torch.ones(
+                1, device=self.device
+            )
+            work = dist.all_reduce(ones, group=self.process_group, async_op=True)
+            self.reducer._set_forward_pass_work_handle(
+                work, self.ddp_join_divide_by_initial_world_size
+            )
+
+        # Calling _rebuild_buckets before forward computation,
+        # It may allocate new buckets before deallocating old buckets
+        # inside _rebuild_buckets. To save peak memory usage,
+        # call _rebuild_buckets before the peak memory usage increases
+        # during forward computation.
+        # This should be called only once during whole training period.
+        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.reducer._rebuild_buckets():
+            logging.info("Reducer buckets have been rebuilt in this iteration.")
Contributor review comment on the logging.info line above: This should probably be a debug or rank_zero_debug message (a sketch of that suggestion follows after the diff).
+
+        if self.require_forward_param_sync:
+            self._sync_params()
+        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.ddp_join_enabled:
+            # Notify joined ranks whether they should sync in backwards pass or not.
+            self._check_global_requires_backward_grad_sync(is_joined_rank=False)
         self.reducer_reset_hooks()
         fx_called: str = ''

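A minimal sketch of what the reviewer's suggestion could look like, assuming rank_zero_debug is importable from pytorch_lightning.utilities.distributed; the helper function below is purely illustrative and not part of this PR:

from pytorch_lightning.utilities.distributed import rank_zero_debug


def log_bucket_rebuild(rebuilt: bool) -> None:
    # Illustrative only: emit the rebuild message at debug level and only on
    # global rank zero, so multi-process runs do not repeat it per process.
    if rebuilt:
        rank_zero_debug("Reducer buckets have been rebuilt in this iteration.")
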
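For broader context, here is a standalone sketch (not part of this PR) of the uneven-inputs scenario this forward() path mirrors, using the join() context manager that torch.nn.parallel.DistributedDataParallel gained in PyTorch 1.7. The batch counts are arbitrary, and the script assumes it is launched via torch.distributed.launch or torchrun so that init_process_group finds the usual environment variables:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    model = DDP(torch.nn.Linear(1, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Deliberately uneven: rank 0 sees 5 batches, rank 1 sees 6, and so on.
    batches = [torch.randn(8, 1) for _ in range(5 + rank)]

    # join() lets ranks that exhaust their data shadow the collective calls of
    # the ranks still training, instead of deadlocking in all-reduce.
    with model.join():
        for batch in batches:
            loss = model(batch).sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()


if __name__ == "__main__":
    main()
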
2 changes: 2 additions & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -62,6 +62,8 @@ def _module_available(module_path: str) -> bool:
 _FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
 _BOLTS_AVAILABLE = _module_available('pl_bolts')

+DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE = LooseVersion(torch.__version__) >= LooseVersion("1.7.0")
+
 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
 FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
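As a quick illustration (version strings chosen arbitrarily), the new flag simply compares the installed torch version against 1.7.0:

from distutils.version import LooseVersion

for version in ("1.6.0", "1.7.0", "1.7.1", "1.8.0"):
    available = LooseVersion(version) >= LooseVersion("1.7.0")
    print(version, available)  # False for 1.6.0, True for 1.7.0 and newer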