Bugfix/all gather #5221

Merged on Jan 9, 2021 (37 commits)

Changes from 30 commits

Commits (37)
74e508e  resolve bug (Dec 21, 2020)
a7cdde8  add tests (tchaton, Dec 21, 2020)
34f5d34  add tests (tchaton, Dec 21, 2020)
b850968  resolve flake8 (tchaton, Dec 21, 2020)
0007566  update (Dec 22, 2020)
44d7c57  Merge branch 'bugfix/all_gather' of https://github.com/PyTorchLightni… (Dec 22, 2020)
9205af6  update (tchaton, Dec 22, 2020)
b7c5df9  remove globals (Dec 28, 2020)
a4b740e  Merge branch 'bugfix/all_gather' of https://github.com/PyTorchLightni… (Dec 28, 2020)
c42d4cf  typo (Dec 28, 2020)
a109502  Update pytorch_lightning/utilities/distributed.py (tchaton, Dec 28, 2020)
03b53d8  update (tchaton, Dec 28, 2020)
ee299a6  update (tchaton, Jan 4, 2021)
710c2ef  add suport int, float (Jan 4, 2021)
053c004  update (tchaton, Jan 4, 2021)
c34101c  Merge branch 'release/1.2-dev' into bugfix/all_gather (tchaton, Jan 4, 2021)
04924ae  resolve pep8 (tchaton, Jan 4, 2021)
cccce28  Update pytorch_lightning/core/lightning.py (tchaton, Jan 4, 2021)
f764832  Update tests/utilities/test_all_gather_grad.py (tchaton, Jan 4, 2021)
2df387b  update doc (tchaton, Jan 5, 2021)
8a906d3  Merge branch 'release/1.2-dev' into bugfix/all_gather (tchaton, Jan 6, 2021)
0ad2f3a  add bool and np.ndarray (tchaton, Jan 7, 2021)
ce32675  Merge branch 'bugfix/all_gather' of https://github.com/PyTorchLightni… (tchaton, Jan 7, 2021)
00813ef  resolve conflicts (tchaton, Jan 7, 2021)
832bb98  resolve conflicts (tchaton, Jan 7, 2021)
56e1e35  Merge branch 'release/1.2-dev' into bugfix/all_gather (tchaton, Jan 7, 2021)
a00a51c  resolve pep8 (tchaton, Jan 7, 2021)
1d27fd1  Merge branch 'bugfix/all_gather' of https://github.com/PyTorchLightni… (tchaton, Jan 7, 2021)
96b3d98  add changelog (tchaton, Jan 7, 2021)
e826acd  Merge branch 'release/1.2-dev' into bugfix/all_gather (tchaton, Jan 7, 2021)
5112f4b  Update pytorch_lightning/core/lightning.py (tchaton, Jan 8, 2021)
1f90f49  Merge branch 'release/1.2-dev' into bugfix/all_gather (tchaton, Jan 8, 2021)
da92045  Merge branch 'release/1.2-dev' into bugfix/all_gather (tchaton, Jan 8, 2021)
6e9141e  update (tchaton, Jan 9, 2021)
448040e  Merge branch 'release/1.2-dev' into bugfix/all_gather (tchaton, Jan 9, 2021)
46126a2  resolve bug (tchaton, Jan 9, 2021)
0df1a98  resolve flake8 (tchaton, Jan 9, 2021)

3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241))


- Accelerator `all_gather` supports collection ([#5221](https://github.com/PyTorchLightning/pytorch-lightning/pull/5221))


### Changed

- `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
33 changes: 25 additions & 8 deletions pytorch_lightning/core/lightning.py
@@ -14,15 +14,16 @@

"""nn.Module with additional great features."""

from abc import ABC
from argparse import Namespace
import collections
import copy
from functools import partial
import inspect
import os
from pathlib import Path
import re
import tempfile
from abc import ABC
from argparse import Namespace
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
@@ -35,10 +36,12 @@
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args

@@ -364,7 +367,12 @@ def __auto_choose_log_on_epoch(self, on_epoch):

return on_epoch

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
def all_gather(
self,
data: Union[torch.Tensor, Dict, List, Tuple],
group: Optional[Any] = None,
sync_grads: bool = False,
):
r"""
Allows users to call ``self.all_gather()`` from the LightningModule, thus making
the ```all_gather``` operation accelerator agnostic.
@@ -373,14 +381,23 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
distributed processes

Args:
tensor: tensor of shape (batch, ...)
tensor: int, float, tensor of shape (batch, ...), or a collection of
int, float, tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op

Return:
A tensor of shape (world_size, batch, ...)
A tensor of shape (world_size, batch, ...), or if the input was a collection
the output will also be a collection with tensors of this shape.
"""
return self.trainer.accelerator_backend.all_gather(tensor, group=group, sync_grads=sync_grads)
if self.trainer.accelerator_backend is not None:
all_gather = self.trainer.accelerator_backend.all_gather
else:
all_gather = all_gather_ddp_if_available

data = convert_to_tensors(data, device=self.device)
all_gather = partial(all_gather, group=group, sync_grads=sync_grads)
return apply_to_collection(data, torch.Tensor, all_gather)

def forward(self, *args, **kwargs):
r"""
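For orientation, here is a minimal sketch (not part of the diff) of how the collection-aware `all_gather` is meant to be called from a LightningModule hook; `GatherExample` and the dict keys are illustrative, and the rest of the module is assumed to be defined as usual:

```python
import torch
from pytorch_lightning import LightningModule


class GatherExample(LightningModule):
    # Illustrative hook only; training_step, optimizers, etc. are assumed elsewhere.
    def training_epoch_end(self, outputs) -> None:
        losses = torch.stack([out["loss"] for out in outputs])
        # With this change, `self.all_gather` accepts ints, floats, np.ndarrays and
        # nested collections of them, not only a single tensor.
        gathered = self.all_gather({"losses": losses, "epoch": self.current_epoch})
        # Under DDP each gathered leaf gains a leading world_size dimension, e.g.
        # gathered["losses"].shape == (world_size, *losses.shape).
```

The output mirrors the input structure, with each leaf shaped `(world_size, batch, ...)` as described in the docstring above.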
41 changes: 32 additions & 9 deletions pytorch_lightning/utilities/apply_func.py
@@ -15,10 +15,13 @@
from abc import ABC
from collections.abc import Mapping, Sequence
from copy import copy
from typing import Any, Callable, Union, Optional
from functools import partial
from typing import Any, Callable, Optional, Union

import numpy as np
import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE

if _TORCHTEXT_AVAILABLE:
@@ -27,11 +30,18 @@
Batch = type(None)


CONVERSION_DTYPES = [
(bool, torch.bool),
(int, torch.int),
(float, torch.float),
(np.ndarray, None),
]

Review comment (Contributor): alternatively you could define a callable (torch.from_numpy) here instead of specifically checking for ndarray below.


def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args,
wrong_dtype: Optional[Union[type, tuple]] = None, **kwargs) -> Any:
"""
Recursively applies a function to all elements of a certain dtype.

Args:
data: the collection to apply the function to
dtype: the given function will be applied to all elements of this dtype
@@ -40,10 +50,8 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
wrong_dtype: the given function won't be applied if this type is specified and the given collections is of
the :attr:`wrong_type` even if it is of type :attr`dtype`
**kwargs: keyword arguments (will be forwarded to calls of ``function``)

Returns:
the resulting collection

"""
elem_type = type(data)

@@ -67,9 +75,7 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
class TransferableDataType(ABC):
"""
A custom type for data that can be moved to a torch device via `.to(...)`.

Example:

>>> isinstance(dict, TransferableDataType)
False
>>> isinstance(torch.rand(2, 3), TransferableDataType)
@@ -96,15 +102,12 @@ def move_data_to_device(batch: Any, device: torch.device):
"""
Transfers a collection of data to the given device. Any object that defines a method
``to(device)`` will be moved and all other objects in the collection will be left untouched.

Args:
batch: A tensor or collection of tensors or anything that has a method `.to(...)`.
See :func:`apply_to_collection` for a list of supported collection types.
device: The device to which the data should be moved

Return:
the same collection but with all contained tensors residing on the new device.

See Also:
- :meth:`torch.Tensor.to`
- :class:`torch.device`
@@ -128,3 +131,23 @@ def batch_to(data):

dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType
return apply_to_collection(batch, dtype=dtype, function=batch_to)


def to_dtype_tensor(value, dtype:torch.dtype = None, device: torch.device = None):
if device is None:
raise MisconfigurationException(
"device (torch.device) should be provided."
)
if isinstance(value, np.ndarray):
return torch.from_numpy(value).to(device)
return torch.tensor(value, dtype=dtype, device=device)


def convert_to_tensors(data, device: torch.device = None):
if device is None:
raise MisconfigurationException(
"device (torch.device) should be provided."
)
for src_dtype, dst_dtype in CONVERSION_DTYPES:
data = apply_to_collection(data, src_dtype, partial(to_dtype_tensor, dtype=dst_dtype, device=device))
return data
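As a reading aid, a small sketch (not part of the diff) of how these helpers behave, together with the callable-based table the reviewer suggests above; the `batch` contents and the `CONVERTERS` name are illustrative:

```python
import numpy as np
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors

device = torch.device("cpu")
batch = {"flag": True, "step": 3, "lr": 1e-3, "ids": np.array([1, 2, 3], dtype=np.int64)}

# convert_to_tensors maps bool -> torch.bool, int -> torch.int, float -> torch.float,
# and np.ndarray via torch.from_numpy (keeping the numpy dtype).
tensors = convert_to_tensors(batch, device=device)
assert tensors["flag"].dtype == torch.bool
assert tensors["step"].dtype == torch.int
assert tensors["lr"].dtype == torch.float
assert tensors["ids"].dtype == torch.int64

# The reviewer's alternative: one converter callable per source type instead of
# special-casing np.ndarray inside to_dtype_tensor (a sketch, not the merged code).
# Note bool must come before int, since bool is a subclass of int in Python.
CONVERTERS = [
    (bool, lambda v: torch.tensor(v, dtype=torch.bool, device=device)),
    (int, lambda v: torch.tensor(v, dtype=torch.int, device=device)),
    (float, lambda v: torch.tensor(v, dtype=torch.float, device=device)),
    (np.ndarray, lambda v: torch.from_numpy(v).to(device)),
]
converted = batch
for src_type, convert in CONVERTERS:
    converted = apply_to_collection(converted, src_type, convert)
assert converted["ids"].dtype == torch.int64
```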
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/distributed.py
@@ -22,7 +22,7 @@
from pytorch_lightning import _logger as log

if torch.distributed.is_available():
from torch.distributed import ReduceOp, group
from torch.distributed import group, ReduceOp
else:
class ReduceOp:
SUM = None
2 changes: 2 additions & 0 deletions tests/special_tests.sh
@@ -20,4 +20,6 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_properly_works
# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
55 changes: 54 additions & 1 deletion tests/utilities/test_all_gather_grad.py
@@ -1,9 +1,13 @@
import os
import pytest
import sys

import pytest
import torch
import numpy as np

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities import AllGatherGrad
from tests.base.boring_model import BoringModel


def setup_ddp(rank, world_size):
@@ -41,3 +45,52 @@ def _test_all_gather_ddp(rank, world_size):
def test_all_gather_ddp():
world_size = 3
torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_all_gather_collection(tmpdir):

class TestModel(BoringModel):

training_epoch_end_called = False

def training_epoch_end(self, outputs) -> None:
self.training_epoch_end_called = True
losses = torch.stack([x["loss"] for x in outputs])
gathered_loss = self.all_gather({
"losses_np_ndarray": np.array([1, 2, 3]),
"losses_bool": [True, False],
"losses_float": [0., 1., 2.],
"losses_int": [0, 1, 2],
"losses": losses,
"losses_list": [losses, losses]
})
assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64
assert gathered_loss["losses_bool"][0].dtype == torch.bool
assert gathered_loss["losses_float"][0].dtype == torch.float
assert gathered_loss["losses_int"][0].dtype == torch.int
assert gathered_loss["losses_list"][0].numel() == 2 * len(losses)
assert gathered_loss["losses"].numel() == 2 * len(losses)

seed_everything(42)

model = TestModel()

limit_train_batches = 8
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
accumulate_grad_batches=2,
enable_pl_optimizer=True,
gpus=2,
accelerator="ddp",
)

trainer.fit(model)
assert model.training_epoch_end_called
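Finally, a sketch (not part of the diff) of the pattern the new `LightningModule.all_gather` uses internally: convert the collection to tensors, then apply a gather function over it. The `fake_all_gather` stand-in simulates a world size of 2 so the snippet runs on a single process without DDP:

```python
import numpy as np
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors


def fake_all_gather(tensor: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real gather op: prepend a world_size dimension (here 2).
    return torch.stack([tensor, tensor])


data = {"loss": torch.rand(4), "step": 7, "mask": np.array([1, 0, 1], dtype=np.int64)}
tensors = convert_to_tensors(data, device=torch.device("cpu"))
gathered = apply_to_collection(tensors, torch.Tensor, fake_all_gather)

assert gathered["loss"].shape == (2, 4)  # (world_size, batch)
assert gathered["step"].shape == (2,)    # scalars gain only the world_size dimension
assert gathered["mask"].dtype == torch.int64
```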