Skip to content

Commit

Permalink
Fix sync_dist for tpus (#6950)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikb11 authored Apr 13, 2021
1 parent 80c5293 commit 1b3e4f9
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 23 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))


- Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
self.precision_plugin.pre_dispatch()

def post_dispatch(self, trainer: 'pl.Trainer') -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
"""Hook to do something after the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch()
self.precision_plugin.post_dispatch()

Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch import Tensor
from torchmetrics import Metric

from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed


class Result(Dict):
Expand Down Expand Up @@ -105,10 +105,11 @@ def log(

# sync across workers when using distributed training
sync_fn = sync_fn or sync_ddp_if_available

if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
# TODO: Find a way to make the reduction only once, so we don't need to clone.
if is_dist_initialized and isinstance(value, torch.Tensor):
if (is_dist_initialized or tpu_distributed) and isinstance(value, torch.Tensor):
value = value.clone()
else:
value = torch.tensor(value, device=device, dtype=torch.float)
Expand Down
7 changes: 4 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
import re
import time
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

import torch
import torch.multiprocessing as mp
Expand All @@ -41,7 +41,6 @@
if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf


if TYPE_CHECKING:
from torch.nn import Module
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -278,4 +277,6 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
Return:
A tensor of shape (world_size, batch, ...)
"""
return xm.all_gather(tensor.unsqueeze(0))
if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
tensor = tensor.unsqueeze(0)
return xm.all_gather(tensor)
4 changes: 1 addition & 3 deletions pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,10 @@
_TORCH_QUANTIZE_AVAILABLE,
_TORCHTEXT_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TPU_AVAILABLE,
_XLA_AVAILABLE,
)
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401
from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: F401

_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
Expand Down
29 changes: 15 additions & 14 deletions pytorch_lightning/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@
import warnings
from functools import partial, wraps
from typing import Any, Optional, Union
from pytorch_lightning.utilities.imports import (
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
)

import torch

from torch.nn.parallel.distributed import DistributedDataParallel

log = logging.getLogger(__name__)
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

if torch.distributed.is_available():
from torch.distributed import group, ReduceOp
Expand All @@ -40,6 +38,9 @@ class group:
WORLD = None


log = logging.getLogger(__name__)


def rank_zero_only(fn):

@wraps(fn)
Expand Down Expand Up @@ -294,19 +295,13 @@ def register_ddp_comm_hook(
)
"""
if not _TORCH_GREATER_EQUAL_1_8:
rank_zero_warn(
"Not registering DDP comm hook. "
"To use communication hooks, please use pytorch>=1.8.0."
)
rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
return
if ddp_comm_hook is None:
return
if ddp_comm_wrapper is not None:
if not _TORCH_GREATER_EQUAL_1_9:
rank_zero_warn(
"Not applying DDP comm wrapper. "
"To use communication wrapper, please use pytorch>=1.9.0."
)
rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
else:
rank_zero_info(
f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
Expand All @@ -318,3 +313,9 @@ def register_ddp_comm_hook(
state=ddp_comm_state,
hook=ddp_comm_hook,
)


def tpu_distributed() -> bool:
    """Tell whether the current run spans more than one TPU replica.

    Returns:
        ``False`` when ``_TPU_AVAILABLE`` is falsy; otherwise whether
        ``xm.xrt_world_size()`` is greater than 1.
    """
    if not _TPU_AVAILABLE:
        return False
    # `xm` is imported at module level only under `if _TPU_AVAILABLE:`, so it is used after the guard.
    return xm.xrt_world_size() > 1
4 changes: 4 additions & 0 deletions pytorch_lightning/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available('torchvision')
_XLA_AVAILABLE = _module_available("torch_xla")

from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402

_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
25 changes: 25 additions & 0 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader

import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TPU_AVAILABLE
Expand Down Expand Up @@ -416,3 +418,26 @@ def test_if_test_works_with_checkpoint_false(tmpdir):
trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_sync_dist():
    """Check ``Result.log(..., sync_dist=True)`` when reducing through ``TPUSpawnPlugin.reduce``."""

    def test_sync_dist(rank):
        # Each spawned process logs the same one-element tensor and requests a SUM reduction.
        plugin = TPUSpawnPlugin()
        value = torch.tensor([1.0])

        result = Result()
        result.log(
            "test_tensor",
            value,
            sync_fn=plugin.reduce,
            sync_dist=True,
            sync_dist_op=torch.distributed.ReduceOp.SUM,
        )

        # 8 processes are spawned below, each contributing 1.0 to the expected sum.
        logged = result["test_tensor"]
        assert logged.item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"

    xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')

0 comments on commit 1b3e4f9

Please sign in to comment.