From 87201df00d87e2f47e9119cf75dd4924215129ff Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 11:13:28 +0000 Subject: [PATCH 01/15] fix tpus --- pytorch_lightning/accelerators/accelerator.py | 8 ++++++++ pytorch_lightning/loggers/tensorboard.py | 3 +++ .../plugins/training_type/tpu_spawn.py | 18 ++++++++++++++---- .../training_type/training_type_plugin.py | 4 ---- pytorch_lightning/trainer/trainer.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 4 ++-- 6 files changed, 29 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 2e8e31139dda2..1526b99003f41 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -388,3 +388,11 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s A tensor of shape (world_size, batch, ...) """ return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + + def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Wraps the dataloader if necessary + + Args: + dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` + """ + return self.training_type_plugin.process_dataloader(dataloader) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 2f8c888eba963..3cba030e4ed40 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -234,6 +234,9 @@ def save(self) -> None: @rank_zero_only def finalize(self, status: str) -> None: + self.close() + + def close(self): self.experiment.flush() self.experiment.close() self.save() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index d4374d0ef9c6a..15e59407f9494 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -46,10 +46,6 @@ def create_mp_queue(self): def distributed_sampler_kwargs(self) -> dict: return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) - @property - def should_finalize(self): - return self.world_size == 1 - @property def is_distributed(self): return self.world_size != 1 @@ -179,6 +175,14 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool: should_stop = int(stop.item()) == self.world_size return should_stop + def reduce(self, output, group: Optional[Any] = None, reduce_op: str = None): + if not isinstance(output, torch.Tensor): + output = torch.tensor(output, device=self.device) + output = xm.mesh_reduce('reduce', output, sum) + if isinstance(reduce_op, str) and reduce_op.lower() == "mean": + output /= self.world_size + return output + def post_dispatch(self) -> None: # TODO: Check if trainer references can be resolved otherwise model = self.lightning_module @@ -213,6 +217,10 @@ def __load_weights_on_main_process(self) -> None: self._model = model + def _close_logger(self, trainer) -> None: + if hasattr(trainer, "logger"): + trainer.logger.close() + @property def xmp_spawn_kwargs(self): return { @@ -225,9 +233,11 @@ def start_training(self, trainer) -> None: # todo: precision pluging is call in accelerator setup and should be moved if 'XLA_USE_BF16' in os.environ: del os.environ["XLA_USE_BF16"] + self._close_logger() xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) def start_testing(self, trainer) -> None: + self._close_logger() xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) def start_predicting(self, trainer) -> None: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index cede3e5f98b43..938a17249e9f6 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -35,10 +35,6 @@ def __init__(self) -> None: self._results = None self.global_rank = 0 - @property - def should_finalize(self): - return True - @property @abstractmethod def on_gpu(self) -> bool: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2453a08ba9067..2b2b2f92dce59 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -711,7 +711,7 @@ def run_evaluation(self, max_batches=None, on_epoch=False): for dataloader_idx, dataloader in enumerate(dataloaders): # bookkeeping dl_outputs = [] - dataloader = self.training_type_plugin.process_dataloader(dataloader) + dataloader = self.accelerator.process_dataloader(dataloader) dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx] for batch_idx, batch in enumerate(dataloader): @@ -823,7 +823,7 @@ def run_predict(self): # run validation/testing for dataloader_idx, dataloader in enumerate(dataloaders): - dataloader = self.training_type_plugin.process_dataloader(dataloader) + dataloader = self.accelerator.process_dataloader(dataloader) dl_max_batches = self.predict_loop.max_batches[dataloader_idx] for batch_idx, batch in enumerate(dataloader): diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0908e96bd1c17..57c0b10f12412 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -140,7 +140,7 @@ def on_train_end(self): # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers. # It might be related to xla tensors blocked when moving the cpu # kill loggers - if self.trainer.logger is not None and self.trainer.training_type_plugin.should_finalize: + if self.trainer.logger is not None: self.trainer.logger.finalize("success") # summarize profile results @@ -502,7 +502,7 @@ def tbptt_split_batch(self, batch): def run_training_epoch(self): # modify dataloader if needed (ddp, etc...) - train_dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader) + train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) # track epoch output epoch_output = [[] for _ in range(self.num_optimizers)] From d6ced923fd7353358609602e13b1d2cd4c23e251 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 11:43:11 +0000 Subject: [PATCH 02/15] update --- .gitignore | 2 ++ .../domain_templates/computer_vision_fine_tuning.py | 12 +++++++++--- pytorch_lightning/accelerators/accelerator.py | 3 ++- pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 ++-- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 390551b8f6e60..cd0ba22453512 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,5 @@ cifar-10-batches-py # ctags tags data +MNIST +runs diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 65bf1bde141fa..b690b0b0c6b45 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -55,7 +55,7 @@ import pytorch_lightning as pl from pl_examples import cli_lightning_logo from pytorch_lightning import _logger as log -from pytorch_lightning import LightningDataModule +from pytorch_lightning import LightningDataModule, seed_everything from pytorch_lightning.callbacks.finetuning import BaseFinetuning from pytorch_lightning.utilities import rank_zero_info @@ -148,7 +148,7 @@ def val_dataloader(self): def add_model_specific_args(parent_parser): parser = argparse.ArgumentParser(parents=[parent_parser]) parser.add_argument( - "--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers" + "--num-workers", default=2, type=int, metavar="W", help="number of CPU workers", dest="num_workers" ) parser.add_argument( "--batch-size", default=8, type=int, metavar="W", help="number of sample in a batch", dest="batch_size" @@ -276,10 +276,11 @@ def add_model_specific_args(parent_parser): help="Name (as in ``torchvision.models``) of the feature extractor", ) parser.add_argument( - "--epochs", default=15, type=int, metavar="N", help="total number of epochs", dest="nb_epochs" + "--epochs", default=5, type=int, metavar="N", help="total number of epochs", dest="nb_epochs" ) parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size") parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use") + parser.add_argument("--tpu_cores", type=int, default=None, help="number of tpu cores to use") parser.add_argument( "--lr", "--learning-rate", default=1e-3, type=float, metavar="LR", help="initial learning rate", dest="lr" ) @@ -315,6 +316,7 @@ def main(args: argparse.Namespace) -> None: For the sake of the example, the images dataset will be downloaded to a temporary directory. """ + seed_everything(42) datamodule = CatDogImageDataModule( dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers @@ -326,6 +328,7 @@ def main(args: argparse.Namespace) -> None: weights_summary=None, progress_bar_refresh_rate=1, num_sanity_val_steps=0, + tpu_cores=args.tpu_cores, gpus=args.gpus, max_epochs=args.nb_epochs, callbacks=[finetuning_callback] @@ -333,6 +336,9 @@ def main(args: argparse.Namespace) -> None: trainer.fit(model, datamodule=datamodule) + if args.nb_epochs >= 5: + assert trainer.callbacks_metrics["val_acc"] > 0.7 + def get_args() -> argparse.Namespace: parent_parser = argparse.ArgumentParser(add_help=False) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 1526b99003f41..967b6a85c878b 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union import torch from torch.optim import Optimizer +from torch.utils.data import DataLoader from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision import ( diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 15e59407f9494..0120e21d45726 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -233,11 +233,11 @@ def start_training(self, trainer) -> None: # todo: precision pluging is call in accelerator setup and should be moved if 'XLA_USE_BF16' in os.environ: del os.environ["XLA_USE_BF16"] - self._close_logger() + self._close_logger(trainer) xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) def start_testing(self, trainer) -> None: - self._close_logger() + self._close_logger(trainer) xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) def start_predicting(self, trainer) -> None: From 30f388e114f7c6048661188d1f7694b192007fd8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 11:52:13 +0000 Subject: [PATCH 03/15] add back reduction in val_loss --- pytorch_lightning/callbacks/model_checkpoint.py | 6 ++++++ tests/helpers/utils.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index a8024bef2a539..d68b3e29683c5 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -554,6 +554,12 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics): epoch = metrics.get("epoch") step = metrics.get("step") + # when `val_loss` is being logged and no ModelCheckpoint is being provided + # `val_loss` will be selected for monitor and need to be reduced to + # prevent processes divergence + if self.monitor == "val_loss": + current = trainer.training_type_plugin.reduce(current, reduce_op="mean") + if self.check_monitor_top_k(current): self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics) elif self.verbose: diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index d23f3d5540e78..6c13e2f506bae 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -91,7 +91,7 @@ def wrapper(*args, **kwargs): def inner_f(queue, **kwargs): try: - func(**kwargs) + func(*args, **kwargs) queue.put(1) except Exception: _trace = traceback.format_exc() From 8f974a588c5b16f33c830d07b39b5f9f7928a4ab Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 12:39:14 +0000 Subject: [PATCH 04/15] resolve some bugs with TPUs --- dockers/tpu-tests/tpu_test_cases.jsonnet | 6 ++++++ .../computer_vision_fine_tuning.py | 19 ++++++++++++++----- tests/models/test_tpu.py | 3 --- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet index 8c3f3693fda50..ade399bc2e0e9 100644 --- a/dockers/tpu-tests/tpu_test_cases.jsonnet +++ b/dockers/tpu-tests/tpu_test_cases.jsonnet @@ -25,6 +25,12 @@ local tputests = base.BaseTest { pytorch_lightning/utilities/xla_device_utils.py \ tests/accelerators/test_tpu_backend.py \ tests/models/test_tpu.py + # Takes too long + # python pl_examples/domain_templates/computer_vision_fine_tuning.py \ + # --tpu_cores 8 \ + # --epochs 15 \ + # --limit_train_batches 8 \ + # --limit_val_batches 8 test_exit_code=$? echo "\n||| END PYTEST LOGS |||\n" coverage xml diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index b690b0b0c6b45..af81e320d0ced 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -71,7 +71,7 @@ def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False): self.train_bn = train_bn def freeze_before_training(self, pl_module: pl.LightningModule): - self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn) + self.freeze(modules=pl_module.feature_extractor, train_bn=False) def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): if epoch == self.milestones[0]: @@ -278,6 +278,18 @@ def add_model_specific_args(parent_parser): parser.add_argument( "--epochs", default=5, type=int, metavar="N", help="total number of epochs", dest="nb_epochs" ) + parser.add_argument( + "--limit_train_batches", + default=1.0, + type=float, + help="How much of training dataset to check (floats = percent, int = num_batches)" + ) + parser.add_argument( + "--limit_val_batches", + default=1.0, + type=float, + help="How much of validation dataset to check (floats = percent, int = num_batches)" + ) parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size") parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use") parser.add_argument("--tpu_cores", type=int, default=None, help="number of tpu cores to use") @@ -301,7 +313,7 @@ def add_model_specific_args(parent_parser): dest="train_bn", ) parser.add_argument( - "--milestones", default=[2, 4], type=list, metavar="M", help="List of two epochs milestones" + "--milestones", default=[5, 10], type=list, metavar="M", help="List of two epochs milestones" ) return parser @@ -336,9 +348,6 @@ def main(args: argparse.Namespace) -> None: trainer.fit(model, datamodule=datamodule) - if args.nb_epochs >= 5: - assert trainer.callbacks_metrics["val_acc"] > 0.7 - def get_args() -> argparse.Namespace: parent_parser = argparse.ArgumentParser(add_help=False) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index d9ea8a9917d2b..cc1206183d739 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -264,9 +264,6 @@ def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores): @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) @pl_multi_process_test def test_broadcast_on_tpu(): """ Checks if an object from the master process is broadcasted to other processes correctly""" From 3d14ff6fc2021eb0ad3af8fbc598ce206c873e3e Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 12:40:46 +0000 Subject: [PATCH 05/15] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d78578fc1fb2..2de970869880b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -288,6 +288,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015)) +- Fixed synchrnization issues with TPUs Training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027)) + + ## [1.1.8] - 2021-02-08 ### Fixed From e4a956db21a3ee98a1918d51a8285806835dcd87 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 14:08:32 +0000 Subject: [PATCH 06/15] update on comments --- CHANGELOG.md | 2 +- dockers/tpu-tests/tpu_test_cases.jsonnet | 6 ------ pytorch_lightning/callbacks/model_checkpoint.py | 1 + pytorch_lightning/loggers/tensorboard.py | 3 --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- tests/helpers/utils.py | 2 +- 6 files changed, 4 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2de970869880b..c28e1fa2f202e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -288,7 +288,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015)) -- Fixed synchrnization issues with TPUs Training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027)) +- Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027)) ## [1.1.8] - 2021-02-08 diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet index ade399bc2e0e9..8c3f3693fda50 100644 --- a/dockers/tpu-tests/tpu_test_cases.jsonnet +++ b/dockers/tpu-tests/tpu_test_cases.jsonnet @@ -25,12 +25,6 @@ local tputests = base.BaseTest { pytorch_lightning/utilities/xla_device_utils.py \ tests/accelerators/test_tpu_backend.py \ tests/models/test_tpu.py - # Takes too long - # python pl_examples/domain_templates/computer_vision_fine_tuning.py \ - # --tpu_cores 8 \ - # --epochs 15 \ - # --limit_train_batches 8 \ - # --limit_val_batches 8 test_exit_code=$? echo "\n||| END PYTEST LOGS |||\n" coverage xml diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d68b3e29683c5..999457ee9ba36 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -557,6 +557,7 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics): # when `val_loss` is being logged and no ModelCheckpoint is being provided # `val_loss` will be selected for monitor and need to be reduced to # prevent processes divergence + # Todo: Move this logic to logger_connector if self.monitor == "val_loss": current = trainer.training_type_plugin.reduce(current, reduce_op="mean") diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 3cba030e4ed40..2f8c888eba963 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -234,9 +234,6 @@ def save(self) -> None: @rank_zero_only def finalize(self, status: str) -> None: - self.close() - - def close(self): self.experiment.flush() self.experiment.close() self.save() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 0120e21d45726..05a4913e2dab1 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -219,7 +219,7 @@ def __load_weights_on_main_process(self) -> None: def _close_logger(self, trainer) -> None: if hasattr(trainer, "logger"): - trainer.logger.close() + trainer.logger.finalize() @property def xmp_spawn_kwargs(self): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 6c13e2f506bae..d23f3d5540e78 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -91,7 +91,7 @@ def wrapper(*args, **kwargs): def inner_f(queue, **kwargs): try: - func(*args, **kwargs) + func(**kwargs) queue.put(1) except Exception: _trace = traceback.format_exc() From 23a594e958f7c686f636d0f4f6aafc1d57349ebb Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 14:11:32 +0000 Subject: [PATCH 07/15] forgot status --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 05a4913e2dab1..7cd3006b20152 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -219,7 +219,7 @@ def __load_weights_on_main_process(self) -> None: def _close_logger(self, trainer) -> None: if hasattr(trainer, "logger"): - trainer.logger.finalize() + trainer.logger.finalize("success") @property def xmp_spawn_kwargs(self): From efcff122e7c8560c1699dbb82426a433f7a75527 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 17 Feb 2021 15:20:48 +0100 Subject: [PATCH 08/15] Fix train_bn arg --- .../domain_templates/computer_vision_fine_tuning.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index af81e320d0ced..4fb7da458638d 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -71,7 +71,7 @@ def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False): self.train_bn = train_bn def freeze_before_training(self, pl_module: pl.LightningModule): - self.freeze(modules=pl_module.feature_extractor, train_bn=False) + self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn) def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): if epoch == self.milestones[0]: @@ -164,13 +164,12 @@ class TransferLearningModel(pl.LightningModule): def __init__( self, backbone: str = "resnet50", - train_bn: bool = True, milestones: tuple = (5, 10), batch_size: int = 32, lr: float = 1e-2, lr_scheduler_gamma: float = 1e-1, num_workers: int = 6, - **kwargs, + **_, ) -> None: """ Args: @@ -178,7 +177,6 @@ def __init__( """ super().__init__() self.backbone = backbone - self.train_bn = train_bn self.milestones = milestones self.batch_size = batch_size self.lr = lr @@ -334,7 +332,7 @@ def main(args: argparse.Namespace) -> None: dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers ) model = TransferLearningModel(**vars(args)) - finetuning_callback = MilestonesFinetuning(milestones=args.milestones) + finetuning_callback = MilestonesFinetuning(milestones=args.milestones, train_bn=args.train_bn) trainer = pl.Trainer( weights_summary=None, From d9a67ae9c06046efc0b05ae690f99ac322362065 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 15:12:03 +0000 Subject: [PATCH 09/15] resolve comments --- pytorch_lightning/callbacks/model_checkpoint.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 999457ee9ba36..83d86b619c7c9 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -557,7 +557,8 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics): # when `val_loss` is being logged and no ModelCheckpoint is being provided # `val_loss` will be selected for monitor and need to be reduced to # prevent processes divergence - # Todo: Move this logic to logger_connector + # TODO: Move this logic to logger_connector. This also needs to be fixed for any + # other monitor logged value which aren't produced from a Metric. if self.monitor == "val_loss": current = trainer.training_type_plugin.reduce(current, reduce_op="mean") From f8c77fd06b3e1e157d96e41f479fe9dee0f42106 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 15:43:59 +0000 Subject: [PATCH 10/15] update on comments --- .../computer_vision_fine_tuning.py | 29 +++++-------------- .../plugins/training_type/tpu_spawn.py | 24 ++++++++++++--- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 4fb7da458638d..65bf1bde141fa 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -55,7 +55,7 @@ import pytorch_lightning as pl from pl_examples import cli_lightning_logo from pytorch_lightning import _logger as log -from pytorch_lightning import LightningDataModule, seed_everything +from pytorch_lightning import LightningDataModule from pytorch_lightning.callbacks.finetuning import BaseFinetuning from pytorch_lightning.utilities import rank_zero_info @@ -148,7 +148,7 @@ def val_dataloader(self): def add_model_specific_args(parent_parser): parser = argparse.ArgumentParser(parents=[parent_parser]) parser.add_argument( - "--num-workers", default=2, type=int, metavar="W", help="number of CPU workers", dest="num_workers" + "--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers" ) parser.add_argument( "--batch-size", default=8, type=int, metavar="W", help="number of sample in a batch", dest="batch_size" @@ -164,12 +164,13 @@ class TransferLearningModel(pl.LightningModule): def __init__( self, backbone: str = "resnet50", + train_bn: bool = True, milestones: tuple = (5, 10), batch_size: int = 32, lr: float = 1e-2, lr_scheduler_gamma: float = 1e-1, num_workers: int = 6, - **_, + **kwargs, ) -> None: """ Args: @@ -177,6 +178,7 @@ def __init__( """ super().__init__() self.backbone = backbone + self.train_bn = train_bn self.milestones = milestones self.batch_size = batch_size self.lr = lr @@ -274,23 +276,10 @@ def add_model_specific_args(parent_parser): help="Name (as in ``torchvision.models``) of the feature extractor", ) parser.add_argument( - "--epochs", default=5, type=int, metavar="N", help="total number of epochs", dest="nb_epochs" - ) - parser.add_argument( - "--limit_train_batches", - default=1.0, - type=float, - help="How much of training dataset to check (floats = percent, int = num_batches)" - ) - parser.add_argument( - "--limit_val_batches", - default=1.0, - type=float, - help="How much of validation dataset to check (floats = percent, int = num_batches)" + "--epochs", default=15, type=int, metavar="N", help="total number of epochs", dest="nb_epochs" ) parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size") parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use") - parser.add_argument("--tpu_cores", type=int, default=None, help="number of tpu cores to use") parser.add_argument( "--lr", "--learning-rate", default=1e-3, type=float, metavar="LR", help="initial learning rate", dest="lr" ) @@ -311,7 +300,7 @@ def add_model_specific_args(parent_parser): dest="train_bn", ) parser.add_argument( - "--milestones", default=[5, 10], type=list, metavar="M", help="List of two epochs milestones" + "--milestones", default=[2, 4], type=list, metavar="M", help="List of two epochs milestones" ) return parser @@ -326,19 +315,17 @@ def main(args: argparse.Namespace) -> None: For the sake of the example, the images dataset will be downloaded to a temporary directory. """ - seed_everything(42) datamodule = CatDogImageDataModule( dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers ) model = TransferLearningModel(**vars(args)) - finetuning_callback = MilestonesFinetuning(milestones=args.milestones, train_bn=args.train_bn) + finetuning_callback = MilestonesFinetuning(milestones=args.milestones) trainer = pl.Trainer( weights_summary=None, progress_bar_refresh_rate=1, num_sanity_val_steps=0, - tpu_cores=args.tpu_cores, gpus=args.gpus, max_epochs=args.nb_epochs, callbacks=[finetuning_callback] diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 7cd3006b20152..77ee664bcb883 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -10,7 +10,8 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything if _TPU_AVAILABLE: @@ -175,12 +176,27 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool: should_stop = int(stop.item()) == self.world_size return should_stop - def reduce(self, output, group: Optional[Any] = None, reduce_op: str = None): + def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): if not isinstance(output, torch.Tensor): output = torch.tensor(output, device=self.device) + + if (isinstance(reduce_op, ReduceOp) and ReduceOp != ReduceOp.SUM) \ + or isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg"): + raise MisconfigurationException( + "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation." + ) + + divide_by_world_size = False + output = xm.mesh_reduce('reduce', output, sum) - if isinstance(reduce_op, str) and reduce_op.lower() == "mean": - output /= self.world_size + + if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): + divide_by_world_size = True + # sync all processes before reduction + + if divide_by_world_size: + output = output / self.world_size + return output def post_dispatch(self) -> None: From 1476ef5e3541694422669936fb244dd5fc47f148 Mon Sep 17 00:00:00 2001 From: chaton Date: Wed, 17 Feb 2021 16:00:32 +0000 Subject: [PATCH 11/15] Update pytorch_lightning/plugins/training_type/tpu_spawn.py Co-authored-by: Jirka Borovec --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 77ee664bcb883..db057281edf4e 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -180,8 +180,8 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ if not isinstance(output, torch.Tensor): output = torch.tensor(output, device=self.device) - if (isinstance(reduce_op, ReduceOp) and ReduceOp != ReduceOp.SUM) \ - or isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg"): + _valid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") + if (isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM) or _valid_reduce_op_str: raise MisconfigurationException( "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation." ) From 6deb4465dc5557b8e1b0be1f8b3921d82052a9c7 Mon Sep 17 00:00:00 2001 From: chaton Date: Wed, 17 Feb 2021 16:00:46 +0000 Subject: [PATCH 12/15] Update pytorch_lightning/plugins/training_type/tpu_spawn.py Co-authored-by: Jirka Borovec --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index db057281edf4e..2263a2ba6ae59 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -186,15 +186,9 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation." ) - divide_by_world_size = False - output = xm.mesh_reduce('reduce', output, sum) if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): - divide_by_world_size = True - # sync all processes before reduction - - if divide_by_world_size: output = output / self.world_size return output From 9af54f9f801fedbf3892d9c4a408de0f6fed0d6b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Feb 2021 17:02:47 +0100 Subject: [PATCH 13/15] Apply suggestions from code review --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 2263a2ba6ae59..0136e78a4381f 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -180,8 +180,9 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ if not isinstance(output, torch.Tensor): output = torch.tensor(output, device=self.device) - _valid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") - if (isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM) or _valid_reduce_op_str: + _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM + _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") + if _invalid_reduce_op or _invalid_reduce_op_str: raise MisconfigurationException( "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation." ) From 0b00c9976aaa88fdb66364fc16a6c7fd734e308b Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 17 Feb 2021 16:16:29 +0000 Subject: [PATCH 14/15] add reduce test --- tests/models/test_tpu.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index cc1206183d739..3fa52bab95055 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -29,6 +29,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset from tests.helpers.utils import pl_multi_process_test +from pytorch_lightning.utilities.distributed import ReduceOp + if _TPU_AVAILABLE: import torch_xla @@ -324,3 +326,25 @@ def test_tpu_cores_with_argparse(cli_args, expected): for k, v in expected.items(): assert getattr(args, k) == v assert Trainer.from_argparse_args(args) + + +@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@pl_multi_process_test +def test_tpu_reduce(): + """Test passing tpu_cores in command line""" + + def test_reduce(rank): + trainer = Trainer(tpu_cores=8) + reduce_ops = ["mean", "AVG", "undefined", "sum", ReduceOp.SUM, ReduceOp.MAX] + for reduce_op in reduce_ops: + if reduce_op == "undefined" or reduce_op == ReduceOp.MAX: + with pytest.raises(MisconfigurationException, match="TPUSpawn TrainingTypePlugin only support"): + result = trainer.training_type_plugin.reduce(1, reduce_op) + else: + result = trainer.training_type_plugin.reduce(1, reduce_op) + if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"): + assert result.item() == 1 + else: + assert result.item() == 8 + + xmp.spawn(test_reduce, nprocs=8, start_method='fork') From 254ce3b30fcbde8526eb2a1d2e6b6eb1d2479173 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 17 Feb 2021 16:19:15 +0000 Subject: [PATCH 15/15] resolve flake8 --- tests/models/test_tpu.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 3fa52bab95055..4c6620b07b74a 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -26,11 +26,10 @@ from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset from tests.helpers.utils import pl_multi_process_test -from pytorch_lightning.utilities.distributed import ReduceOp - if _TPU_AVAILABLE: import torch_xla @@ -331,14 +330,15 @@ def test_tpu_cores_with_argparse(cli_args, expected): @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") @pl_multi_process_test def test_tpu_reduce(): - """Test passing tpu_cores in command line""" - + """Test tpu spawn reduce operation """ + def test_reduce(rank): trainer = Trainer(tpu_cores=8) + # faster this way reduce_ops = ["mean", "AVG", "undefined", "sum", ReduceOp.SUM, ReduceOp.MAX] for reduce_op in reduce_ops: if reduce_op == "undefined" or reduce_op == ReduceOp.MAX: - with pytest.raises(MisconfigurationException, match="TPUSpawn TrainingTypePlugin only support"): + with pytest.raises(MisconfigurationException, match="TPUSpawn TrainingTypePlugin only support"): result = trainer.training_type_plugin.reduce(1, reduce_op) else: result = trainer.training_type_plugin.reduce(1, reduce_op)