Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sync_dist for tpus #6950

Merged
merged 12 commits into from
Apr 13, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `AttributeError for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))
carmocca marked this conversation as resolved.
Show resolved Hide resolved


- Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
self.precision_plugin.pre_dispatch()

def post_dispatch(self, trainer: 'pl.Trainer') -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
"""Hook to do something after the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch()
self.precision_plugin.post_dispatch()

Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch import Tensor
from torchmetrics import Metric

from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed


class Result(Dict):
Expand Down Expand Up @@ -105,10 +105,11 @@ def log(

# sync across workers when using distributed training
sync_fn = sync_fn or sync_ddp_if_available

if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
# TODO: Find a way to make the reduction only once, so we don't need to clone.
if is_dist_initialized and isinstance(value, torch.Tensor):
if (is_dist_initialized or tpu_distributed) and isinstance(value, torch.Tensor):
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
value = value.clone()
else:
value = torch.tensor(value, device=device, dtype=torch.float)
Expand Down
7 changes: 4 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
import re
import time
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

import torch
import torch.multiprocessing as mp
Expand All @@ -41,7 +41,6 @@
if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf


if TYPE_CHECKING:
from torch.nn import Module
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -278,4 +277,6 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
Return:
A tensor of shape (world_size, batch, ...)
"""
return xm.all_gather(tensor.unsqueeze(0))
if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
tensor = tensor.unsqueeze(0)
return xm.all_gather(tensor)
4 changes: 1 addition & 3 deletions pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,10 @@
_TORCH_QUANTIZE_AVAILABLE,
_TORCHTEXT_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TPU_AVAILABLE,
_XLA_AVAILABLE,
)
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401
from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: F401

_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
Expand Down
29 changes: 15 additions & 14 deletions pytorch_lightning/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@
import warnings
from functools import partial, wraps
from typing import Any, Optional, Union
from pytorch_lightning.utilities.imports import (
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
)

import torch

from torch.nn.parallel.distributed import DistributedDataParallel

log = logging.getLogger(__name__)
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

if torch.distributed.is_available():
from torch.distributed import group, ReduceOp
Expand All @@ -40,6 +38,9 @@ class group:
WORLD = None


log = logging.getLogger(__name__)


def rank_zero_only(fn):

@wraps(fn)
Expand Down Expand Up @@ -294,19 +295,13 @@ def register_ddp_comm_hook(
)
"""
if not _TORCH_GREATER_EQUAL_1_8:
rank_zero_warn(
"Not registering DDP comm hook. "
"To use communication hooks, please use pytorch>=1.8.0."
)
rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
return
if ddp_comm_hook is None:
return
if ddp_comm_wrapper is not None:
if not _TORCH_GREATER_EQUAL_1_9:
rank_zero_warn(
"Not applying DDP comm wrapper. "
"To use communication wrapper, please use pytorch>=1.9.0."
)
rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
else:
rank_zero_info(
f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
Expand All @@ -318,3 +313,9 @@ def register_ddp_comm_hook(
state=ddp_comm_state,
hook=ddp_comm_hook,
)


def tpu_distributed():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat !

kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
if _TPU_AVAILABLE:
return xm.xrt_world_size() > 1
return False
4 changes: 4 additions & 0 deletions pytorch_lightning/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available('torchvision')
_XLA_AVAILABLE = _module_available("torch_xla")

from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402
carmocca marked this conversation as resolved.
Show resolved Hide resolved

_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
25 changes: 25 additions & 0 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader

import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TPU_AVAILABLE
Expand Down Expand Up @@ -416,3 +418,26 @@ def test_if_test_works_with_checkpoint_false(tmpdir):
trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_sync_dist():
"""Test tpu spawn sync dist operation """

def test_sync_dist(rank):
tensor = torch.tensor([1.0])
training_type_plugin = TPUSpawnPlugin()

res = Result()
res.log(
"test_tensor",
tensor,
sync_fn=training_type_plugin.reduce,
sync_dist=True,
sync_dist_op=torch.distributed.ReduceOp.SUM
)

assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"

xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')