Commit 083beea

Merge 254ce3b into 5157ba5
tchaton committed Feb 17, 2021
2 parents 5157ba5 + 254ce3b commit 083beea
Showing 9 changed files with 77 additions and 17 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -155,3 +155,5 @@ cifar-10-batches-py
# ctags
tags
data
MNIST
runs
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -291,6 +291,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015))


- Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027))


## [1.1.8] - 2021-02-08

### Fixed
11 changes: 10 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import (
@@ -388,3 +389,11 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Wraps the dataloader if necessary
Args:
dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
"""
return self.training_type_plugin.process_dataloader(dataloader)
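
The new `Accelerator.process_dataloader` simply forwards to the training type plugin, so the trainer loops can stay plugin-agnostic. A minimal sketch of that delegation pattern, assuming a hypothetical plugin and accelerator class (`SketchPlugin`, `SketchAccelerator`, and the toy dataset are illustrative, not part of this commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class SketchPlugin:
    """Hypothetical training type plugin; a TPU plugin could wrap the loader here."""

    def process_dataloader(self, dataloader):
        return dataloader  # no-op in this sketch


class SketchAccelerator:
    def __init__(self, training_type_plugin):
        self.training_type_plugin = training_type_plugin

    def process_dataloader(self, dataloader):
        # Same one-line delegation as Accelerator.process_dataloader above.
        return self.training_type_plugin.process_dataloader(dataloader)


loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
loader = SketchAccelerator(SketchPlugin()).process_dataloader(loader)
```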
8 changes: 8 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -554,6 +554,14 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
epoch = metrics.get("epoch")
step = metrics.get("step")

# when `val_loss` is logged and no ModelCheckpoint is provided,
# `val_loss` is selected as the monitor and needs to be reduced to
# prevent the processes from diverging
# TODO: Move this logic to logger_connector. This also needs to be fixed for any
# other monitored logged value which isn't produced from a Metric.
if self.monitor == "val_loss":
current = trainer.training_type_plugin.reduce(current, reduce_op="mean")

if self.check_monitor_top_k(current):
self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
elif self.verbose:
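
The reduction added above is meant to guarantee that every process compares the same monitored value, so the top-k decision cannot diverge across TPU cores. A rough standalone illustration of that idea, with a plain tensor mean standing in for `trainer.training_type_plugin.reduce(current, reduce_op="mean")` and made-up per-process values:

```python
import torch

# Hypothetical per-process `val_loss` values; without the reduction each process
# would make its own checkpointing decision from its local value.
per_process_val_loss = [torch.tensor(0.31), torch.tensor(0.27)]

# Stand-in for training_type_plugin.reduce(current, reduce_op="mean"):
current = torch.stack(per_process_val_loss).mean()

best_so_far = torch.tensor(0.30)
should_save = bool(current < best_so_far)  # identical outcome on every process
print(current.item(), should_save)
```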
31 changes: 26 additions & 5 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -10,7 +10,8 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _TPU_AVAILABLE:
@@ -46,10 +47,6 @@ def create_mp_queue(self):
def distributed_sampler_kwargs(self) -> dict:
return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

@property
def should_finalize(self):
return self.world_size == 1

@property
def is_distributed(self):
return self.world_size != 1
@@ -179,6 +176,24 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
should_stop = int(stop.item()) == self.world_size
return should_stop

def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
if not isinstance(output, torch.Tensor):
output = torch.tensor(output, device=self.device)

_invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
_invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
if _invalid_reduce_op or _invalid_reduce_op_str:
raise MisconfigurationException(
"Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation."
)

output = xm.mesh_reduce('reduce', output, sum)

if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
output = output / self.world_size

return output

def post_dispatch(self) -> None:
# TODO: Check if trainer references can be resolved otherwise
model = self.lightning_module
@@ -213,6 +228,10 @@ def __load_weights_on_main_process(self) -> None:

self._model = model

def _close_logger(self, trainer) -> None:
if hasattr(trainer, "logger"):
trainer.logger.finalize("success")

@property
def xmp_spawn_kwargs(self):
return {
@@ -225,9 +244,11 @@ def start_training(self, trainer) -> None:
# todo: the precision plugin is called in accelerator setup and should be moved
if 'XLA_USE_BF16' in os.environ:
del os.environ["XLA_USE_BF16"]
self._close_logger(trainer)
xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)

def start_testing(self, trainer) -> None:
self._close_logger(trainer)
xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)

def start_predicting(self, trainer) -> None:
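
Note that the new `reduce` only accepts sum/mean-style operations because `xm.mesh_reduce` is called with `sum` and the mean is obtained by dividing the summed result by the world size afterwards. A CPU-only sketch of the same logic, with the gathered per-process values and the world size assumed rather than coming from XLA:

```python
import torch


def reduce_like_tpu_spawn(reduce_op, gathered, world_size):
    """Mimics the TPUSpawnPlugin.reduce logic above without XLA:
    sum the gathered values, then divide when a mean/avg was requested."""
    if isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg"):
        raise ValueError("only `sum`, `mean`, `avg` reduce operations are supported")
    output = torch.stack([torch.as_tensor(v, dtype=torch.float32) for v in gathered]).sum()
    if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"):
        output = output / world_size
    return output


# Each of 8 processes contributes the value 1, as in test_tpu_reduce below.
print(reduce_like_tpu_spawn("sum", [1] * 8, world_size=8))   # tensor(8.)
print(reduce_like_tpu_spawn("mean", [1] * 8, world_size=8))  # tensor(1.)
```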
@@ -35,10 +35,6 @@ def __init__(self) -> None:
self._results = None
self.global_rank = 0

@property
def should_finalize(self):
return True

@property
@abstractmethod
def on_gpu(self) -> bool:
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -711,7 +711,7 @@ def run_evaluation(self, max_batches=None, on_epoch=False):
for dataloader_idx, dataloader in enumerate(dataloaders):
# bookkeeping
dl_outputs = []
dataloader = self.training_type_plugin.process_dataloader(dataloader)
dataloader = self.accelerator.process_dataloader(dataloader)
dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

for batch_idx, batch in enumerate(dataloader):
@@ -823,7 +823,7 @@ def run_predict(self):

# run validation/testing
for dataloader_idx, dataloader in enumerate(dataloaders):
dataloader = self.training_type_plugin.process_dataloader(dataloader)
dataloader = self.accelerator.process_dataloader(dataloader)
dl_max_batches = self.predict_loop.max_batches[dataloader_idx]

for batch_idx, batch in enumerate(dataloader):
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -140,7 +140,7 @@ def on_train_end(self):
# todo: TPU with 8 cores hangs on flush with TensorBoard. Might apply to all loggers.
# It might be related to xla tensors being blocked when moving to the cpu
# kill loggers
if self.trainer.logger is not None and self.trainer.training_type_plugin.should_finalize:
if self.trainer.logger is not None:
self.trainer.logger.finalize("success")

# summarize profile results
@@ -502,7 +502,7 @@ def tbptt_split_batch(self, batch):

def run_training_epoch(self):
# modify dataloader if needed (ddp, etc...)
train_dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)

# track epoch output
epoch_output = [[] for _ in range(self.num_optimizers)]
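
The todo comment in `on_train_end` above points at the motivation: TPU processes can hang flushing TensorBoard, so the `should_finalize` gate was dropped and the TPU plugin instead closes the logger before spawning workers. A small sketch of that close-before-spawn pattern, using a dummy logger and a plain callable standing in for `xmp.spawn`:

```python
class DummyLogger:
    """Stand-in for trainer.logger; only the finalize call matters here."""

    def finalize(self, status: str) -> None:
        print(f"flushing logger with status={status!r}")


def start_training(logger, spawn_fn):
    # Flush/close the logger in the parent process before forking workers,
    # mirroring TPUSpawnPlugin._close_logger followed by xmp.spawn above.
    if logger is not None:
        logger.finalize("success")
    spawn_fn()


start_training(DummyLogger(), lambda: print("xmp.spawn(new_process, ...) would run here"))
```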
27 changes: 24 additions & 3 deletions tests/models/test_tpu.py
@@ -26,6 +26,7 @@
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.utils import pl_multi_process_test
@@ -264,9 +265,6 @@ def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
@pl_multi_process_test
def test_broadcast_on_tpu():
""" Checks if an object from the master process is broadcasted to other processes correctly"""
Expand Down Expand Up @@ -327,3 +325,26 @@ def test_tpu_cores_with_argparse(cli_args, expected):
for k, v in expected.items():
assert getattr(args, k) == v
assert Trainer.from_argparse_args(args)


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_tpu_reduce():
"""Test tpu spawn reduce operation """

def test_reduce(rank):
trainer = Trainer(tpu_cores=8)
# faster this way
reduce_ops = ["mean", "AVG", "undefined", "sum", ReduceOp.SUM, ReduceOp.MAX]
for reduce_op in reduce_ops:
if reduce_op == "undefined" or reduce_op == ReduceOp.MAX:
with pytest.raises(MisconfigurationException, match="TPUSpawn TrainingTypePlugin only support"):
result = trainer.training_type_plugin.reduce(1, reduce_op=reduce_op)
else:
result = trainer.training_type_plugin.reduce(1, reduce_op=reduce_op)
if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"):
assert result.item() == 1
else:
assert result.item() == 8

xmp.spawn(test_reduce, nprocs=8, start_method='fork')
