Fix sync_dist for tpus (#6950)
(cherry picked from commit 1b3e4f9)
kaushikb11 authored and SeanNaren committed Apr 13, 2021
1 parent 60590ee commit d54a883
Showing 7 changed files with 53 additions and 9 deletions.
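Before the per-file diffs, a sketch of the user-facing call path this commit fixes: logging a metric with `sync_dist=True` while training on TPU cores. This is a minimal illustration only, assuming a PyTorch Lightning release containing this commit and an attached TPU runtime; `ToyModel`, the random dataset, and `tpu_cores=8` are assumptions, not part of the commit.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ToyModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x = batch[0]
        loss = self.layer(x).sum()
        # Before this fix, sync_dist=True skipped the reduction on TPU because only
        # torch.distributed (which XLA does not initialize) was checked.
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(max_epochs=1, tpu_cores=8)  # tpu_cores assumes a TPU VM / Colab TPU session
    trainer.fit(ToyModel(), train_data)
```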
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))


- Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))


5 changes: 3 additions & 2 deletions pytorch_lightning/core/step_result.py
@@ -22,7 +22,7 @@
from torch import Tensor
from torchmetrics import Metric

-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed


class Result(Dict):
@@ -139,10 +139,11 @@ def log(

        # sync across workers when using distributed training
        sync_fn = sync_fn or sync_ddp_if_available

        if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
            is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
            # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            if is_dist_initialized and isinstance(value, torch.Tensor):
+            if (is_dist_initialized or tpu_distributed()) and isinstance(value, torch.Tensor):
                value = value.clone()
            else:
                value = torch.tensor(value, device=device, dtype=torch.float)
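To make the intent of the changed branch concrete, here is a standalone sketch (not the library code) of the decision it implements: a logged value is cloned before the in-place reduction only when a process group is active, either `torch.distributed` or an XLA world larger than one process; everything else is promoted to a float tensor first. `_tpu_distributed` and `prepare_for_sync` below are illustrative stand-ins for the helper and branch touched by this diff.

```python
import torch


def _tpu_distributed() -> bool:
    # Stand-in for pytorch_lightning.utilities.distributed.tpu_distributed();
    # hard-coded to False here because no torch_xla runtime is assumed.
    return False


def prepare_for_sync(value, device="cpu"):
    """Sketch of the branch changed in Result.log above."""
    is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
    if (is_dist_initialized or _tpu_distributed()) and isinstance(value, torch.Tensor):
        # Clone so the in-place all-reduce performed by the sync_fn
        # does not modify the tensor the caller passed in.
        return value.clone()
    # Plain numbers (and values logged outside any process group) become float tensors.
    return torch.tensor(value, device=device, dtype=torch.float)


print(prepare_for_sync(3.0))  # tensor(3.) -- ready to be handed to the sync_fn reduction
```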
7 changes: 4 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -15,7 +15,7 @@
import os
import re
import time
-from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

import torch
import torch.multiprocessing as mp
@@ -40,7 +40,6 @@
if _OMEGACONF_AVAILABLE:
    from omegaconf import DictConfig, ListConfig, OmegaConf


if TYPE_CHECKING:
    from torch.nn import Module
    from torch.utils.data import DataLoader
@@ -276,4 +275,6 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
        Return:
            A tensor of shape (world_size, batch, ...)
        """
-        return xm.all_gather(tensor.unsqueeze(0))
+        if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+        return xm.all_gather(tensor)
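The `all_gather` change above stops unconditionally adding a leading dimension: only 0-dim scalars are unsqueezed, since `xm.all_gather` concatenates along dim 0 and cannot concatenate a dimensionless tensor, while already-batched tensors are gathered as-is. A CPU-only sketch of the shape behaviour follows; `fake_all_gather` is an assumed stand-in for `xm.all_gather` (which requires a TPU), and `tpu_all_gather` mirrors the fixed plugin method rather than being the library code.

```python
import torch


def fake_all_gather(tensor: torch.Tensor, world_size: int = 8) -> torch.Tensor:
    # Stand-in for xm.all_gather: concatenates every replica's tensor along dim 0.
    # Here all "replicas" contribute the same local tensor.
    return torch.cat([tensor] * world_size, dim=0)


def tpu_all_gather(tensor: torch.Tensor) -> torch.Tensor:
    # Mirrors the fixed TPUSpawnPlugin.all_gather: only 0-dim scalars get a
    # leading dimension, since concatenation needs at least one dimension.
    if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
        tensor = tensor.unsqueeze(0)
    return fake_all_gather(tensor)


print(tpu_all_gather(torch.tensor(1.0)).shape)  # torch.Size([8])    -- scalar promoted to shape (1,)
print(tpu_all_gather(torch.ones(4, 3)).shape)   # torch.Size([32, 3]) -- gathered as-is, no extra dim
```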
4 changes: 1 addition & 3 deletions pytorch_lightning/utilities/__init__.py
@@ -42,12 +42,10 @@
    _TORCH_QUANTIZE_AVAILABLE,
    _TORCHTEXT_AVAILABLE,
    _TORCHVISION_AVAILABLE,
+    _TPU_AVAILABLE,
    _XLA_AVAILABLE,
)
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401
-from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: F401

-_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
14 changes: 13 additions & 1 deletion pytorch_lightning/utilities/distributed.py
@@ -20,7 +20,10 @@

import torch

-log = logging.getLogger(__name__)
+from pytorch_lightning.utilities.imports import _TPU_AVAILABLE

+if _TPU_AVAILABLE:
+    import torch_xla.core.xla_model as xm

if torch.distributed.is_available():
    from torch.distributed import group, ReduceOp
@@ -34,6 +37,9 @@ class group:
        WORLD = None


+log = logging.getLogger(__name__)


def rank_zero_only(fn):

    @wraps(fn)
@@ -222,3 +228,9 @@ def all_gather_ddp_if_available(
with torch.no_grad():
return AllGatherGrad.apply(tensor, group)
return tensor


def tpu_distributed() -> bool:
    if _TPU_AVAILABLE:
        return xm.xrt_world_size() > 1
    return False
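A usage sketch for the new `tpu_distributed()` helper, assuming the pytorch-lightning version introduced by this commit: it reports `True` only when `torch_xla` is importable and the XRT world size is greater than one, so CPU/GPU machines and single-core TPU sessions keep the previous behaviour.

```python
from pytorch_lightning.utilities.distributed import tpu_distributed

# On a machine without torch_xla this is simply False; inside an 8-core TPU
# spawn it is True in every process because xm.xrt_world_size() == 8.
if tpu_distributed():
    print("TPU process group detected - logged values will be reduced across cores")
else:
    print("no TPU process group - logged values are used as-is")
```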
4 changes: 4 additions & 0 deletions pytorch_lightning/utilities/imports.py
@@ -82,3 +82,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available('torchvision')
_XLA_AVAILABLE = _module_available("torch_xla")

from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402

_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
25 changes: 25 additions & 0 deletions tests/models/test_tpu.py
@@ -16,13 +16,15 @@
from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader

import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TPU_AVAILABLE
@@ -397,3 +399,26 @@ def test_if_test_works_with_checkpoint_false(tmpdir):
    trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_sync_dist():
    """Test tpu spawn sync dist operation"""

    def test_sync_dist(rank):
        tensor = torch.tensor([1.0])
        training_type_plugin = TPUSpawnPlugin()

        res = Result()
        res.log(
            "test_tensor",
            tensor,
            sync_fn=training_type_plugin.reduce,
            sync_dist=True,
            sync_dist_op=torch.distributed.ReduceOp.SUM
        )

        assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"

    xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')
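The expected value of 8 in the new test comes from the reduction itself: each of the 8 spawned TPU processes logs `tensor([1.0])`, and the `SUM` reduce therefore yields 8.0. A CPU-only sanity check of that arithmetic (outside `xmp.spawn`, which needs TPU hardware):

```python
import torch

world_size = 8  # nprocs used by xmp.spawn in the test above
per_process = [torch.tensor([1.0]) for _ in range(world_size)]  # one logged value per replica
assert torch.stack(per_process).sum().item() == 8.0
```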
