Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HotFix] Resolve TPU Training #6027

Merged
merged 17 commits into from
Feb 17, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,5 @@ cifar-10-batches-py
# ctags
tags
data
MNIST
runs
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015))


- Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027))


## [1.1.8] - 2021-02-08

### Fixed
Expand Down
29 changes: 21 additions & 8 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pytorch_lightning import _logger as log
from pytorch_lightning import LightningDataModule
from pytorch_lightning import LightningDataModule, seed_everything
tchaton marked this conversation as resolved.
Show resolved Hide resolved
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
from pytorch_lightning.utilities import rank_zero_info

Expand Down Expand Up @@ -148,7 +148,7 @@ def val_dataloader(self):
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser])
parser.add_argument(
"--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
"--num-workers", default=2, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
)
parser.add_argument(
"--batch-size", default=8, type=int, metavar="W", help="number of sample in a batch", dest="batch_size"
Expand All @@ -164,21 +164,19 @@ class TransferLearningModel(pl.LightningModule):
def __init__(
self,
backbone: str = "resnet50",
train_bn: bool = True,
milestones: tuple = (5, 10),
batch_size: int = 32,
lr: float = 1e-2,
lr_scheduler_gamma: float = 1e-1,
num_workers: int = 6,
**kwargs,
**_,
) -> None:
"""
Args:
dl_path: Path where the data will be downloaded
"""
super().__init__()
self.backbone = backbone
self.train_bn = train_bn
self.milestones = milestones
self.batch_size = batch_size
self.lr = lr
Expand Down Expand Up @@ -276,10 +274,23 @@ def add_model_specific_args(parent_parser):
help="Name (as in ``torchvision.models``) of the feature extractor",
)
parser.add_argument(
"--epochs", default=15, type=int, metavar="N", help="total number of epochs", dest="nb_epochs"
"--epochs", default=5, type=int, metavar="N", help="total number of epochs", dest="nb_epochs"
tchaton marked this conversation as resolved.
Show resolved Hide resolved
)
parser.add_argument(
"--limit_train_batches",
default=1.0,
type=float,
help="How much of training dataset to check (floats = percent, int = num_batches)"
)
parser.add_argument(
"--limit_val_batches",
default=1.0,
type=float,
help="How much of validation dataset to check (floats = percent, int = num_batches)"
tchaton marked this conversation as resolved.
Show resolved Hide resolved
tchaton marked this conversation as resolved.
Show resolved Hide resolved
)
parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size")
parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use")
parser.add_argument("--tpu_cores", type=int, default=None, help="number of tpu cores to use")
parser.add_argument(
"--lr", "--learning-rate", default=1e-3, type=float, metavar="LR", help="initial learning rate", dest="lr"
)
Expand All @@ -300,7 +311,7 @@ def add_model_specific_args(parent_parser):
dest="train_bn",
)
parser.add_argument(
"--milestones", default=[2, 4], type=list, metavar="M", help="List of two epochs milestones"
"--milestones", default=[5, 10], type=list, metavar="M", help="List of two epochs milestones"
)
return parser

Expand All @@ -315,17 +326,19 @@ def main(args: argparse.Namespace) -> None:
For the sake of the example, the images dataset will be downloaded
to a temporary directory.
"""
seed_everything(42)

datamodule = CatDogImageDataModule(
dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers
)
model = TransferLearningModel(**vars(args))
finetuning_callback = MilestonesFinetuning(milestones=args.milestones)
finetuning_callback = MilestonesFinetuning(milestones=args.milestones, train_bn=args.train_bn)

trainer = pl.Trainer(
weights_summary=None,
progress_bar_refresh_rate=1,
num_sanity_val_steps=0,
tpu_cores=args.tpu_cores,
gpus=args.gpus,
max_epochs=args.nb_epochs,
callbacks=[finetuning_callback]
Expand Down
11 changes: 10 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import (
Expand Down Expand Up @@ -388,3 +389,11 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""Wraps the dataloader if necessary

Args:
dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
"""
return self.training_type_plugin.process_dataloader(dataloader)
7 changes: 7 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,13 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
epoch = metrics.get("epoch")
step = metrics.get("step")

# when `val_loss` is being logged and no ModelCheckpoint is being provided
# `val_loss` will be selected for monitor and need to be reduced to
# prevent processes divergence
# Todo: Move this logic to logger_connector
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if self.monitor == "val_loss":
current = trainer.training_type_plugin.reduce(current, reduce_op="mean")

if self.check_monitor_top_k(current):
self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
elif self.verbose:
Expand Down
18 changes: 14 additions & 4 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ def create_mp_queue(self):
def distributed_sampler_kwargs(self) -> dict:
return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

@property
def should_finalize(self):
return self.world_size == 1

@property
def is_distributed(self):
return self.world_size != 1
Expand Down Expand Up @@ -179,6 +175,14 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
should_stop = int(stop.item()) == self.world_size
return should_stop

def reduce(self, output, group: Optional[Any] = None, reduce_op: str = None):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(output, torch.Tensor):
output = torch.tensor(output, device=self.device)
output = xm.mesh_reduce('reduce', output, sum)
if isinstance(reduce_op, str) and reduce_op.lower() == "mean":
output /= self.world_size
return output

tchaton marked this conversation as resolved.
Show resolved Hide resolved
def post_dispatch(self) -> None:
# TODO: Check if trainer references can be resolved otherwise
model = self.lightning_module
Expand Down Expand Up @@ -213,6 +217,10 @@ def __load_weights_on_main_process(self) -> None:

self._model = model

def _close_logger(self, trainer) -> None:
if hasattr(trainer, "logger"):
trainer.logger.finalize("success")

@property
def xmp_spawn_kwargs(self):
return {
Expand All @@ -225,9 +233,11 @@ def start_training(self, trainer) -> None:
# todo: precision pluging is call in accelerator setup and should be moved
if 'XLA_USE_BF16' in os.environ:
del os.environ["XLA_USE_BF16"]
self._close_logger(trainer)
xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)

def start_testing(self, trainer) -> None:
self._close_logger(trainer)
tchaton marked this conversation as resolved.
Show resolved Hide resolved
xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)

def start_predicting(self, trainer) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ def __init__(self) -> None:
self._results = None
self.global_rank = 0

@property
def should_finalize(self):
return True

@property
@abstractmethod
def on_gpu(self) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ def run_evaluation(self, max_batches=None, on_epoch=False):
for dataloader_idx, dataloader in enumerate(dataloaders):
# bookkeeping
dl_outputs = []
dataloader = self.training_type_plugin.process_dataloader(dataloader)
dataloader = self.accelerator.process_dataloader(dataloader)
dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

for batch_idx, batch in enumerate(dataloader):
Expand Down Expand Up @@ -823,7 +823,7 @@ def run_predict(self):

# run validation/testing
for dataloader_idx, dataloader in enumerate(dataloaders):
dataloader = self.training_type_plugin.process_dataloader(dataloader)
dataloader = self.accelerator.process_dataloader(dataloader)
dl_max_batches = self.predict_loop.max_batches[dataloader_idx]

for batch_idx, batch in enumerate(dataloader):
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def on_train_end(self):
# todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
# It might be related to xla tensors blocked when moving the cpu
# kill loggers
if self.trainer.logger is not None and self.trainer.training_type_plugin.should_finalize:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if self.trainer.logger is not None:
self.trainer.logger.finalize("success")

# summarize profile results
Expand Down Expand Up @@ -502,7 +502,7 @@ def tbptt_split_batch(self, batch):

def run_training_epoch(self):
# modify dataloader if needed (ddp, etc...)
train_dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)

# track epoch output
epoch_output = [[] for _ in range(self.num_optimizers)]
Expand Down
3 changes: 0 additions & 3 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,6 @@ def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
@pl_multi_process_test
def test_broadcast_on_tpu():
""" Checks if an object from the master process is broadcasted to other processes correctly"""
Expand Down