Skip to content

Commit

Permalink
Bugfix/all gather (#5221)
Browse files Browse the repository at this point in the history
* resolve bug

* add tests

* add tests

* resolve flake8

* update

* update

* remove globals

* typo

* Update pytorch_lightning/utilities/distributed.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update

* update

* add suport int, float

* update

* resolve pep8

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update tests/utilities/test_all_gather_grad.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* update doc

* add bool and np.ndarray

* resolve conflicts

* resolve conflicts

* resolve pep8

* add changelog

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: Ubuntu <ubuntu@ip-172-31-62-109.ec2.internal>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
4 people authored Jan 9, 2021
1 parent 48718d7 commit be255de
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 19 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BackboneLambdaFinetuningCallback` ([#5377](https://github.com/PyTorchLightning/pytorch-lightning/pull/5377))


- Accelerator `all_gather` supports collection ([#5221](https://github.com/PyTorchLightning/pytorch-lightning/pull/5221))


- Added `image_gradients` functional metric to compute the image gradients of a given input image. ([#5056](https://github.com/PyTorchLightning/pytorch-lightning/pull/5056))


Expand Down
33 changes: 25 additions & 8 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@

"""nn.Module with additional great features."""

from abc import ABC
from argparse import Namespace
import collections
import copy
from functools import partial
import inspect
import os
from pathlib import Path
import re
import tempfile
from abc import ABC
from argparse import Namespace
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
Expand All @@ -35,10 +36,12 @@
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args

Expand Down Expand Up @@ -364,7 +367,12 @@ def __auto_choose_log_on_epoch(self, on_epoch):

return on_epoch

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
def all_gather(
self,
data: Union[torch.Tensor, Dict, List, Tuple],
group: Optional[Any] = None,
sync_grads: bool = False,
):
r"""
Allows users to call ``self.all_gather()`` from the LightningModule, thus making
the ```all_gather``` operation accelerator agnostic.
Expand All @@ -373,14 +381,23 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
distributed processes
Args:
tensor: tensor of shape (batch, ...)
tensor: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op
Return:
A tensor of shape (world_size, batch, ...)
A tensor of shape (world_size, batch, ...), or if the input was a collection
the output will also be a collection with tensors of this shape.
"""
return self.trainer.accelerator_backend.all_gather(tensor, group=group, sync_grads=sync_grads)
group = group if group is not None else torch.distributed.group.WORLD
if self.trainer.accelerator_backend is not None:
all_gather = self.trainer.accelerator_backend.all_gather
else:
all_gather = all_gather_ddp_if_available

data = convert_to_tensors(data, device=self.device)
all_gather = partial(all_gather, group=group, sync_grads=sync_grads)
return apply_to_collection(data, torch.Tensor, all_gather)

def forward(self, *args, **kwargs):
r"""
Expand Down
48 changes: 39 additions & 9 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
from abc import ABC
from collections.abc import Mapping, Sequence
from copy import copy
from typing import Any, Callable, Union, Optional
from functools import partial
from typing import Any, Callable, Optional, Union

import numpy as np
import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE

if _TORCHTEXT_AVAILABLE:
Expand All @@ -27,11 +30,35 @@
Batch = type(None)


def to_dtype_tensor(value, dtype:torch.dtype = None, device: torch.device = None):
if device is None:
raise MisconfigurationException(
"device (torch.device) should be provided."
)
return torch.tensor(value, dtype=dtype, device=device)


def from_numpy(value, device: torch.device = None):
if device is None:
raise MisconfigurationException(
"device (torch.device) should be provided."
)
return torch.from_numpy(value).to(device)


CONVERSION_DTYPES = [
# bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
(bool, partial(to_dtype_tensor, dtype=torch.uint8)),
(int, partial(to_dtype_tensor, dtype=torch.int)),
(float, partial(to_dtype_tensor, dtype=torch.float)),
(np.ndarray, from_numpy),
]


def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args,
wrong_dtype: Optional[Union[type, tuple]] = None, **kwargs) -> Any:
"""
Recursively applies a function to all elements of a certain dtype.
Args:
data: the collection to apply the function to
dtype: the given function will be applied to all elements of this dtype
Expand All @@ -40,10 +67,8 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
wrong_dtype: the given function won't be applied if this type is specified and the given collections is of
the :attr:`wrong_type` even if it is of type :attr`dtype`
**kwargs: keyword arguments (will be forwarded to calls of ``function``)
Returns:
the resulting collection
"""
elem_type = type(data)

Expand All @@ -67,9 +92,7 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
class TransferableDataType(ABC):
"""
A custom type for data that can be moved to a torch device via `.to(...)`.
Example:
>>> isinstance(dict, TransferableDataType)
False
>>> isinstance(torch.rand(2, 3), TransferableDataType)
Expand All @@ -96,15 +119,12 @@ def move_data_to_device(batch: Any, device: torch.device):
"""
Transfers a collection of data to the given device. Any object that defines a method
``to(device)`` will be moved and all other objects in the collection will be left untouched.
Args:
batch: A tensor or collection of tensors or anything that has a method `.to(...)`.
See :func:`apply_to_collection` for a list of supported collection types.
device: The device to which the data should be moved
Return:
the same collection but with all contained tensors residing on the new device.
See Also:
- :meth:`torch.Tensor.to`
- :class:`torch.device`
Expand All @@ -128,3 +148,13 @@ def batch_to(data):

dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType
return apply_to_collection(batch, dtype=dtype, function=batch_to)


def convert_to_tensors(data, device: torch.device = None):
if device is None:
raise MisconfigurationException(
"device (torch.device) should be provided."
)
for src_dtype, conversion_func in CONVERSION_DTYPES:
data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device))
return data
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pytorch_lightning import _logger as log

if torch.distributed.is_available():
from torch.distributed import ReduceOp, group
from torch.distributed import group, ReduceOp
else:
class ReduceOp:
SUM = None
Expand Down
2 changes: 2 additions & 0 deletions tests/special_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection
# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
56 changes: 55 additions & 1 deletion tests/utilities/test_all_gather_grad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import os
import pytest
import sys

import numpy as np
import pytest
import torch

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.utilities import AllGatherGrad
from tests.base.boring_model import BoringModel


def setup_ddp(rank, world_size):
Expand Down Expand Up @@ -41,3 +45,53 @@ def _test_all_gather_ddp(rank, world_size):
def test_all_gather_ddp():
world_size = 3
torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_all_gather_collection(tmpdir):

class TestModel(BoringModel):

training_epoch_end_called = False

def training_epoch_end(self, outputs) -> None:
self.training_epoch_end_called = True
losses = torch.stack([x["loss"] for x in outputs])
gathered_loss = self.all_gather({
"losses_np_ndarray": np.array([1, 2, 3]),
"losses_bool": [True, False],
"losses_float": [0., 1., 2.],
"losses_int": [0, 1, 2],
"losses": losses,
"losses_list": [losses, losses]
})
assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64
# torch.bool can't be all_gathered
assert gathered_loss["losses_bool"][0].dtype == torch.uint8
assert gathered_loss["losses_float"][0].dtype == torch.float
assert gathered_loss["losses_int"][0].dtype == torch.int
assert gathered_loss["losses_list"][0].numel() == 2 * len(losses)
assert gathered_loss["losses"].numel() == 2 * len(losses)

seed_everything(42)

model = TestModel()

limit_train_batches = 8
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
accumulate_grad_batches=2,
enable_pl_optimizer=True,
gpus=2,
accelerator="ddp",
)

trainer.fit(model)
assert model.training_epoch_end_called

0 comments on commit be255de

Please sign in to comment.