support number for logging with sync_dist=True (#5080)
* support number

* add two tests

* wip

* add ddp in special test

* remove a test

* move device to bottom

* simplify test

* update test

* Update pytorch_lightning/core/step_result.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* resolve sync_ddp

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
2 people authored and Borda committed Dec 29, 2020
1 parent e975c98 commit 25640bc
Showing 5 changed files with 56 additions and 13 deletions.
pytorch_lightning/core/lightning.py (1 change: 1 addition & 0 deletions)
@@ -276,6 +276,7 @@ def log(
                 sync_dist_group,
                 accelerator.sync_tensor,
                 self._current_dataloader_idx,
+                self.device,
             )
 
     def log_dict(
pytorch_lightning/core/step_result.py (12 changes: 8 additions & 4 deletions)
@@ -15,15 +15,15 @@
 """[Train, Eval]Result for easier logging, checkpointing, early stopping, epoch-wise reduction."""
 
 import numbers
+import os
 from copy import copy
-from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple, Iterable
+from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import Tensor
-import os
 
-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 from pytorch_lightning.metrics import Metric
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 
 
 class Result(Dict):
@@ -128,6 +128,7 @@ def log(
         sync_dist_group: Optional[Any] = None,
         sync_fn: Callable = None,
         dataloader_idx: Optional[int] = None,
+        device: torch.device = None,
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
@@ -138,7 +139,10 @@
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
             is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
             # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            value = value.clone() if is_dist_initialized else value
+            if is_dist_initialized and isinstance(value, torch.Tensor):
+                value = value.clone()
+            else:
+                value = torch.tensor(value, device=device, dtype=torch.float)
             value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)
 
         if 'meta' not in self:
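The core of the change sits in Result.log above: with sync_dist=True, a logged value that is a plain Python number is first wrapped in a tensor on the module's device (which pytorch_lightning/core/lightning.py now forwards) so it can go through the same all-reduce path as a tensor. A minimal sketch of that preparation step, using a hypothetical helper name prepare_for_sync rather than the library's exact code path:

import numbers

import torch


def prepare_for_sync(value, device=None):
    """Return a tensor that a distributed all-reduce can operate on."""
    if isinstance(value, torch.Tensor):
        # clone so the in-place reduction does not overwrite the tensor that was logged
        return value.clone()
    if isinstance(value, numbers.Number):
        # wrap plain Python ints/floats on the caller's device so that
        # torch.distributed can reduce them like any other tensor
        return torch.tensor(value, device=device, dtype=torch.float)
    raise TypeError(f"sync_dist expects a Tensor or a number, got {type(value)!r}")


# prepare_for_sync(2) -> tensor(2.), ready to hand to sync_ddp_if_available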
pytorch_lightning/utilities/distributed.py (15 changes: 7 additions & 8 deletions)
@@ -15,14 +15,14 @@
 import os
 import warnings
 from functools import wraps
+from typing import Any, Optional, Union
 
 import torch
 
 from pytorch_lightning import _logger as log
-from typing import Union, Optional, Any
 
 if torch.distributed.is_available():
-    from torch.distributed import ReduceOp
-    from torch.distributed import group
+    from torch.distributed import ReduceOp, group
 else:
     class ReduceOp:
         SUM = None
@@ -145,15 +145,14 @@ def sync_ddp(
     if group is None:
         group = torch.distributed.group.WORLD
 
-    if reduce_op is None:
-        reduce_op = torch.distributed.ReduceOp.SUM
-    elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
-        reduce_op = torch.distributed.ReduceOp.SUM
+    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
+
+    if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
         divide_by_world_size = True
 
     # sync all processes before reduction
     torch.distributed.barrier(group=group)
-    torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)
+    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
 
     if divide_by_world_size:
        result = result / torch.distributed.get_world_size(group)
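For context, the branch above now accepts either a torch.distributed.ReduceOp or a string: anything that is not already a ReduceOp falls back to SUM, and the string spellings "mean"/"avg" (any case, since there is no native mean op here) are emulated as SUM followed by a division by the world size. A simplified standalone rendering of that logic, under the assumption that the process group is already initialized (which sync_ddp's callers check in the real code) and with the hypothetical name reduce_value:

import torch
from torch.distributed import ReduceOp  # importable whenever torch.distributed is available


def reduce_value(result: torch.Tensor, reduce_op=None, group=None) -> torch.Tensor:
    if group is None:
        group = torch.distributed.group.WORLD

    # None, "sum", "mean", "avg", ... all reduce with SUM; only a real ReduceOp passes through
    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
    # string mean/avg is emulated by dividing the summed result afterwards
    divide_by_world_size = isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean")

    torch.distributed.barrier(group=group)  # sync all processes before reduction
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)

    if divide_by_world_size:
        result = result / torch.distributed.get_world_size(group)
    return result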
tests/special_tests.sh (2 changes: 1 addition & 1 deletion)
@@ -19,4 +19,4 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
-# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
+python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
tests/trainer/logging_tests/test_train_loop_logging_1_0.py (39 changes: 39 additions & 0 deletions)
@@ -18,6 +18,7 @@
 import collections
 import itertools
 import os
+import platform
 from unittest import mock
 
 import numpy as np
@@ -686,6 +687,7 @@ class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
             acc = self.step(batch[0])
             self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            self.log('foo_2', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
             return acc
 
         def validation_step(self, batch, batch_idx):
@@ -705,9 +707,46 @@ def validation_step(self, batch, batch_idx):
     trainer.fit(model)
 
     assert trainer.logged_metrics['foo'] == fake_result
+    assert trainer.logged_metrics['foo_2'] == 2
     assert trainer.logged_metrics['bar'] == fake_result
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+def test_logging_sync_dist_true_ddp(tmpdir):
+    """
+    Tests to ensure that the sync_dist flag works with ddp
+    """
+    class TestLoggingSyncDistModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch[0])
+            self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM')
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            self.training_step_called = True
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('bar', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='AVG')
+            return {"x": loss}
+
+    model = TestLoggingSyncDistModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        weights_summary=None,
+        accelerator="ddp",
+        gpus=2,
+    )
+    trainer.fit(model)
+
+    assert trainer.logged_metrics['foo'] == 2
+    assert trainer.logged_metrics['bar'] == 2
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_logging_sync_dist_true_gpu(tmpdir):
     """
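From a user's point of view, what the new tests pin down is that a plain Python number passed to self.log with sync_dist=True is now reduced across processes just like a tensor. A small illustrative sketch of that usage, assuming a hypothetical compute_loss helper (the tests above, with BoringModel and the exact Trainer flags, are the authoritative reference):

from pytorch_lightning import LightningModule


class SyncLoggingModule(LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical user helper, not part of the API
        # a tensor and a plain int/float are now handled the same way under sync_dist=True
        self.log('loss', loss, sync_dist=True, sync_dist_op='sum')
        self.log('batch_size', batch[0].size(0), sync_dist=True, sync_dist_op='mean')
        return loss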
