[HotFix] Resolve TPU Training #6027

Merged 17 commits on Feb 17, 2021
Changes from 11 commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -155,3 +155,5 @@ cifar-10-batches-py
# ctags
tags
data
MNIST
runs
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -288,6 +288,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015))


- Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027))


## [1.1.8] - 2021-02-08

### Fixed
11 changes: 10 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import (
@@ -388,3 +389,11 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
            A tensor of shape (world_size, batch, ...)
        """
        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

    def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Wraps the dataloader if necessary

Args:
dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
"""
return self.training_type_plugin.process_dataloader(dataloader)
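The new hook is a thin passthrough: the accelerator simply defers to its training-type plugin, which is what lets a TPU run swap the plain `DataLoader` for a device-aware loader while other backends return it unchanged. As a rough, hypothetical sketch of what a TPU plugin could do inside this hook (not the code in this PR; the `MpDeviceLoader` usage and the `isinstance` guard are assumptions):

```python
# Hypothetical sketch: how a TPU training-type plugin might implement
# `process_dataloader`, prefetching batches onto the XLA device owned by
# the current worker process. Not the implementation shipped in this PR.
from typing import Iterable, Union

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl
from torch.utils.data import DataLoader


def process_dataloader(dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
    if not isinstance(dataloader, DataLoader):
        # leave custom iterables untouched
        return dataloader
    device = xm.xla_device()  # TPU core bound to this process
    return xla_pl.MpDeviceLoader(dataloader, device)
```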
8 changes: 8 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -554,6 +554,14 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
        epoch = metrics.get("epoch")
        step = metrics.get("step")

        # when `val_loss` is logged and no user `ModelCheckpoint` is provided,
        # `val_loss` is selected as the monitor and needs to be reduced across
        # processes to prevent them from diverging
        # TODO: move this logic to the logger_connector. This also needs to be fixed
        # for any other monitored value that isn't produced from a Metric.
        if self.monitor == "val_loss":
            current = trainer.training_type_plugin.reduce(current, reduce_op="mean")

        if self.check_monitor_top_k(current):
            self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
        elif self.verbose:
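The extra reduction matters because, under `TPUSpawn`, every process computes `val_loss` on its own shard of the validation data; without averaging, `check_monitor_top_k` could answer differently per process and the checkpoint state would diverge. A toy, self-contained illustration of the idea (plain CPU tensors stand in for the per-process values; the numbers are made up):

```python
# Toy illustration of why the monitored value is reduced with `mean`:
# each of the 8 spawned processes sees a different validation shard, so
# its local `val_loss` differs slightly. Averaging yields one shared
# value, so every process makes the same top-k decision.
import torch

world_size = 8
local_val_losses = [torch.tensor(0.50 + 0.01 * rank) for rank in range(world_size)]

# conceptually what `training_type_plugin.reduce(current, reduce_op="mean")` produces
global_val_loss = torch.stack(local_val_losses).mean()
print(global_val_loss)  # tensor(0.5350), identical on every process
```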
36 changes: 31 additions & 5 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -10,7 +10,8 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _TPU_AVAILABLE:
@@ -46,10 +47,6 @@ def create_mp_queue(self):
    def distributed_sampler_kwargs(self) -> dict:
        return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

    @property
    def should_finalize(self):
        return self.world_size == 1

    @property
    def is_distributed(self):
        return self.world_size != 1
@@ -179,6 +176,29 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
        should_stop = int(stop.item()) == self.world_size
        return should_stop

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if not isinstance(output, torch.Tensor):
            output = torch.tensor(output, device=self.device)

        if (isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM) \
                or isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg"):
            raise MisconfigurationException(
                "Currently, the TPUSpawn TrainingTypePlugin only supports `sum`, `mean`, `avg` reduce operations."
            )

        divide_by_world_size = False

        # sync all processes before the reduction
        output = xm.mesh_reduce('reduce', output, sum)

        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
            divide_by_world_size = True

        if divide_by_world_size:
            output = output / self.world_size

        return output

    def post_dispatch(self) -> None:
        # TODO: Check if trainer references can be resolved otherwise
        model = self.lightning_module
@@ -213,6 +233,10 @@ def __load_weights_on_main_process(self) -> None:

        self._model = model

    def _close_logger(self, trainer) -> None:
        if getattr(trainer, "logger", None) is not None:
            trainer.logger.finalize("success")

    @property
    def xmp_spawn_kwargs(self):
        return {
@@ -225,9 +249,11 @@ def start_training(self, trainer) -> None:
        # todo: the precision plugin is called in accelerator setup and should be moved
        if 'XLA_USE_BF16' in os.environ:
            del os.environ["XLA_USE_BF16"]
        self._close_logger(trainer)
        xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)

    def start_testing(self, trainer) -> None:
        self._close_logger(trainer)
        xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)

    def start_predicting(self, trainer) -> None:
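The `reduce` method above is built on `xm.mesh_reduce`, which gathers a value from every TPU process and applies a Python callable to the collected list; that is also why only `sum`, `mean`, and `avg` are accepted. A minimal sketch of the same pattern in isolation (assumed to run inside an `xmp.spawn` worker; this is not the plugin method itself):

```python
# Minimal sketch of the mesh-reduce pattern used by TPUSpawnPlugin.reduce,
# assuming it executes inside a process started by xmp.spawn on a TPU host.
import torch
import torch_xla.core.xla_model as xm


def reduce_mean(value: torch.Tensor) -> torch.Tensor:
    # mesh_reduce is a rendezvous: every process contributes `value`, and
    # the callable (here: sum) is applied to the list of contributions.
    total = xm.mesh_reduce("reduce", value, sum)
    # divide by the number of participating processes to obtain the mean
    return total / xm.xrt_world_size()
```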
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -35,10 +35,6 @@ def __init__(self) -> None:
        self._results = None
        self.global_rank = 0

    @property
    def should_finalize(self):
        return True

    @property
    @abstractmethod
    def on_gpu(self) -> bool:
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -711,7 +711,7 @@ def run_evaluation(self, max_batches=None, on_epoch=False):
        for dataloader_idx, dataloader in enumerate(dataloaders):
            # bookkeeping
            dl_outputs = []
            dataloader = self.training_type_plugin.process_dataloader(dataloader)
            dataloader = self.accelerator.process_dataloader(dataloader)
            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
@@ -823,7 +823,7 @@ def run_predict(self):

        # run validation/testing
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dataloader = self.training_type_plugin.process_dataloader(dataloader)
            dataloader = self.accelerator.process_dataloader(dataloader)
            dl_max_batches = self.predict_loop.max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -140,7 +140,7 @@ def on_train_end(self):
        # todo: training on 8 TPU cores hangs in the TensorBoard flush. Might happen for all loggers.
        # It might be related to XLA tensors being blocked when moved to the CPU.
        # kill loggers
        if self.trainer.logger is not None and self.trainer.training_type_plugin.should_finalize:
        if self.trainer.logger is not None:
            self.trainer.logger.finalize("success")

        # summarize profile results
@@ -502,7 +502,7 @@ def tbptt_split_batch(self, batch):

    def run_training_epoch(self):
        # modify dataloader if needed (ddp, etc...)
        train_dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
        train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)

        # track epoch output
        epoch_output = [[] for _ in range(self.num_optimizers)]
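Taken together with `_close_logger` in `tpu_spawn.py`, this change restores a simple ordering: the parent process finalizes its logger before `xmp.spawn` forks the TPU workers, and afterwards each worker may call `finalize` unconditionally in `on_train_end` instead of consulting the removed `should_finalize` flag. A hedged sketch of that ordering (the `new_process` callable, `nprocs`, and the trainer wiring are illustrative stand-ins, not the plugin's exact API):

```python
# Hypothetical sketch of the "finalize before spawn" ordering enforced by this PR.
import torch_xla.distributed.xla_multiprocessing as xmp


def start_training(trainer, new_process, nprocs: int = 8) -> None:
    # flush and close the parent-process logger first (mirrors `_close_logger`),
    # so no open TensorBoard file handles are inherited by the workers
    if getattr(trainer, "logger", None) is not None:
        trainer.logger.finalize("success")

    # then spawn one worker per TPU core; each worker later finalizes its own
    # logger in `on_train_end` without any extra guard
    xmp.spawn(new_process, args=(trainer,), nprocs=nprocs)
```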
3 changes: 0 additions & 3 deletions tests/models/test_tpu.py
@@ -264,9 +264,6 @@ def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(
    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
@pl_multi_process_test
def test_broadcast_on_tpu():
""" Checks if an object from the master process is broadcasted to other processes correctly"""