From e891ceb83651160ee1497455ffd63a75621bd013 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 13 Apr 2021 04:37:54 -0700 Subject: [PATCH] Remove evaluation loop legacy dict returns for `*_epoch_end` hooks (#6973) Co-authored-by: Carlos Mocholi --- CHANGELOG.md | 3 ++ .../logger_connector/logger_connector.py | 29 +---------- pytorch_lightning/trainer/evaluation_loop.py | 48 +++---------------- pytorch_lightning/trainer/trainer.py | 10 ++-- .../trainer/data_flow/test_eval_loop_flow.py | 13 ++--- 5 files changed, 18 insertions(+), 85 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0d229887b283..a28a18be47184 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -153,6 +153,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed +- Removed evaluation loop legacy returns for `*_epoch_end` hooks ([#6973](https://github.com/PyTorchLightning/pytorch-lightning/pull/6973)) + + - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8fd716ee64ad7..904fb08a2cdc9 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -25,7 +25,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.trainer.states import RunningStage, TrainerState -from pytorch_lightning.utilities import DeviceType, flatten_dict +from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -297,33 +297,6 @@ def get_evaluate_epoch_results(self): self.eval_loop_results = [] return results - def _track_callback_metrics(self, eval_results): - if len(eval_results) > 0 and (eval_results[0] is None or not isinstance(eval_results[0], Result)): - return - - flat = {} - if isinstance(eval_results, list): - for eval_result in eval_results: - # with a scalar return, auto set it to "val_loss" for callbacks - if isinstance(eval_result, torch.Tensor): - flat = {'val_loss': eval_result} - elif isinstance(eval_result, dict): - flat = flatten_dict(eval_result) - - self.trainer.logger_connector.callback_metrics.update(flat) - if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): - self.trainer.logger_connector.evaluation_callback_metrics.update(flat) - else: - # with a scalar return, auto set it to "val_loss" for callbacks - if isinstance(eval_results, torch.Tensor): - flat = {'val_loss': eval_results} - else: - flat = flatten_dict(eval_results) - - self.trainer.logger_connector.callback_metrics.update(flat) - if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): - self.trainer.logger_connector.evaluation_callback_metrics.update(flat) - def on_train_epoch_end(self): # inform cached logger connector epoch finished self.cached_results.has_batch_loop_finished = True diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 8b7543e6bf50f..8ea2c79460805 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -190,61 +190,25 @@ def evaluation_epoch_end(self): # unset dataloder_idx in model self.trainer.logger_connector.evaluation_epoch_end() - # call the model epoch end - deprecated_results = self.__run_eval_epoch_end(self.num_dataloaders) - - # enable returning anything - for i, r in enumerate(deprecated_results): - if not isinstance(r, (dict, Result, torch.Tensor)): - deprecated_results[i] = [] - - return deprecated_results - - def log_epoch_metrics_on_evaluation_end(self): - # get the final loop results - eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results() - return eval_loop_results - - def __run_eval_epoch_end(self, num_dataloaders): - model = self.trainer.lightning_module - - # with a single dataloader don't pass an array outputs = self.outputs + # with a single dataloader don't pass an array + eval_results = outputs[0] if self.num_dataloaders == 1 else outputs - eval_results = outputs - if num_dataloaders == 1: - eval_results = outputs[0] - - user_reduced = False + # call the model epoch end + model = self.trainer.lightning_module if self.trainer.testing: if is_overridden('test_epoch_end', model=model): model._current_fx_name = 'test_epoch_end' - eval_results = model.test_epoch_end(eval_results) - user_reduced = True + model.test_epoch_end(eval_results) else: if is_overridden('validation_epoch_end', model=model): model._current_fx_name = 'validation_epoch_end' - eval_results = model.validation_epoch_end(eval_results) - user_reduced = True + model.validation_epoch_end(eval_results) # capture logging self.trainer.logger_connector.cache_logged_metrics() - # depre warning - if eval_results is not None and user_reduced: - step = 'testing_epoch_end' if self.trainer.testing else 'validation_epoch_end' - self.warning_cache.warn( - f'The {step} should not return anything as of 9.1.' - ' To log, use self.log(...) or self.write(...) directly in the LightningModule' - ) - - if not isinstance(eval_results, list): - eval_results = [eval_results] - - self.trainer.logger_connector._track_callback_metrics(eval_results) - - return eval_results def __gather_epoch_end_eval_results(self, outputs): eval_results = [] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e6bf36df92a01..eefb8cbee408f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -712,7 +712,7 @@ def run_evaluation(self, on_epoch=False): self.evaluation_loop.outputs.append(dl_outputs) # lightning module method - deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end() + self.evaluation_loop.evaluation_epoch_end() # hook self.evaluation_loop.on_evaluation_epoch_end() @@ -725,7 +725,7 @@ def run_evaluation(self, on_epoch=False): self.evaluation_loop.on_evaluation_end() # log epoch metrics - eval_loop_results = self.evaluation_loop.log_epoch_metrics_on_evaluation_end() + eval_loop_results = self.logger_connector.get_evaluate_epoch_results() # save predictions to disk self.evaluation_loop.predictions.to_disk() @@ -735,7 +735,7 @@ def run_evaluation(self, on_epoch=False): torch.set_grad_enabled(True) - return eval_loop_results, deprecated_eval_results + return eval_loop_results def track_output_for_epoch_end(self, outputs, output): if output is not None: @@ -757,7 +757,7 @@ def run_evaluate(self): assert self.evaluating with self.profiler.profile(f"run_{self._running_stage}_evaluation"): - eval_loop_results, _ = self.run_evaluation() + eval_loop_results = self.run_evaluation() if len(eval_loop_results) == 0: return 1 @@ -831,7 +831,7 @@ def run_sanity_check(self, ref_model): self.on_sanity_check_start() # run eval step - _, eval_results = self.run_evaluation() + self.run_evaluation() self.on_sanity_check_end() diff --git a/tests/trainer/data_flow/test_eval_loop_flow.py b/tests/trainer/data_flow/test_eval_loop_flow.py index 575de5727a21e..8fdb321b6f230 100644 --- a/tests/trainer/data_flow/test_eval_loop_flow.py +++ b/tests/trainer/data_flow/test_eval_loop_flow.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Tests to ensure that the training loop works with a dict (1.0) +Tests the evaluation loop """ -import pytest import torch from pytorch_lightning import Trainer @@ -189,8 +188,6 @@ def validation_epoch_end(self, outputs): assert out_a == self.out_a assert out_b == self.out_b - return {'no returns needed'} - def backward(self, loss, optimizer, optimizer_idx): return LightningModule.backward(self, loss, optimizer, optimizer_idx) @@ -206,8 +203,7 @@ def backward(self, loss, optimizer, optimizer_idx): weights_summary=None, ) - with pytest.warns(UserWarning, match=r".*should not return anything as of 9.1.*"): - trainer.fit(model) + trainer.fit(model) # make sure correct steps were called assert model.validation_step_called @@ -254,8 +250,6 @@ def validation_epoch_end(self, outputs): assert out_a == self.out_a assert out_b == self.out_b - return {'no returns needed'} - def backward(self, loss, optimizer, optimizer_idx): return LightningModule.backward(self, loss, optimizer, optimizer_idx) @@ -270,8 +264,7 @@ def backward(self, loss, optimizer, optimizer_idx): weights_summary=None, ) - with pytest.warns(UserWarning, match=r".*should not return anything as of 9.1.*"): - trainer.fit(model) + trainer.fit(model) # make sure correct steps were called assert model.validation_step_called