From edb3e834db22fd8a0cfdc6bde89b1c4b555f3a7c Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 2 Dec 2020 13:24:50 +0100 Subject: [PATCH 01/32] Refactor Trainer in advance of implementing Trainer.validate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Replace the `Trainer.testing` attribute with `Trainer.evaluating`, which is currently set to `'test'` if the top-level function called by the user was `Trainer.test(…)` and `None` otherwise. In the next PR, it will be set to `'validation’` when the user calls `validate(…)`. * Update the other components to use the new attribute instead of `Trainer.testing` * Disable the `EarlyStopping` and `ModelCheckpoint` callbacks when `evaluating`. This has no effect when evaluating on the test set, since they were already disabled, but it will be necessary for the validation set * Rename a few other attributes of `Trainer` to clarify that they will be used by both `test(…)` and `validate(…)` --- pytorch_lightning/accelerators/accelerator.py | 4 +- pytorch_lightning/callbacks/early_stopping.py | 4 +- .../callbacks/model_checkpoint.py | 1 + .../trainer/configuration_validator.py | 6 +- .../logger_connector/logger_connector.py | 8 +- .../trainer/connectors/model_connector.py | 6 +- pytorch_lightning/trainer/evaluation_loop.py | 17 ++-- pytorch_lightning/trainer/trainer.py | 82 +++++++++++-------- pytorch_lightning/trainer/training_loop.py | 2 +- tests/trainer/test_trainer.py | 6 +- 10 files changed, 80 insertions(+), 56 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 931a39e07af89..ed49bb1a5a7b7 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -60,7 +60,7 @@ def broadcast(self, obj, src=0): return obj def train_or_test(self): - if self.trainer.testing: + if self.trainer.evaluating: results = self.trainer.run_test() else: results = self.trainer.train() @@ -160,7 +160,7 @@ def early_stopping_should_stop(self, pl_module): return self.trainer.should_stop def setup_optimizers(self, model): - if self.trainer.testing is True: + if self.trainer.evaluating: return optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 005a3f8cde4ad..3a2b5c2a57259 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -134,13 +134,13 @@ def on_load_checkpoint(self, checkpointed_state): self.patience = checkpointed_state['patience'] def on_validation_end(self, trainer, pl_module): - if trainer.running_sanity_check: + if trainer.running_sanity_check or trainer.evaluating: return self._run_early_stopping_check(trainer, pl_module) def on_validation_epoch_end(self, trainer, pl_module): - if trainer.running_sanity_check: + if trainer.running_sanity_check or trainer.evaluating: return if self._validate_condition_metric(trainer.logger_connector.callback_metrics): diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d41928cd55aea..0efaef9c660b7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -220,6 +220,7 @@ def save_checkpoint(self, trainer, pl_module): or self.period < 1 # no models are saved or (epoch + 1) % self.period # skip epoch or trainer.running_sanity_check # 
don't save anything during sanity check + or trainer.evaluating # don't save anything during evaluation: might delete the checkpoint being evaluated or self.last_global_step_saved == global_step # already saved at the last step ): return diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 01c0119e857ec..23967dc1bc2a9 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -31,12 +31,12 @@ def verify_loop_configurations(self, model: LightningModule): model: The model to check the configuration. """ - if not self.trainer.testing: + if not self.trainer.evaluating: self.__verify_train_loop_configuration(model) self.__verify_eval_loop_configuration(model, 'validation') else: - # check test loop configuration - self.__verify_eval_loop_configuration(model, 'test') + # check evaluation loop configurations + self.__verify_eval_loop_configuration(model, self.trainer.evaluating) def __verify_train_loop_configuration(self, model): # ----------------------------------- diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index cab08edd58531..33ff30380eabb 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -265,7 +265,7 @@ def prepare_eval_loop_results(self): for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): self.add_to_eval_loop_results(dl_idx, has_been_initialized) - def get_evaluate_epoch_results(self, test_mode): + def get_evaluate_epoch_results(self): if not self.trainer.running_sanity_check: # log all the metrics as a single dict metrics_to_log = self.cached_results.get_epoch_log_metrics() @@ -274,11 +274,11 @@ def get_evaluate_epoch_results(self, test_mode): self.prepare_eval_loop_results() - # log results of test - if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test: + # log results of evaluation + if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate: print('-' * 80) for result_idx, results in enumerate(self.eval_loop_results): - print(f'DATALOADER:{result_idx} TEST RESULTS') + print(f'DATALOADER:{result_idx} {self.trainer.evaluating.upper()} RESULTS') pprint(results) print('-' * 80) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index c5a8c48357b44..5b4022488a5a8 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -36,7 +36,11 @@ def copy_trainer_model_properties(self, model): m.use_ddp2 = self.trainer.use_ddp2 m.use_ddp = self.trainer.use_ddp m.use_amp = self.trainer.amp_backend is not None - m.testing = self.trainer.testing + # Currently, the only users of m.testing appear to be DP and DDP, + # which use it to determine whether the model is currently inside + # the validation or test loop. For this reason it must check if + # trainer.evaluating is equal to "test" specifically. 
+ m.testing = self.trainer.evaluating == 'test' m.use_single_gpu = self.trainer.use_single_gpu m.use_tpu = self.trainer.use_tpu m.tpu_local_core_rank = self.trainer.tpu_local_core_rank diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 097727a6bed78..11da428b83453 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -13,6 +13,7 @@ # limitations under the License. import torch +import pytorch_lightning as pl from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.distributed import rank_zero_warn @@ -22,7 +23,7 @@ class EvaluationLoop(object): - def __init__(self, trainer): + def __init__(self, trainer: 'pl.Trainer'): self.trainer = trainer self.testing = False self.outputs = [] @@ -39,13 +40,15 @@ def on_trainer_init(self): self.trainer.test_dataloaders = None self.trainer.val_dataloaders = None self.trainer.running_sanity_check = False - self.trainer.testing = False - # when .test() is called, it sets this - self.trainer.tested_ckpt_path = None + # .validate() sets this to 'validation' and .test() sets this to 'test' + self.trainer.evaluating = None - # when true, prints test results - self.trainer.verbose_test = True + # .validate() and .test() set this when they load a checkpoint + self.trainer.evaluated_ckpt_path = None + + # when true, print evaluation results in .validate() and .test() + self.trainer.verbose_evaluate = True def get_evaluation_dataloaders(self, max_batches): # select dataloaders @@ -216,7 +219,7 @@ def evaluation_epoch_end(self): def log_epoch_metrics_on_evaluation_end(self): # get the final loop results - eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results(self.testing) + eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results() return eval_loop_results def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fd715988ef370..d7b9361f9f5a1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -441,10 +441,6 @@ def fit( # hook self.data_connector.prepare_data(model) - # bookkeeping - # we reuse fit in .test() but change its behavior using this flag - self.testing = os.environ.get('PL_TESTING_MODE', self.testing) - # ---------------------------- # SET UP TRAINING # ---------------------------- @@ -720,33 +716,31 @@ def test( datamodule: Optional[LightningDataModule] = None, ): r""" - - Separates from fit to make sure you never run on your test set until you want to. + Perform one evaluation epoch over the test set. It's separated from + fit to make sure you never run on your test set until you want to. Args: ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the weights from the last epoch to test. Default to ``best``. - + If ``None``, use the current weights of the model. Default to ``best``. datamodule: A instance of :class:`LightningDataModule`. - - model: The model to test. - - test_dataloaders: Either a single - Pytorch Dataloader or a list of them, specifying validation samples. - - verbose: If True, prints the test results + model: The model to evaluate. + test_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying test samples. + verbose: If True, prints the test results. 
Returns: - The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries + The dictionary with final test results returned by test_epoch_end. + If test_epoch_end is not defined, the output is a list of the dictionaries + returned by test_step. """ # -------------------- # SETUP HOOK # -------------------- - self.verbose_test = verbose + self.verbose_evaluate = verbose self.logger_connector.set_stage("test") - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule' @@ -756,15 +750,15 @@ def test( self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test') if model is not None: - results = self.__test_given_model(model, test_dataloaders) + results = self.__evaluate_given_model(model, test_dataloaders, 'test') else: - results = self.__test_using_best_weights(ckpt_path, test_dataloaders) + results = self.__evaluate_using_best_weights(ckpt_path, test_dataloaders, 'test') self.teardown('test') return results - def __test_using_best_weights(self, ckpt_path, test_dataloaders): + def __evaluate_using_best_weights(self, ckpt_path, test_dataloaders, stage: str): model = self.get_model() # if user requests the best checkpoint but we don't have it, error @@ -796,22 +790,20 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) # run tests - self.tested_ckpt_path = ckpt_path - self.testing = True - os.environ['PL_TESTING_MODE'] = '1' + self.evaluating = stage + self.evaluated_ckpt_path = ckpt_path self.model = model results = self.fit(model) - self.testing = False - del os.environ['PL_TESTING_MODE'] + self.evaluating = None # teardown if self.is_function_implemented('teardown'): model_ref = self.get_model() - model_ref.teardown('test') + model_ref.teardown(stage) return results - def __test_given_model(self, model, test_dataloaders): + def __evaluate_given_model(self, model, test_dataloaders, stage: str): # attach data if test_dataloaders is not None: @@ -819,17 +811,35 @@ def __test_given_model(self, model, test_dataloaders): # run test # sets up testing so we short circuit to eval - self.testing = True + self.evaluating = stage self.model = model results = self.fit(model) - self.testing = False + self.evaluating = None # teardown if self.is_function_implemented('teardown'): - model.teardown('test') + model.teardown(stage) return results + @property + def testing(self): + warnings.warn( + 'Trainer.testing has been deprecated in v1.1 and will be removed ' + 'in v1.3, use Trainer.evaluating instead.', + DeprecationWarning, stacklevel=2 + ) + return bool(self.evaluating) + + @property + def tested_ckpt_path(self): + warnings.warn( + 'Trainer.tested_ckpt_path has been renamed Trainer.evaluated_ckpt_path ' + 'in v1.1 and will be removed in v1.3.', + DeprecationWarning, stacklevel=2 + ) + return self.evaluated_ckpt_path + def tune( self, model: LightningModule, @@ -856,11 +866,17 @@ def tune( def call_setup_hook(self, model): # call setup after the ddp process has connected - stage_name = 'test' if self.testing else 'fit' + stage_name = self.evaluating or 'fit' + if self.datamodule is not None: - called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit + called = { + None: 
self.datamodule.has_setup_fit, + 'test': self.datamodule.has_setup_test, + }[self.evaluating] + if not called: self.datamodule.setup(stage_name) + self.setup(model, stage_name) model.setup(stage_name) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9a4f324033d39..ff19b9b8a9858 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -161,7 +161,7 @@ def setup_training(self, model: LightningModule): ref_model.on_pretrain_routine_start() # print model summary - if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing: + if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.evaluating: if self.trainer.weights_summary in ModelSummary.MODES: ref_model.summarize(mode=self.trainer.weights_summary) else: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 085d361952844..ea1bfc3a75605 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -728,12 +728,12 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): trainer.test(ckpt_path=ckpt_path) else: trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and # use the weights from the end of training trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path is None + assert trainer.evaluated_ckpt_path is None else: # specific checkpoint, pick one from saved ones if save_top_k == 0: @@ -746,7 +746,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): ].absolute() ) trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == ckpt_path + assert trainer.evaluated_ckpt_path == ckpt_path def test_disabled_training(tmpdir): From 5a54485fe8bd8d3d2ca1026c4b2ba5b0555f3a1e Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 2 Dec 2020 16:43:20 +0100 Subject: [PATCH 02/32] Add Trainer.validate(...) method to perform one evaluation epoch over the validation set --- pytorch_lightning/callbacks/progress.py | 22 ++++++-- pytorch_lightning/core/datamodule.py | 15 ++++- pytorch_lightning/trainer/trainer.py | 73 ++++++++++++++++++++++--- tests/base/datamodules.py | 4 +- 4 files changed, 99 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 6582f16fd27be..b00dca548671f 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -282,9 +282,13 @@ def init_train_tqdm(self) -> tqdm: def init_validation_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for validation. """ + + # The main progress bar doesn't exist in trainer.validate(...) 
+ has_main_bar = int(self.main_progress_bar is not None) + bar = tqdm( desc='Validating', - position=(2 * self.process_position + 1), + position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, leave=False, dynamic_ncols=True, @@ -341,7 +345,10 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) if not trainer.running_sanity_check: - self._update_bar(self.main_progress_bar) # fill up remaining + # The main progress bar doesn't exist in trainer.validate(...) + if self.main_progress_bar is not None: + self._update_bar(self.main_progress_bar) # fill up remaining + self.val_progress_bar = self.init_validation_tqdm() self.val_progress_bar.total = convert_inf(self.total_val_batches) @@ -349,11 +356,18 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.val_batch_idx, self.total_val_batches): self._update_bar(self.val_progress_bar) - self._update_bar(self.main_progress_bar) + + # The main progress bar doesn't exist in trainer.validate(...) + if self.main_progress_bar is not None: + self._update_bar(self.main_progress_bar) def on_validation_end(self, trainer, pl_module): super().on_validation_end(trainer, pl_module) - self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + + # The main progress bar doesn't exist in trainer.validate(...) + if self.main_progress_bar is not None: + self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + self.val_progress_bar.close() def on_train_end(self, trainer, pl_module): diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index fe81d641c86d6..3ff9f4cf889d4 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -76,13 +76,16 @@ def wrapped_fn(*args, **kwargs): if fn.__name__ == "setup": # Get stage either by grabbing from args or checking kwargs. - # If not provided, set call status of 'fit' and 'test' to True. + # If not provided, set call status of 'fit', 'validation', and 'test' to True. # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() stage = args[1] if len(args) > 1 else kwargs.get("stage", None) if stage == "fit" or stage is None: obj._has_setup_fit = True + if stage == "validation" or stage is None: + obj._has_setup_validation = True + if stage == "test" or stage is None: obj._has_setup_test = True @@ -155,6 +158,7 @@ def __init__( # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False self._has_setup_fit = False + self._has_setup_validation = False self._has_setup_test = False @property @@ -230,6 +234,15 @@ def has_setup_fit(self): """ return self._has_setup_fit + @property + def has_setup_validation(self): + """Return bool letting you know if datamodule.setup('validation') has been called or not. + + Returns: + bool: True if datamodule.setup('validation') has been called. False by default. + """ + return self._has_setup_validation + @property def has_setup_test(self): """Return bool letting you know if datamodule.setup('test') has been called or not. 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5aeb3aba3d31c..a335c49beae71 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -658,8 +658,12 @@ def track_output_for_epoch_end(self, outputs, output): def run_test(self): # only load test dataloader for testing # self.reset_test_dataloader(ref_model) - with self.profiler.profile("run_test_evaluation"): - eval_loop_results, _ = self.run_evaluation(test_mode=True) + if self.evaluating == 'test': + with self.profiler.profile("run_test_evaluation"): + eval_loop_results, _ = self.run_evaluation(test_mode=True) + else: + with self.profiler.profile("run_validate_evaluation"): + eval_loop_results, _ = self.run_evaluation(test_mode=False) if len(eval_loop_results) == 0: return 1 @@ -707,6 +711,56 @@ def run_sanity_check(self, ref_model): self.on_sanity_check_end() self.running_sanity_check = False + def validate( + self, + model: Optional[LightningModule] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + ): + r""" + Perform one evaluation epoch over the validation set. + + Args: + ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. + If ``None``, use the current weights of the model. Default to ``best``. + datamodule: A instance of :class:`LightningDataModule`. + model: The model to evaluate. + val_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying validation samples. + verbose: If True, prints the validation results. + + Returns: + The dictionary with final validation results returned by validation_epoch_end. + If validation_epoch_end is not defined, the output is a list of the dictionaries + returned by validation_step. 
+ """ + # -------------------- + # SETUP HOOK + # -------------------- + self.verbose_evaluate = verbose + + self.logger_connector.set_stage("validation") + + # If you supply a datamodule you can't supply val_dataloaders + if val_dataloaders and datamodule: + raise MisconfigurationException( + 'You cannot pass val_dataloaders to trainer.validate if you supply a datamodule' + ) + + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'validation') + + if model is not None: + results = self.__evaluate_given_model(model, val_dataloaders, 'validation') + else: + results = self.__evaluate_using_best_weights(ckpt_path, val_dataloaders, 'validation') + + self.teardown('validation') + + return results + def test( self, model: Optional[LightningModule] = None, @@ -758,7 +812,7 @@ def test( return results - def __evaluate_using_best_weights(self, ckpt_path, test_dataloaders, stage: str): + def __evaluate_using_best_weights(self, ckpt_path, dataloaders, stage: str): model = self.get_model() # if user requests the best checkpoint but we don't have it, error @@ -786,8 +840,9 @@ def __evaluate_using_best_weights(self, ckpt_path, test_dataloaders, stage: str) model.load_state_dict(ckpt['state_dict']) # attach dataloaders - if test_dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if dataloaders is not None: + kwargs = {'test_dataloaders' if stage == 'test' else 'val_dataloaders': dataloaders} + self.data_connector.attach_dataloaders(model, **kwargs) # run tests self.evaluating = stage @@ -803,11 +858,12 @@ def __evaluate_using_best_weights(self, ckpt_path, test_dataloaders, stage: str) return results - def __evaluate_given_model(self, model, test_dataloaders, stage: str): + def __evaluate_given_model(self, model, dataloaders, stage: str): # attach data - if test_dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if dataloaders is not None: + kwargs = {'test_dataloaders' if stage == 'test' else 'val_dataloaders': dataloaders} + self.data_connector.attach_dataloaders(model, **kwargs) # run test # sets up testing so we short circuit to eval @@ -871,6 +927,7 @@ def call_setup_hook(self, model): if self.datamodule is not None: called = { None: self.datamodule.has_setup_fit, + 'validation': self.datamodule.has_setup_validation, 'test': self.datamodule.has_setup_test, }[self.evaluating] diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py index e4d0b4bff89d7..94e4ba9c1efe9 100644 --- a/tests/base/datamodules.py +++ b/tests/base/datamodules.py @@ -33,7 +33,7 @@ def prepare_data(self): def setup(self, stage: Optional[str] = None): - if stage == "fit" or stage is None: + if stage != 'test': mnist_full = TrialMNIST( root=self.data_dir, train=True, num_samples=64, download=True ) @@ -88,7 +88,7 @@ def setup(self, stage: Optional[str] = None): # Assign train/val datasets for use in dataloaders # TODO: need to split using random_split once updated to torch >= 1.6 - if stage == "fit" or stage is None: + if stage != 'test': self.mnist_train = MNIST( self.data_dir, train=True, normalize=(0.1307, 0.3081) ) From e06775c02112059e2f32233b0972be836f86a5a0 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 2 Dec 2020 16:45:05 +0100 Subject: [PATCH 03/32] =?UTF-8?q?Rename=20methods=20in=20Trainer=20and=20A?= 
=?UTF-8?q?ccelerator=20to=20reflect=20that=20they=20are=20used=20by=20bot?= =?UTF-8?q?h=20test(=E2=80=A6)=20and=20validate(=E2=80=A6)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytorch_lightning/accelerators/accelerator.py | 4 ++-- pytorch_lightning/accelerators/cpu_accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp2_accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp_accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp_spawn_accelerator.py | 4 ++-- pytorch_lightning/accelerators/dp_accelerator.py | 4 ++-- pytorch_lightning/accelerators/gpu_accelerator.py | 5 +++-- pytorch_lightning/accelerators/horovod_accelerator.py | 4 ++-- pytorch_lightning/accelerators/tpu_accelerator.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 2 +- 12 files changed, 24 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index eae8a20fda471..600d074a1cb55 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -59,9 +59,9 @@ def barrier(self, name: Optional[str] = None): def broadcast(self, obj, src=0): return obj - def train_or_test(self): + def train_or_evaluate(self): if self.trainer.evaluating: - results = self.trainer.run_test() + results = self.trainer.run_test_or_validate() else: results = self.trainer.train() return results diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index fe0ab59fb554f..279b6327bba5a 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -57,8 +57,8 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() return results def training_step(self, args): diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index f43866881cabb..0acc5d6b65339 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -181,8 +181,8 @@ def ddp_train(self, process_idx, mp_queue, model): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 687b5c21874fb..90347a60a4566 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -275,8 +275,8 @@ def ddp_train(self, process_idx, model): self.barrier('ddp_setup') self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 982da2f53216b..879ad3cdb8b74 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -145,8 
+145,8 @@ def ddp_train(self, process_idx, mp_queue, model): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # get original model model = self.trainer.get_model() diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 28817c6845f5b..316fac61ca732 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -174,8 +174,8 @@ def ddp_train(self, process_idx, model): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index a06d0b82d6d15..b871f6cbf0c6d 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -157,8 +157,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # get original model model = self.trainer.get_model() diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 4b4e1eac8a66c..214b4d88f03aa 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -106,8 +106,8 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() return results diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index b12d275c8ac26..e3f0fb9890809 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -62,8 +62,9 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() + return results def training_step(self, args): diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index b2cec906178f9..d4027c772e061 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -111,8 +111,8 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(self.trainer.model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # Make sure all workers have finished training before returning to the user hvd.join() diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 6da5150d1fa8a..b38bc27c2b391 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -129,8 +129,8 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine # set up training routine 
self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # save weights at the end of training self.__save_end_of_training_weights(model, trainer) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a335c49beae71..e2758254f0419 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -655,7 +655,7 @@ def track_output_for_epoch_end(self, outputs, output): outputs.append(output) return outputs - def run_test(self): + def run_test_or_validate(self): # only load test dataloader for testing # self.reset_test_dataloader(ref_model) if self.evaluating == 'test': From b4e409c7e70052cd3028b57d385e459150e9fdb4 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 2 Dec 2020 16:47:22 +0100 Subject: [PATCH 04/32] Update docs to mention the new Trainer.validate method and associated hook and callback calls --- docs/source/trainer.rst | 15 ++++++++++++++- pytorch_lightning/callbacks/base.py | 4 ++-- pytorch_lightning/core/hooks.py | 8 ++++---- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index 79d6284a4e27c..a69a327722ab3 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -148,6 +148,19 @@ So you can run it like so: ------------ +Validation +---------- +You can perform an evaluation epoch over the validation set, outside of the training loop, +using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be +useful if you want to collect new metrics from a model right at its initialization +or that has already been trained. + +.. code-block:: python + + trainer.validate(val_dataloaders=val_dataloaders) + +------------ + Testing ------- Once you're done training, feel free to run the test set! @@ -155,7 +168,7 @@ Once you're done training, feel free to run the test set! .. code-block:: python - trainer.test(test_dataloader=test_dataloader) + trainer.test(test_dataloaders=test_dataloaders) ------------ diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 3f6b4ffe9622a..8ca0ef301c260 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -28,11 +28,11 @@ class Callback(abc.ABC): """ def setup(self, trainer, pl_module, stage: str): - """Called when fit or test begins""" + """Called when fit, validate, or test begins""" pass def teardown(self, trainer, pl_module, stage: str): - """Called when fit or test ends""" + """Called when fit, validate, or test ends""" pass def on_init_start(self, trainer): diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 57979b73f2cb6..a4251484991f2 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -26,12 +26,12 @@ class ModelHooks: """Hooks to be used in LightningModule.""" def setup(self, stage: str): """ - Called at the beginning of fit and test. + Called at the beginning of fit (training + validation), validation, and test. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. Args: - stage: either 'fit' or 'test' + stage: either 'fit', 'validation', or 'test' Example:: @@ -54,10 +54,10 @@ def setup(stage): def teardown(self, stage: str): """ - Called at the end of fit and test. + Called at the end of fit (training + validation), validation, and test. 
Args: - stage: either 'fit' or 'test' + stage: either 'fit', 'validation', or 'test' """ def on_fit_start(self): From 96e42ba08c9e430496dfdd5510543d71f770c710 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 2 Dec 2020 16:47:49 +0100 Subject: [PATCH 05/32] =?UTF-8?q?Add=20tests=20for=20Trainer.validate(?= =?UTF-8?q?=E2=80=A6)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/backends/test_dp.py | 18 ++++- tests/callbacks/test_callbacks.py | 22 ++++++ tests/callbacks/test_progress_bar.py | 30 +++++++- tests/checkpointing/test_model_checkpoint.py | 8 +++ tests/core/test_datamodules.py | 56 +++++++++++++++ tests/trainer/test_config_validator.py | 37 +++++++--- tests/trainer/test_dataloaders.py | 42 +++++++++++ tests/trainer/test_optimizers.py | 18 +++++ tests/trainer/test_states.py | 40 ++++++++++- tests/trainer/test_trainer.py | 45 ++++++++++++ tests/trainer/test_trainer_validate_loop.py | 76 ++++++++++++++++++++ 11 files changed, 378 insertions(+), 14 deletions(-) create mode 100644 tests/trainer/test_trainer_validate_loop.py diff --git a/tests/backends/test_dp.py b/tests/backends/test_dp.py index c051b442cb7a7..b697440280f80 100644 --- a/tests/backends/test_dp.py +++ b/tests/backends/test_dp.py @@ -67,7 +67,7 @@ def test_multi_gpu_model_dp(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_dp_test(tmpdir): +def test_dp_evaluate(tmpdir): tutils.set_random_master_port() import os @@ -84,6 +84,22 @@ def test_dp_test(tmpdir): ) trainer.fit(model) assert 'ckpt' in trainer.checkpoint_callback.best_model_path + + # validate + results = trainer.validate() + assert 'val_acc' in results[0] + + old_weights = model.c_d1.weight.clone().detach().cpu() + + results = trainer.validate(model) + assert 'val_acc' in results[0] + + # make sure weights didn't change + new_weights = model.c_d1.weight.clone().detach().cpu() + + assert torch.all(torch.eq(old_weights, new_weights)) + + # test results = trainer.test() assert 'test_acc' in results[0] diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index bb740b1dcbb1c..6f427afef7728 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -101,6 +101,28 @@ def test_trainer_callback_system(torch_save): call.teardown(trainer, model, 'fit'), ] + callback_mock.reset_mock() + trainer = Trainer(**trainer_options) + trainer.validate(model) + + assert callback_mock.method_calls == [ + call.on_init_start(trainer), + call.on_init_end(trainer), + call.setup(trainer, model, 'validation'), + call.on_fit_start(trainer, model), + call.on_pretrain_routine_start(trainer, model), + call.on_pretrain_routine_end(trainer, model), + call.on_validation_start(trainer, model), + call.on_validation_epoch_start(trainer, model), + call.on_validation_batch_start(trainer, model, ANY, 0, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), + call.on_validation_epoch_end(trainer, model), + call.on_validation_end(trainer, model), + call.on_fit_end(trainer, model), + call.teardown(trainer, model, 'fit'), + call.teardown(trainer, model, 'validation'), + ] + callback_mock.reset_mock() trainer = Trainer(**trainer_options) trainer.test(model) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 3c19748765e52..988da6f233dd2 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -84,7 +84,7 @@ def 
test_progress_bar_totals(tmpdir): limit_val_batches=1.0, max_epochs=1, ) - bar = trainer.progress_bar_callback + bar: ProgressBar = trainer.progress_bar_callback assert 0 == bar.total_train_batches assert 0 == bar.total_val_batches assert 0 == bar.total_test_batches @@ -113,6 +113,17 @@ def test_progress_bar_totals(tmpdir): assert 0 == bar.total_test_batches assert bar.test_progress_bar is None + trainer.validate(model) + + # check validation progress bar total + k = bar.total_val_batches + assert sum(len(loader) for loader in trainer.val_dataloaders) == k + assert bar.val_progress_bar.total == k + + # validation progress bar should have reached the end + assert bar.val_progress_bar.n == k + assert bar.val_batch_idx == k + trainer.test(model) # check test progress bar total @@ -135,7 +146,7 @@ def test_progress_bar_fast_dev_run(tmpdir): trainer.fit(model) - progress_bar = trainer.progress_bar_callback + progress_bar: ProgressBar = trainer.progress_bar_callback assert 1 == progress_bar.total_train_batches # total val batches are known only after val dataloaders have reloaded @@ -150,6 +161,13 @@ def test_progress_bar_fast_dev_run(tmpdir): assert 2 == progress_bar.main_progress_bar.total assert 2 == progress_bar.main_progress_bar.n + trainer.validate(model) + + # the validation progress bar should display 1 batch + assert 1 == progress_bar.val_batch_idx + assert 1 == progress_bar.val_progress_bar.total + assert 1 == progress_bar.val_progress_bar.n + trainer.test(model) # the test progress bar should display 1 batch @@ -207,8 +225,16 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal trainer.fit(model) assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 + + trainer.validate(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 trainer.test(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps assert progress_bar.test_batches_seen == progress_bar.total_test_batches diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 33bc19a894d8f..e3e6dfe4ceddc 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -797,6 +797,9 @@ def get_model(): assert trainer.current_epoch == epochs - 1 assert_checkpoint_log_dir(0) + trainer.validate(model) + assert trainer.current_epoch == epochs - 1 + trainer.test(model) assert trainer.current_epoch == epochs - 1 @@ -817,6 +820,11 @@ def get_model(): ) assert_trainer_init(trainer) + trainer.validate(model) + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs + trainer.test(model) assert not trainer.checkpoint_connector.has_trained assert trainer.global_step == epochs * limit_train_batches diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 3e683025e8867..32f4aebe445d4 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -111,6 +111,7 @@ def 
test_base_datamodule_with_verbose_setup(tmpdir): dm = TrialMNISTDataModule() dm.prepare_data() dm.setup('fit') + dm.setup('validation') dm.setup('test') @@ -118,16 +119,19 @@ def test_data_hooks_called(tmpdir): dm = TrialMNISTDataModule() assert dm.has_prepared_data is False assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert dm.has_setup_test is False dm.prepare_data() assert dm.has_prepared_data is True assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert dm.has_setup_test is False dm.setup() assert dm.has_prepared_data is True assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is True @@ -135,21 +139,31 @@ def test_data_hooks_called_verbose(tmpdir): dm = TrialMNISTDataModule() assert dm.has_prepared_data is False assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert dm.has_setup_test is False dm.prepare_data() assert dm.has_prepared_data is True assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert dm.has_setup_test is False dm.setup('fit') assert dm.has_prepared_data is True assert dm.has_setup_fit is True + assert dm.has_setup_validation is False + assert dm.has_setup_test is False + + dm.setup('validation') + assert dm.has_prepared_data is True + assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is False dm.setup('test') assert dm.has_prepared_data is True assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is True @@ -160,10 +174,17 @@ def test_data_hooks_called_with_stage_kwarg(tmpdir): dm.setup(stage='fit') assert dm.has_setup_fit is True + assert dm.has_setup_validation is False + assert dm.has_setup_test is False + + dm.setup(stage='validation') + assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is False dm.setup(stage='test') assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is True @@ -254,6 +275,21 @@ def test_dm_checkpoint_save(tmpdir): assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__ +def test_validate_loop_only(tmpdir): + reset_seed() + + dm = TrialMNISTDataModule(tmpdir) + + model = EvalModelTemplate() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + weights_summary=None, + ) + trainer.validate(model, datamodule=dm) + + def test_test_loop_only(tmpdir): reset_seed() @@ -287,6 +323,11 @@ def test_full_loop(tmpdir): result = trainer.fit(model, dm) assert result == 1 + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert result['val_acc'] > 0.8 + # test result = trainer.test(datamodule=dm) result = result[0] @@ -312,6 +353,11 @@ def test_trainer_attached_to_dm(tmpdir): assert result == 1 assert dm.trainer is not None + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert dm.trainer is not None + # test result = trainer.test(datamodule=dm) result = result[0] @@ -338,6 +384,11 @@ def test_full_loop_single_gpu(tmpdir): result = trainer.fit(model, dm) assert result == 1 + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert result['val_acc'] > 0.8 + # test result = trainer.test(datamodule=dm) result = result[0] @@ -365,6 +416,11 @@ def test_full_loop_dp(tmpdir): result = trainer.fit(model, dm) assert result == 1 + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert 
result['val_acc'] > 0.8 + # test result = trainer.test(datamodule=dm) result = result[0] diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 1ab97304f2338..b724fc8587e24 100755 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -19,9 +19,6 @@ from tests.base import EvalModelTemplate -# TODO: add matching messages - - def test_wrong_train_setting(tmpdir): """ * Test that an error is thrown when no `train_dataloader()` is defined @@ -31,12 +28,12 @@ def test_wrong_train_setting(tmpdir): hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): + with pytest.raises(MisconfigurationException, match=r'No `train_dataloader\(\)` method defined.'): model = EvalModelTemplate(**hparams) model.train_dataloader = None trainer.fit(model) - with pytest.raises(MisconfigurationException): + with pytest.raises(MisconfigurationException, match=r'No `training_step\(\)` method defined.'): model = EvalModelTemplate(**hparams) model.training_step = None trainer.fit(model) @@ -47,7 +44,7 @@ def test_wrong_configure_optimizers(tmpdir): tutils.reset_seed() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): + with pytest.raises(MisconfigurationException, match=r'No `configure_optimizers\(\)` method defined.'): model = EvalModelTemplate() model.configure_optimizers = None trainer.fit(model) @@ -62,13 +59,13 @@ def test_val_loop_config(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # no val data has val loop - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): model = EvalModelTemplate(**hparams) model.validation_step = None trainer.fit(model) # has val loop but no val data - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'): model = EvalModelTemplate(**hparams) model.val_dataloader = None trainer.fit(model) @@ -82,13 +79,33 @@ def test_test_loop_config(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # has test loop but no test data - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you defined a test_step but have no test_dataloader'): model = EvalModelTemplate(**hparams) model.test_dataloader = None trainer.test(model) # has test data but no test loop - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you passed in a test_dataloader but have no test_step'): model = EvalModelTemplate(**hparams) model.test_step = None trainer.test(model, test_dataloaders=model.dataloader(train=False)) + + +def test_validation_loop_config(tmpdir): + """" + When either validation loop or validation data are missing + """ + hparams = EvalModelTemplate.get_default_hparams() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + # has val loop but no val data + with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'): + model = EvalModelTemplate(**hparams) + model.val_dataloader = None + trainer.validate(model) + + # has val data but no val loop + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): + model = EvalModelTemplate(**hparams) + model.validation_step = None + trainer.validate(model, val_dataloaders=model.dataloader(train=False)) diff --git 
a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index f16ef22faa507..d0b838b5fbf45 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -172,6 +172,48 @@ def test_step(self, batch, batch_idx, *args, **kwargs): trainer.test(ckpt_path=ckpt_path) +@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +def test_multiple_validate_dataloader(tmpdir, ckpt_path): + """Verify multiple val_dataloaders.""" + + model_template = EvalModelTemplate() + + class MultipleValDataloaderModel(EvalModelTemplate): + def val_dataloader(self): + return model_template.val_dataloader__multiple() + + def validation_step(self, batch, batch_idx, *args, **kwargs): + return model_template.validation_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs) + + def validation_epoch_end(self, outputs): + return model_template.validation_epoch_end__multiple_dataloaders(outputs) + + model = MultipleValDataloaderModel() + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + ) + trainer.fit(model) + if ckpt_path == 'specific': + ckpt_path = trainer.checkpoint_callback.best_model_path + trainer.validate(ckpt_path=ckpt_path) + + # verify there are 2 test loaders + assert len(trainer.val_dataloaders) == 2, \ + 'Multiple val_dataloaders not initiated properly' + + # make sure predictions are good for each test set + for dataloader in trainer.val_dataloaders: + tpipes.run_prediction(dataloader, trainer.model) + + # run the validate method + trainer.validate(ckpt_path=ckpt_path) + + def test_train_dataloader_passed_to_fit(tmpdir): """Verify that train dataloader can be passed to fit """ diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py index 2e76192836740..27f0bcda66926 100644 --- a/tests/trainer/test_optimizers.py +++ b/tests/trainer/test_optimizers.py @@ -337,6 +337,24 @@ def test_init_optimizers_during_testing(tmpdir): assert len(trainer.optimizer_frequencies) == 0 +def test_init_optimizers_during_validation(tmpdir): + """ + Test that optimizers is an empty list during validation. 
+ """ + model = EvalModelTemplate() + model.configure_optimizers = model.configure_optimizers__multiple_schedulers + + trainer = Trainer( + default_root_dir=tmpdir, + limit_test_batches=10 + ) + trainer.validate(model, ckpt_path=None) + + assert len(trainer.lr_schedulers) == 0 + assert len(trainer.optimizers) == 0 + assert len(trainer.optimizer_frequencies) == 0 + + def test_multiple_optimizers_callbacks(tmpdir): """ Tests that multiple optimizers can be used with callbacks diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index 0244f654227a2..f6e29b7187d61 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -23,7 +23,7 @@ class StateSnapshotCallback(Callback): def __init__(self, snapshot_method: str): super().__init__() - assert snapshot_method in ['on_batch_start', 'on_test_batch_start'] + assert snapshot_method in ['on_batch_start', 'on_validation_batch_start', 'on_test_batch_start'] self.snapshot_method = snapshot_method self.trainer_state = None @@ -31,6 +31,10 @@ def on_batch_start(self, trainer, pl_module): if self.snapshot_method == 'on_batch_start': self.trainer_state = trainer.state + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + if self.snapshot_method == 'on_validation_batch_start': + self.trainer_state = trainer.state + def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): if self.snapshot_method == 'on_test_batch_start': self.trainer_state = trainer.state @@ -191,6 +195,40 @@ def test_finished_state_after_test(tmpdir): assert trainer.state == TrainerState.FINISHED +def test_running_state_during_validation(tmpdir): + """ Tests that state is set to RUNNING during test """ + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + snapshot_callback = StateSnapshotCallback(snapshot_method='on_validation_batch_start') + + trainer = Trainer( + callbacks=[snapshot_callback], + default_root_dir=tmpdir, + fast_dev_run=True, + ) + + trainer.validate(model) + + assert snapshot_callback.trainer_state == TrainerState.RUNNING + + +def test_finished_state_after_validation(tmpdir): + """ Tests that state is FINISHED after fit """ + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + ) + + trainer.validate(model) + + assert trainer.state == TrainerState.FINISHED + + @pytest.mark.parametrize("extra_params", [ pytest.param(dict(fast_dev_run=True), id='Fast-Run'), pytest.param(dict(max_steps=1), id='Single-Step'), diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 9fac8f3cccf91..3a6543a8389da 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -749,6 +749,47 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): assert trainer.evaluated_ckpt_path == ckpt_path +@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) +@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) +def test_validate_checkpoint_path(tmpdir, ckpt_path, save_top_k): + hparams = EvalModelTemplate.get_default_hparams() + + model = EvalModelTemplate(**hparams) + trainer = Trainer( + max_epochs=2, + progress_bar_refresh_rate=0, + default_root_dir=tmpdir, + checkpoint_callback=ModelCheckpoint(monitor="early_stop_on", save_top_k=save_top_k), + ) + trainer.fit(model) + if ckpt_path == "best": + # ckpt_path is 'best', meaning we load the best weights + if save_top_k == 0: + with 
pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"): + trainer.validate(ckpt_path=ckpt_path) + else: + trainer.validate(ckpt_path=ckpt_path) + assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path + elif ckpt_path is None: + # ckpt_path is None, meaning we don't load any checkpoints and + # use the weights from the end of training + trainer.validate(ckpt_path=ckpt_path) + assert trainer.evaluated_ckpt_path is None + else: + # specific checkpoint, pick one from saved ones + if save_top_k == 0: + with pytest.raises(FileNotFoundError): + trainer.validate(ckpt_path="random.ckpt") + else: + ckpt_path = str( + list((Path(tmpdir) / f"lightning_logs/version_{trainer.logger.version}/checkpoints").iterdir())[ + 0 + ].absolute() + ) + trainer.validate(ckpt_path=ckpt_path) + assert trainer.evaluated_ckpt_path == ckpt_path + + def test_disabled_training(tmpdir): """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`.""" @@ -1425,6 +1466,10 @@ def setup(self, model, stage): assert trainer.stage == "test" assert trainer.get_model().stage == "test" + trainer.validate(ckpt_path=None) + assert trainer.stage == "validation" + assert trainer.get_model().stage == "validation" + @pytest.mark.parametrize( "train_batches, max_steps, log_interval", diff --git a/tests/trainer/test_trainer_validate_loop.py b/tests/trainer/test_trainer_validate_loop.py new file mode 100644 index 0000000000000..a2205a4b50dc2 --- /dev/null +++ b/tests/trainer/test_trainer_validate_loop.py @@ -0,0 +1,76 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +import torch + +import pytorch_lightning as pl +import tests.base.develop_utils as tutils +from tests.base import EvalModelTemplate + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_single_gpu_validate(tmpdir): + tutils.set_random_master_port() + + model = EvalModelTemplate() + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0], + ) + trainer.fit(model) + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + results = trainer.validate() + assert 'val_acc' in results[0] + + old_weights = model.c_d1.weight.clone().detach().cpu() + + results = trainer.validate(model) + assert 'val_acc' in results[0] + + # make sure weights didn't change + new_weights = model.c_d1.weight.clone().detach().cpu() + + assert torch.all(torch.eq(old_weights, new_weights)) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_ddp_spawn_validate(tmpdir): + tutils.set_random_master_port() + + model = EvalModelTemplate() + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + distributed_backend='ddp_spawn', + ) + trainer.fit(model) + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + results = trainer.validate() + assert 'val_acc' in results[0] + + old_weights = model.c_d1.weight.clone().detach().cpu() + + results = trainer.validate(model) + assert 'val_acc' in results[0] + + # make sure weights didn't change + new_weights = model.c_d1.weight.clone().detach().cpu() + + assert torch.all(torch.eq(old_weights, new_weights)) From 85b3c9fe2593d848adc0358c08af8ecfd404c426 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 2 Dec 2020 16:47:59 +0100 Subject: [PATCH 06/32] Update CHANGELOG.md --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 241424deae720..46c2aff62c79b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -70,6 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `Pytorch Geometric` integration example with Lightning ([#4568](https://github.com/PyTorchLightning/pytorch-lightning/pull/4568)) +- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ( + [#4707](https://github.com/PyTorchLightning/pytorch-lightning/pull/4707)) + + ### Changed - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903)) From a6be0d888875836bb3800b3cbae057b7472f7c53 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Fri, 4 Dec 2020 19:30:25 +0100 Subject: [PATCH 07/32] Replace usages of Trainer.testing with Trainer.evaluating, should be the last in the codebase --- pytorch_lightning/accelerators/ddp_spawn_accelerator.py | 4 ++-- pytorch_lightning/accelerators/tpu_accelerator.py | 6 +++--- pytorch_lightning/plugins/sharded_plugin.py | 2 +- tests/base/model_train_steps.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index a06d0b82d6d15..b538e6c67dad0 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -223,7 +223,7 @@ def __recover_child_process_weights(self, model, best_path, last_path): # todo, pass also best score # load last weights - if last_path is not None and not self.trainer.testing: + if last_path is not None and not self.trainer.evaluating: ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt) @@ -242,7 +242,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): # save the last weights last_path = None - if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0: + if not self.trainer.evaluating and best_model_path is not None and len(best_model_path) > 0: last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) atomic_save(model.state_dict(), last_path) mp_queue.put(last_path) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 6da5150d1fa8a..edc415ee4118c 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -81,7 +81,7 @@ def teardown(self): # todo, pass also bets score # load last weights - if last_path and not self.trainer.testing: + if last_path and not self.trainer.evaluating: ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt) @@ -109,7 +109,7 @@ def __load_weights_on_main_process(self): model = self.trainer.model # load weights if not interrupted - if self.trainer.on_colab_kaggle and not self.trainer.testing: + if self.trainer.on_colab_kaggle and not self.trainer.evaluating: self.load_spawn_weights(model) self.trainer.model = model @@ -342,7 +342,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): # save the last weights last_path = None - if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0: + if not self.trainer.evaluating and best_model_path is not None and len(best_model_path) > 0: last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) state_dict = move_data_to_device(model.state_dict(), torch.device("cpu")) atomic_save(state_dict, last_path) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 
f8a793af85310..cba7f21df8c1d 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -57,7 +57,7 @@ def _optim_state_dict(self, optimizer): def _wrap_optimizers(self, model): trainer = model.trainer - if trainer.testing is True: + if trainer.evaluating is True: return self._reinit_with_fairscale_oss(trainer) diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index caec6db9aaa10..4b78f79a08f2e 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -146,7 +146,7 @@ def eval_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None loss_val = y_hat.sum() result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val) - eval_name = 'validation' if not self.trainer.testing else 'test' + eval_name = 'test' if self.trainer.evaluating == 'test' else 'validation' result.log(f'{eval_name}_step_metric', loss_val + 1, on_step=True) setattr(self, f'{eval_name}_step_called', True) @@ -157,7 +157,7 @@ def eval_step_end_full_loop_result_obj_dp(self, result): """ Full loop flow train step (result obj + dp) """ - eval_name = 'validation' if not self.trainer.testing else 'test' + eval_name = 'test' if self.trainer.evaluating == 'test' else 'validation' reduced = getattr(result, f'{eval_name}_step_metric_step').mean() setattr(result, f'{eval_name}_step_metric_step', reduced) @@ -178,7 +178,7 @@ def eval_epoch_end_full_loop_result_obj_dp(self, result): """ Full loop flow train step (result obj + dp) """ - eval_name = 'validation' if not self.trainer.testing else 'test' + eval_name = 'test' if self.trainer.evaluating == 'test' else 'validation' result.log(f'{eval_name}_epoch_end_metric', torch.tensor(1).type_as(result.checkpoint_on), on_epoch=True) result.checkpoint_on = result.checkpoint_on.mean() result.early_stop_on = result.early_stop_on.mean() From 595f4e8e331f3eb0b15fd4b0ceaca63f682a7bb2 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 8 Dec 2020 12:30:26 +0100 Subject: [PATCH 08/32] Clean up calls to LightningDataModule.setup() --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e2758254f0419..b2854ae971a32 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -926,10 +926,10 @@ def call_setup_hook(self, model): if self.datamodule is not None: called = { - None: self.datamodule.has_setup_fit, + 'fit': self.datamodule.has_setup_fit, 'validation': self.datamodule.has_setup_validation, 'test': self.datamodule.has_setup_test, - }[self.evaluating] + }[stage_name] if not called: self.datamodule.setup(stage_name) From 0b0924884ed77d6b4b60f48720e9e00a096eeeb7 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 8 Dec 2020 13:00:52 +0100 Subject: [PATCH 09/32] Update test_trainer_validate_loop.py to use BoringModel instead of EvalModelTemplate --- tests/trainer/test_trainer_validate_loop.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/trainer/test_trainer_validate_loop.py b/tests/trainer/test_trainer_validate_loop.py index a2205a4b50dc2..ec8dd82260e3f 100644 --- a/tests/trainer/test_trainer_validate_loop.py +++ b/tests/trainer/test_trainer_validate_loop.py @@ -16,14 +16,14 @@ import pytorch_lightning as pl import tests.base.develop_utils as tutils -from tests.base import EvalModelTemplate +from tests.base import BoringModel 
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_single_gpu_validate(tmpdir): tutils.set_random_master_port() - model = EvalModelTemplate() + model = BoringModel() trainer = pl.Trainer( default_root_dir=tmpdir, max_epochs=2, @@ -34,12 +34,12 @@ def test_single_gpu_validate(tmpdir): trainer.fit(model) assert 'ckpt' in trainer.checkpoint_callback.best_model_path results = trainer.validate() - assert 'val_acc' in results[0] + assert 'x' in results[0] old_weights = model.c_d1.weight.clone().detach().cpu() results = trainer.validate(model) - assert 'val_acc' in results[0] + assert 'x' in results[0] # make sure weights didn't change new_weights = model.c_d1.weight.clone().detach().cpu() @@ -51,7 +51,7 @@ def test_single_gpu_validate(tmpdir): def test_ddp_spawn_validate(tmpdir): tutils.set_random_master_port() - model = EvalModelTemplate() + model = BoringModel() trainer = pl.Trainer( default_root_dir=tmpdir, max_epochs=2, @@ -63,12 +63,12 @@ def test_ddp_spawn_validate(tmpdir): trainer.fit(model) assert 'ckpt' in trainer.checkpoint_callback.best_model_path results = trainer.validate() - assert 'val_acc' in results[0] + assert 'x' in results[0] old_weights = model.c_d1.weight.clone().detach().cpu() results = trainer.validate(model) - assert 'val_acc' in results[0] + assert 'x' in results[0] # make sure weights didn't change new_weights = model.c_d1.weight.clone().detach().cpu() From 06b4419665a2b994db7adf52589e896b9b1c6a5d Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 8 Dec 2020 15:35:43 +0100 Subject: [PATCH 10/32] Fix ShardedPlugin when evaluating MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Trainer.evaluating should not be compared with `is True`, since it’s either None when not in evaluation mode, ‘validation’ or ‘test’. 
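For reference, a minimal sketch of the comparison semantics this change relies on (plain Python, not part of the patch; it assumes only the attribute values described above, i.e. `trainer.evaluating` is `None` during fit and the string `'validation'` or `'test'` while evaluating):

    # Sketch: why `if trainer.evaluating:` works where `is True` does not.
    for value in (None, 'validation', 'test'):
        old_check = value is True   # never matches a string, so the early return never fired
        new_check = bool(value)     # True for 'validation' and 'test', False during fit
        print(f"{value!r}: old={old_check} new={new_check}")

Using plain truthiness also keeps the check correct if further evaluation modes are added later.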
--- pytorch_lightning/plugins/sharded_plugin.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 3faf33dd49cf5..b1bc26a1b185c 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -57,7 +57,8 @@ def _optim_state_dict(self, optimizer): def _wrap_optimizers(self, model): trainer = model.trainer - if trainer.evaluating is True: + + if trainer.evaluating: return self._reinit_with_fairscale_oss(trainer) From 389940e85f9cfa833591e9588299e7a25b8eb150 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 8 Dec 2020 15:36:19 +0100 Subject: [PATCH 11/32] Add tests for Trainer.validate with ShardedPlugin --- tests/plugins/test_sharded_plugin.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 5010c39de7a80..f89c10eaf6e29 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -301,9 +301,9 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_sharded_plugin_test(tmpdir): +def test_ddp_sharded_plugin_evaluate(tmpdir): """ - Test to ensure we can use test without fit + Test to ensure we can use validate and test without fit """ model = BoringModel() trainer = Trainer( @@ -312,6 +312,7 @@ def test_ddp_sharded_plugin_test(tmpdir): fast_dev_run=True, ) + trainer.validate(model) trainer.test(model) @@ -319,9 +320,9 @@ def test_ddp_sharded_plugin_test(tmpdir): @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_sharded_plugin_test_multigpu(tmpdir): +def test_ddp_sharded_plugin_evaluate_multigpu(tmpdir): """ - Test to ensure we can use test without fit + Test to ensure we can use validate and test without fit """ model = BoringModel() trainer = Trainer( @@ -331,4 +332,5 @@ def test_ddp_sharded_plugin_test_multigpu(tmpdir): fast_dev_run=True, ) + trainer.validate(model) trainer.test(model) From 6d0a95a97f302984713064f2ff78d657f33635b8 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Thu, 10 Dec 2020 12:31:57 +0100 Subject: [PATCH 12/32] Remove superfluous calls to LoggerConnector.set_stage in validate() and test() The stage set here is being overwritten in `Trainer.run_evaluation` anyway --- pytorch_lightning/trainer/trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6973ec2c113b0..ad4df27d7dfb6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -731,8 +731,6 @@ def validate( # -------------------- self.verbose_evaluate = verbose - self.logger_connector.set_stage("validation") - # If you supply a datamodule you can't supply val_dataloaders if val_dataloaders and datamodule: raise MisconfigurationException( @@ -782,8 +780,6 @@ def test( # -------------------- self.verbose_evaluate = verbose - self.logger_connector.set_stage("test") - # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( From 
704b121e5e39749728af4da5546844a8a90d6754 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Thu, 10 Dec 2020 12:39:09 +0100 Subject: [PATCH 13/32] Update more docstrings to mention Trainer.validate --- pytorch_lightning/trainer/callback_hook.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 7228791aef7da..e9ec20f33d9bb 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -27,12 +27,12 @@ class TrainerCallbackHookMixin(ABC): get_model: Callable def setup(self, model, stage: str): - """Called in the beginning of fit and test""" + """Called in the beginning of fit, validate and test""" for callback in self.callbacks: callback.setup(self, model, stage) def teardown(self, stage: str): - """Called at the end of fit and test""" + """Called at the end of fit, validate and test""" for callback in self.callbacks: callback.teardown(self, self.get_model(), stage) From 12a85b3f74d300fe673ca70090fc6e629ae6bd3a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 01:54:23 +0100 Subject: [PATCH 14/32] Pass {fit,validate,test,predict} to setup() --- pytorch_lightning/callbacks/base.py | 4 +- pytorch_lightning/core/datamodule.py | 58 ++++++++----- pytorch_lightning/core/hooks.py | 8 +- pytorch_lightning/trainer/callback_hook.py | 24 +++--- pytorch_lightning/trainer/model_hooks.py | 6 -- pytorch_lightning/trainer/states.py | 16 ++-- pytorch_lightning/trainer/trainer.py | 57 ++++++------- tests/callbacks/test_callbacks.py | 35 ++++---- tests/core/test_datamodules.py | 97 ++++++++++++---------- tests/helpers/boring_model.py | 4 +- 10 files changed, 168 insertions(+), 141 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index d53acf0f7030d..494d94cf446de 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -34,11 +34,11 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul pass def setup(self, trainer, pl_module: LightningModule, stage: str) -> None: - """Called when fit or test begins""" + """Called when fit, validate, test, predict, or tune begins""" pass def teardown(self, trainer, pl_module: LightningModule, stage: str) -> None: - """Called when fit or test ends""" + """Called when fit, validate, test, predict, or tune ends""" pass def on_init_start(self, trainer) -> None: diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 29b93abe3e6a1..31c05e3bcc4c4 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -55,10 +55,10 @@ def __call__(cls, *args, **kwargs): def track_data_hook_calls(fn): """A decorator that checks if prepare_data/setup have been called. 
- - When dm.prepare_data() is called, dm.has_prepared_data gets set to True - - When dm.setup('fit') is called, dm.has_setup_fit gets set to True - - When dm.setup('test') is called, dm.has_setup_test gets set to True - - When dm.setup() is called without stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True + - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True + - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True + - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}`` + it's corresponding `dm_has_setup_{stage}` gets set to True Args: fn (function): Function that will be tracked to see if it has been called. @@ -77,15 +77,15 @@ def wrapped_fn(*args, **kwargs): if fn.__name__ == "setup": # Get stage either by grabbing from args or checking kwargs. - # If not provided, set call status of 'fit' and 'test' to True. + # If not provided, set call status of 'fit', 'validate', and 'test' to True. # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() stage = args[1] if len(args) > 1 else kwargs.get("stage", None) - if stage == "fit" or stage is None: - obj._has_setup_fit = True - - if stage == "test" or stage is None: - obj._has_setup_test = True + if stage is None: + for s in ("fit", "validate", "test"): + setattr(obj, f"_has_setup_{s}", True) + else: + setattr(obj, f"_has_setup_{stage}", True) if fn.__name__ == "prepare_data": obj._has_prepared_data = True @@ -156,7 +156,9 @@ def __init__( # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False self._has_setup_fit = False + self._has_setup_validate = False self._has_setup_test = False + self._has_setup_predict = False @property def train_transforms(self): @@ -214,32 +216,50 @@ def size(self, dim=None) -> Union[Tuple, int]: return self.dims @property - def has_prepared_data(self): - """Return bool letting you know if datamodule.prepare_data() has been called or not. + def has_prepared_data(self) -> bool: + """Return bool letting you know if ``datamodule.prepare_data()`` has been called or not. Returns: - bool: True if datamodule.prepare_data() has been called. False by default. + bool: True if ``datamodule.prepare_data()`` has been called. False by default. """ return self._has_prepared_data @property - def has_setup_fit(self): - """Return bool letting you know if datamodule.setup('fit') has been called or not. + def has_setup_fit(self) -> bool: + """Return bool letting you know if ``datamodule.setup('fit')`` has been called or not. Returns: - bool: True if datamodule.setup('fit') has been called. False by default. + bool: True ``if datamodule.setup('fit')`` has been called. False by default. """ return self._has_setup_fit @property - def has_setup_test(self): - """Return bool letting you know if datamodule.setup('test') has been called or not. + def has_setup_validate(self) -> bool: + """Return bool letting you know if ``datamodule.setup('validate')`` has been called or not. + + Returns: + bool: True if ``datamodule.setup('validate')`` has been called. False by default. + """ + return self._has_setup_validate + + @property + def has_setup_test(self) -> bool: + """Return bool letting you know if ``datamodule.setup('test')`` has been called or not. Returns: - bool: True if datamodule.setup('test') has been called. False by default. + bool: True if ``datamodule.setup('test')`` has been called. False by default. 
""" return self._has_setup_test + @property + def has_setup_predict(self) -> bool: + """Return bool letting you know if ``datamodule.setup('predict')`` has been called or not. + + Returns: + bool: True if ``datamodule.setup('predict')`` has been called. False by default. + """ + return self._has_setup_predict + @abstractmethod def prepare_data(self, *args, **kwargs): pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 604803365298c..a6567e3d52f0f 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -27,12 +27,12 @@ class ModelHooks: def setup(self, stage: str) -> None: """ - Called at the beginning of fit and test. + Called at the beginning of fit (train + validate), validate, test, predict, or tune. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. Args: - stage: either 'fit' or 'test' + stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` Example:: @@ -55,10 +55,10 @@ def setup(stage): def teardown(self, stage: str) -> None: """ - Called at the end of fit and test. + Called at the end of fit (train + validate), validate, test, predict, or tune. Args: - stage: either 'fit' or 'test' + stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` """ def on_fit_start(self) -> None: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8f9fc3ad930b0..71433429f7c03 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -29,18 +29,18 @@ class TrainerCallbackHookMixin(ABC): callbacks: List[Callback] = [] lightning_module: LightningModule - def on_before_accelerator_backend_setup(self, model): - """Called in the beginning of fit and test""" + def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def setup(self, model, stage: str): - """Called in the beginning of fit and test""" + def setup(self, model: LightningModule, stage: str) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.setup(self, model, stage) - def teardown(self, stage: str): - """Called at the end of fit and test""" + def teardown(self, stage: str) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.teardown(self, self.lightning_module, stage) @@ -124,15 +124,15 @@ def on_train_end(self): for callback in self.callbacks: callback.on_train_end(self, self.lightning_module) - def on_pretrain_routine_start(self, model): - """Called when the train begins.""" + def on_pretrain_routine_start(self) -> None: + """Called when the pre-train routine begins.""" for callback in self.callbacks: - callback.on_pretrain_routine_start(self, model) + callback.on_pretrain_routine_start(self, self.lightning_module) - def on_pretrain_routine_end(self, model): - """Called when the train ends.""" + def on_pretrain_routine_end(self) -> None: + """Called when the pre-train routine ends.""" for callback in self.callbacks: - callback.on_pretrain_routine_end(self, model) + callback.on_pretrain_routine_end(self, self.lightning_module) def 
on_batch_start(self): """Called when the training batch begins.""" diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index 7e3d6cc78320c..e98ebf088a8dc 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -22,12 +22,6 @@ class TrainerModelHooksMixin(ABC): lightning_module: LightningModule - def is_function_implemented(self, f_name, model=None): - if model is None: - model = self.lightning_module - f_op = getattr(model, f_name, None) - return callable(f_op) - def has_arg(self, f_name, arg_name): model = self.lightning_module f_op = getattr(model, f_name, None) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index d0c2ded659f67..2688fb6754977 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -27,14 +27,14 @@ class TrainerState(LightningEnum): >>> TrainerState.FINISHED == 'finished' True """ - INITIALIZING = 'INITIALIZING' # trainer creation - FITTING = 'FITTING' # trainer.fit() - VALIDATING = 'VALIDATING' # trainer.validate() - TESTING = 'TESTING' # trainer.test() - PREDICTING = 'PREDICTING' # trainer.predict() - TUNING = 'TUNING' # trainer.tune() - FINISHED = 'FINISHED' - INTERRUPTED = 'INTERRUPTED' + INITIALIZING = 'initializing' # trainer creation + FITTING = 'fit' # trainer.fit() + VALIDATING = 'validate' # trainer.validate() + TESTING = 'test' # trainer.test() + PREDICTING = 'predict' # trainer.predict() + TUNING = 'tune' # trainer.tune() + FINISHED = 'finished' + INTERRUPTED = 'interrupted' @property def stopped(self) -> bool: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cc1964f07039b..7cd666b17ca7b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -443,7 +443,7 @@ def fit( # ---------------------------- self.call_setup_hook(model) self.call_hook("on_before_accelerator_backend_setup", model) - self.accelerator.setup(self, model) + self.accelerator.setup(self, model) # note: this sets up self.lightning_module self.setup_trainer(model) # ---------------------------- @@ -473,7 +473,8 @@ def fit( # TRAIN # ---------------------------- # hook - self.call_hook("on_fit_start") + if self.state == TrainerState.FITTING: + self.call_hook("on_fit_start") # plugin will setup fitting (e.g. 
ddp will launch child processes) self.pre_dispatch() @@ -488,12 +489,11 @@ def fit( # POST-Training CLEAN UP # ---------------------------- # hook - self.call_hook('on_fit_end') + if self.state == TrainerState.FITTING: + self.call_hook('on_fit_end') - # hook - self.teardown('fit') - if self.is_function_implemented('teardown'): - model.teardown('fit') + # teardown + self.call_teardown_hook(model) if self.state != TrainerState.INTERRUPTED: self.state = TrainerState.FINISHED @@ -541,9 +541,8 @@ def _pre_training_routine(self): # on pretrain routine start ref_model = self.lightning_module - self.on_pretrain_routine_start(ref_model) - if self.is_function_implemented("on_pretrain_routine_start"): - ref_model.on_pretrain_routine_start() + self.on_pretrain_routine_start() + ref_model.on_pretrain_routine_start() # print model summary if self.is_global_zero and self.weights_summary is not None and not self.testing: @@ -556,9 +555,8 @@ def _pre_training_routine(self): self.checkpoint_connector.restore_weights() # on pretrain routine end - self.on_pretrain_routine_end(ref_model) - if self.is_function_implemented("on_pretrain_routine_end"): - ref_model.on_pretrain_routine_end() + self.on_pretrain_routine_end() + ref_model.on_pretrain_routine_end() def run_train(self) -> None: @@ -880,8 +878,6 @@ def test( self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders) ) - self.teardown('test') - assert self.state.stopped self.testing = False @@ -929,10 +925,6 @@ def __evaluate_using_weights( # run test results = self.fit(model) - # teardown - if self.is_function_implemented('teardown', model=model): - model.teardown('test') - return results def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None): @@ -944,10 +936,6 @@ def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, # sets up testing so we short circuit to eval results = self.fit(model) - # teardown - if self.is_function_implemented('teardown', model=model): - model.teardown('test') - return results def predict( @@ -1035,17 +1023,26 @@ def tune( assert self.state.stopped self.tuning = False - def call_setup_hook(self, model): - # call setup after the ddp process has connected - stage_name = 'test' if self.evaluating else 'fit' + def call_setup_hook(self, model: LightningModule) -> None: + assert self.state.running, f"TrainerState: {self.state}" + state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = state.value if self.datamodule is not None: - called = getattr(self.datamodule, f'has_setup_{stage_name}') + called = getattr(self.datamodule, f'has_setup_{state}') if not called: - self.datamodule.setup(stage_name) + self.datamodule.setup(state) + + self.setup(model, state) + model.setup(state) + + def call_teardown_hook(self, model: LightningModule) -> None: + assert self.state.running, f"TrainerState: {self.state}" + state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = state.value - self.setup(model, stage_name) - model.setup(stage_name) + self.teardown(state) + model.teardown(state) def _reset_result_and_set_hook_fx_name(self, hook_name): # on_before_zero_grad is called within training_step diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 8a25ecc9f983b..2426348f770bf 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -19,29 +19,20 @@ @mock.patch("torch.save") # need to mock torch.save or we get 
pickle error -def test_trainer_callback_system(_, tmpdir): - """Test the callback system.""" +def test_trainer_callback_system_fit(_, tmpdir): + """Test the callback system for fit.""" model = BoringModel() - callback_mock = MagicMock() - - trainer_options = dict( + trainer = Trainer( default_root_dir=tmpdir, callbacks=[callback_mock], max_epochs=1, limit_val_batches=1, limit_train_batches=3, - limit_test_batches=2, progress_bar_refresh_rate=0, ) - # no call yet - callback_mock.assert_not_called() - - # fit model - trainer = Trainer(**trainer_options) - # check that only the to calls exists assert trainer.callbacks[0] == callback_mock assert callback_mock.method_calls == [ @@ -49,6 +40,7 @@ def test_trainer_callback_system(_, tmpdir): call.on_init_end(trainer), ] + # fit model trainer.fit(model) assert callback_mock.method_calls == [ @@ -104,8 +96,20 @@ def test_trainer_callback_system(_, tmpdir): call.teardown(trainer, model, 'fit'), ] - callback_mock.reset_mock() - trainer = Trainer(**trainer_options) + +def test_trainer_callback_system_test(tmpdir): + """Test the callback system for test.""" + + model = BoringModel() + callback_mock = MagicMock() + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[callback_mock], + max_epochs=1, + limit_test_batches=2, + progress_bar_refresh_rate=0, + ) + trainer.test(model) assert callback_mock.method_calls == [ @@ -113,7 +117,6 @@ def test_trainer_callback_system(_, tmpdir): call.on_init_end(trainer), call.setup(trainer, model, 'test'), call.on_before_accelerator_backend_setup(trainer, model), - call.on_fit_start(trainer, model), call.on_test_start(trainer, model), call.on_test_epoch_start(trainer, model), call.on_test_batch_start(trainer, model, ANY, 0, 0), @@ -123,8 +126,6 @@ def test_trainer_callback_system(_, tmpdir): call.on_test_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_test_end(trainer, model), - call.on_fit_end(trainer, model), - call.teardown(trainer, model, 'fit'), call.teardown(trainer, model, 'test'), ] diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 866bffcdd7441..e1b4301842ecd 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -17,6 +17,7 @@ from unittest import mock from unittest.mock import PropertyMock +import pytest import torch import torch.nn.functional as F @@ -108,13 +109,13 @@ def prepare_data(self, *args, **kwargs): dm.prepare_data() -def test_base_datamodule(tmpdir): +def test_helper_boringdatamodule(tmpdir): dm = BoringDataModule() dm.prepare_data() dm.setup() -def test_base_datamodule_with_verbose_setup(tmpdir): +def test_helper_boringdatamodule_with_verbose_setup(tmpdir): dm = BoringDataModule() dm.prepare_data() dm.setup('fit') @@ -123,55 +124,67 @@ def test_base_datamodule_with_verbose_setup(tmpdir): def test_data_hooks_called(tmpdir): dm = BoringDataModule() - assert dm.has_prepared_data is False - assert dm.has_setup_fit is False - assert dm.has_setup_test is False + assert not dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_validate + assert not dm.has_setup_predict dm.prepare_data() - assert dm.has_prepared_data is True - assert dm.has_setup_fit is False - assert dm.has_setup_test is False + assert dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_validate + assert not dm.has_setup_predict dm.setup() - assert dm.has_prepared_data is True - assert dm.has_setup_fit is True - assert dm.has_setup_test 
is True + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate + assert not dm.has_setup_predict -def test_data_hooks_called_verbose(tmpdir): +@pytest.mark.parametrize("use_kwarg", (False, True)) +def test_data_hooks_called_verbose(tmpdir, use_kwarg): dm = BoringDataModule() - assert dm.has_prepared_data is False - assert dm.has_setup_fit is False - assert dm.has_setup_test is False + assert not dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test dm.prepare_data() - assert dm.has_prepared_data is True - assert dm.has_setup_fit is False - assert dm.has_setup_test is False - - dm.setup('fit') - assert dm.has_prepared_data is True - assert dm.has_setup_fit is True - assert dm.has_setup_test is False - - dm.setup('test') - assert dm.has_prepared_data is True - assert dm.has_setup_fit is True - assert dm.has_setup_test is True - - -def test_data_hooks_called_with_stage_kwarg(tmpdir): - dm = BoringDataModule() - dm.prepare_data() - assert dm.has_prepared_data is True - - dm.setup(stage='fit') - assert dm.has_setup_fit is True - assert dm.has_setup_test is False - - dm.setup(stage='test') - assert dm.has_setup_fit is True - assert dm.has_setup_test is True + assert dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='fit') if use_kwarg else dm.setup('fit') + assert dm.has_prepared_data + assert dm.has_setup_fit + assert not dm.has_setup_validate + assert not dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='validate') if use_kwarg else dm.setup('validate') + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_validate + assert not dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='test') if use_kwarg else dm.setup('test') + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_validate + assert dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='predict') if use_kwarg else dm.setup('predict') + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_validate + assert dm.has_setup_test + assert dm.has_setup_predict def test_dm_add_argparse_args(tmpdir): diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index ea26310a45315..6ef2518bbef11 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -151,9 +151,11 @@ def prepare_data(self): def setup(self, stage: Optional[str] = None): if stage == "fit" or stage is None: self.random_train = Subset(self.random_full, indices=range(64)) - self.random_val = Subset(self.random_full, indices=range(64, 128)) self.dims = self.random_train[0].shape + if stage in ("fit", "validate") or stage is None: + self.random_val = Subset(self.random_full, indices=range(64, 128)) + if stage == "test" or stage is None: self.random_test = Subset(self.random_full, indices=range(128, 192)) self.dims = getattr(self, "dims", self.random_test[0].shape) From d49ccd1b9fb0afadaa28c4bead0e0cb7e5b1fc91 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 02:43:12 +0100 Subject: [PATCH 15/32] Fix doctest --- pytorch_lightning/trainer/states.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 2688fb6754977..33a2326c518d5 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -21,10 +21,10 @@ 
class TrainerState(LightningEnum): functions such as `trainer.fit()` and `trainer.test(). >>> # you can compare the type with a string - >>> TrainerState.FITTING == 'FITTING' + >>> TrainerState.FITTING == 'fit' True >>> # which is case insensitive - >>> TrainerState.FINISHED == 'finished' + >>> TrainerState.FINISHED == 'FINISHED' True """ INITIALIZING = 'initializing' # trainer creation From 23db13507878a60cb17844f6133c5b7adf9fa9ca Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:24:29 +0100 Subject: [PATCH 16/32] stage: Optional[str] = None --- pytorch_lightning/callbacks/base.py | 6 +++--- pytorch_lightning/core/hooks.py | 8 ++++---- pytorch_lightning/trainer/callback_hook.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 18 ++++++++++-------- tests/models/test_hooks.py | 16 ++++++---------- 5 files changed, 26 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 494d94cf446de..0ba1fd4ff7785 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -17,7 +17,7 @@ """ import abc -from typing import Any, Dict +from typing import Any, Dict, Optional from pytorch_lightning.core.lightning import LightningModule @@ -33,11 +33,11 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul """Called before accelerator is being setup""" pass - def setup(self, trainer, pl_module: LightningModule, stage: str) -> None: + def setup(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: """Called when fit, validate, test, predict, or tune begins""" pass - def teardown(self, trainer, pl_module: LightningModule, stage: str) -> None: + def teardown(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: """Called when fit, validate, test, predict, or tune ends""" pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index a6567e3d52f0f..9826f9d44ac2c 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -25,14 +25,14 @@ class ModelHooks: """Hooks to be used in LightningModule.""" - def setup(self, stage: str) -> None: + def setup(self, stage: Optional[str] = None) -> None: """ Called at the beginning of fit (train + validate), validate, test, predict, or tune. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` Example:: @@ -53,12 +53,12 @@ def setup(stage): """ - def teardown(self, stage: str) -> None: + def teardown(self, stage: Optional[str] = None) -> None: """ Called at the end of fit (train + validate), validate, test, predict, or tune. 
Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` """ def on_fit_start(self) -> None: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 71433429f7c03..f174cd725bd36 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,7 +15,7 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import Any, Callable, Dict, List, Type +from typing import Any, Callable, Dict, List, Type, Optional from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule @@ -34,12 +34,12 @@ def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def setup(self, model: LightningModule, stage: str) -> None: + def setup(self, model: LightningModule, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.setup(self, model, stage) - def teardown(self, stage: str) -> None: + def teardown(self, stage: Optional[str] = None) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.teardown(self, self.lightning_module, stage) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7cd666b17ca7b..d58de7d803146 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1031,18 +1031,20 @@ def call_setup_hook(self, model: LightningModule) -> None: if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') if not called: - self.datamodule.setup(state) + self.datamodule.setup(stage=state) - self.setup(model, state) - model.setup(state) + self.setup(model, stage=state) + model.setup(stage=state) def call_teardown_hook(self, model: LightningModule) -> None: - assert self.state.running, f"TrainerState: {self.state}" - state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - state = state.value + if self.state.running: + state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = state.value + else: + state = None - self.teardown(state) - model.teardown(state) + self.teardown(stage=state) + model.teardown(stage=state) def _reset_result_and_set_hook_fx_name(self, hook_name): # on_before_zero_grad is called within training_step diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 1a7803800b384..7c53925bd7cc4 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -404,7 +404,7 @@ def on_test_end(self): self.called.append(inspect.currentframe().f_code.co_name) super().on_test_end() - def teardown(self, stage: str): + def teardown(self, stage=None): self.called.append(inspect.currentframe().f_code.co_name) super().teardown(stage) @@ -420,12 +420,12 @@ def teardown(self, stage: str): limit_train_batches=2, limit_test_batches=1, progress_bar_refresh_rate=0, + weights_summary=None, ) assert model.called == [] trainer.fit(model) - expected = [ 'on_fit_start', 'on_pretrain_routine_start', @@ -469,11 +469,10 @@ def teardown(self, stage: str): assert model.called == expected - model2 = HookedModel() - trainer.test(model2) + model = 
HookedModel() + trainer.test(model, verbose=False) expected = [ - 'on_fit_start', 'on_test_model_eval', 'on_test_start', 'on_test_epoch_start', @@ -483,9 +482,6 @@ def teardown(self, stage: str): 'on_epoch_end', 'on_test_end', 'on_test_model_train', - 'on_fit_end', - 'teardown', # for 'fit' - 'teardown', # for 'test' + 'teardown', ] - - assert model2.called == expected + assert model.called == expected From 84f5fdb8e6b6b4254a0f635281c5356756d00ba5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:26:48 +0100 Subject: [PATCH 17/32] Trailing whitespace --- tests/core/test_datamodules.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index e1b4301842ecd..ab51a87329e2f 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -131,17 +131,17 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_predict dm.prepare_data() - assert dm.has_prepared_data + assert dm.has_prepared_data assert not dm.has_setup_fit assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict dm.setup() - assert dm.has_prepared_data - assert dm.has_setup_fit - assert dm.has_setup_test - assert dm.has_setup_validate + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate assert not dm.has_setup_predict @@ -153,21 +153,21 @@ def test_data_hooks_called_verbose(tmpdir, use_kwarg): assert not dm.has_setup_test dm.prepare_data() - assert dm.has_prepared_data + assert dm.has_prepared_data assert not dm.has_setup_fit assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='fit') if use_kwarg else dm.setup('fit') assert dm.has_prepared_data - assert dm.has_setup_fit + assert dm.has_setup_fit assert not dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='validate') if use_kwarg else dm.setup('validate') - assert dm.has_prepared_data - assert dm.has_setup_fit + assert dm.has_prepared_data + assert dm.has_setup_fit assert dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict @@ -180,11 +180,11 @@ def test_data_hooks_called_verbose(tmpdir, use_kwarg): assert not dm.has_setup_predict dm.setup(stage='predict') if use_kwarg else dm.setup('predict') - assert dm.has_prepared_data - assert dm.has_setup_fit - assert dm.has_setup_validate - assert dm.has_setup_test - assert dm.has_setup_predict + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_validate + assert dm.has_setup_test + assert dm.has_setup_predict def test_dm_add_argparse_args(tmpdir): From 188b9feae8b386114d792113758afb51fa9c5931 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:50:52 +0100 Subject: [PATCH 18/32] Update docs and CHANGELOG --- CHANGELOG.md | 3 +++ docs/source/extensions/datamodules.rst | 12 ++++++------ docs/source/starter/introduction_guide.rst | 8 ++++---- docs/source/starter/new-project.rst | 2 +- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f8f7a08b089b..f6ef0d56b3792 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Changed `setup()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + ### Deprecated diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index a6c083dc61fcf..85134fda06fa2 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -80,7 +80,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa self.data_dir = data_dir self.batch_size = batch_size - def setup(self, stage=None): + def setup(self, stage: Optional[str] = None): self.mnist_test = MNIST(self.data_dir, train=False) mnist_full = MNIST(self.data_dir, train=True) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) @@ -138,7 +138,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th MNIST(self.data_dir, train=True, download=True) MNIST(self.data_dir, train=False, download=True) - def setup(self, stage=None): + def setup(self, stage: Optional[str] = None): # Assign train/val datasets for use in dataloaders if stage == 'fit' or stage is None: @@ -382,12 +382,12 @@ still ensures the method runs on the correct devices) dm = MNISTDataModule() dm.prepare_data() - dm.setup('fit') + dm.setup(stage='fit') model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab) trainer.fit(model, dm) - dm.setup('test') + dm.setup(stage='test') trainer.test(datamodule=dm) ---------------- @@ -403,7 +403,7 @@ You can of course use DataModules in plain PyTorch code as well. dm.prepare_data() # splits/transforms - dm.setup('fit') + dm.setup(stage='fit') # use data for batch in dm.train_dataloader(): @@ -412,7 +412,7 @@ You can of course use DataModules in plain PyTorch code as well. ... # lazy load test data - dm.setup('test') + dm.setup(stage='test') for batch in dm.test_dataloader(): ... diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst index 2ee31304299e0..c65894367a39e 100644 --- a/docs/source/starter/introduction_guide.rst +++ b/docs/source/starter/introduction_guide.rst @@ -240,7 +240,7 @@ In this case, it's better to group the full definition of a dataset into a `Data tokenize() build_vocab() - def setup(self): + def setup(self, stage: Optional[str] = None): # called on every GPU vocab = load_vocab() self.vocab_size = len(vocab) @@ -310,8 +310,8 @@ An alternative to using a DataModule is to defer initialization of the models mo download_data() tokenize() - def setup(self, step): - # step is either 'fit' or 'test' 90% of the time not relevant + def setup(self, stage: Optional[str] = None): + # step is either 'fit', 'validate', 'test', or 'predict'. 
90% of the time not relevant data = load_data() num_classes = data.classes self.l1 = nn.Linear(..., num_classes) @@ -598,7 +598,7 @@ In this method we do all the preparation we need to do once (instead of on every MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) - def setup(self, stage): + def setup(self, stage: Optional[str] = None): # transform transform=transforms.Compose([transforms.ToTensor()]) mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform) diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst index 0f1362616a9b1..23f91914063d9 100644 --- a/docs/source/starter/new-project.rst +++ b/docs/source/starter/new-project.rst @@ -651,7 +651,7 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning. MNIST(os.getcwd(), train=False, download=True) # OPTIONAL, called for every GPU/machine (assigning state is OK) - def setup(self, stage): + def setup(self, stage: Optional[str] = None): # transforms transform=transforms.Compose([ transforms.ToTensor(), From 37473f0c549590c5c342b53c564af7affb9fb05b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:52:14 +0100 Subject: [PATCH 19/32] Mention teardown --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f6ef0d56b3792..327f923a79ff1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,7 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) -- Changed `setup()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) +- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) ### Deprecated From 0a30abf931ec5a5f1127bdf98514df6f489cb735 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:59:49 +0100 Subject: [PATCH 20/32] Self-review --- pytorch_lightning/core/datamodule.py | 20 ++++++++++---------- pytorch_lightning/trainer/callback_hook.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 31c05e3bcc4c4..1b6852c071fe1 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -57,8 +57,8 @@ def track_data_hook_calls(fn): - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}`` - it's corresponding `dm_has_setup_{stage}` gets set to True + - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. + Its corresponding `dm_has_setup_{stage}` attribute gets set to True Args: fn (function): Function that will be tracked to see if it has been called. @@ -226,37 +226,37 @@ def has_prepared_data(self) -> bool: @property def has_setup_fit(self) -> bool: - """Return bool letting you know if ``datamodule.setup('fit')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='fit')`` has been called or not. 
Returns: - bool: True ``if datamodule.setup('fit')`` has been called. False by default. + bool: True ``if datamodule.setup(stage='fit')`` has been called. False by default. """ return self._has_setup_fit @property def has_setup_validate(self) -> bool: - """Return bool letting you know if ``datamodule.setup('validate')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='validate')`` has been called or not. Returns: - bool: True if ``datamodule.setup('validate')`` has been called. False by default. + bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default. """ return self._has_setup_validate @property def has_setup_test(self) -> bool: - """Return bool letting you know if ``datamodule.setup('test')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='test')`` has been called or not. Returns: - bool: True if ``datamodule.setup('test')`` has been called. False by default. + bool: True if ``datamodule.setup(stage='test')`` has been called. False by default. """ return self._has_setup_test @property def has_setup_predict(self) -> bool: - """Return bool letting you know if ``datamodule.setup('predict')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='predict')`` has been called or not. Returns: - bool: True if ``datamodule.setup('predict')`` has been called. False by default. + bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default. """ return self._has_setup_predict diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index f174cd725bd36..5aa9f1a44276b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -40,7 +40,7 @@ def setup(self, model: LightningModule, stage: Optional[str]) -> None: callback.setup(self, model, stage) def teardown(self, stage: Optional[str] = None) -> None: - """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + """Called at the end of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.teardown(self, self.lightning_module, stage) From 0e9d69c35356824d5cc1b8e986c850ad71de50af Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 14:39:30 +0100 Subject: [PATCH 21/32] Address Borda's comments --- docs/source/conf.py | 1 + pytorch_lightning/trainer/model_hooks.py | 10 +++++++++- pytorch_lightning/trainer/trainer.py | 3 +-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 813d5ee978821..ccf824bb37d9b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -371,6 +371,7 @@ def package_list_from_file(file): doctest_global_setup = """ import importlib import os +from typing import Optional import torch from torch import nn import pytorch_lightning as pl diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index e98ebf088a8dc..b924675d8505c 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -14,6 +14,7 @@ import inspect from abc import ABC +from typing import Optional from pytorch_lightning.core.lightning import LightningModule @@ -22,7 +23,14 @@ class TrainerModelHooksMixin(ABC): lightning_module: LightningModule - def has_arg(self, f_name, arg_name): + def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) 
-> bool: + # note: currently unused - kept as it is public + if model is None: + model = self.lightning_module + f_op = getattr(model, f_name, None) + return callable(f_op) + + def has_arg(self, f_name: str, arg_name: str) -> bool: model = self.lightning_module f_op = getattr(model, f_name, None) return arg_name in inspect.signature(f_op).parameters diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d58de7d803146..45fc40731b545 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1025,8 +1025,8 @@ def tune( def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" + # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - state = state.value if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') @@ -1039,7 +1039,6 @@ def call_setup_hook(self, model: LightningModule) -> None: def call_teardown_hook(self, model: LightningModule) -> None: if self.state.running: state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - state = state.value else: state = None From 9758c7be8653a9db529c9b89bc4cc78859ad5ec7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 17:59:44 +0100 Subject: [PATCH 22/32] Fixing conflicts --- .../accelerators/cpu_accelerator.py | 88 ---- .../accelerators/ddp2_accelerator.py | 267 ------------- .../accelerators/ddp_accelerator.py | 375 ------------------ .../accelerators/ddp_cpu_spawn_accelerator.py | 296 -------------- .../accelerators/ddp_hpc_accelerator.py | 256 ------------ .../accelerators/ddp_spawn_accelerator.py | 328 --------------- .../accelerators/dp_accelerator.py | 187 --------- .../accelerators/gpu_accelerator.py | 107 ----- .../accelerators/horovod_accelerator.py | 195 --------- .../accelerators/tpu_accelerator.py | 365 ----------------- tests/base/datamodules.py | 115 ------ tests/trainer/test_trainer_validate_loop.py | 76 ---- 12 files changed, 2655 deletions(-) delete mode 100644 pytorch_lightning/accelerators/cpu_accelerator.py delete mode 100644 pytorch_lightning/accelerators/ddp2_accelerator.py delete mode 100644 pytorch_lightning/accelerators/ddp_accelerator.py delete mode 100644 pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py delete mode 100644 pytorch_lightning/accelerators/ddp_hpc_accelerator.py delete mode 100644 pytorch_lightning/accelerators/ddp_spawn_accelerator.py delete mode 100644 pytorch_lightning/accelerators/dp_accelerator.py delete mode 100644 pytorch_lightning/accelerators/gpu_accelerator.py delete mode 100644 pytorch_lightning/accelerators/horovod_accelerator.py delete mode 100644 pytorch_lightning/accelerators/tpu_accelerator.py delete mode 100644 tests/base/datamodules.py delete mode 100644 tests/trainer/test_trainer_validate_loop.py diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py deleted file mode 100644 index 025b25f715412..0000000000000 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
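For reference, the per-stage ``has_setup_*`` gating consumed by ``call_setup_hook`` in the hunks above works roughly as follows. This is a minimal sketch with toy names (``ToyDataModule`` and a free-standing ``call_setup_hook``), not the actual Lightning classes; it only illustrates how the ``f'has_setup_{state}'`` lookup decides whether ``setup(stage=...)`` still needs to run::

    # Minimal sketch of the stage-gated setup call (toy names, not Lightning's API)
    from typing import Optional

    class ToyDataModule:
        def __init__(self) -> None:
            self._has_setup_fit = False

        @property
        def has_setup_fit(self) -> bool:
            return self._has_setup_fit

        def setup(self, stage: Optional[str] = None) -> None:
            print(f"setup(stage={stage!r})")
            self._has_setup_fit = True

    def call_setup_hook(dm: ToyDataModule, state: str = "fit") -> None:
        # 'fit' also covers tune(), since there are no separate "tune" dataloaders
        if not getattr(dm, f"has_setup_{state}"):
            dm.setup(stage=state)

    dm = ToyDataModule()
    call_setup_hook(dm)   # runs setup(stage='fit')
    call_setup_hook(dm)   # skipped: has_setup_fit is already True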
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable, Optional, Union - -import torch - -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.exceptions import MisconfigurationException - - -class CPUAccelerator(Accelerator): - - def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None): - """ - Runs training on CPU - - Example:: - - # default - trainer = Trainer(accelerator=CPUAccelerator()) - - """ - super().__init__(trainer, cluster_environment) - self.nickname = None - - def setup(self, model): - # run through amp wrapper - if self.trainer.amp_backend: - raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option') - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - self.trainer.model = model - - def train(self): - model = self.trainer.model - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or evaluate - results = self.train_or_evaluate() - return results - - def _step(self, model_step: Callable, args): - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = model_step(*args) - else: - output = model_step(*args) - return output - - def training_step(self, args): - return self._step(self.trainer.model.training_step, args) - - def validation_step(self, args): - return self._step(self.trainer.model.validation_step, args) - - def test_step(self, args): - return self._step(self.trainer.model.test_step, args) - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return tensor - - @property - def require_distributed_sampler(self): - return False diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py deleted file mode 100644 index d8d68d2a49b02..0000000000000 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ /dev/null @@ -1,267 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License -from typing import Any, List, Optional, Union - -import torch -import torch.distributed as torch_distrib -from torch.nn.parallel import DistributedDataParallel - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.step_result import Result -from pytorch_lightning.distributed.dist import LightningDistributed -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available - - -class DDP2Accelerator(Accelerator): - - def __init__(self, - trainer, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - """ - Runs training using DDP2 strategy on a cluster - - Example:: - - # default - trainer = Trainer(accelerator=DDP2Accelerator()) - - """ - super().__init__(trainer, cluster_environment, ddp_plugin) - self.task_idx = None - self.dist = LightningDistributed() - self.nickname = 'ddp2' - - def setup(self, model): - self.trainer.model = model - self.task_idx = self.cluster_environment.local_rank() - - def train(self): - model = self.trainer.model - return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model) - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def _step(self, args): - args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def barrier(self, name: Optional[str] = None): - if torch_distrib.is_initialized(): - torch_distrib.barrier() - - def training_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - return output - - def validation_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - return output - - def test_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - return output - - def set_world_ranks(self, process_idx): - # Todo: required argument `process_idx` is not used - self.trainer.local_rank = self.trainer.node_rank - self.trainer.global_rank = self.trainer.node_rank - self.trainer.world_size = self.trainer.num_nodes - - def broadcast(self, obj, src=0): - return self.dist.broadcast(obj) - - def init_device(self, process_idx): - self.trainer.root_gpu = process_idx - torch.cuda.set_device(self.trainer.root_gpu) - - def model_to_device(self, model): - model.cuda(self.trainer.root_gpu) - - def get_device_ids(self): - device_ids = self.trainer.data_parallel_device_ids - return device_ids - - def ddp_train(self, process_idx, mp_queue, model): - """ - Entry point for ddp - - Args: - process_idx: current process rank - mp_queue: multiprocessing queue - model: pointer to current :class:`LightningModule` - - Returns: - Dict with evaluation results - - """ - # Todo: required argument `mp_queue` is not used - # show progressbar only on progress_rank 0 - if 
(self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - # determine which process we are and world size - self.set_world_ranks(process_idx) - - # set warning rank - rank_zero_only.rank = self.trainer.global_rank - - # Initialize cuda device - self.init_device(process_idx) - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - model.trainer = self.trainer - self.init_ddp_connection( - self.trainer.global_rank, - self.trainer.world_size, - self.trainer.is_slurm_managing_tasks - ) - - if isinstance(self.ddp_plugin, RPCPlugin): - if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) - self.ddp_plugin.exit_rpc_process() - if self.ddp_plugin.return_after_exit_rpc_process: - return - else: - self.ddp_plugin.on_main_rpc_connection(self.trainer) - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # on world_size=0 let everyone know training is starting - if self.trainer.is_global_zero and not torch.distributed.is_initialized(): - log.info('-' * 100) - log.info(f'distributed_backend={self.trainer.distributed_backend}') - log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes') - log.info('-' * 100) - - # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_batchnorm: - model = self.configure_sync_batchnorm(model) - - # move the model to the correct device - self.model_to_device(model) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - self.ddp_plugin.on_after_setup_optimizers(self.trainer) - - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - # device ids change depending on the DDP setup - device_ids = self.get_device_ids() - - # allow user to configure ddp - model = self.configure_ddp(model, device_ids) - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or evaluate - results = self.train_or_evaluate() - - # clean up memory - torch.cuda.empty_cache() - return results - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ) -> DistributedDataParallel: - self.ddp_plugin.device_ids = device_ids - model = self.ddp_plugin.configure_ddp(model, device_ids) - return model - - def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: - """ - Add global batchnorm for a model spread across multiple GPUs and nodes. - - Override to synchronize batchnorm between specific process groups instead - of the whole world or use a different sync_bn like `apex`'s version. - - Args: - model: pointer to current :class:`LightningModule`. 
- - Return: - LightningModule with batchnorm layers synchronized between process groups - """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) - - return model - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return sync_ddp_if_available(tensor, group, reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def get_reference_model(self, model) -> LightningModule: - return self.ddp_plugin.get_model_from_plugin(model) - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict( - num_replicas=self.trainer.num_nodes, - rank=self.trainer.global_rank - ) - if self.ddp_plugin is not None: - distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) - return distributed_sampler_kwargs - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py deleted file mode 100644 index 10589608c8758..0000000000000 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ /dev/null @@ -1,375 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License -import os -import subprocess -import sys -from os.path import abspath -from time import sleep -from typing import Any, List, Optional, Union - -import numpy as np -import torch -import torch.distributed as torch_distrib -from torch.nn.parallel import DistributedDataParallel - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.distributed.dist import LightningDistributed -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType -from pytorch_lightning.utilities.distributed import ( - all_gather_ddp_if_available, - find_free_network_port, - rank_zero_only, - sync_ddp_if_available, -) -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.seed import seed_everything - -if _HYDRA_AVAILABLE: - from hydra.core.hydra_config import HydraConfig - from hydra.utils import get_original_cwd, to_absolute_path - - -class DDPAccelerator(Accelerator): - - def __init__(self, - trainer: Optional = None, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - """ - Runs training using DDP strategy on a single machine (manually, not via cluster start) - - Example:: - - # default - trainer = Trainer(accelerator=DDPAccelerator()) - - """ - super().__init__(trainer, cluster_environment, ddp_plugin) - self.task_idx = None - self._has_spawned_children = False - self.interactive_ddp_procs = [] - self.dist = LightningDistributed() - self.nickname = 'ddp' - - def setup(self, model): - # first track model - self.trainer.model = model - - # start the other scripts - if os.environ.get('PL_IN_DDP_SUBPROCESS', '0') != '1': - self._call_children_scripts() - - # set the task idx - self.task_idx = int(os.environ['LOCAL_RANK']) - - def _call_children_scripts(self): - assert self.trainer.global_rank == 0 - self._check_can_spawn_children() - self._has_spawned_children = True - - os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR', '127.0.0.1') - os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port())) - - # allow the user to pass the node rank - node_rank = '0' - node_rank = os.environ.get('NODE_RANK', node_rank) - node_rank = os.environ.get('GROUP_RANK', node_rank) - os.environ['NODE_RANK'] = node_rank - os.environ['LOCAL_RANK'] = '0' - - # when user is using hydra find the absolute path - path_lib = abspath if not _HYDRA_AVAILABLE else to_absolute_path - - # pull out the commands used to run the script and resolve the abs file path - command = sys.argv - try: - full_path = path_lib(command[0]) - # todo: specify the possible exception - except Exception: - full_path = abspath(command[0]) - - command[0] = full_path - # use the same python interpreter and actually running - command = [sys.executable] + command - - # the visible devices tell us how many GPUs we want to use. - # when the trainer script was called the device has already been scoped by the time - # code reaches this point. 
so, to call the scripts, we need to leave cuda visible devices alone - # but forward the GPUs selected via environment variables - if self.trainer.data_parallel_device_ids is None: - raise MisconfigurationException('you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)') - - os.environ['PL_TRAINER_GPUS'] = ','.join([str(i) for i in self.trainer.data_parallel_device_ids]) - os.environ['PL_IN_DDP_SUBPROCESS'] = '1' - - if self.trainer.logger is not None: - os.environ['PL_EXP_VERSION'] = str(self.trainer.logger.version) - - num_gpus = len(self.trainer.data_parallel_device_ids) - os.environ['WORLD_SIZE'] = f'{num_gpus * self.trainer.num_nodes}' - - self.interactive_ddp_procs = [] - for local_rank in range(1, self.trainer.num_processes): - env_copy = os.environ.copy() - env_copy['LOCAL_RANK'] = f'{local_rank}' - - # remove env var if global seed not set - if os.environ.get('PL_GLOBAL_SEED') is None and 'PL_GLOBAL_SEED' in env_copy: - del env_copy['PL_GLOBAL_SEED'] - - # start process - # if hydra is available and initialized, make sure to set the cwd correctly - cwd: Optional[str] = None - if _HYDRA_AVAILABLE: - if HydraConfig.initialized(): - cwd = get_original_cwd() - proc = subprocess.Popen(command, env=env_copy, cwd=cwd) - self.interactive_ddp_procs.append(proc) - - # starting all processes at once can cause issues - # with dataloaders delay between 1-10 seconds - delay = np.random.uniform(1, 5, 1)[0] - sleep(delay) - - def train(self): - model = self.trainer.model - - results = self.ddp_train(process_idx=self.task_idx, model=model) - if 'WORLD_SIZE' in os.environ: - del os.environ['WORLD_SIZE'] - return results - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def _step(self, args): - args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def barrier(self, name: Optional[str] = None): - if self.rpc_enabled: - # Allow RPC to handle barrier on main RPC processes - self.ddp_plugin.barrier() - elif torch_distrib.is_initialized(): - torch_distrib.barrier(group=self.ddp_plugin.data_parallel_group) - - def _check_can_spawn_children(self): - if self._has_spawned_children: - raise RuntimeError( - "You tried to run `.fit` or `.test` multiple times in the same script." - " This is not supported in DDP mode, switch to `accelerator='ddp_spawn'` instead." 
- ) - - def set_world_ranks(self, process_idx): - self.trainer.local_rank = process_idx - self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx - self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - - def init_device(self, process_idx): - # Todo: required argument `process_idx` is not used - self.trainer.root_gpu = self.trainer.data_parallel_device_ids[self.trainer.local_rank] - torch.cuda.set_device(self.trainer.root_gpu) - - def model_to_device(self, model): - model.cuda(self.trainer.root_gpu) - - def get_device_ids(self): - device_ids = [self.trainer.root_gpu] - return device_ids - - def on_train_end(self): - pass - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) - torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM) - self.barrier('early_stopping') - should_stop = stop == self.trainer.world_size - return should_stop - - def broadcast(self, obj, src=0): - return self.dist.broadcast(obj, group=self.ddp_plugin.data_parallel_group) - - def ddp_train(self, process_idx, model): - """ - Entry point for ddp - - Args: - process_idx: - model: - - Returns: - Dict with evaluation results - - """ - seed = os.environ.get("PL_GLOBAL_SEED") - if seed is not None: - seed_everything(int(seed)) - - # show progressbar only on progress_rank 0 - if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - # determine which process we are and world size - self.set_world_ranks(process_idx) - - # set warning rank - rank_zero_only.rank = self.trainer.global_rank - - # Initialize cuda device - self.init_device(process_idx) - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - model.trainer = self.trainer - self.init_ddp_connection( - self.trainer.global_rank, - self.trainer.world_size, - self.trainer.is_slurm_managing_tasks - ) - - if isinstance(self.ddp_plugin, RPCPlugin): - if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) - self.ddp_plugin.exit_rpc_process() - if self.ddp_plugin.return_after_exit_rpc_process: - return - else: - self.ddp_plugin.on_main_rpc_connection(self.trainer) - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # on world_size=0 let everyone know training is starting - if self.trainer.is_global_zero and not torch.distributed.is_initialized(): - log.info('-' * 100) - log.info(f'distributed_backend={self.trainer.distributed_backend}') - log.info(f'All DDP processes registered. 
Starting ddp with {self.trainer.world_size} processes') - log.info('-' * 100) - - # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_batchnorm: - model = self.configure_sync_batchnorm(model) - - # move the model to the correct device - self.model_to_device(model) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - # device ids change depending on the DDP setup - device_ids = self.get_device_ids() - - # allow user to configure ddp - model = self.configure_ddp(model, device_ids) - - # set up training routine - self.barrier('ddp_setup') - self.trainer.train_loop.setup_training(model) - - # train or evaluate - results = self.train_or_evaluate() - - # clean up memory - torch.cuda.empty_cache() - - return results - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ) -> DistributedDataParallel: - self.ddp_plugin.device_ids = device_ids - model = self.ddp_plugin.configure_ddp(model, device_ids) - return model - - def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: - """ - Add global batchnorm for a model spread across multiple GPUs and nodes. - - Override to synchronize batchnorm between specific process groups instead - of the whole world or use a different sync_bn like `apex`'s version. - - Args: - model: pointer to current :class:`LightningModule`. - - Return: - LightningModule with batchnorm layers synchronized between process groups - """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) - - return model - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - """ - - """ - return sync_ddp_if_available(tensor, group, reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def get_reference_model(self, model) -> LightningModule: - return self.ddp_plugin.get_model_from_plugin(model) - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict( - num_replicas=self.trainer.num_nodes * self.trainer.num_processes, - rank=self.trainer.global_rank - ) - if self.ddp_plugin is not None: - distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) - return distributed_sampler_kwargs - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py deleted file mode 100644 index a6e98fa888516..0000000000000 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ /dev/null @@ -1,296 +0,0 @@ -# Copyright The PyTorch Lightning team. 
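The ``distributed_sampler_kwargs`` property that these accelerators expose (see the end of the DDP accelerator above, with ``num_replicas = num_nodes * num_processes`` and ``rank = global_rank``) is typically consumed when building a ``torch.utils.data.DistributedSampler``. A rough, self-contained illustration of that consumption, with made-up world-size and rank values::

    # Sketch: how distributed_sampler_kwargs is typically consumed.
    # The dataset and the 2-node x 4-process values are made up for illustration.
    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.randn(64, 3))

    # num_replicas = num_nodes * num_processes, rank = global_rank
    sampler_kwargs = dict(num_replicas=2 * 4, rank=5)
    sampler = DistributedSampler(dataset, **sampler_kwargs)

    loader = DataLoader(dataset, sampler=sampler, batch_size=8)
    print(len(sampler))   # each of the 8 replicas sees 8 of the 64 samples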
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License -import os -from typing import Any, List, Optional, Union - -import torch -import torch.distributed as torch_distrib -import torch.multiprocessing as mp -from torch.nn.parallel import DistributedDataParallel - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.distributed.dist import LightningDistributed -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.distributed import ( - all_gather_ddp_if_available, - find_free_network_port, - rank_zero_only, - rank_zero_warn, - sync_ddp_if_available, -) - - -class DDPCPUSpawnAccelerator(Accelerator): - - def __init__(self, - trainer, - nprocs: int, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - """ - Runs training using DDP (on a single machine or manually on multiple machines), using mp.spawn - - Example:: - - # default - trainer = Trainer(accelerator=DDPCPUSpawnAccelerator()) - - """ - super().__init__(trainer, cluster_environment, ddp_plugin) - self.mp_queue = None - self.nprocs = nprocs - self.dist = LightningDistributed() - self.nickname = 'ddp_cpu' - - def setup(self, model): - os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port())) - - # pass in a state q - smp = mp.get_context('spawn') - self.mp_queue = smp.SimpleQueue() - - self.trainer.model = model - - def train(self): - model = self.trainer.model - - # train in children process - mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,)) - - # restore main state with best weights - best_path = self.mp_queue.get() - results = self.mp_queue.get() - - # recover the weights of the processes trained in the children - self.__recover_child_process_weights(model, best_path) - return results - - def ddp_train(self, process_idx, mp_queue, model): - """ - Entry point for ddp - - Args: - process_idx: - mp_queue: multiprocessing queue - model: - """ - # show progressbar only on progress_rank 0 - if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - # determine which process we are and world size - self.set_world_ranks(process_idx) - - # set warning rank - rank_zero_only.rank = self.trainer.global_rank - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - model.trainer = self.trainer - self.init_ddp_connection( - self.trainer.global_rank, - self.trainer.world_size, - self.trainer.is_slurm_managing_tasks - ) - - if isinstance(self.ddp_plugin, RPCPlugin): - if not self.ddp_plugin.is_main_rpc_process: 
- self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) - self.ddp_plugin.exit_rpc_process() - if self.ddp_plugin.return_after_exit_rpc_process: - return - else: - self.ddp_plugin.on_main_rpc_connection(self.trainer) - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # on world_size=0 let everyone know training is starting - if self.trainer.is_global_zero and not torch.distributed.is_initialized(): - log.info('-' * 100) - log.info(f'distributed_backend={self.trainer.distributed_backend}') - log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes') - log.info('-' * 100) - - # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_batchnorm: - model = self.configure_sync_batchnorm(model) - - # move the model to the correct device - self.model_to_device(model, process_idx) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - self.ddp_plugin.on_after_setup_optimizers(self.trainer) - - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - # DDP spawn already spawned off each process... no need to do anything - device_ids = self.get_device_ids() - - # allow user to configure ddp - model = self.configure_ddp(model, device_ids) - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or evaluate - results = self.train_or_evaluate() - - # get original model - model = self.trainer.get_model() - - # persist info in ddp_spawn - self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results) - - # clean up memory - torch.cuda.empty_cache() - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def _step(self, args): - args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def barrier(self, name: Optional[str] = None): - if torch_distrib.is_initialized(): - torch_distrib.barrier() - - def broadcast(self, obj, src=0): - return self.dist.broadcast(obj) - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) - torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM) - torch_distrib.barrier() - should_stop = stop == self.trainer.world_size - return should_stop - - def set_world_ranks(self, process_idx): - self.trainer.local_rank = process_idx - self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx - self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - - def model_to_device(self, model, process_idx): - # Todo: required argument `process_idx` is not used - model.cpu() - - def get_device_ids(self): - device_ids = None - return device_ids - - def __recover_child_process_weights(self, model, best_path): - # transfer back the best path to the trainer - if self.trainer.checkpoint_callback: - self.trainer.checkpoint_callback.best_model_path = best_path - - self.trainer.model = model - - def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): - # Todo: required argument `model` is not used 
- # track the best model path - best_model_path = None - if self.trainer.checkpoint_callback is not None: - best_model_path = self.trainer.checkpoint_callback.best_model_path - - if self.trainer.global_rank == 0 and mp_queue is not None: - rank_zero_warn('cleaning up ddp environment...') - # todo, pass complete checkpoint as state dictionary - mp_queue.put(best_model_path) - mp_queue.put(results) - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ) -> DistributedDataParallel: - self.ddp_plugin.device_ids = device_ids - model = self.ddp_plugin.configure_ddp(model, device_ids) - return model - - def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: - """ - Add global batchnorm for a model spread across multiple GPUs and nodes. - - Override to synchronize batchnorm between specific process groups instead - of the whole world or use a different sync_bn like `apex`'s version. - - Args: - model: pointer to current :class:`LightningModule`. - - Return: - LightningModule with batchnorm layers synchronized between process groups - """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) - - return model - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return sync_ddp_if_available(tensor, group, reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def get_reference_model(self, model) -> LightningModule: - return self.ddp_plugin.get_model_from_plugin(model) - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict( - num_replicas=self.trainer.num_nodes * self.trainer.num_processes, - rank=self.trainer.global_rank - ) - if self.ddp_plugin is not None: - distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) - return distributed_sampler_kwargs - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py deleted file mode 100644 index 64e51d60f1fdf..0000000000000 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License -from typing import Any, List, Optional, Union - -import torch -import torch.distributed as dist -import torch.distributed as torch_distrib -from torch.nn.parallel import DistributedDataParallel - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.distributed.dist import LightningDistributed -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available - - -class DDPHPCAccelerator(Accelerator): - - def __init__(self, - trainer, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - """ - Runs training using DDP on an HPC cluster - - Example:: - - # default - trainer = Trainer(accelerator=DDPHPCAccelerator()) - - """ - super().__init__(trainer, cluster_environment, ddp_plugin) - self.task_idx = None - self._has_spawned_children = False - self.dist = LightningDistributed() - self.nickname = 'ddp' - - def setup(self, model): - self.trainer.model = model - self.task_idx = self.cluster_environment.local_rank() - - def train(self): - model = self.trainer.model - self.ddp_train(process_idx=self.task_idx, model=model) - - def set_world_ranks(self, process_idx): - self.trainer.local_rank = process_idx - self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx - self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - - def init_device(self, process_idx): - self.trainer.root_gpu = process_idx - torch.cuda.set_device(self.trainer.root_gpu) - - def model_to_device(self, model): - model.cuda(self.trainer.root_gpu) - - def get_device_ids(self): - device_ids = [self.trainer.root_gpu] - return device_ids - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def _step(self, args): - args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def barrier(self, name: Optional[str] = None): - if torch_distrib.is_initialized(): - torch_distrib.barrier() - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) - dist.all_reduce(stop, op=dist.reduce_op.SUM) - dist.barrier() - should_stop = stop == self.trainer.world_size - return should_stop - - def broadcast(self, obj, src=0): - return self.dist.broadcast(obj) - - def ddp_train(self, process_idx, model): - """ - Entry point for ddp - - Args: - process_idx: - model: - - Returns: - Dict with evaluation results - - """ - # determine which process we are and world size - self.set_world_ranks(process_idx) - self.init_device(process_idx) - - # toggle prog bar - if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - 
# set warning rank - rank_zero_only.rank = self.trainer.global_rank - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - model.trainer = self.trainer - self.init_ddp_connection( - self.trainer.global_rank, - self.trainer.world_size, - self.trainer.is_slurm_managing_tasks - ) - - if isinstance(self.ddp_plugin, RPCPlugin): - if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) - self.ddp_plugin.exit_rpc_process() - if self.ddp_plugin.return_after_exit_rpc_process: - return - else: - self.ddp_plugin.on_main_rpc_connection(self.trainer) - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # on world_size=0 let everyone know training is starting - if self.trainer.is_global_zero and not torch.distributed.is_initialized(): - log.info('-' * 100) - log.info(f'distributed_backend={self.trainer.distributed_backend}') - log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes') - log.info('-' * 100) - - # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_batchnorm: - model = self.configure_sync_batchnorm(model) - - # move the model to the correct device - self.model_to_device(model) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - self.ddp_plugin.on_after_setup_optimizers(self.trainer) - - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - # device ids change depending on the DDP setup - device_ids = self.get_device_ids() - - # allow user to configure ddp - model = self.configure_ddp(model, device_ids) - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or evaluate - results = self.train_or_evaluate() - - # clean up memory - torch.cuda.empty_cache() - - return results - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ) -> DistributedDataParallel: - self.ddp_plugin.device_ids = device_ids - model = self.ddp_plugin.configure_ddp(model, device_ids) - return model - - def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: - """ - Add global batchnorm for a model spread across multiple GPUs and nodes. - - Override to synchronize batchnorm between specific process groups instead - of the whole world or use a different sync_bn like `apex`'s version. - - Args: - model: pointer to current :class:`LightningModule`. - - Return: - LightningModule with batchnorm layers synchronized between process groups - """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) - - return model - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return sync_ddp_if_available(tensor, group, reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) 
- """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def get_reference_model(self, model) -> LightningModule: - return self.ddp_plugin.get_model_from_plugin(model) - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict( - num_replicas=self.trainer.num_nodes * self.trainer.num_processes, - rank=self.trainer.global_rank - ) - if self.ddp_plugin is not None: - distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) - return distributed_sampler_kwargs - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py deleted file mode 100644 index b283a5a1b2223..0000000000000 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License -import os -import re -from typing import Any, List, Optional, Union - -import torch -import torch.distributed as torch_distrib -import torch.multiprocessing as mp -from torch.nn.parallel import DistributedDataParallel - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.distributed import LightningDistributed -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.distributed import ( - all_gather_ddp_if_available, - find_free_network_port, - rank_zero_only, - rank_zero_warn, - sync_ddp_if_available, -) -from pytorch_lightning.utilities.seed import seed_everything - - -class DDPSpawnAccelerator(Accelerator): - - def __init__(self, - trainer, - nprocs: int, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - """ - Runs training using DDP using mp.spawn via manual launch (not cluster launch) - - Example:: - - # default - trainer = Trainer(accelerator=DDPSpawnAccelerator()) - - """ - super().__init__(trainer, cluster_environment, ddp_plugin) - self.mp_queue = None - self.nprocs = nprocs - self.dist = LightningDistributed() - self.nickname = 'ddp' - - def setup(self, model): - os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port())) - - # pass in a state q - smp = mp.get_context('spawn') - self.mp_queue = smp.SimpleQueue() - - self.trainer.model = model - - def train(self): - model = self.trainer.model - - # train in children process - mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,)) - - # restore main 
state with best weights - best_path = self.mp_queue.get() - results = self.mp_queue.get() - last_path = self.mp_queue.get() - - # recover the weights of the processes trained in the children - self.__recover_child_process_weights(model, best_path, last_path) - return results - - def ddp_train(self, process_idx, mp_queue, model, is_master: bool = False, proc_offset: int = 0): - """ - Entry point for ddp - - Args: - process_idx: - mp_queue: multiprocessing queue - model: - """ - seed = os.environ.get("PL_GLOBAL_SEED") - if seed is not None: - seed_everything(int(seed)) - - # offset the process id if requested - process_idx = process_idx + proc_offset - - # show progressbar only on progress_rank 0 - if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - # determine which process we are and world size - self.set_world_ranks(process_idx) - - # set warning rank - rank_zero_only.rank = self.trainer.global_rank - - # Initialize cuda device - self.init_device(process_idx, is_master) - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - model.trainer = self.trainer - self.init_ddp_connection( - self.trainer.global_rank, - self.trainer.world_size, - self.trainer.is_slurm_managing_tasks - ) - - if isinstance(self.ddp_plugin, RPCPlugin): - if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) - self.ddp_plugin.exit_rpc_process() - if self.ddp_plugin.return_after_exit_rpc_process: - return - else: - self.ddp_plugin.on_main_rpc_connection(self.trainer) - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # on world_size=0 let everyone know training is starting - if self.trainer.is_global_zero and not torch.distributed.is_initialized(): - log.info('-' * 100) - log.info(f'distributed_backend={self.trainer.distributed_backend}') - log.info(f'All DDP processes registered. 
Starting ddp with {self.trainer.world_size} processes') - log.info('-' * 100) - - # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_batchnorm: - model = self.configure_sync_batchnorm(model) - - # move the model to the correct device - self.model_to_device(model) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - self.ddp_plugin.on_after_setup_optimizers(self.trainer) - - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - # device ids change depending on the DDP setup - device_ids = self.get_device_ids() - - # allow user to configure ddp - model = self.configure_ddp(model, device_ids) - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or evaluate - results = self.train_or_evaluate() - - # get original model - model = self.trainer.get_model() - - # persist info in ddp_spawn - self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results) - - # clean up memory - torch.cuda.empty_cache() - - def set_world_ranks(self, process_idx): - self.trainer.local_rank = process_idx - self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx - self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - - def init_device(self, process_idx, is_master): - # Todo: required argument `process_idx` is not used - # Todo: required argument `is_master` is not used - gpu_idx = self.trainer.data_parallel_device_ids[self.trainer.local_rank] - self.trainer.root_gpu = gpu_idx - torch.cuda.set_device(self.trainer.root_gpu) - - def model_to_device(self, model): - model.cuda(self.trainer.root_gpu) - - def get_device_ids(self): - device_ids = [self.trainer.root_gpu] - return device_ids - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def _step(self, args): - args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def barrier(self, name: Optional[str] = None): - if torch_distrib.is_initialized(): - torch_distrib.barrier() - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) - torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM) - torch_distrib.barrier() - should_stop = stop == self.trainer.world_size - return should_stop - - def broadcast(self, obj, src=0): - return self.dist.broadcast(obj) - - def __recover_child_process_weights(self, model, best_path, last_path): - # transfer back the best path to the trainer - if self.trainer.checkpoint_callback: - self.trainer.checkpoint_callback.best_model_path = best_path - # todo, pass also best score - - # load last weights - if last_path is not None and not self.trainer.evaluating: - ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) - model.load_state_dict(ckpt) - - self.trainer.model = model - - def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): - best_model_path = None - if self.trainer.checkpoint_callback is not None: - best_model_path = self.trainer.checkpoint_callback.best_model_path - - if 
self.trainer.global_rank == 0 and mp_queue is not None: - rank_zero_warn('cleaning up ddp environment...') - # todo, pass complete checkpoint as state dictionary - mp_queue.put(best_model_path) - mp_queue.put(results) - - # save the last weights - last_path = None - if not self.trainer.evaluating and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) - atomic_save(model.state_dict(), last_path) - mp_queue.put(last_path) - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ) -> DistributedDataParallel: - self.ddp_plugin.device_ids = device_ids - model = self.ddp_plugin.configure_ddp(model, device_ids) - return model - - def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: - """ - Add global batchnorm for a model spread across multiple GPUs and nodes. - - Override to synchronize batchnorm between specific process groups instead - of the whole world or use a different sync_bn like `apex`'s version. - - Args: - model: pointer to current :class:`LightningModule`. - - Return: - LightningModule with batchnorm layers synchronized between process groups - """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) - - return model - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return sync_ddp_if_available(tensor, group, reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def get_reference_model(self, model) -> LightningModule: - return self.ddp_plugin.get_model_from_plugin(model) - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict( - num_replicas=self.trainer.num_nodes * self.trainer.num_processes, - rank=self.trainer.global_rank - ) - if self.ddp_plugin is not None: - distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) - return distributed_sampler_kwargs - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py deleted file mode 100644 index b4220d32bd7c0..0000000000000 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
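The spawn accelerators deleted above share one handoff pattern that interacts with the new ``trainer.evaluating`` flag: the child process pushes ``best_model_path``, the results, and a last-weights checkpoint path (``None`` when evaluating, so the checkpoint being validated or tested is never overwritten) into a queue, and the parent pops them back in the same order. A stripped-down sketch of that pattern, using plain ``multiprocessing`` and placeholder paths rather than the real accelerator code::

    # Sketch of the spawn-process handoff (placeholder names/paths, not Lightning's API)
    import multiprocessing as mp

    def _child(queue, evaluating: bool) -> None:
        # ... the spawned rank would train or evaluate here ...
        best_model_path = "/tmp/best.ckpt"      # placeholder
        results = {"val_loss": 0.123}           # placeholder

        queue.put(best_model_path)
        queue.put(results)
        # only fit() persists "last" weights; validate()/test() pass None instead
        last_path = None if evaluating else best_model_path.replace(".ckpt", ".tmp_end.ckpt")
        queue.put(last_path)

    def run(evaluating: bool = False):
        ctx = mp.get_context("spawn")
        queue = ctx.SimpleQueue()
        proc = ctx.Process(target=_child, args=(queue, evaluating))
        proc.start()

        # parent reads back in the same order the child wrote
        best_path, results, last_path = queue.get(), queue.get(), queue.get()
        proc.join()
        return best_path, results, last_path

    if __name__ == "__main__":
        print(run(evaluating=True))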
-from typing import Optional - -import torch -from torch import optim - -from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.step_result import Result -from pytorch_lightning.distributed import LightningDistributed -from pytorch_lightning.overrides.data_parallel import LightningDataParallel -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.exceptions import MisconfigurationException - - -class DataParallelAccelerator(Accelerator): - - def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None): - """ - Runs training using DP via manual start (not HPC cluster) - - Example:: - - # default - trainer = Trainer(accelerator=DataParallelAccelerator()) - - """ - super().__init__(trainer, cluster_environment) - self.model_autocast_original_forward = None - self.dist = LightningDistributed() - self.nickname = 'dp' - - def setup(self, model): - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # put model on correct device - model.cuda(self.trainer.root_gpu) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - # init torch data parallel - model = self.__init_torch_data_parallel(model) - - # hack forward to do autocast for the user - self.model_autocast_original_forward = model.forward - - # init half precision - if self.trainer.amp_backend: - model = self.__init_half_precision(model) - - self.trainer.model = model - - def __init_torch_data_parallel(self, model): - # create list of device ids - device_ids = self.trainer.data_parallel_device_ids - if isinstance(device_ids, int): - device_ids = list(range(device_ids)) - - # set dp device - torch.cuda.set_device(self.trainer.root_gpu) - model = LightningDataParallel(model, device_ids=device_ids) - return model - - def __init_half_precision(self, model): - if self.trainer.amp_backend == AMPType.NATIVE: - self.__init_native_amp(model) - else: - model = self.__init_nvidia_apex(model) - return model - - def __init_native_amp(self, model): - model.forward = torch.cuda.amp.autocast()(model.forward) - - def __init_nvidia_apex(self, model): - # check for this bug (amp + dp + !01 doesn't work) - # https://github.com/NVIDIA/apex/issues/227 - if self.trainer.amp_level == 'O2': - raise MisconfigurationException( - f'Amp level {self.trainer.amp_level} with DataParallel is not supported.' - f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.' 
- f' We recommend you switch to ddp if you want to use amp') - else: - model = self.trainer.precision_connector.connect(model) - - return model - - def train(self): - model = self.trainer.model - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or evaluate - results = self.train_or_evaluate() - - return results - - def teardown(self): - # replace the original fwd function - self.trainer.model.forward = self.model_autocast_original_forward - self.barrier() - - def _step(self, args): - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def training_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - elif isinstance(output, torch.Tensor): - output = output.mean() - return output - - def validation_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - elif isinstance(output, torch.Tensor): - output = output.mean() - return output - - def test_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - elif isinstance(output, torch.Tensor): - output = output.mean() - return output - - def reinit_scheduler_properties(self, optimizers: list, schedulers: list): - """ - Reinitialize optimizer.step properties added by schedulers - """ - for scheduler in schedulers: - scheduler = scheduler['scheduler'] - - for optimizer in optimizers: - # check that we dont mix users optimizers and schedulers - if scheduler.optimizer == optimizer: - # Find the mro belonging to the base lr scheduler class - for i, mro in enumerate(scheduler.__class__.__mro__): - is_regular_scheduler = optim.lr_scheduler._LRScheduler - is_lr_reduce_on_plateau = optim.lr_scheduler.ReduceLROnPlateau - if is_regular_scheduler or is_lr_reduce_on_plateau: - idx = i - state = scheduler.state_dict() - else: - state = None - - scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer) - if state is not None: - scheduler.load_state_dict(state) - - def get_reference_model(self, model) -> LightningModule: - if isinstance(model, LightningDataParallel): - return model.module - return model - - @property - def require_distributed_sampler(self): - return False diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py deleted file mode 100644 index 9ee7bd608dc84..0000000000000 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any, Callable, Optional, Union - -import torch - -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.distributed.dist import LightningDistributed -from pytorch_lightning.utilities import AMPType - - -class GPUAccelerator(Accelerator): - amp_backend: AMPType - - def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None): - """ - Runs training using a single GPU - - Example:: - - # default - trainer = Trainer(accelerator=GPUAccelerator()) - - """ - super().__init__(trainer, cluster_environment) - self.dist = LightningDistributed() - self.nickname = None - - def setup(self, model): - - # call setup - self.trainer.call_setup_hook(model) - - torch.cuda.set_device(self.trainer.root_gpu) - model.cuda(self.trainer.root_gpu) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - self.trainer.model = model - - def train(self): - model = self.trainer.model - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or evaluate - results = self.train_or_evaluate() - - return results - - def _step(self, model_step: Callable, args): - args[0] = self.to_device(args[0]) - - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = model_step(*args) - else: - output = model_step(*args) - - return output - - def training_step(self, args): - return self._step(self.trainer.model.training_step, args) - - def validation_step(self, args): - return self._step(self.trainer.model.validation_step, args) - - def test_step(self, args): - return self._step(self.trainer.model.test_step, args) - - def to_device(self, batch): - gpu_id = 0 - if isinstance(self.trainer.data_parallel_device_ids, list): - gpu_id = self.trainer.data_parallel_device_ids[0] - - # Don't copy the batch since there is a single gpu that the batch could - # be referenced from and if there are multiple optimizers the batch will - # wind up copying it to the same device repeatedly. - return self.batch_to_device(batch, gpu_id) - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return tensor - - @property - def require_distributed_sampler(self): - return False diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py deleted file mode 100644 index 8471f89cf35f0..0000000000000 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from contextlib import ExitStack -from typing import Any, Callable, Optional, Union - -import torch -from torch.optim.lr_scheduler import _LRScheduler - -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType, DeviceType -from pytorch_lightning.utilities.distributed import rank_zero_only - -if _HOROVOD_AVAILABLE: - import horovod.torch as hvd - - -class HorovodAccelerator(Accelerator): - amp_backend: AMPType - - def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None): - """ - Runs training using horovod - - Example:: - - # default - trainer = Trainer(accelerator=HorovodAccelerator()) - - """ - super().__init__(trainer, cluster_environment) - self.nickname = 'horovod' - - def setup(self, model): - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - if torch.cuda.is_available() and self.trainer._device_type == DeviceType.GPU: - # Horovod: pin GPU to local rank - assert self.trainer.root_gpu == hvd.local_rank() - torch.cuda.set_device(self.trainer.root_gpu) - model.cuda(self.trainer.root_gpu) - - # avoid duplicating progress bar - if hvd.rank() != 0 and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - # Horovod: scale the learning rate by the number of workers to account for - # increased total batch size - for optimizer in self.trainer.optimizers: - for param_group in optimizer.param_groups: - param_group['lr'] *= hvd.size() - - # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR - for scheduler in self.trainer.lr_schedulers: - scheduler = scheduler['scheduler'] - if isinstance(scheduler, _LRScheduler): - scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs] - - # Horovod: broadcast parameters & optimizer state to ensure consistent initialization - hvd.broadcast_parameters(model.state_dict(), root_rank=0) - for optimizer in self.trainer.optimizers: - hvd.broadcast_optimizer_state(optimizer, root_rank=0) - - def _filter_named_parameters(model, optimizer): - opt_params = set([p for group in optimizer.param_groups for p in group.get('params', [])]) - return [(name, p) for name, p in model.named_parameters() if p in opt_params] - - # Horovod: wrap optimizers to perform gradient aggregation via allreduce - self.trainer.optimizers = [ - hvd.DistributedOptimizer(optimizer, named_parameters=_filter_named_parameters(model, optimizer)) - for optimizer in self.trainer.optimizers - ] - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - # Update logger rank info from Horovod to avoid race conditions from different ranks - # creating directories / writing files in the same locations. 
- self.trainer.global_rank = hvd.rank() - rank_zero_only.rank = self.trainer.global_rank - - self.trainer.model = model - - def train(self): - with ExitStack() as stack: - for optimizer in self.trainer.optimizers: - # Synchronization will be performed explicitly following backward() - stack.enter_context(optimizer.skip_synchronize()) - - # set up training routine - self.trainer.train_loop.setup_training(self.trainer.model) - - # train or evaluate - results = self.train_or_evaluate() - - # Make sure all workers have finished training before returning to the user - hvd.join() - return results - - def _step(self, model_step: Callable, args): - if self.trainer._device_type == DeviceType.GPU: - args[0] = self.batch_to_device(args[0], hvd.local_rank()) - - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = model_step(*args) - else: - output = model_step(*args) - - return output - - def training_step(self, args): - return self._step(self.trainer.model.training_step, args) - - def validation_step(self, args): - return self._step(self.trainer.model.validation_step, args) - - def test_step(self, args): - return self._step(self.trainer.model.test_step, args) - - def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): - super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs) - optimizer.synchronize() - - def on_train_epoch_end(self, outputs): - hvd.join(hvd.local_rank() if self.trainer._device_type == DeviceType.GPU else -1) - - def barrier(self, name: Optional[str] = None): - hvd.join() - - def broadcast(self, obj, src=0): - obj = hvd.broadcast_object(obj, src) - return obj - - def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None): - if group is not None: - raise ValueError( - "Horovod does not support allgather using a subcommunicator at this time. " - "Unset `group`." - ) - - if len(result.shape) == 0: - # Convert scalars to single dimension tensors - result = result.reshape(1) - - # sync and gather all - hvd.join() - gathered = hvd.allgather(result) - gathered_result = list(gathered.split(1, dim=0)) - return gathered_result - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - if group is not None: - raise ValueError( - "Horovod does not support allreduce using a subcommunicator at this time. " - "Unset `group`." - ) - - if reduce_op is None or reduce_op == "sum": - reduce_op = hvd.Sum - elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): - reduce_op = hvd.Average - else: - raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") - - # sync all processes before reduction - hvd.join() - return hvd.allreduce(tensor, op=reduce_op) - - @property - def distributed_sampler_kwargs(self): - return dict(num_replicas=hvd.size(), rank=hvd.rank()) - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py deleted file mode 100644 index 14aa0c4e66706..0000000000000 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import io -import os -import re -from typing import Any, Callable, Optional, Union - -import torch -import torch.multiprocessing as mp -from torch.optim import Optimizer - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core import LightningModule -from pytorch_lightning.utilities import ( - _TPU_AVAILABLE, - move_data_to_device, - rank_zero_info, - rank_zero_only, - rank_zero_warn, -) -from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -if _TPU_AVAILABLE: - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.distributed.parallel_loader as xla_pl - import torch_xla.distributed.xla_multiprocessing as xmp - - -class TPUAccelerator(Accelerator): - - def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None): - """ - Runs training using TPUs (colab, single machine or pod) - - Example:: - - # default - trainer = Trainer(accelerator=TPUAccelerator()) - - """ - super().__init__(trainer, cluster_environment) - self.start_method = None - self.mp_queue = None - self.nickname = None - - def setup(self, model): - rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores') - - # TODO: Move this check to Trainer __init__ or device parser - if not _TPU_AVAILABLE: - raise MisconfigurationException('PyTorch XLA not installed.') - - # see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2 - self.start_method = 'fork' - - # pass in a state q - smp = mp.get_context(self.start_method) - self.mp_queue = smp.SimpleQueue() - - self.trainer.model = model - - def teardown(self): - model = self.trainer.model - - # restore main state with best weights - best_path = self.mp_queue.get() - results = self.mp_queue.get() - last_path = self.mp_queue.get() - - # transfer back the best path to the trainer - if self.trainer.checkpoint_callback is not None: - self.trainer.checkpoint_callback.best_model_path = best_path - # todo, pass also bets score - - # load last weights - if last_path and not self.trainer.evaluating: - ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) - model.load_state_dict(ckpt) - - self.trainer.model = model - - # when training completes, load the weights back in main process - self.__load_weights_on_main_process() - return results - - def train(self): - model = self.trainer.model - - # train - if self.trainer.tpu_id is not None: - self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue) - else: - xmp.spawn( - self.tpu_train_in_process, - args=(model, self.trainer, self.mp_queue), - nprocs=self.trainer.tpu_cores, - start_method=self.start_method - ) - - def __load_weights_on_main_process(self): - model = self.trainer.model - - # load weights if not interrupted - if self.trainer.on_colab_kaggle and not self.trainer.evaluating: - self.load_spawn_weights(model) - - self.trainer.model = model - - def 
tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None): - """ - Here we are inside each individual process - """ - # Todo: required argument `tpu_core_idx` is not used - if not trainer: - trainer = self.trainer - - trainer.call_setup_hook(model) - - # setup TPU training - self.__setup_tpu_training(model, trainer) - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or evaluate - results = self.train_or_evaluate() - - # save weights at the end of training - self.__save_end_of_training_weights(model, trainer) - - # persist info in spawn - self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results) - - def _step(self, model_step: Callable, args): - args[0] = self.to_device(args[0]) - return model_step(*args) - - def training_step(self, args): - return self._step(self.trainer.model.training_step, args) - - def validation_step(self, args): - return self._step(self.trainer.model.validation_step, args) - - def test_step(self, args): - return self._step(self.trainer.model.test_step, args) - - def process_dataloader(self, dataloader): - device = xm.xla_device(self.trainer.tpu_id) - dataloader = xla_pl.ParallelLoader(dataloader, [device]) - dataloader = dataloader.per_device_loader(device) - return dataloader - - def to_device(self, batch): - """ - Transfers the data to the TPU. - - Args: - batch: A tensor or collection of tensors. - - Return: - the tensor on the TPU device. - - See Also: - - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` - """ - if not _TPU_AVAILABLE: - raise MisconfigurationException( - 'Requested to transfer batch to TPU but XLA is not available.' - ' Are you sure this machine has TPUs?' - ) - device = xm.xla_device(self.trainer.tpu_id) - - return self.batch_to_device(batch, device) - - def __save_end_of_training_weights(self, model: LightningModule, trainer): - # when training ends on these platforms dump weights to get out of the main process - if trainer.on_colab_kaggle: - rank_zero_warn('cleaning up... 
please do not interrupt') - self.save_spawn_weights(model) - - def __setup_tpu_training(self, model: LightningModule, trainer): - # use the default device from the process - # tpu_device = xm.xla_device() - - # if given an ordinal device, use this as the device - if trainer.tpu_id is not None: - tpu_device = xm.xla_device(trainer.tpu_id) - else: - tpu_device = xm.xla_device() - # track the device and move model to it - trainer._device = tpu_device - model.to(trainer._device) - - # get the appropriate tpu ranks - trainer.tpu_local_core_rank = xm.get_local_ordinal() - trainer.tpu_global_core_rank = xm.get_ordinal() - - # avoid duplicating progress bar - if trainer.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: - trainer.progress_bar_callback.disable() - - trainer.global_rank = trainer.tpu_local_core_rank - rank_zero_only.rank = trainer.global_rank - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - # init 16 bit for TPU - if trainer.precision == 16: - os.environ['XLA_USE_BF16'] = str(1) - - log.info(f'INIT TPU local core: {trainer.tpu_local_core_rank},' - f' global rank: {trainer.tpu_global_core_rank}' - f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}') - - def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): - # do backward pass - if self.trainer.train_loop.automatic_optimization: - model = self.trainer.get_model() - model.backward(closure_loss, optimizer, opt_idx) - else: - closure_loss.backward(*args, **kwargs) - - # detach after backward - closure_loss = closure_loss.detach() - - return closure_loss - - def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0): - # this code is a modification of torch.nn.utils.clip_grad_norm_ - # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md - model = self.trainer.get_model() - parameters = model.parameters() - max_norm = grad_clip_val - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - - device = parameters[0].device - out = torch.empty(len(parameters), device=device) - for i, p in enumerate(parameters): - torch.norm(p.grad.data.to(device), norm_type, out=out[i]) - total_norm = torch.norm(out, norm_type) - - clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon) - clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) - for p in parameters: - p.grad.data.mul_(clip_coef.to(p.grad.data.device)) - - def barrier(self, name: Optional[str] = None): - torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}") - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device, dtype=torch.int32) - stop = xm.mesh_reduce("stop_signal", stop, sum) - torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check") - should_stop = int(stop.item()) == self.trainer.world_size - return should_stop - - def save_spawn_weights(self, model): - """ - Dump a temporary checkpoint after ddp ends to get weights out of the process - """ - # Todo: required argument `model` is not used - if self.trainer.is_global_zero: - path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt') - self.trainer.save_checkpoint(path) - return path - - def load_spawn_weights(self, original_model): - """ - Load the temp weights saved in the process - To recover the 
trained model from the ddp process we load the saved weights - """ - - loaded_model = original_model - - if self.trainer.is_global_zero: - # load weights saved in ddp - path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt') - loaded_model = original_model.__class__.load_from_checkpoint(path) - - # copy loaded weights to old model - original_model.load_state_dict(loaded_model.state_dict()) - - # remove ddp weights - os.remove(path) - - return loaded_model - - def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): - if self.trainer.distributed_backend not in ("ddp_spawn", "ddp_cpu", "tpu"): - return - - # track the best model path - best_model_path = None - if self.trainer.checkpoint_callback is not None: - best_model_path = self.trainer.checkpoint_callback.best_model_path - - if self.trainer.global_rank == 0 and mp_queue is not None: - rank_zero_warn('cleaning up ddp environment...') - # todo, pass complete checkpoint as state dictionary - mp_queue.put(best_model_path) - mp_queue.put(results) - - # save the last weights - last_path = None - if not self.trainer.evaluating and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) - state_dict = move_data_to_device(model.state_dict(), torch.device("cpu")) - atomic_save(state_dict, last_path) - mp_queue.put(last_path) - - def broadcast(self, obj, src=0): - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - data_tensor = torch.tensor(data).to(xm.xla_device(), dtype=torch.float) - data = xm.all_gather(data_tensor) - buffer = io.BytesIO(data.cpu().byte().numpy()) - obj = torch.load(buffer) - return obj - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return tensor - - @property - def norm_clipping_epsilon(self): - return 1e-6 - - def on_save(self, checkpoint): - """ - Move XLA tensors to CPU before saving - Recommended on XLA Guide: - https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors - """ - return move_data_to_device(checkpoint, torch.device("cpu")) - - @property - def distributed_sampler_kwargs(self): - return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) - - @property - def require_distributed_sampler(self): - return True diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py deleted file mode 100644 index 1fa2ca24e6307..0000000000000 --- a/tests/base/datamodules.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any, Dict, Optional - -from torch.utils.data import DataLoader, random_split -from torch.utils.data.distributed import DistributedSampler - -from pytorch_lightning.core.datamodule import LightningDataModule -from tests.base.datasets import MNIST, TrialMNIST - - -class TrialMNISTDataModule(LightningDataModule): - def __init__(self, data_dir: str = "./"): - super().__init__() - self.data_dir = data_dir - self.non_picklable = None - self.checkpoint_state: Optional[str] = None - - def prepare_data(self): - TrialMNIST(self.data_dir, train=True, download=True) - TrialMNIST(self.data_dir, train=False, download=True) - - def setup(self, stage: Optional[str] = None): - - if stage != 'test': - mnist_full = TrialMNIST( - root=self.data_dir, train=True, num_samples=64, download=True - ) - self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64]) - self.dims = self.mnist_train[0][0].shape - - if stage == "test" or stage is None: - self.mnist_test = TrialMNIST( - root=self.data_dir, train=False, num_samples=64, download=True - ) - self.dims = getattr(self, "dims", self.mnist_test[0][0].shape) - - self.non_picklable = lambda x: x ** 2 - - def train_dataloader(self): - return DataLoader(self.mnist_train, batch_size=32) - - def val_dataloader(self): - return DataLoader(self.mnist_val, batch_size=32) - - def test_dataloader(self): - return DataLoader(self.mnist_test, batch_size=32) - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - checkpoint[self.__class__.__name__] = self.__class__.__name__ - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - self.checkpoint_state = checkpoint.get(self.__class__.__name__) - - -class MNISTDataModule(LightningDataModule): - def __init__( - self, data_dir: str = "./", batch_size: int = 32, dist_sampler: bool = False - ) -> None: - super().__init__() - - self.dist_sampler = dist_sampler - self.data_dir = data_dir - self.batch_size = batch_size - - # self.dims is returned when you call dm.size() - # Setting default dims here because we know them. - # Could optionally be assigned dynamically in dm.setup() - self.dims = (1, 28, 28) - - def prepare_data(self): - # download only - MNIST(self.data_dir, train=True, download=True, normalize=(0.1307, 0.3081)) - MNIST(self.data_dir, train=False, download=True, normalize=(0.1307, 0.3081)) - - def setup(self, stage: Optional[str] = None): - - # Assign train/val datasets for use in dataloaders - # TODO: need to split using random_split once updated to torch >= 1.6 - if stage != 'test': - self.mnist_train = MNIST( - self.data_dir, train=True, normalize=(0.1307, 0.3081) - ) - - # Assign test dataset for use in dataloader(s) - if stage == "test" or stage is None: - self.mnist_test = MNIST( - self.data_dir, train=False, normalize=(0.1307, 0.3081) - ) - - def train_dataloader(self): - dist_sampler = None - if self.dist_sampler: - dist_sampler = DistributedSampler(self.mnist_train, shuffle=False) - - return DataLoader( - self.mnist_train, - batch_size=self.batch_size, - sampler=dist_sampler, - shuffle=False, - ) - - def test_dataloader(self): - return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False) diff --git a/tests/trainer/test_trainer_validate_loop.py b/tests/trainer/test_trainer_validate_loop.py deleted file mode 100644 index ec8dd82260e3f..0000000000000 --- a/tests/trainer/test_trainer_validate_loop.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright The PyTorch Lightning team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest -import torch - -import pytorch_lightning as pl -import tests.base.develop_utils as tutils -from tests.base import BoringModel - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_single_gpu_validate(tmpdir): - tutils.set_random_master_port() - - model = BoringModel() - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0], - ) - trainer.fit(model) - assert 'ckpt' in trainer.checkpoint_callback.best_model_path - results = trainer.validate() - assert 'x' in results[0] - - old_weights = model.c_d1.weight.clone().detach().cpu() - - results = trainer.validate(model) - assert 'x' in results[0] - - # make sure weights didn't change - new_weights = model.c_d1.weight.clone().detach().cpu() - - assert torch.all(torch.eq(old_weights, new_weights)) - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_ddp_spawn_validate(tmpdir): - tutils.set_random_master_port() - - model = BoringModel() - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0, 1], - distributed_backend='ddp_spawn', - ) - trainer.fit(model) - assert 'ckpt' in trainer.checkpoint_callback.best_model_path - results = trainer.validate() - assert 'x' in results[0] - - old_weights = model.c_d1.weight.clone().detach().cpu() - - results = trainer.validate(model) - assert 'x' in results[0] - - # make sure weights didn't change - new_weights = model.c_d1.weight.clone().detach().cpu() - - assert torch.all(torch.eq(old_weights, new_weights)) From 18280dfa096e9ea38f3970ad8f6544621de40d11 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 18:01:13 +0100 Subject: [PATCH 23/32] Implement Trainer.validate --- CHANGELOG.md | 3 + docs/source/common/trainer.rst | 15 +- pytorch_lightning/callbacks/progress.py | 11 +- .../trainer/configuration_validator.py | 21 ++- pytorch_lightning/trainer/trainer.py | 70 +++++++- tests/accelerators/test_common.py | 43 +++++ tests/accelerators/test_dp.py | 33 ---- tests/callbacks/test_callbacks.py | 82 +++++++--- tests/callbacks/test_progress_bar.py | 22 ++- tests/checkpointing/test_model_checkpoint.py | 10 ++ tests/core/test_datamodules.py | 114 +------------- tests/models/test_hooks.py | 16 ++ tests/plugins/test_sharded_plugin.py | 26 +-- tests/trainer/optimization/test_optimizers.py | 21 ++- tests/trainer/test_config_validator.py | 62 +++++--- tests/trainer/test_dataloaders.py | 149 ++++++------------ tests/trainer/test_trainer.py | 70 ++++---- tests/trainer/test_trainer_test_loop.py | 76 --------- 18 files changed, 413 insertions(+), 431 deletions(-) create mode 100644 tests/accelerators/test_common.py delete mode 100644 tests/trainer/test_trainer_test_loop.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 327f923a79ff1..f3bf6b4021731 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ 
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948)) + + - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 17cfc7eccbc20..6edf896ada01c 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -151,6 +151,19 @@ So you can run it like so: ------------ +Validation +---------- +You can perform an evaluation epoch over the validation set, outside of the training loop, +using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be +useful if you want to collect new metrics from a model right at its initialization +or after it has already been trained. + +.. code-block:: python + + trainer.validate(val_dataloaders=val_dataloaders) + +------------ + Testing ------- Once you're done training, feel free to run the test set! @@ -158,7 +171,7 @@ Once you're done training, feel free to run the test set! .. code-block:: python - trainer.test(test_dataloaders=test_dataloader) + trainer.test(test_dataloaders=test_dataloaders) ------------ diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index c382e67b21a64..74e57e2b5642e 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -355,9 +355,11 @@ def init_predict_tqdm(self) -> tqdm: def init_validation_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for validation. """ + # The main progress bar doesn't exist in `trainer.validate()` + has_main_bar = self.main_progress_bar is not None bar = tqdm( desc='Validating', - position=(2 * self.process_position + 1), + position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, leave=False, dynamic_ncols=True, @@ -426,7 +428,8 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, def on_validation_end(self, trainer, pl_module): super().on_validation_end(trainer, pl_module) - self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + if self.main_progress_bar is not None: + self.main_progress_bar.set_postfix(trainer.progress_bar_dict) self.val_progress_bar.close() def on_train_end(self, trainer, pl_module): @@ -479,8 +482,10 @@ def print( def _should_update(self, current, total): return self.is_enabled and (current % self.refresh_rate == 0 or current == total) - def _update_bar(self, bar): + def _update_bar(self, bar: Optional[tqdm]) -> None: """ Updates the bar by the refresh rate without overshooting. """ + if bar is None: + return if bar.total is not None: delta = min(self.refresh_rate, bar.total - bar.n) else: diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 1bf38048ee159..8c539b5ff478d 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -22,7 +23,7 @@ class ConfigValidator(object): def __init__(self, trainer): self.trainer = trainer - def verify_loop_configurations(self, model: LightningModule): + def verify_loop_configurations(self, model: LightningModule) -> None: r""" Checks that the model is configured correctly before the run is started. @@ -30,10 +31,16 @@ def verify_loop_configurations(self, model: LightningModule): model: The model to check the configuration. """ - if self.trainer.training: + if self.trainer.state == TrainerState.FITTING: self.__verify_train_loop_configuration(model) - elif self.trainer.evaluating: - self.__verify_eval_loop_configuration(model) + self.__verify_eval_loop_configuration(model, 'val') + elif self.trainer.state == TrainerState.TUNING: + self.__verify_train_loop_configuration(model) + elif self.trainer.state == TrainerState.VALIDATING: + self.__verify_eval_loop_configuration(model, 'val') + elif self.trainer.state == TrainerState.TESTING: + self.__verify_eval_loop_configuration(model, 'test') + # TODO: add predict def __verify_train_loop_configuration(self, model): # ----------------------------------- @@ -81,11 +88,9 @@ def __verify_train_loop_configuration(self, model): ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.' ) - def __verify_eval_loop_configuration(self, model): - stage = "val" if self.trainer.validating else "test" - + def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None: loader_name = f'{stage}_dataloader' - step_name = f'{stage}_step' + step_name = 'validation_step' if stage == 'val' else 'test_step' has_loader = is_overridden(loader_name, model) has_step = is_overridden(step_name, model) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 45fc40731b545..fea327896126e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -826,6 +826,65 @@ def run_sanity_check(self, ref_model): self._running_stage = stage + def validate( + self, + model: Optional[LightningModule] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + ): + r""" + Perform one evaluation epoch over the validation set. + + Args: + model: The model to validate. + + val_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying validation samples. + + ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. + If ``None``, use the current weights of the model. + + verbose: If True, prints the validation results. + + datamodule: A instance of :class:`LightningDataModule`. + + Returns: + The dictionary with final validation results returned by validation_epoch_end. + If validation_epoch_end is not defined, the output is a list of the dictionaries + returned by validation_step. 
+ """ + # -------------------- + # SETUP HOOK + # -------------------- + self.verbose_evaluate = verbose + + self.state = TrainerState.VALIDATING + self.validating = True + + # If you supply a datamodule you can't supply val_dataloaders + if val_dataloaders and datamodule: + raise MisconfigurationException( + 'You cannot pass both `trainer.val(val_dataloaders=..., datamodule=...)`' + ) + + model_provided = model is not None + model = model or self.lightning_module + + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model, datamodule) + results = ( + self.__evaluate_given_model(model, dataloaders=val_dataloaders) + if model_provided else + self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=val_dataloaders) + ) + + assert self.state.stopped + self.validating = False + + return results + def test( self, model: Optional[LightningModule] = None, @@ -839,17 +898,18 @@ def test( fit to make sure you never run on your test set until you want to. Args: - ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the current weights of the model. Default to ``best``. - datamodule: A instance of :class:`LightningDataModule`. - model: The model to test. test_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying test samples. + ckpt_path: Either ``best`` or path to the checkpoint you wish to test. + If ``None``, use the current weights of the model. + verbose: If True, prints the test results. + datamodule: A instance of :class:`LightningDataModule`. + Returns: Returns a list of dictionaries, one for each test dataloader containing their respective metrics. """ @@ -864,7 +924,7 @@ def test( # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( - 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule' + 'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`' ) model_provided = model is not None diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py new file mode 100644 index 0000000000000..8944730cb3c3b --- /dev/null +++ b/tests/accelerators/test_common.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from pytorch_lightning import Trainer +from tests.accelerators.test_dp import CustomClassificationModelDP +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.runif import RunIf + + +@pytest.mark.parametrize("trainer_kwargs", ( + pytest.param({"gpus": 1}, marks=RunIf(min_gpus=1)), + pytest.param({"accelerator": "dp", "gpus": 2}, marks=RunIf(min_gpus=2)), + pytest.param({"accelerator": "ddp_spawn", "gpus": 2}, marks=RunIf(min_gpus=2)), +)) +def test_evaluate(tmpdir, trainer_kwargs, tutils=None): + tutils.set_random_master_port() + + dm = ClassifDataModule() + model = CustomClassificationModelDP() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + deterministic=True, + **trainer_kwargs + ) + + result = trainer.fit(model, datamodule=dm) + assert result + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + + old_weights = model.layer_0.weight.clone().detach().cpu() + + result = trainer.validate(datamodule=dm) + assert result[0]['val_acc'] > 0.7 + + result = trainer.test(datamodule=dm) + assert result[0]['test_acc'] > 0.6 + + # make sure weights didn't change + new_weights = 
model.layer_0.weight.clone().detach().cpu() + assert torch.testing.assert_allclose(old_weights, new_weights) \ No newline at end of file diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index 4736c6788c208..fad3d5ad2daa7 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch import torch.nn.functional as F import pytorch_lightning as pl @@ -24,8 +23,6 @@ from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel -PRETEND_N_OF_GPUS = 16 - class CustomClassificationModelDP(ClassificationModel): @@ -95,36 +92,6 @@ def test_multi_gpu_model_dp(tmpdir): memory.get_memory_profile('min_max') -@RunIf(min_gpus=2) -def test_dp_test(tmpdir): - tutils.set_random_master_port() - - dm = ClassifDataModule() - model = CustomClassificationModelDP() - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0, 1], - accelerator='dp', - ) - trainer.fit(model, datamodule=dm) - assert 'ckpt' in trainer.checkpoint_callback.best_model_path - results = trainer.test(datamodule=dm) - assert 'test_acc' in results[0] - - old_weights = model.layer_0.weight.clone().detach().cpu() - - results = trainer.test(model, datamodule=dm) - assert 'test_acc' in results[0] - - # make sure weights didn't change - new_weights = model.layer_0.weight.clone().detach().cpu() - - assert torch.all(torch.eq(old_weights, new_weights)) - - @RunIf(min_gpus=2) def test_dp_training_step_dict(tmpdir): """ diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 2426348f770bf..626eb59dffb9c 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -19,8 +19,8 @@ @mock.patch("torch.save") # need to mock torch.save or we get pickle error -def test_trainer_callback_system_fit(_, tmpdir): - """Test the callback system for fit.""" +def test_trainer_callback_hook_system_fit(_, tmpdir): + """Test the callback hook system for fit.""" model = BoringModel() callback_mock = MagicMock() @@ -97,8 +97,8 @@ def test_trainer_callback_system_fit(_, tmpdir): ] -def test_trainer_callback_system_test(tmpdir): - """Test the callback system for test.""" +def test_trainer_callback_hook_system_test(tmpdir): + """Test the callback hook system for test.""" model = BoringModel() callback_mock = MagicMock() @@ -130,6 +130,42 @@ def test_trainer_callback_system_test(tmpdir): ] +def test_trainer_callback_hook_system_validate(tmpdir): + """Test the callback hook system for validate.""" + + model = BoringModel() + callback_mock = MagicMock() + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[callback_mock], + max_epochs=1, + limit_val_batches=2, + progress_bar_refresh_rate=0, + ) + + trainer.validate(model) + + assert callback_mock.method_calls == [ + call.on_init_start(trainer), + call.on_init_end(trainer), + call.setup(trainer, model, 'validate'), + call.on_before_accelerator_backend_setup(trainer, model), + call.on_validation_start(trainer, model), + call.on_validation_epoch_start(trainer, model), + call.on_validation_batch_start(trainer, model, ANY, 0, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), + call.on_validation_batch_start(trainer, model, ANY, 1, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0), + 
call.on_validation_epoch_end(trainer, model), + call.on_epoch_end(trainer, model), + call.on_validation_end(trainer, model), + call.teardown(trainer, model, 'validate'), + ] + + +# TODO: add callback tests for predict and tune + + def test_callbacks_configured_in_model(tmpdir): """ Test the callback system with callbacks added through the model hook. """ @@ -166,22 +202,29 @@ def assert_expected_calls(_trainer, model_callback, trainer_callback): # .fit() trainer_options.update(callbacks=[trainer_callback_mock]) trainer = Trainer(**trainer_options) + assert trainer_callback_mock in trainer.callbacks assert model_callback_mock not in trainer.callbacks trainer.fit(model) + assert model_callback_mock in trainer.callbacks assert trainer.callbacks[-1] == model_callback_mock assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock) # .test() - model_callback_mock.reset_mock() - trainer_callback_mock.reset_mock() - trainer_options.update(callbacks=[trainer_callback_mock]) - trainer = Trainer(**trainer_options) - trainer.test(model) - assert model_callback_mock in trainer.callbacks - assert trainer.callbacks[-1] == model_callback_mock - assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock) + for fn in ("test", "validate"): + model_callback_mock.reset_mock() + trainer_callback_mock.reset_mock() + + trainer_options.update(callbacks=[trainer_callback_mock]) + trainer = Trainer(**trainer_options) + + trainer_fn = getattr(trainer, fn) + trainer_fn(model) + + assert model_callback_mock in trainer.callbacks + assert trainer.callbacks[-1] == model_callback_mock + assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock) def test_configure_callbacks_hook_multiple_calls(tmpdir): @@ -208,10 +251,13 @@ def configure_callbacks(self): callbacks_after_fit = trainer.callbacks.copy() assert callbacks_after_fit == callbacks_before_fit + [model_callback_mock] - trainer.test(model) - callbacks_after_test = trainer.callbacks.copy() - assert callbacks_after_test == callbacks_after_fit + for fn in ("test", "validate"): + trainer_fn = getattr(trainer, fn) + trainer_fn(model) + + callbacks_after = trainer.callbacks.copy() + assert callbacks_after == callbacks_after_fit - trainer.test(ckpt_path=None) - callbacks_after_test = trainer.callbacks.copy() - assert callbacks_after_test == callbacks_after_fit + trainer_fn(ckpt_path=None) + callbacks_after = trainer.callbacks.copy() + assert callbacks_after == callbacks_after_fit diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index e4171a8520353..c0c69f7d03406 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -89,7 +89,6 @@ def test_progress_bar_totals(tmpdir): trainer = Trainer( default_root_dir=tmpdir, progress_bar_refresh_rate=1, - limit_val_batches=1.0, max_epochs=1, ) bar = trainer.progress_bar_callback @@ -121,6 +120,12 @@ def test_progress_bar_totals(tmpdir): assert 0 == bar.total_test_batches assert bar.test_progress_bar is None + trainer.validate(model) + + assert bar.val_progress_bar.total == m + assert bar.val_progress_bar.n == m + assert bar.val_batch_idx == m + trainer.test(model) # check test progress bar total @@ -156,6 +161,13 @@ def test_progress_bar_fast_dev_run(tmpdir): assert 2 == progress_bar.main_progress_bar.total assert 2 == progress_bar.main_progress_bar.n + trainer.validate(model) + + # the validation progress bar should display 1 batch + assert 1 == progress_bar.val_batch_idx + assert 1 == 
progress_bar.val_progress_bar.total + assert 1 == progress_bar.val_progress_bar.n + trainer.test(model) # the test progress bar should display 1 batch @@ -213,8 +225,16 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal trainer.fit(model) assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 + + trainer.validate(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 trainer.test(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps assert progress_bar.test_batches_seen == progress_bar.total_test_batches diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 48e4a22e1ec05..79a6ed6f86fac 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -847,6 +847,9 @@ def assert_checkpoint_log_dir(idx): assert_checkpoint_log_dir(0) assert_checkpoint_content(ckpt_dir) + trainer.validate(model) + assert trainer.current_epoch == epochs - 1 + trainer.test(model) assert trainer.current_epoch == epochs - 1 @@ -860,17 +863,24 @@ def assert_checkpoint_log_dir(idx): assert_trainer_init(trainer) model = ExtendedBoringModel() + trainer.test(model) assert not trainer.checkpoint_connector.has_trained # resume_from_checkpoint is resumed when calling `.fit` assert trainer.global_step == 0 assert trainer.current_epoch == 0 + trainer.fit(model) assert not trainer.checkpoint_connector.has_trained assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs assert_checkpoint_log_dir(idx) + trainer.validate(model) + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs + def test_configure_model_checkpoint(tmpdir): """ Test all valid and invalid ways a checkpoint callback can be passed to the Trainer. 
""" diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index ab51a87329e2f..1f671380f869b 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -297,20 +297,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__ -def test_test_loop_only(tmpdir): - reset_seed() - - dm = BoringDataModule() - model = BoringModel() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - weights_summary=None, - ) - trainer.test(model, datamodule=dm) - - def test_full_loop(tmpdir): reset_seed() @@ -327,109 +313,17 @@ def test_full_loop(tmpdir): # fit model result = trainer.fit(model, dm) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert dm.trainer is not None assert result - # test - result = trainer.test(datamodule=dm) - assert result[0]['test_acc'] > 0.6 - - -def test_trainer_attached_to_dm(tmpdir): - reset_seed() - - dm = BoringDataModule() - model = BoringModel() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - weights_summary=None, - deterministic=True, - ) - - # fit model - trainer.fit(model, dm) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + # validate + result = trainer.validate(datamodule=dm) assert dm.trainer is not None + assert result[0]['val_acc'] > 0.7 # test result = trainer.test(datamodule=dm) - result = result[0] assert dm.trainer is not None - - -@RunIf(min_gpus=1) -def test_full_loop_single_gpu(tmpdir): - reset_seed() - - dm = ClassifDataModule() - model = ClassificationModel() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - weights_summary=None, - gpus=1, - deterministic=True, - ) - - # fit model - result = trainer.fit(model, dm) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert result - - # test - result = trainer.test(datamodule=dm) - assert result[0]['test_acc'] > 0.6 - - -@RunIf(min_gpus=2) -def test_full_loop_dp(tmpdir): - set_random_master_port() - - class CustomClassificationModelDP(ClassificationModel): - - def _step(self, batch, batch_idx): - x, y = batch - logits = self(x) - return {'logits': logits, 'y': y} - - def training_step(self, batch, batch_idx): - out = self._step(batch, batch_idx) - loss = F.cross_entropy(out['logits'], out['y']) - return loss - - def validation_step(self, batch, batch_idx): - return self._step(batch, batch_idx) - - def test_step(self, batch, batch_idx): - return self._step(batch, batch_idx) - - def test_step_end(self, outputs): - self.log('test_acc', self.test_acc(outputs['logits'], outputs['y'])) - - dm = ClassifDataModule() - model = CustomClassificationModelDP() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - weights_summary=None, - accelerator='dp', - gpus=2, - deterministic=True, - ) - - # fit model - result = trainer.fit(model, datamodule=dm) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert result - - # test - result = trainer.test(datamodule=dm) assert result[0]['test_acc'] > 0.6 diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 7c53925bd7cc4..0d1c7cf40a2bf 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -466,7 +466,23 @@ def teardown(self, stage=None): 'on_fit_end', 'teardown', ] + assert model.called == expected + + model = 
HookedModel() + trainer.validate(model, verbose=False) + expected = [ + 'on_validation_model_eval', + 'on_validation_start', + 'on_validation_epoch_start', + 'on_validation_batch_start', + 'on_validation_batch_end', + 'on_validation_epoch_end', + 'on_epoch_end', + 'on_validation_end', + 'on_validation_model_train', + 'teardown', + ] assert model.called == expected model = HookedModel() diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index b59563f70e4aa..a48f048160ee5 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -259,30 +259,20 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): @RunIf(skip_windows=True, special=True, fairscale=True) -def test_ddp_sharded_plugin_test(tmpdir): +@pytest.mark.parametrize("trainer_kwargs", ( + {'num_processes': 2}, + pytest.param({'gpus': 2}, marks=RunIf(min_gpus=2)) +)) +def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs): """ - Test to ensure we can use test without fit + Test to ensure we can use validate and test without fit """ model = BoringModel() trainer = Trainer( accelerator='ddp_sharded_spawn', - num_processes=2, - fast_dev_run=True, - ) - - trainer.test(model) - - -@RunIf(min_gpus=2, skip_windows=True, fairscale=True) -def test_ddp_sharded_plugin_test_multigpu(tmpdir): - """ - Test to ensure we can use test without fit - """ - model = BoringModel() - trainer = Trainer( - accelerator='ddp_sharded_spawn', - gpus=2, fast_dev_run=True, + **trainer_kwargs, ) + trainer.validate(model) trainer.test(model) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 01c23ed18fe65..34845c46b45eb 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -287,15 +287,22 @@ def test_configure_optimizers_with_frequency(tmpdir): assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" -def test_init_optimizers_during_testing(tmpdir): +@pytest.mark.parametrize("fn", ("validate", "test")) +def test_init_optimizers_during_evaluation(tmpdir, fn): """ - Test that optimizers is an empty list during testing. + Test that optimizers is an empty list during evaluation """ - model = EvalModelTemplate() - model.configure_optimizers = model.configure_optimizers__multiple_schedulers - - trainer = Trainer(default_root_dir=tmpdir, limit_test_batches=10) - trainer.test(model, ckpt_path=None) + class TestModel(BoringModel): + def configure_optimizers(self): + optimizer1 = torch.optim.Adam(self.parameters(), lr=0.1) + optimizer2 = torch.optim.Adam(self.parameters(), lr=0.1) + lr_scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=1) + lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=1) + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] + + trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=10, limit_test_batches=10) + validate_or_test = getattr(trainer, fn) + validate_or_test(TestModel(), ckpt_path=None) assert len(trainer.lr_schedulers) == 0 assert len(trainer.optimizers) == 0 diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 00ad020aa1b57..59e10480a485e 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -13,12 +13,9 @@ # limitations under the License. 
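# Illustrative sketch (not part of the patch): the behaviour exercised by
# `test_init_optimizers_during_evaluation` above. When only `validate()`/`test()`
# is called, optimizer and scheduler setup is skipped entirely. `TwoOptimizerModel`
# is a hypothetical name; `BoringModel` is the same test helper used throughout.
import torch
from pytorch_lightning import Trainer
from tests.helpers import BoringModel


class TwoOptimizerModel(BoringModel):
    def configure_optimizers(self):
        # two optimizers, each paired with its own StepLR schedule
        opt1 = torch.optim.Adam(self.parameters(), lr=0.1)
        opt2 = torch.optim.Adam(self.parameters(), lr=0.1)
        sch1 = torch.optim.lr_scheduler.StepLR(opt1, step_size=1)
        sch2 = torch.optim.lr_scheduler.StepLR(opt2, step_size=1)
        return [opt1, opt2], [sch1, sch2]


if __name__ == "__main__":
    trainer = Trainer(limit_val_batches=2, limit_test_batches=2)
    trainer.validate(TwoOptimizerModel(), ckpt_path=None)
    # evaluation never calls configure_optimizers, so both lists stay empty
    assert len(trainer.optimizers) == 0 and len(trainer.lr_schedulers) == 0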
import pytest -import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate - -# TODO: add matching messages +from tests.helpers import BoringModel def test_wrong_train_setting(tmpdir): @@ -26,49 +23,44 @@ def test_wrong_train_setting(tmpdir): * Test that an error is thrown when no `train_dataloader()` is defined * Test that an error is thrown when no `training_step()` is defined """ - tutils.reset_seed() - hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(**hparams) + with pytest.raises(MisconfigurationException, match=r'No `train_dataloader\(\)` method defined.'): + model = BoringModel() model.train_dataloader = None trainer.fit(model) - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(**hparams) + with pytest.raises(MisconfigurationException, match=r'No `training_step\(\)` method defined.'): + model = BoringModel() model.training_step = None trainer.fit(model) def test_wrong_configure_optimizers(tmpdir): """ Test that an error is thrown when no `configure_optimizers()` is defined """ - tutils.reset_seed() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate() + with pytest.raises(MisconfigurationException, match=r'No `configure_optimizers\(\)` method defined.'): + model = BoringModel() model.configure_optimizers = None trainer.fit(model) -def test_val_loop_config(tmpdir): +def test_fit_val_loop_config(tmpdir): """" When either val loop or val data are missing raise warning """ - tutils.reset_seed() - hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # no val data has val loop - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): + model = BoringModel() model.validation_step = None trainer.fit(model) # has val loop but no val data - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'): + model = BoringModel() model.val_dataloader = None trainer.fit(model) @@ -77,17 +69,35 @@ def test_test_loop_config(tmpdir): """" When either test loop or test data are missing """ - hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # has test loop but no test data - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you defined a test_step but have no test_dataloader'): + model = BoringModel() model.test_dataloader = None trainer.test(model) # has test data but no test loop - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you passed in a test_dataloader but have no test_step'): + model = BoringModel() model.test_step = None - trainer.test(model, test_dataloaders=model.dataloader(train=False)) + trainer.test(model) + + +def test_val_loop_config(tmpdir): + """" + When either validation loop or validation data are missing + """ + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + # has val loop but no val data + with pytest.warns(UserWarning, match=r'you 
defined a validation_step but have no val_dataloader'): + model = BoringModel() + model.val_dataloader = None + trainer.validate(model) + + # has val data but no val loop + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): + model = BoringModel() + model.validation_step = None + trainer.validate(model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 5530779b4f77d..f3cde9f2f6eab 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -125,8 +125,7 @@ def test_multiple_val_dataloader(tmpdir): assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # verify there are 2 val loaders - assert len(trainer.val_dataloaders) == 2, \ - 'Multiple val_dataloaders not initiated properly' + assert len(trainer.val_dataloaders) == 2, 'Multiple val_dataloaders not initiated properly' # make sure predictions are good for each val set for dataloader in trainer.val_dataloaders: @@ -134,18 +133,22 @@ def test_multiple_val_dataloader(tmpdir): @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) -def test_multiple_test_dataloader(tmpdir, ckpt_path): - """Verify multiple test_dataloader.""" - - model_template = EvalModelTemplate() +def test_multiple_eval_dataloader(tmpdir, ckpt_path): + """Verify multiple evaluation dataloaders.""" class MultipleTestDataloaderModel(EvalModelTemplate): - def test_dataloader(self): return [self.dataloader(train=False), self.dataloader(train=False)] - def test_step(self, batch, batch_idx, *args, **kwargs): - return model_template.test_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs) + def test_step(self, *args, **kwargs): + return super().test_step__multiple_dataloaders(*args, **kwargs) + + def val_dataloader(self): + return self.test_dataloader() + + def validation_step(self, *args, **kwargs): + output = self.test_step(*args, **kwargs) + return {k.replace("test_", "val_"): v for k, v in output.items()} model = MultipleTestDataloaderModel() @@ -159,17 +162,18 @@ def test_step(self, batch, batch_idx, *args, **kwargs): trainer.fit(model) if ckpt_path == 'specific': ckpt_path = trainer.checkpoint_callback.best_model_path - trainer.test(ckpt_path=ckpt_path) - - # verify there are 2 test loaders - assert len(trainer.test_dataloaders) == 2, 'Multiple test_dataloaders not initiated properly' - # make sure predictions are good for each test set - for dataloader in trainer.test_dataloaders: + trainer.validate(ckpt_path=ckpt_path) + # verify there are 2 loaders + assert len(trainer.val_dataloaders) == 2 + # make sure predictions are good for each dl + for dataloader in trainer.val_dataloaders: tpipes.run_prediction_eval_model_template(trainer.model, dataloader) - # run the test method trainer.test(ckpt_path=ckpt_path) + assert len(trainer.test_dataloaders) == 2 + for dataloader in trainer.test_dataloaders: + tpipes.run_prediction_eval_model_template(trainer.model, dataloader) def test_train_dataloader_passed_to_fit(tmpdir): @@ -189,90 +193,45 @@ def test_train_dataloader_passed_to_fit(tmpdir): assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" -def test_train_val_dataloaders_passed_to_fit(tmpdir): - """ Verify that train & val dataloader can be passed to fit """ - - # train, val passed to fit - model = EvalModelTemplate() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - ) - fit_options = 
dict(train_dataloader=model.dataloader(train=True), val_dataloaders=model.dataloader(train=False)) - - trainer.fit(model, **fit_options) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert len(trainer.val_dataloaders) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - - @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) -def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path): - """Verify train, val & test dataloader(s) can be passed to fit and test method""" +@pytest.mark.parametrize("n", (1, 2)) +def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): + """Verify that dataloaders can be passed.""" model = EvalModelTemplate() + if n == 1: + dataloaders = model.dataloader(train=False) + else: + dataloaders = [model.dataloader(train=False)] * 2 + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + model.test_step = model.test_step__multiple_dataloaders - # train, val and test passed to fit + # train, multiple val and multiple test passed to fit trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2, ) - fit_options = dict(train_dataloader=model.dataloader(train=True), val_dataloaders=model.dataloader(train=False)) - trainer.fit(model, **fit_options) - - if ckpt_path == 'specific': - ckpt_path = trainer.checkpoint_callback.best_model_path - test_options = dict(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path) - trainer.test(**test_options) + trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert len(trainer.val_dataloaders) == 1, \ - f'val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.test_dataloaders) == 1, \ - f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' - - -@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) -def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path): - """Verify that multiple val & test dataloaders can be passed to fit.""" + assert len(trainer.val_dataloaders) == n - model = EvalModelTemplate() - model.validation_step = model.validation_step__multiple_dataloaders - model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders - model.test_step = model.test_step__multiple_dataloaders - - # train, multiple val and multiple test passed to fit - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - ) - fit_options = dict( - train_dataloader=model.dataloader(train=True), - val_dataloaders=[model.dataloader(train=False), model.dataloader(train=False)] - ) - trainer.fit(model, **fit_options) if ckpt_path == 'specific': ckpt_path = trainer.checkpoint_callback.best_model_path - test_options = dict( - test_dataloaders=[model.dataloader(train=False), model.dataloader(train=False)], ckpt_path=ckpt_path - ) - trainer.test(**test_options) - assert len(trainer.val_dataloaders) == 2, \ - f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.test_dataloaders) == 2, \ - f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' + trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path) + trainer.validate(val_dataloaders=dataloaders, 
ckpt_path=ckpt_path) + + assert len(trainer.val_dataloaders) == n + assert len(trainer.test_dataloaders) == n @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0.0, 0.0, 0.0), - pytest.param(1.0, 1.0, 1.0), + (0.0, 0.0, 0.0), + (1.0, 1.0, 1.0), ]) def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent""" @@ -299,8 +258,8 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0, 0, 0), - pytest.param(10, 10, 10), + (0, 0, 0), + (10, 10, 10), ]) def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" @@ -327,10 +286,10 @@ def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, lim @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0.0, 0.0, 0.0), - pytest.param(0, 0, 0.5), - pytest.param(1.0, 1.0, 1.0), - pytest.param(0.2, 0.4, 0.4), + (0.0, 0.0, 0.0), + (0, 0, 0.5), + (1.0, 1.0, 1.0), + (0.2, 0.4, 0.4), ]) def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify num_batches for train, val & test dataloaders passed with batch limit in percent""" @@ -362,9 +321,9 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0, 0, 0), - pytest.param(1, 2, 3), - pytest.param(1, 2, 1e50), + (0, 0, 0), + (1, 2, 3), + (1, 2, 1e50), ]) @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): @@ -445,10 +404,10 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): if fast_dev_run == 'temp': with pytest.raises(MisconfigurationException, match='either a bool or an int'): - trainer = Trainer(**trainer_options) + Trainer(**trainer_options) elif fast_dev_run == -1: with pytest.raises(MisconfigurationException, match='should be >= 0'): - trainer = Trainer(**trainer_options) + Trainer(**trainer_options) else: trainer = Trainer(**trainer_options) @@ -1191,12 +1150,6 @@ def test_replace_sampler_with_multiprocessing_context(tmpdir): train = RandomDataset(32, 64) context = 'spawn' train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True) - - class ExtendedBoringModel(BoringModel): - - def train_dataloader(self): - return train - trainer = Trainer( max_epochs=1, progress_bar_refresh_rate=20, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3e090fb44943e..2592f138c5b9f 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -599,44 +599,57 @@ def test_benchmark_option(tmpdir): assert torch.backends.cudnn.benchmark -@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) -@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) -def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k): - hparams = EvalModelTemplate.get_default_hparams() +@pytest.mark.parametrize("ckpt_path", (None, 
"best", "specific")) +@pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2)) +@pytest.mark.parametrize("fn", ("validate", "test")) +def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn): + class TestModel(BoringModel): + def validation_step(self, batch, batch_idx): + self.log("foo", -batch_idx) + return super().validation_step(batch, batch_idx) - model = EvalModelTemplate(**hparams) + model = TestModel() trainer = Trainer( max_epochs=2, progress_bar_refresh_rate=0, default_root_dir=tmpdir, - callbacks=[ModelCheckpoint(monitor="early_stop_on", save_top_k=save_top_k)], + callbacks=[ModelCheckpoint(monitor="foo", save_top_k=save_top_k)], ) trainer.fit(model) + + test_or_validate = getattr(trainer, fn) if ckpt_path == "best": # ckpt_path is 'best', meaning we load the best weights if save_top_k == 0: with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"): - trainer.test(ckpt_path=ckpt_path) + test_or_validate(ckpt_path=ckpt_path) else: - trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + test_or_validate(ckpt_path=ckpt_path) + if fn == "test": + assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + else: + assert trainer.validated_ckpt_path == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and # use the weights from the end of training - trainer.test(ckpt_path=ckpt_path) + test_or_validate(ckpt_path=ckpt_path) assert trainer.tested_ckpt_path is None + assert trainer.validated_ckpt_path is None else: # specific checkpoint, pick one from saved ones if save_top_k == 0: with pytest.raises(FileNotFoundError): - trainer.test(ckpt_path="random.ckpt") + test_or_validate(ckpt_path="random.ckpt") else: ckpt_path = str( list((Path(tmpdir) / f"lightning_logs/version_{trainer.logger.version}/checkpoints").iterdir() )[0].absolute() ) - trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == ckpt_path + test_or_validate(ckpt_path=ckpt_path) + if fn == "test": + assert trainer.tested_ckpt_path == ckpt_path + else: + assert trainer.validated_ckpt_path == ckpt_path def test_disabled_training(tmpdir): @@ -1292,10 +1305,11 @@ def test_trainer_pickle(tmpdir): cloudpickle.dumps(trainer) -def test_trainer_setup_call(tmpdir): - """Test setup call with fit and test call.""" +@pytest.mark.parametrize("stage", ("fit", "validate", "test")) +def test_trainer_setup_call(tmpdir, stage): + """Test setup call gets the correct stage""" - class CurrentModel(EvalModelTemplate): + class CurrentModel(BoringModel): def setup(self, stage): self.stage = stage @@ -1311,21 +1325,23 @@ def setup(self, model, stage): # fit model trainer = TrainerSubclass(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False) - trainer.fit(model) - assert trainer.stage == "fit" - assert trainer.lightning_module.stage == "fit" + if stage == "fit": + trainer.fit(model) + elif stage == "validate": + trainer.validate(model, ckpt_path=None) + else: + trainer.test(model, ckpt_path=None) - trainer.test(ckpt_path=None) - assert trainer.stage == "test" - assert trainer.lightning_module.stage == "test" + assert trainer.stage == stage + assert trainer.lightning_module.stage == stage @pytest.mark.parametrize( "train_batches, max_steps, log_interval", [ - pytest.param(10, 10, 1), - pytest.param(3, 10, 1), - pytest.param(3, 10, 5), + (10, 10, 1), + (3, 10, 1), + (3, 10, 5), ], ) 
@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics") @@ -1398,7 +1414,7 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] model = BoringModel() - datamodule = TestLightningDataModule(dataloaders) + dm = TestLightningDataModule(dataloaders) trainer = Trainer( default_root_dir=tmpdir, @@ -1411,7 +1427,7 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T plugins=plugins, ) if datamodule: - results = trainer.predict(model, datamodule=datamodule) + results = trainer.predict(model, datamodule=dm) else: results = trainer.predict(model, dataloaders=dataloaders) diff --git a/tests/trainer/test_trainer_test_loop.py b/tests/trainer/test_trainer_test_loop.py deleted file mode 100644 index 7e2a9299fc8a0..0000000000000 --- a/tests/trainer/test_trainer_test_loop.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - -import pytorch_lightning as pl -import tests.helpers.utils as tutils -from tests.base import EvalModelTemplate -from tests.helpers.runif import RunIf - - -@RunIf(min_gpus=2) -def test_single_gpu_test(tmpdir): - tutils.set_random_master_port() - - model = EvalModelTemplate() - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0], - ) - trainer.fit(model) - assert 'ckpt' in trainer.checkpoint_callback.best_model_path - results = trainer.test() - assert 'test_acc' in results[0] - - old_weights = model.c_d1.weight.clone().detach().cpu() - - results = trainer.test(model) - assert 'test_acc' in results[0] - - # make sure weights didn't change - new_weights = model.c_d1.weight.clone().detach().cpu() - - assert torch.all(torch.eq(old_weights, new_weights)) - - -@RunIf(min_gpus=2) -def test_ddp_spawn_test(tmpdir): - tutils.set_random_master_port() - - model = EvalModelTemplate() - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0, 1], - accelerator='ddp_spawn', - ) - trainer.fit(model) - assert 'ckpt' in trainer.checkpoint_callback.best_model_path - results = trainer.test() - assert 'test_acc' in results[0] - - old_weights = model.c_d1.weight.clone().detach().cpu() - - results = trainer.test(model) - assert 'test_acc' in results[0] - - # make sure weights didn't change - new_weights = model.c_d1.weight.clone().detach().cpu() - - assert torch.all(torch.eq(old_weights, new_weights)) From e582d58aefda46f2b22bd46e86ef798ee7133009 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 18:44:00 +0100 Subject: [PATCH 24/32] Refactor --- .../trainer/connectors/data_connector.py | 10 +-- pytorch_lightning/trainer/trainer.py | 66 +++++++------------ tests/trainer/test_dataloaders.py | 4 +- 3 files changed, 30 insertions(+), 50 deletions(-) diff --git 
a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 9e08cf031175f..d787f796f3d88 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -93,10 +93,10 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa def attach_dataloaders( self, model, - train_dataloader=None, - val_dataloaders=None, - test_dataloaders=None, - predict_dataloaders=None, + train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, ): # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations @@ -112,7 +112,7 @@ def attach_dataloaders( if predict_dataloaders is not None: model.predict_dataloader = _PatchDataLoader(predict_dataloaders) - def attach_datamodule(self, model, datamodule: Optional[LightningDataModule]) -> None: + def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None: # We use datamodule if it's been provided, otherwise we check model for it datamodule = datamodule or getattr(model, 'datamodule', None) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fea327896126e..ba38c6e8c4b1e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -874,11 +874,14 @@ def validate( # Attach datamodule to get setup/prepare_data added to model before the call to it below self.data_connector.attach_datamodule(model, datamodule) - results = ( - self.__evaluate_given_model(model, dataloaders=val_dataloaders) - if model_provided else - self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=val_dataloaders) - ) + # Attach dataloaders (if given) + self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders) + + if not model_provided: + self.__evaluate_using_weights(model, ckpt_path=ckpt_path) + + # run validate + results = self.fit(model) assert self.state.stopped self.validating = False @@ -932,11 +935,14 @@ def test( # Attach datamodule to get setup/prepare_data added to model before the call to it below self.data_connector.attach_datamodule(model, datamodule) - results = ( - self.__evaluate_given_model(model, dataloaders=test_dataloaders) - if model_provided else - self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders) - ) + # Attach dataloaders (if given) + self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + + if not model_provided: + self.__evaluate_using_weights(model, ckpt_path=ckpt_path) + + # run test + results = self.fit(model) assert self.state.stopped self.testing = False @@ -947,7 +953,6 @@ def __evaluate_using_weights( self, model, ckpt_path: Optional[str] = None, - dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None ): # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: @@ -961,43 +966,22 @@ def __evaluate_using_weights( if ckpt_path == 'best': ckpt_path = self.checkpoint_callback.best_model_path - if len(ckpt_path) == 0: - rank_zero_warn( - f'`.test()` found no path for the best weights, {ckpt_path}. 
Please' + if not ckpt_path: + raise MisconfigurationException( + f'`.test()` found no path for the best weights: "{ckpt_path}". Please' ' specify a path for a checkpoint `.test(ckpt_path=PATH)`' ) - return {} self.training_type_plugin.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt['state_dict']) - # attach dataloaders - if dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) - if self.validating: self.validated_ckpt_path = ckpt_path else: self.tested_ckpt_path = ckpt_path - # run test - results = self.fit(model) - - return results - - def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None): - # attach data - if dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) - - # run test - # sets up testing so we short circuit to eval - results = self.fit(model) - - return results - def predict( self, model: Optional[LightningModule] = None, @@ -1037,15 +1021,11 @@ def predict( 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.' ) - if datamodule is not None: - # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule) - - # attach data - if dataloaders is not None: - self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model, datamodule) + # Attach dataloaders (if given) + self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) - self.model = model results = self.fit(model) assert self.state.stopped diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index f3cde9f2f6eab..e4aea38fb7f37 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -163,14 +163,14 @@ def validation_step(self, *args, **kwargs): if ckpt_path == 'specific': ckpt_path = trainer.checkpoint_callback.best_model_path - trainer.validate(ckpt_path=ckpt_path) + trainer.validate(ckpt_path=ckpt_path, verbose=False) # verify there are 2 loaders assert len(trainer.val_dataloaders) == 2 # make sure predictions are good for each dl for dataloader in trainer.val_dataloaders: tpipes.run_prediction_eval_model_template(trainer.model, dataloader) - trainer.test(ckpt_path=ckpt_path) + trainer.test(ckpt_path=ckpt_path, verbose=False) assert len(trainer.test_dataloaders) == 2 for dataloader in trainer.test_dataloaders: tpipes.run_prediction_eval_model_template(trainer.model, dataloader) From 5b99ec06eedef1444c274bab9c536c2a25e6b499 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 8 Mar 2021 15:39:00 +0100 Subject: [PATCH 25/32] flake8 --- tests/accelerators/test_common.py | 2 +- tests/core/test_datamodules.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 8944730cb3c3b..9fcb54caa4603 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -40,4 +40,4 @@ def test_evaluate(tmpdir, trainer_kwargs, tutils=None): # make sure weights didn't change new_weights = model.layer_0.weight.clone().detach().cpu() - assert torch.testing.assert_allclose(old_weights, new_weights) \ No newline at end of file + assert torch.testing.assert_allclose(old_weights, 
new_weights) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 1f671380f869b..2118fec6c207b 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -19,7 +19,6 @@ import pytest import torch -import torch.nn.functional as F from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint @@ -29,7 +28,7 @@ from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel -from tests.helpers.utils import reset_seed, set_random_master_port +from tests.helpers.utils import reset_seed @mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock) From 9f4dce2c42c0a32046d9fc9adecb3625a7eac465 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 8 Mar 2021 16:02:09 +0100 Subject: [PATCH 26/32] Refactor --- pytorch_lightning/trainer/trainer.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c185a03ca7859..83753c7506b1f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -869,7 +869,7 @@ def validate( # If you supply a datamodule you can't supply val_dataloaders if val_dataloaders and datamodule: raise MisconfigurationException( - 'You cannot pass both `trainer.val(val_dataloaders=..., datamodule=...)`' + 'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`' ) model_provided = model is not None @@ -881,7 +881,7 @@ def validate( self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders) if not model_provided: - self.__evaluate_using_weights(model, ckpt_path=ckpt_path) + self.validated_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) # run validate results = self.fit(model) @@ -942,7 +942,7 @@ def test( self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) if not model_provided: - self.__evaluate_using_weights(model, ckpt_path=ckpt_path) + self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) # run test results = self.fit(model) @@ -952,11 +952,11 @@ def test( return results - def __evaluate_using_weights( + def __load_ckpt_weights( self, model, ckpt_path: Optional[str] = None, - ): + ) -> Optional[str]: # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: raise MisconfigurationException( @@ -979,11 +979,7 @@ def __evaluate_using_weights( ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt['state_dict']) - - if self.validating: - self.validated_ckpt_path = ckpt_path - else: - self.tested_ckpt_path = ckpt_path + return ckpt_path def predict( self, From 088d4bc0ca4e3efb4acde01b449b955d7cd5e914 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 8 Mar 2021 16:33:31 +0100 Subject: [PATCH 27/32] Missing import --- tests/accelerators/test_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 9fcb54caa4603..6301f19b456e3 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -1,6 +1,7 @@ import pytest import torch +import tests.helpers.utils as tutils from pytorch_lightning import Trainer from tests.accelerators.test_dp import 
CustomClassificationModelDP from tests.helpers.datamodules import ClassifDataModule @@ -12,7 +13,7 @@ pytest.param({"accelerator": "dp", "gpus": 2}, marks=RunIf(min_gpus=2)), pytest.param({"accelerator": "ddp_spawn", "gpus": 2}, marks=RunIf(min_gpus=2)), )) -def test_evaluate(tmpdir, trainer_kwargs, tutils=None): +def test_evaluate(tmpdir, trainer_kwargs): tutils.set_random_master_port() dm = ClassifDataModule() From 58fcca41c255652bcbabba9b85fbd7554b796114 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 8 Mar 2021 17:40:32 +0100 Subject: [PATCH 28/32] Fix test --- tests/accelerators/test_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 6301f19b456e3..3265f621aa872 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -34,11 +34,11 @@ def test_evaluate(tmpdir, trainer_kwargs): old_weights = model.layer_0.weight.clone().detach().cpu() result = trainer.validate(datamodule=dm) - assert result[0]['val_acc'] > 0.7 + assert result[0]['val_acc'] > 0.55 result = trainer.test(datamodule=dm) assert result[0]['test_acc'] > 0.6 # make sure weights didn't change new_weights = model.layer_0.weight.clone().detach().cpu() - assert torch.testing.assert_allclose(old_weights, new_weights) + torch.testing.assert_allclose(old_weights, new_weights) From babb73da7d58dd87e11f52ff329fe2104bb7cd35 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 8 Mar 2021 18:07:55 +0100 Subject: [PATCH 29/32] Same threshold --- tests/accelerators/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 3265f621aa872..6962af7249d1b 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -37,7 +37,7 @@ def test_evaluate(tmpdir, trainer_kwargs): assert result[0]['val_acc'] > 0.55 result = trainer.test(datamodule=dm) - assert result[0]['test_acc'] > 0.6 + assert result[0]['test_acc'] > 0.55 # make sure weights didn't change new_weights = model.layer_0.weight.clone().detach().cpu() From 235dc278a5b29e7a5cff0484f4f15bb6543303d5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Mar 2021 00:29:20 +0100 Subject: [PATCH 30/32] Address tchaton's comments --- pytorch_lightning/trainer/states.py | 2 +- pytorch_lightning/trainer/trainer.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 33a2326c518d5..b1f188ab047fe 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -58,7 +58,7 @@ class RunningStage(LightningEnum): """ TRAINING = 'train' SANITY_CHECKING = 'sanity_check' - VALIDATING = 'validation' + VALIDATING = 'validate' TESTING = 'test' PREDICTING = 'predict' TUNING = 'tune' diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 83753c7506b1f..c21f65803f3e6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -970,9 +970,10 @@ def __load_ckpt_weights( ckpt_path = self.checkpoint_callback.best_model_path if not ckpt_path: + fn = self.state.value raise MisconfigurationException( - f'`.test()` found no path for the best weights: "{ckpt_path}". Please' - ' specify a path for a checkpoint `.test(ckpt_path=PATH)`' + f'`.{fn}()` found no path for the best weights: "{ckpt_path}". 
Please' + ' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`' ) self.training_type_plugin.barrier() From e423b98c119b4da095a5b6bc2e923728a087bee8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 10 Mar 2021 02:57:58 +0100 Subject: [PATCH 31/32] Missing import --- tests/accelerators/test_dp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index 23e3ce67cfc8b..6b84e1a70ae58 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch import torch.nn.functional as F from torch.utils.data import DataLoader From 8fab50f035c328ead9e7f73b7e2f35ab3cfd5451 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 10 Mar 2021 19:20:43 +0100 Subject: [PATCH 32/32] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/trainer/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0e27c516c257f..c3039d24aadc0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -839,6 +839,7 @@ def validate( ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. If ``None``, use the current weights of the model. + When the model is given as argument, this parameter will not apply. verbose: If True, prints the validation results. @@ -902,6 +903,7 @@ def test( ckpt_path: Either ``best`` or path to the checkpoint you wish to test. If ``None``, use the current weights of the model. + When the model is given as argument, this parameter will not apply. verbose: If True, prints the test results.
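# Illustrative sketch (not part of the patch) of the two calling conventions described
# by the docstrings above. `BoringModel` is the test helper used throughout this PR;
# everything else is the Trainer API as introduced or updated here.
from pytorch_lightning import Trainer
from tests.helpers import BoringModel

model = BoringModel()
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2)
trainer.fit(model)

# Passing the model evaluates its current in-memory weights;
# per the docstring note, `ckpt_path` does not apply in this case.
trainer.validate(model)

# Omitting the model loads weights through `__load_ckpt_weights`:
# 'best' restores the best checkpoint tracked by ModelCheckpoint, an explicit
# path restores that file, and a missing best path raises MisconfigurationException.
trainer.test(ckpt_path="best")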