diff --git a/CHANGELOG.md b/CHANGELOG.md index 254fa7e2bbbc5..8e4ca17b5922c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -231,6 +231,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705)) +- Fixed TPU Spawn all gather ([#6896](https://github.com/PyTorchLightning/pytorch-lightning/pull/6896)) + + - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898)) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 35a475e3e790d..087f6df7a1c6a 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, TYPE_CHECKING, Union -import torch from torch.optim import Optimizer from pytorch_lightning.accelerators.accelerator import Accelerator @@ -57,21 +56,6 @@ def run_optimizer_step( ) -> None: xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs}) - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """ - Function to gather a tensor from several distributed processes - Args: - tensor: tensor of shape (batch, ...) - group: not available with TPUs - sync_grads: not available with TPUs - Return: - A tensor of shape (world_size, batch, ...) - """ - # todo: Add support for backward with all_gather - if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed: - return xm.all_gather(tensor).view(-1, *tensor.shape) - return tensor - def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0): model = self.lightning_module diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index d546067e88a1c..b072a29c7fbc6 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -195,14 +195,14 @@ def broadcast(self, obj: object, src: int = 0) -> object: return obj def reduce_boolean_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.device) - decision = self.reduce(decision, "sum") + decision = torch.tensor(int(decision), device=self.lightning_module.device) + decision = self.reduce(decision, reduce_op="sum") decision = bool(decision == self.world_size) return decision def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): if not isinstance(output, torch.Tensor): - output = torch.tensor(output, device=self.device) + output = torch.tensor(output, device=self.lightning_module.device) _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") @@ -267,3 +267,15 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: if _OMEGACONF_AVAILABLE: checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container) self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath) + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + Args: + tensor: tensor of shape (batch, ...) + group: not available with TPUs + sync_grads: not available with TPUs + Return: + A tensor of shape (world_size, batch, ...) + """ + return xm.all_gather(tensor.unsqueeze(0)) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 30149a7e022cd..6409f2ef4bcbf 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -229,8 +229,8 @@ def test_tpu_clip_grad_by_value(tmpdir): progress_bar_refresh_rate=0, max_epochs=4, tpu_cores=1, - limit_train_batches=4, - limit_val_batches=4, + limit_train_batches=10, + limit_val_batches=10, gradient_clip_val=0.5, gradient_clip_algorithm='value' )