Fix TPU Spawn gather (#6896)
(cherry picked from commit 5552503)
kaushikb11 authored and SeanNaren committed Apr 13, 2021
1 parent f895e9f commit 8245540
Showing 4 changed files with 70 additions and 39 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -228,6 +228,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
 
 
+- Fixed TPU Spawn all gather ([#6896](https://github.com/PyTorchLightning/pytorch-lightning/pull/6896))
+
+
 - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
 
 
46 changes: 27 additions & 19 deletions pytorch_lightning/accelerators/tpu.py
@@ -1,6 +1,18 @@
-from typing import Any, Callable, Optional, Union
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, TYPE_CHECKING, Union
 
 import torch
 from torch.optim import Optimizer
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -16,10 +28,19 @@
 
 xla_clip_grad_norm_ = clip_grad_norm_
 
+if TYPE_CHECKING:
+    from pytorch_lightning.core.lightning import LightningModule
+    from pytorch_lightning.trainer.trainer import Trainer
+
 
 class TPUAccelerator(Accelerator):
 
-    def setup(self, trainer, model):
+    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If AMP is used with TPU, or if TPUs are not using a single TPU core or TPU spawn training.
+        """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException(
                 "amp + tpu is not supported. "
@@ -30,24 +51,11 @@ def setup(self, trainer, model):
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
         return super().setup(trainer, model)
 
-    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+    def run_optimizer_step(
+        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
+    ) -> None:
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})
 
-    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
-        """
-        Function to gather a tensor from several distributed processes
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: not available with TPUs
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        # todo: Add support for backward with all_gather
-        if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
-            return xm.all_gather(tensor).view(-1, *tensor.shape)
-        return tensor
-
     def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
 
         model = self.lightning_module
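Taken together, the tpu.py changes delete the accelerator-level `all_gather` override; the TPU-specific gather is reimplemented on `TPUSpawnPlugin` below. The shape handling also changes: the removed code gathered a `(batch, ...)` tensor into `(world_size * batch, ...)` and `view`-ed it back, while the plugin version unsqueezes a leading dimension before gathering. A minimal sketch of why both arrive at `(world_size, batch, ...)`, using `torch.cat` as a local stand-in for `xm.all_gather` and illustrative sizes:

```python
import torch

world_size, batch, features = 8, 4, 32  # illustrative values
per_process = torch.randn(batch, features)

# Removed accelerator code: xm.all_gather(tensor) concatenates along dim 0,
# giving (world_size * batch, features); view() then restores the split.
gathered_flat = torch.cat([per_process] * world_size, dim=0)
old_result = gathered_flat.view(-1, *per_process.shape)

# New plugin code: unsqueeze first, so the concat dimension *is* the
# world dimension and no reshape is needed afterwards.
gathered_stacked = torch.cat([per_process.unsqueeze(0)] * world_size, dim=0)

assert old_result.shape == gathered_stacked.shape == (world_size, batch, features)
```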
18 changes: 15 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -192,14 +192,14 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         return obj
 
     def reduce_boolean_decision(self, decision: bool) -> bool:
-        decision = torch.tensor(int(decision), device=self.device)
-        decision = self.reduce(decision, "sum")
+        decision = torch.tensor(int(decision), device=self.lightning_module.device)
+        decision = self.reduce(decision, reduce_op="sum")
         decision = bool(decision == self.world_size)
         return decision
 
     def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
         if not isinstance(output, torch.Tensor):
-            output = torch.tensor(output, device=self.device)
+            output = torch.tensor(output, device=self.lightning_module.device)
 
         _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
         _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
@@ -265,3 +265,15 @@ def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
         if _OMEGACONF_AVAILABLE:
             checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
         self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
+
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+        """
+        Function to gather a tensor from several distributed processes
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: not available with TPUs
+            sync_grads: not available with TPUs
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return xm.all_gather(tensor.unsqueeze(0))
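Two details of the tpu_spawn.py fix are worth spelling out. First, `TPUSpawnPlugin` has no `self.device` attribute, so the decision and output tensors are now placed on `self.lightning_module.device`. Second, given the `reduce` signature above, the old call `self.reduce(decision, "sum")` passed `"sum"` positionally into the `group` parameter and left `reduce_op=None`; the keyword form fixes that. The voting logic itself is a unanimity check, sketched here for a single process with the cross-process sum-reduce replaced by a local sum over hypothetical per-process votes:

```python
import torch

def reduce_boolean_decision_sketch(votes) -> bool:
    # Each process contributes int(decision), i.e. 0 or 1; the "sum" reduce
    # is replaced here by a local sum over the hypothetical per-process votes.
    world_size = len(votes)
    summed = torch.tensor([int(v) for v in votes]).sum()
    # Unanimity: the decision holds only if every process voted True.
    return bool(summed == world_size)

assert reduce_boolean_decision_sketch([True, True, True]) is True
assert reduce_boolean_decision_sketch([True, False, True]) is False
```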
42 changes: 25 additions & 17 deletions tests/models/test_tpu.py
@@ -54,7 +54,7 @@ def val_dataloader(self):
         return DataLoader(RandomDataset(32, 2000), batch_size=32)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_cores_1(tmpdir):
     """Make sure model trains on TPU."""
@@ -73,7 +73,7 @@ def test_model_tpu_cores_1(tmpdir):
 
 
 @pytest.mark.parametrize('tpu_core', [1, 5])
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_index(tmpdir, tpu_core):
     """Make sure model trains on TPU."""
@@ -92,7 +92,7 @@ def test_model_tpu_index(tmpdir, tpu_core):
     assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}'
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_cores_8(tmpdir):
     """Make sure model trains on TPU."""
@@ -111,7 +111,7 @@ def test_model_tpu_cores_8(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False, min_acc=0.05)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_16bit_tpu_cores_1(tmpdir):
     """Make sure model trains on TPU."""
@@ -132,7 +132,7 @@ def test_model_16bit_tpu_cores_1(tmpdir):
 
 
 @pytest.mark.parametrize('tpu_core', [1, 5])
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_16bit_tpu_index(tmpdir, tpu_core):
     """Make sure model trains on TPU."""
@@ -153,7 +153,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
     assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_16bit_tpu_cores_8(tmpdir):
     """Make sure model trains on TPU."""
@@ -173,7 +173,7 @@ def test_model_16bit_tpu_cores_8(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False, min_acc=0.05)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_early_stop(tmpdir):
     """Test if single TPU core training works"""
@@ -200,7 +200,7 @@ def validation_step(self, *args, **kwargs):
     trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_grad_norm(tmpdir):
     """Test if grad_norm works on TPU."""
@@ -219,16 +219,24 @@ def test_tpu_grad_norm(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_dataloaders_passed_to_fit(tmpdir):
     """Test if dataloaders passed to trainer works on TPU"""
 
     tutils.reset_seed()
     model = BoringModel()
 
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8)
-    trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        tpu_cores=8,
+    )
+    trainer.fit(
+        model,
+        train_dataloader=model.train_dataloader(),
+        val_dataloaders=model.val_dataloader(),
+    )
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -237,7 +245,7 @@ def test_dataloaders_passed_to_fit(tmpdir):
     [pytest.param(1, None), pytest.param(8, None),
      pytest.param([1], 1), pytest.param([8], 8)],
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires missing TPU")
+@RunIf(tpu=True)
 def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id):
     """Test if trainer.tpu_id is set as expected"""
     assert Trainer(tpu_cores=tpu_cores).accelerator_connector.tpu_id == expected_tpu_id
@@ -258,13 +266,13 @@ def test_exception_when_no_tpu_found(tmpdir):
 
 
 @pytest.mark.parametrize('tpu_cores', [1, 8, [1]])
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):
     """Test if distributed_backend is set to `tpu` when tpu_cores is not None"""
     assert Trainer(tpu_cores=tpu_cores).distributed_backend == "tpu"
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_broadcast_on_tpu():
     """ Checks if an object from the master process is broadcasted to other processes correctly"""
@@ -296,7 +304,7 @@ def test_broadcast(rank):
         pytest.param(10, None, True),
     ],
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     if error_expected:
@@ -312,7 +320,7 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     [pytest.param('--tpu_cores=8', {'tpu_cores': 8}),
      pytest.param("--tpu_cores=1,", {'tpu_cores': '1,'})]
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_cores_with_argparse(cli_args, expected):
     """Test passing tpu_cores in command line"""
@@ -327,7 +335,7 @@ def test_tpu_cores_with_argparse(cli_args, expected):
     assert Trainer.from_argparse_args(args)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_reduce():
     """Test tpu spawn reduce operation """
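The recurring test change swaps the hand-written `skipif` marks for the test suite's `RunIf` helper. As a rough, hypothetical sketch of what `RunIf(tpu=True)` reduces to (the real helper in the Lightning test suite checks more conditions than `tpu`), the decorator folds availability checks into a single combined `skipif` mark:

```python
import pytest

_TPU_AVAILABLE = False  # stand-in; Lightning detects TPUs via torch_xla

def RunIf(tpu: bool = False, **kwargs):
    # Collect the requested requirements, then emit one combined skipif mark.
    conditions, reasons = [], []
    if tpu:
        conditions.append(not _TPU_AVAILABLE)
        reasons.append("TPU")
    return pytest.mark.skipif(
        any(conditions),
        reason=f"Requires: [{' + '.join(reasons)}]",
        **kwargs,
    )

@RunIf(tpu=True)
def test_requires_tpu():
    assert True
```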
