
Bugfix/all gather #5221

Merged (37 commits, Jan 9, 2021)
Changes from 21 commits
Commits
74e508e
resolve bug
Dec 21, 2020
a7cdde8
add tests
tchaton Dec 21, 2020
34f5d34
add tests
tchaton Dec 21, 2020
b850968
resolve flake8
tchaton Dec 21, 2020
0007566
update
Dec 22, 2020
44d7c57
Merge branch 'bugfix/all_gather' of https://github.com/PyTorchLightni…
Dec 22, 2020
9205af6
update
tchaton Dec 22, 2020
b7c5df9
remove globals
Dec 28, 2020
a4b740e
Merge branch 'bugfix/all_gather' of https://github.com/PyTorchLightni…
Dec 28, 2020
c42d4cf
typo
Dec 28, 2020
a109502
Update pytorch_lightning/utilities/distributed.py
tchaton Dec 28, 2020
03b53d8
update
tchaton Dec 28, 2020
ee299a6
update
tchaton Jan 4, 2021
710c2ef
add suport int, float
Jan 4, 2021
053c004
update
tchaton Jan 4, 2021
c34101c
Merge branch 'release/1.2-dev' into bugfix/all_gather
tchaton Jan 4, 2021
04924ae
resolve pep8
tchaton Jan 4, 2021
cccce28
Update pytorch_lightning/core/lightning.py
tchaton Jan 4, 2021
f764832
Update tests/utilities/test_all_gather_grad.py
tchaton Jan 4, 2021
2df387b
update doc
tchaton Jan 5, 2021
8a906d3
Merge branch 'release/1.2-dev' into bugfix/all_gather
tchaton Jan 6, 2021
0ad2f3a
add bool and np.ndarray
tchaton Jan 7, 2021
ce32675
Merge branch 'bugfix/all_gather' of https://github.com/PyTorchLightni…
tchaton Jan 7, 2021
00813ef
resolve conflicts
tchaton Jan 7, 2021
832bb98
resolve conflicts
tchaton Jan 7, 2021
56e1e35
Merge branch 'release/1.2-dev' into bugfix/all_gather
tchaton Jan 7, 2021
a00a51c
resolve pep8
tchaton Jan 7, 2021
1d27fd1
Merge branch 'bugfix/all_gather' of https://github.com/PyTorchLightni…
tchaton Jan 7, 2021
96b3d98
add changelog
tchaton Jan 7, 2021
e826acd
Merge branch 'release/1.2-dev' into bugfix/all_gather
tchaton Jan 7, 2021
5112f4b
Update pytorch_lightning/core/lightning.py
tchaton Jan 8, 2021
1f90f49
Merge branch 'release/1.2-dev' into bugfix/all_gather
tchaton Jan 8, 2021
da92045
Merge branch 'release/1.2-dev' into bugfix/all_gather
tchaton Jan 8, 2021
6e9141e
update
tchaton Jan 9, 2021
448040e
Merge branch 'release/1.2-dev' into bugfix/all_gather
tchaton Jan 9, 2021
46126a2
resolve bug
tchaton Jan 9, 2021
0df1a98
resolve flake8
tchaton Jan 9, 2021
29 changes: 25 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -22,6 +22,7 @@
import tempfile
from abc import ABC
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -38,7 +39,9 @@
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args

@@ -364,7 +367,12 @@ def __auto_choose_log_on_epoch(self, on_epoch):

return on_epoch

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
def all_gather(
self,
data: Union[torch.Tensor, Dict, List, Tuple],
group: Optional[Any] = None,
sync_grads: bool = False,
):
r"""
Allows users to call ``self.all_gather()`` from the LightningModule, thus making
the ``all_gather`` operation accelerator agnostic.
@@ -373,14 +381,27 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
distributed processes

Args:
tensor: tensor of shape (batch, ...)
data: int, float, tensor of shape (batch, ...), or a collection of
int, float, tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op

Return:
A tensor of shape (world_size, batch, ...)
A tensor of shape (world_size, batch, ...), or, if the input was a collection,
a collection of tensors of this shape.
"""
return self.trainer.accelerator_backend.all_gather(tensor, group=group, sync_grads=sync_grads)
if self.trainer.accelerator_backend is not None:
all_gather = self.trainer.accelerator_backend.all_gather
else:
all_gather = all_gather_ddp_if_available

def to_dtype_tensor(value, dtype=None):
return torch.tensor(value, dtype=dtype, device=self.device)

data = apply_to_collection(data, float, partial(to_dtype_tensor, dtype=torch.float))
data = apply_to_collection(data, int, partial(to_dtype_tensor, dtype=torch.int))
all_gather = partial(all_gather, group=group, sync_grads=sync_grads)
return apply_to_collection(data, torch.Tensor, all_gather)

def forward(self, *args, **kwargs):
r"""
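
For reference, a minimal sketch (not part of this diff) of how the extended API can be used from a LightningModule hook; the model class and dictionary keys below are hypothetical:

import torch
from pytorch_lightning import LightningModule


class GatherDemoModel(LightningModule):  # hypothetical model, for illustration only

    def training_epoch_end(self, outputs):
        losses = torch.stack([out["loss"] for out in outputs])
        # Tensors, ints, floats, and (nested) collections of them are accepted;
        # numbers are converted to tensors on ``self.device`` before gathering.
        gathered = self.all_gather({"losses": losses, "num_batches": len(outputs)})
        # Every leaf now carries an extra leading ``world_size`` dimension.
        self.print(gathered["losses"].shape, gathered["num_batches"].shape)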
1 change: 0 additions & 1 deletion pytorch_lightning/utilities/apply_func.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from collections.abc import Mapping, Sequence
from copy import copy
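
The collection handling in ``all_gather`` is built on ``apply_to_collection``, which maps a function over every element of a matching type inside nested dicts, lists, and tuples. A small illustrative sketch (the sample data is made up):

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

data = {"a": torch.zeros(2), "b": [torch.ones(2), 3.0]}
# Only the ``torch.Tensor`` leaves are transformed; the plain float is left as-is.
doubled = apply_to_collection(data, torch.Tensor, lambda t: t * 2)
# -> {"a": tensor([0., 0.]), "b": [tensor([2., 2.]), 3.0]}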
10 changes: 10 additions & 0 deletions pytorch_lightning/utilities/distributed.py
@@ -20,6 +20,8 @@
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.exceptions import MisconfigurationException


if torch.distributed.is_available():
from torch.distributed import ReduceOp, group
@@ -202,6 +204,14 @@ def all_gather_ddp_if_available(
Return:
A tensor of shape (world_size, batch, ...)
"""
if group is None:
group = torch.distributed.group.WORLD
if group is None:
raise MisconfigurationException(
"The provided group was None and `torch.distributed.group` isn't available."
" Gathering tensor across processes won't be possible."
)

if torch.distributed.is_available() and torch.distributed.is_initialized():
if sync_grads:
return AllGatherGrad.apply(tensor, group)
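
As a quick sanity check of the new default-group handling, the helper can be called outside of DDP, where it should simply return the input tensor unchanged. A sketch, assuming a standard torch build with the distributed package available but no process group initialized:

import torch
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available

tensor = torch.randn(4, 3)
# ``group`` defaults to ``torch.distributed.group.WORLD``; since no process group
# is initialized here, the function falls through and returns the tensor as-is.
# Under DDP with N processes the result would have shape (N, 4, 3).
gathered = all_gather_ddp_if_available(tensor)
assert gathered.shape == (4, 3)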
2 changes: 2 additions & 0 deletions tests/special_tests.sh
@@ -19,4 +19,6 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_properly_works
# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
50 changes: 49 additions & 1 deletion tests/utilities/test_all_gather_grad.py
@@ -1,9 +1,12 @@
import os
import pytest
import sys

import pytest
import torch

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities import AllGatherGrad
from tests.base.boring_model import BoringModel


def setup_ddp(rank, world_size):
@@ -41,3 +44,48 @@ def _test_all_gather_ddp(rank, world_size):
def test_all_gather_ddp():
world_size = 3
torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_all_gather_collection(tmpdir):

class TestModel(BoringModel):

training_epoch_end_called = False

def training_epoch_end(self, outputs) -> None:
self.training_epoch_end_called = True
losses = torch.stack([x["loss"] for x in outputs])
gathered_loss = self.all_gather({
"losses_float": [0., 1., 2.],
"losses_int": [0, 1, 2],
"losses": losses,
"losses_list": [losses, losses]
})
assert gathered_loss["losses_float"][0].dtype == torch.float
assert gathered_loss["losses_int"][0].dtype == torch.int
assert gathered_loss["losses_list"][0].numel() == 2 * len(losses)
assert gathered_loss["losses"].numel() == 2 * len(losses)

seed_everything(42)

model = TestModel()

limit_train_batches = 8
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
accumulate_grad_batches=2,
enable_pl_optimizer=True,
gpus=2,
accelerator="ddp",
)

trainer.fit(model)
assert model.training_epoch_end_called
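
For intuition on the assertions above: with ``gpus=2`` the gather stacks one copy of the data per rank, so each gathered tensor gains a leading dimension of size 2. A rough sketch with made-up sizes:

import torch

world_size, num_losses = 2, 8  # hypothetical: 2 ranks, 8 losses per rank
losses = torch.randn(num_losses)
gathered = torch.stack([losses, losses])  # conceptually what all_gather produces
assert gathered.shape == (world_size, num_losses)
assert gathered.numel() == world_size * num_losses  # matches 2 * len(losses)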