support number for logging with sync_dist=True #5080

Merged
merged 17 commits on Dec 16, 2020
1 change: 1 addition & 0 deletions pytorch_lightning/core/lightning.py
@@ -279,6 +279,7 @@ def log(
             sync_dist_group,
             accelerator.sync_tensor,
             self._current_dataloader_idx,
+            self.device,
         )

     def log_dict(
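For readers following the call chain: `LightningModule.log` now forwards `self.device` so that `Result.log` (next file) can build a tensor on the correct device when the logged value is a plain Python number. This matters because `torch.distributed.all_reduce` with the NCCL backend only operates on CUDA tensors. Below is a minimal, hedged sketch of the conversion this enables; the helper name and the device used are illustrative, not part of the PR.

```python
import numbers

import torch


def _to_device_tensor(value, device):
    """Illustrative helper: wrap a logged Python number as a float tensor on `device`."""
    if isinstance(value, numbers.Number):
        # On a GPU process `device` would be e.g. torch.device("cuda:0"),
        # which is what makes the later all_reduce valid under NCCL.
        return torch.tensor(value, device=device, dtype=torch.float)
    return value


print(_to_device_tensor(2, torch.device("cpu")))  # tensor(2.)
```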
12 changes: 8 additions & 4 deletions pytorch_lightning/core/step_result.py
@@ -15,15 +15,15 @@
"""[Train, Eval]Result for easier logging, checkpointing, early stopping, epoch-wise reduction."""

import numbers
import os
from copy import copy
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple, Iterable
from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
import os

from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.metrics import Metric
from pytorch_lightning.utilities.distributed import sync_ddp_if_available


class Result(Dict):
@@ -128,6 +128,7 @@ def log(
         sync_dist_group: Optional[Any] = None,
         sync_fn: Callable = None,
         dataloader_idx: Optional[int] = None,
+        device: torch.device = None,
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
@@ -138,7 +139,10 @@ def log(
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
             is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
             # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            value = value.clone() if is_dist_initialized else value
+            if is_dist_initialized and isinstance(value, torch.Tensor):
+                value = value.clone()
+            else:
+                value = torch.tensor(value, device=device, dtype=torch.float)
             value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

         if 'meta' not in self:
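Taken together, the new branch means: tensors are still cloned before the in-place reduction, while plain numbers are first wrapped as float tensors on the module's device and then passed through `sync_fn`. Here is a hedged simulation of that path with a stand-in `sync_fn`; the fake reduction simply mimics a 2-process SUM and is not the real `sync_ddp_if_available`, and the distributed-initialization check is omitted for brevity.

```python
import numbers

import torch

WORLD_SIZE = 2  # pretend two DDP processes log the same value


def fake_sync_fn(value, group=None, reduce_op="sum"):
    """Stand-in for sync_ddp_if_available: SUM across a fake 2-process world."""
    return value * WORLD_SIZE


def log_value(value, device, sync_fn):
    # Mirrors (in simplified form) the updated logic in Result.log for sync_dist=True.
    if isinstance(value, (torch.Tensor, numbers.Number)):
        if isinstance(value, torch.Tensor):
            value = value.clone()  # don't mutate the caller's tensor in the in-place reduce
        else:
            value = torch.tensor(value, device=device, dtype=torch.float)
        value = sync_fn(value, group=None, reduce_op="sum")
    return value


print(log_value(2, torch.device("cpu"), fake_sync_fn))                  # tensor(4.)
print(log_value(torch.tensor(1.0), torch.device("cpu"), fake_sync_fn))  # tensor(2.)
```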
15 changes: 7 additions & 8 deletions pytorch_lightning/utilities/distributed.py
@@ -15,14 +15,14 @@
 import os
 import warnings
 from functools import wraps
+from typing import Any, Optional, Union

 import torch

 from pytorch_lightning import _logger as log
-from typing import Union, Optional, Any

 if torch.distributed.is_available():
-    from torch.distributed import ReduceOp
-    from torch.distributed import group
+    from torch.distributed import ReduceOp, group
 else:
     class ReduceOp:
         SUM = None
@@ -145,15 +145,14 @@ def sync_ddp(
     if group is None:
         group = torch.distributed.group.WORLD

-    if reduce_op is None:
-        reduce_op = torch.distributed.ReduceOp.SUM
-    elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
-        reduce_op = torch.distributed.ReduceOp.SUM
+    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
+
+    if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
         divide_by_world_size = True

     # sync all processes before reduction
     torch.distributed.barrier(group=group)
-    torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)
+    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)

     if divide_by_world_size:
         result = result / torch.distributed.get_world_size(group)
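The reworked `sync_ddp` accepts either a `ReduceOp` or a string: anything that is not already a `ReduceOp` falls back to `ReduceOp.SUM`, and the strings "avg"/"mean" (now matched case-insensitively) additionally divide the summed result by the world size, so a mean is computed as a sum followed by a division. Below is a small hedged sketch of just that dispatch, factored into a standalone function for illustration; the function name is not part of the PR.

```python
import torch

# Fall back to a plain attribute holder when torch.distributed is unavailable,
# mirroring the guard at the top of pytorch_lightning/utilities/distributed.py.
if torch.distributed.is_available():
    from torch.distributed import ReduceOp
else:
    class ReduceOp:
        SUM = None


def resolve_reduce_op(reduce_op):
    """Illustrative: map a user-supplied reduce_op to (op, divide_by_world_size)."""
    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
    divide_by_world_size = isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean")
    return op, divide_by_world_size


print(resolve_reduce_op("SUM"))  # SUM, no division
print(resolve_reduce_op("avg"))  # SUM, then divide by world size
print(resolve_reduce_op(None))   # defaults to SUM
```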
2 changes: 1 addition & 1 deletion tests/special_tests.sh
@@ -19,4 +19,4 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
-# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
+python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
39 changes: 39 additions & 0 deletions tests/trainer/logging_tests/test_train_loop_logging_1_0.py
@@ -18,6 +18,7 @@
 import collections
 import itertools
 import os
+import platform
 from unittest import mock

 import numpy as np
@@ -685,6 +686,7 @@ class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
             acc = self.step(batch[0])
             self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            self.log('foo_2', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
             return acc

         def validation_step(self, batch, batch_idx):
@@ -704,9 +706,46 @@ def validation_step(self, batch, batch_idx):
     trainer.fit(model)

     assert trainer.logged_metrics['foo'] == fake_result
+    assert trainer.logged_metrics['foo_2'] == 2
     assert trainer.logged_metrics['bar'] == fake_result


+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+def test_logging_sync_dist_true_ddp(tmpdir):
"""
Tests to ensure that the sync_dist flag works with ddp
"""
class TestLoggingSyncDistModel(BoringModel):
def training_step(self, batch, batch_idx):
acc = self.step(batch[0])
self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM')
return acc

def validation_step(self, batch, batch_idx):
self.training_step_called = True
output = self.layer(batch)
loss = self.loss(batch, output)
self.log('bar', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='AVG')
return {"x": loss}

model = TestLoggingSyncDistModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
weights_summary=None,
accelerator="ddp",
gpus=2,
)
trainer.fit(model)

assert trainer.logged_metrics['foo'] == 2
assert trainer.logged_metrics['bar'] == 2


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_logging_sync_dist_true_gpu(tmpdir):
     """
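Putting it together from a user's point of view: after this PR, a plain Python number passed to `self.log(..., sync_dist=True)` is wrapped as a float tensor on `self.device` and reduced across processes just like a tensor would be. The sketch below is a hedged usage example in the spirit of the new tests; the model, dataset, and Trainer flags are illustrative assumptions (2-GPU DDP, PL 1.1-era API as used in this PR), not code from the repository.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class NumberLoggingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, = batch
        loss = self.layer(x).sum()
        # Plain ints/floats now work with sync_dist=True: they are converted to
        # float tensors on self.device and all-reduced with the requested op.
        self.log("batch_size", x.size(0), on_epoch=True, sync_dist=True, sync_dist_op="sum")
        self.log("constant", 1.5, on_epoch=True, sync_dist=True, sync_dist_op="mean")
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)
    trainer = pl.Trainer(gpus=2, accelerator="ddp", max_epochs=1, limit_train_batches=2)
    trainer.fit(NumberLoggingModel(), data)
```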