diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index edad1be868cad..4e0d366e8a7c5 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -276,6 +276,7 @@ def log(
             sync_dist_group,
             accelerator.sync_tensor,
             self._current_dataloader_idx,
+            self.device,
         )
 
     def log_dict(
diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 7101ec17c4bbc..64fe5b58651f7 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -15,15 +15,15 @@
 """[Train, Eval]Result for easier logging, checkpointing, early stopping, epoch-wise reduction."""
 
 import numbers
+import os
 from copy import copy
-from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple, Iterable
+from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import Tensor
-import os
 
-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 from pytorch_lightning.metrics import Metric
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 
 
 class Result(Dict):
@@ -128,6 +128,7 @@ def log(
         sync_dist_group: Optional[Any] = None,
         sync_fn: Callable = None,
         dataloader_idx: Optional[int] = None,
+        device: torch.device = None,
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
@@ -138,7 +139,10 @@ def log(
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
             is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
             # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            value = value.clone() if is_dist_initialized else value
+            if is_dist_initialized and isinstance(value, torch.Tensor):
+                value = value.clone()
+            else:
+                value = torch.tensor(value, device=device, dtype=torch.float)
             value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)
 
         if 'meta' not in self:
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 9724f05247c00..c315c6633b6fb 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -15,14 +15,14 @@
 import os
 import warnings
 from functools import wraps
+from typing import Any, Optional, Union
 
 import torch
+
 from pytorch_lightning import _logger as log
-from typing import Union, Optional, Any
 
 if torch.distributed.is_available():
-    from torch.distributed import ReduceOp
-    from torch.distributed import group
+    from torch.distributed import ReduceOp, group
 else:
     class ReduceOp:
         SUM = None
@@ -145,15 +145,14 @@ def sync_ddp(
     if group is None:
         group = torch.distributed.group.WORLD
 
-    if reduce_op is None:
-        reduce_op = torch.distributed.ReduceOp.SUM
-    elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
-        reduce_op = torch.distributed.ReduceOp.SUM
+    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
+
+    if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
         divide_by_world_size = True
 
     # sync all processes before reduction
     torch.distributed.barrier(group=group)
-    torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)
+    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
 
     if divide_by_world_size:
         result = result / torch.distributed.get_world_size(group)
diff --git a/tests/special_tests.sh b/tests/special_tests.sh
index f7cb581951783..950e3776bbc7f 100644
--- a/tests/special_tests.sh
+++ b/tests/special_tests.sh
@@ -19,4 +19,4 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
-# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
+python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
index a77b4eb451e28..6b021462129ef 100644
--- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
+++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
@@ -18,6 +18,7 @@
 import collections
 import itertools
 import os
+import platform
 from unittest import mock
 
 import numpy as np
@@ -686,6 +687,7 @@ class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
             acc = self.step(batch[0])
             self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            self.log('foo_2', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
             return acc
 
         def validation_step(self, batch, batch_idx):
@@ -705,9 +707,46 @@ def validation_step(self, batch, batch_idx):
     trainer.fit(model)
 
     assert trainer.logged_metrics['foo'] == fake_result
+    assert trainer.logged_metrics['foo_2'] == 2
     assert trainer.logged_metrics['bar'] == fake_result
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+def test_logging_sync_dist_true_ddp(tmpdir):
+    """
+    Tests to ensure that the sync_dist flag works with ddp
+    """
+    class TestLoggingSyncDistModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch[0])
+            self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM')
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            self.training_step_called = True
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('bar', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='AVG')
+            return {"x": loss}
+
+    model = TestLoggingSyncDistModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        weights_summary=None,
+        accelerator="ddp",
+        gpus=2,
+    )
+    trainer.fit(model)
+
+    assert trainer.logged_metrics['foo'] == 2
+    assert trainer.logged_metrics['bar'] == 2
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_logging_sync_dist_true_gpu(tmpdir):
     """
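
For context beyond the patch itself, below is a minimal usage sketch of the behaviour this change enables: passing a plain Python number to self.log(..., sync_dist=True) and having it reduced across DDP ranks. Only the self.log call mirrors the new test; the model, dataset, and trainer settings (SyncDistExample, the random TensorDataset, 2-GPU DDP) are illustrative assumptions, not code from this PR.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class SyncDistExample(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, = batch
        loss = self.layer(x).sum()
        # A plain Python number is now accepted: Result.log converts it to a
        # float tensor on self.device before sync_fn all-reduces it.
        self.log('num_batches', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == '__main__':
    train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    # Mirrors the new test's setup: 2-GPU DDP, so 'num_batches' is summed across ranks.
    trainer = pl.Trainer(max_epochs=1, accelerator='ddp', gpus=2)
    trainer.fit(SyncDistExample(), train_data)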