Fix TPU Spawn gather (#6896)
kaushikb11 authored Apr 9, 2021
1 parent 2e53fd3 commit 5552503
Showing 4 changed files with 21 additions and 22 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -231,6 +231,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))


+- Fixed TPU Spawn all gather ([#6896](https://github.com/PyTorchLightning/pytorch-lightning/pull/6896))
+
+
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))


18 changes: 1 addition & 17 deletions pytorch_lightning/accelerators/tpu.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, TYPE_CHECKING, Union

-import torch
from torch.optim import Optimizer

from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -57,21 +56,6 @@ def run_optimizer_step(
    ) -> None:
        xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
-        """
-        Function to gather a tensor from several distributed processes
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: not available with TPUs
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        # todo: Add support for backward with all_gather
-        if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
-            return xm.all_gather(tensor).view(-1, *tensor.shape)
-        return tensor
-
    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):

        model = self.lightning_module
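Note on the removal above: after this change the TPU accelerator no longer overrides all_gather, so calls fall through to the base Accelerator and, assuming the base class delegates gathers to its training-type plugin (as this move suggests), land in the TPUSpawnPlugin.all_gather added later in this diff. A minimal sketch of that delegation pattern, with simplified signatures and illustrative class bodies rather than the library's actual code:

# Minimal sketch of the delegation pattern; not PyTorch Lightning's actual code.
from typing import Any, Optional

import torch


class TrainingTypePlugin:
    """Illustrative stand-in for a training-type plugin."""

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        # Single-process fallback: nothing to gather.
        return tensor


class Accelerator:
    """Illustrative stand-in for the base accelerator."""

    def __init__(self, training_type_plugin: TrainingTypePlugin) -> None:
        self.training_type_plugin = training_type_plugin

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        # With the TPU-specific override removed, every accelerator routes
        # gathers through its plugin, which knows whether the run is distributed.
        return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads)


acc = Accelerator(TrainingTypePlugin())
out = acc.all_gather(torch.ones(2))  # falls through to the plugin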
18 changes: 15 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -195,14 +195,14 @@ def broadcast(self, obj: object, src: int = 0) -> object:
        return obj

    def reduce_boolean_decision(self, decision: bool) -> bool:
-        decision = torch.tensor(int(decision), device=self.device)
-        decision = self.reduce(decision, "sum")
+        decision = torch.tensor(int(decision), device=self.lightning_module.device)
+        decision = self.reduce(decision, reduce_op="sum")
        decision = bool(decision == self.world_size)
        return decision

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if not isinstance(output, torch.Tensor):
-            output = torch.tensor(output, device=self.device)
+            output = torch.tensor(output, device=self.lightning_module.device)

        _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
        _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
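The reduce_boolean_decision fix above is subtle: in the old call self.reduce(decision, "sum"), the string binds positionally to the group parameter, so reduce_op stays None and no reduction mode is selected. A standalone sketch of the pitfall, using a hypothetical stub that mirrors the signature above:

# Standalone illustration of the keyword-argument fix; `reduce` here is a stub.
from typing import Any, Optional


def reduce(output: Any, group: Optional[Any] = None, reduce_op: Optional[str] = None) -> None:
    print(f"group={group!r}, reduce_op={reduce_op!r}")


reduce("decision", "sum")            # group='sum', reduce_op=None  -> the old bug
reduce("decision", reduce_op="sum")  # group=None,  reduce_op='sum' -> the fixed call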
@@ -267,3 +267,15 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        if _OMEGACONF_AVAILABLE:
            checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
        self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
+
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+        """
+        Function to gather a tensor from several distributed processes
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: not available with TPUs
+            sync_grads: not available with TPUs
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return xm.all_gather(tensor.unsqueeze(0))
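The unsqueeze(0) matters because xm.all_gather concatenates its inputs along dim 0, so a local tensor of shape (batch, ...) comes back as (world_size, batch, ...), as the docstring promises. A rough shape check that simulates the gather with plain torch.cat, so no TPU is needed; world_size and fake_all_gather are illustrative stand-ins, not library API:

# Shape-only simulation of xm.all_gather; assumes every process holds an
# identically shaped tensor.
import torch

world_size = 8  # pretend there are 8 TPU cores


def fake_all_gather(tensor: torch.Tensor) -> torch.Tensor:
    # xm.all_gather concatenates one tensor per process along dim 0;
    # we mimic that with world_size copies of the local tensor.
    return torch.cat([tensor] * world_size, dim=0)


local = torch.randn(4, 3)                        # shape (batch, ...)
gathered = fake_all_gather(local.unsqueeze(0))   # each process contributes (1, 4, 3)
assert gathered.shape == (world_size, 4, 3)      # (world_size, batch, ...)

scalar = torch.tensor(1.0)                       # 0-dim tensors gather cleanly too
assert fake_all_gather(scalar.unsqueeze(0)).shape == (world_size,)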
4 changes: 2 additions & 2 deletions tests/models/test_tpu.py
@@ -229,8 +229,8 @@ def test_tpu_clip_grad_by_value(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=4,
        tpu_cores=1,
-        limit_train_batches=4,
-        limit_val_batches=4,
+        limit_train_batches=10,
+        limit_val_batches=10,
        gradient_clip_val=0.5,
        gradient_clip_algorithm='value'
    )
