[WIP] uneven input support for DDP #14284

Closed · wants to merge 4 commits
80 changes: 80 additions & 0 deletions src/pytorch_lightning/overrides/distributed.py
@@ -11,13 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools
from typing import Any, cast, Iterable, Iterator, List, Optional, Sized, Union

import torch
from torch import distributed as dist
from torch import Tensor
from torch.distributed.algorithms.join import JoinHook
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler
from typing_extensions import Self

import pytorch_lightning as pl
from lightning_lite.utilities.distributed import _DatasetSamplerWrapper
@@ -147,3 +151,79 @@ def batch_size(self) -> int:
@property
def sampler(self) -> Union[Sampler, Iterable]:
return self._sampler.sampler


class DistCallRecorder:
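    """Context manager that temporarily monkey-patches the ``torch.distributed`` collectives listed in
    ``_PATCH_NAMES`` (or the given ``patch_names``) so that every call issued while the context is active is
    executed normally and also recorded as a ``(name, args, kwargs)`` tuple for later replay by a join hook."""
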
_PATCH_NAMES = [
"send",
"recv",
"broadcast",
"all_reduce",
"reduce",
"all_gather",
"gather",
"scatter",
"reduce_scatter",
"all_to_all",
"barrier",
]

def __init__(self, patch_names: Optional[List[str]] = None) -> None:
self.patch_names = self._PATCH_NAMES if patch_names is None else patch_names
self.patched = dict()
self.recorded_calls = []

def __enter__(self) -> Self:
self.recorded_calls = []
self.patched = dict()
for name in self.patch_names:
if hasattr(dist, name):
orig_fn = getattr(dist, name)
self.patched[name] = orig_fn
patched_fn = self._create_patch(name, orig_fn)
setattr(dist, name, patched_fn)
return self

def __exit__(self, _, __, ___) -> None:
for name, fn in self.patched.items():
setattr(dist, name, fn)
self.recorded_calls = []
self.patched = dict()

def _create_patch(self, name, fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
self.recorded_calls.append((name, args, kwargs))
return fn(*args, **kwargs)

return wrapper


class LightningJoinHook(JoinHook):
def __init__(self, ddp, divide_by_initial_world_size, dist_recorder: DistCallRecorder):
"""Sets config variables for internal usage."""
        assert isinstance(ddp, DistributedDataParallel), (
            "DDP join hook requires passing in a DistributedDataParallel instance as the state"
        )
ddp.logger._set_uneven_input_join()
self.ddp = ddp
self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
self.dist_recorder = dist_recorder
super().__init__()

def main_hook(self) -> None:
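        # Replay the collectives recorded while the DistCallRecorder context was active, going through the
        # original (unpatched) functions stored in ``self.dist_recorder.patched``, so that this joined rank
        # keeps its collective calls matched with the ranks that still have batches left.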
while self.dist_recorder.recorded_calls:
recorded_name, args, kwargs = self.dist_recorder.recorded_calls.pop(0)
self.dist_recorder.patched[recorded_name](*args, **kwargs)

def post_hook(self, is_last_joiner: bool) -> None:
return super().post_hook(is_last_joiner)


class LightningDistributedDataParallel(DistributedDataParallel):
def join_hook(self, **kwargs):
divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True)
dist_recorder = kwargs.get("dist_recorder", DistCallRecorder())
return LightningJoinHook(
self, divide_by_initial_world_size=divide_by_initial_world_size, dist_recorder=dist_recorder
)
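
For context, here is a minimal standalone sketch (not part of this diff) of what the recorder does: every torch.distributed collective issued inside the context is executed normally and also captured as a (name, args, kwargs) tuple, which is what the join hook later replays. The import path assumes this PR's module location, and the single-process gloo group is only there so the snippet can run without a launcher.

import os

import torch
import torch.distributed as dist

from pytorch_lightning.overrides.distributed import DistCallRecorder  # module touched by this PR

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

with DistCallRecorder() as recorder:
    tensor = torch.ones(1)
    dist.all_reduce(tensor)  # executed as usual *and* recorded by the patched wrapper
    dist.barrier()
    # The original (unpatched) functions live in recorder.patched; a join hook can
    # replay recorder.recorded_calls through them to shadow these collectives.
    print([name for name, _, _ in recorder.recorded_calls])  # ['all_reduce', 'barrier']

dist.destroy_process_group()
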
13 changes: 10 additions & 3 deletions src/pytorch_lightning/strategies/ddp.py
@@ -24,6 +24,7 @@
import torch
import torch.distributed
from torch import Tensor
+from torch.distributed.algorithms.join import Join
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim.optimizer import Optimizer
@@ -41,7 +42,11 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
-from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.distributed import (
+    DistCallRecorder,
+    LightningDistributedDataParallel,
+    prepare_for_backward,
+)
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
@@ -188,7 +193,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self.determine_ddp_device_ids()
log.detail(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
-        return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
+        return LightningDistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

def setup_distributed(self) -> None:
log.detail(f"{self.__class__.__name__}: setting up distributed...")
@@ -345,7 +350,9 @@ def reduce(
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.model is not None
with self.precision_plugin.train_step_context():
-            return self.model(*args, **kwargs)
+            with DistCallRecorder() as recorder:
+                with Join([self.model], enable=True, dist_recorder=recorder):
Comment by @carmocca (Contributor), Sep 19, 2022:
This should be opt-in (Maybe you set enable=True by default whilst it's WIP)

+                    return self.model(*args, **kwargs)

def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
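
For reference, the upstream building block wired in above is torch.distributed.algorithms.join.Join, which ships with PyTorch. The standalone sketch below (not part of the diff) shows the standard pattern: each rank trains on a different number of batches, and ranks that finish early shadow the collectives of the others instead of hanging. Join's enable argument is also a natural hook for making the behaviour opt-in, as suggested in the review comment above. The two-process gloo setup and the per-rank batch counts are assumptions for illustration only.

import os

import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP


def run(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(torch.nn.Linear(4, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Deliberately uneven inputs: rank 0 sees 5 batches, every other rank sees 3.
    num_batches = 5 if rank == 0 else 3

    # Join keeps the gradient all-reduces matched: a rank that runs out of
    # batches shadows the collectives of the ranks that are still training.
    with Join([model], enable=True):
        for _ in range(num_batches):
            optimizer.zero_grad()
            loss = model(torch.randn(8, 4)).sum()
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    torch.multiprocessing.spawn(run, args=(2,), nprocs=2)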