diff --git a/CHANGELOG.md b/CHANGELOG.md
index 24fcc76c10ecf..3ee272dcee98b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -228,6 +228,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
 
+- Fixed TPU Spawn all gather ([#6896](https://github.com/PyTorchLightning/pytorch-lightning/pull/6896))
+
+
 - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
 
 
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index 772c9f354ac9f..087f6df7a1c6a 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -1,6 +1,18 @@
-from typing import Any, Callable, Optional, Union
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, TYPE_CHECKING, Union
 
-import torch
 from torch.optim import Optimizer
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -16,10 +28,19 @@
 
     xla_clip_grad_norm_ = clip_grad_norm_
 
+if TYPE_CHECKING:
+    from pytorch_lightning.core.lightning import LightningModule
+    from pytorch_lightning.trainer.trainer import Trainer
+
 
 class TPUAccelerator(Accelerator):
 
-    def setup(self, trainer, model):
+    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If AMP is used with TPU, or if TPUs are not using a single TPU core or TPU spawn training.
+        """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException(
                 "amp + tpu is not supported. "
@@ -30,24 +51,11 @@ def setup(self, trainer, model):
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
         return super().setup(trainer, model)
 
-    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+    def run_optimizer_step(
+        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
+    ) -> None:
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})
 
-    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
-        """
-        Function to gather a tensor from several distributed processes
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: not available with TPUs
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        # todo: Add support for backward with all_gather
-        if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
-            return xm.all_gather(tensor).view(-1, *tensor.shape)
-        return tensor
-
     def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
 
         model = self.lightning_module
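The `all_gather` override above is removed from `TPUAccelerator`; the gather itself moves to `TPUSpawnPlugin` in the next file. As a sanity check of the shape contract stated in the removed docstring, here is a CPU-only sketch in which `torch.cat` stands in for `xm.all_gather`'s default dim-0 concatenation; it is an illustration under that assumption, not TPU code.

```python
# CPU-only illustration: torch.cat stands in for xm.all_gather, which concatenates
# one tensor per process along dim 0. world_size = 8 is an assumed value.
import torch

world_size = 8
local = torch.randn(4, 32)  # per-process tensor of shape (batch, ...)

# Old TPUAccelerator.all_gather: gather flat, then reshape back.
old = torch.cat([local] * world_size, dim=0).view(-1, *local.shape)

# New TPUSpawnPlugin.all_gather: add a leading dim first, then gather along it.
new = torch.cat([local.unsqueeze(0)] * world_size, dim=0)

assert old.shape == new.shape == (world_size, 4, 32)  # (world_size, batch, ...)
```

Both reshaping strategies yield the documented `(world_size, batch, ...)` result on this toy data; the sketch only shows the shape arithmetic, not why the PR prefers `unsqueeze(0)`.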
- """ - # todo: Add support for backward with all_gather - if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed: - return xm.all_gather(tensor).view(-1, *tensor.shape) - return tensor - def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0): model = self.lightning_module diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index aa37586d8bfb5..2ca747b98ae78 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -192,14 +192,14 @@ def broadcast(self, obj: object, src: int = 0) -> object: return obj def reduce_boolean_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.device) - decision = self.reduce(decision, "sum") + decision = torch.tensor(int(decision), device=self.lightning_module.device) + decision = self.reduce(decision, reduce_op="sum") decision = bool(decision == self.world_size) return decision def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): if not isinstance(output, torch.Tensor): - output = torch.tensor(output, device=self.device) + output = torch.tensor(output, device=self.lightning_module.device) _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") @@ -265,3 +265,15 @@ def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None: if _OMEGACONF_AVAILABLE: checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container) self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath) + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + Args: + tensor: tensor of shape (batch, ...) + group: not available with TPUs + sync_grads: not available with TPUs + Return: + A tensor of shape (world_size, batch, ...) 
+ """ + return xm.all_gather(tensor.unsqueeze(0)) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 24c0b615b95bb..261adf5bc45af 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -54,7 +54,7 @@ def val_dataloader(self): return DataLoader(RandomDataset(32, 2000), batch_size=32) -@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@RunIf(tpu=True) @pl_multi_process_test def test_model_tpu_cores_1(tmpdir): """Make sure model trains on TPU.""" @@ -73,7 +73,7 @@ def test_model_tpu_cores_1(tmpdir): @pytest.mark.parametrize('tpu_core', [1, 5]) -@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@RunIf(tpu=True) @pl_multi_process_test def test_model_tpu_index(tmpdir, tpu_core): """Make sure model trains on TPU.""" @@ -92,7 +92,7 @@ def test_model_tpu_index(tmpdir, tpu_core): assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}' -@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@RunIf(tpu=True) @pl_multi_process_test def test_model_tpu_cores_8(tmpdir): """Make sure model trains on TPU.""" @@ -111,7 +111,7 @@ def test_model_tpu_cores_8(tmpdir): tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False, min_acc=0.05) -@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@RunIf(tpu=True) @pl_multi_process_test def test_model_16bit_tpu_cores_1(tmpdir): """Make sure model trains on TPU.""" @@ -132,7 +132,7 @@ def test_model_16bit_tpu_cores_1(tmpdir): @pytest.mark.parametrize('tpu_core', [1, 5]) -@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@RunIf(tpu=True) @pl_multi_process_test def test_model_16bit_tpu_index(tmpdir, tpu_core): """Make sure model trains on TPU.""" @@ -153,7 +153,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core): assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables" -@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@RunIf(tpu=True) @pl_multi_process_test def test_model_16bit_tpu_cores_8(tmpdir): """Make sure model trains on TPU.""" @@ -173,7 +173,7 @@ def test_model_16bit_tpu_cores_8(tmpdir): tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False, min_acc=0.05) -@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@RunIf(tpu=True) @pl_multi_process_test def test_model_tpu_early_stop(tmpdir): """Test if single TPU core training works""" @@ -200,7 +200,7 @@ def validation_step(self, *args, **kwargs): trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32)) -@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@RunIf(tpu=True) @pl_multi_process_test def test_tpu_grad_norm(tmpdir): """Test if grad_norm works on TPU.""" @@ -219,7 +219,7 @@ def test_tpu_grad_norm(tmpdir): tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) -@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@RunIf(tpu=True) @pl_multi_process_test def test_dataloaders_passed_to_fit(tmpdir): """Test if dataloaders passed to trainer works on TPU""" @@ -227,8 +227,16 @@ def test_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8) - trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader()) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + tpu_cores=8, + 
@@ -237,7 +245,7 @@
     [pytest.param(1, None), pytest.param(8, None), pytest.param([1], 1), pytest.param([8], 8)],
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires missing TPU")
+@RunIf(tpu=True)
 def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id):
     """Test if trainer.tpu_id is set as expected"""
     assert Trainer(tpu_cores=tpu_cores).accelerator_connector.tpu_id == expected_tpu_id
@@ -258,13 +266,13 @@
 
 
 @pytest.mark.parametrize('tpu_cores', [1, 8, [1]])
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):
     """Test if distributed_backend is set to `tpu` when tpu_cores is not None"""
     assert Trainer(tpu_cores=tpu_cores).distributed_backend == "tpu"
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_broadcast_on_tpu():
     """ Checks if an object from the master process is broadcasted to other processes correctly"""
@@ -296,7 +304,7 @@ def test_broadcast(rank):
         pytest.param(10, None, True),
     ],
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     if error_expected:
@@ -312,7 +320,7 @@
     [pytest.param('--tpu_cores=8', {'tpu_cores': 8}), pytest.param("--tpu_cores=1,", {'tpu_cores': '1,'})]
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_cores_with_argparse(cli_args, expected):
     """Test passing tpu_cores in command line"""
@@ -327,7 +335,7 @@
     assert Trainer.from_argparse_args(args)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_reduce():
     """Test tpu spawn reduce operation """
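Taken together, the change makes `TPUSpawnPlugin.all_gather` the code path that runs when user code gathers tensors during TPU spawn training, assuming (as the move relies on) that the base `Accelerator` delegates `all_gather` to its training-type plugin. A hypothetical end-user fragment, not part of this diff, showing where the fix would be felt:

```python
# Hypothetical usage fragment (not part of this diff): with tpu_cores=8, the
# gather below is expected to come back with a leading world_size dimension,
# i.e. shape (8,) for a scalar metric, via the new TPUSpawnPlugin.all_gather.
import torch
import pytorch_lightning as pl


class GatherModel(pl.LightningModule):

    def validation_step(self, batch, batch_idx):
        local_metric = torch.tensor(1.0, device=self.device)
        gathered = self.all_gather(local_metric)  # (world_size,) on 8 TPU cores
        self.log("metric_mean", gathered.mean())
```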