Skip to content

Commit

Permalink
Remove evaluation loop legacy dict returns for *_epoch_end hooks (#…
Browse files Browse the repository at this point in the history
…6973)



Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
  • Loading branch information
ananthsub and carmocca committed Apr 13, 2021
1 parent 23e8dff commit e891ceb
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 85 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed evaluation loop legacy returns for `*_epoch_end` hooks ([#6973](https://github.com/PyTorchLightning/pytorch-lightning/pull/6973))


- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.trainer.states import RunningStage, TrainerState
from pytorch_lightning.utilities import DeviceType, flatten_dict
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden

Expand Down Expand Up @@ -297,33 +297,6 @@ def get_evaluate_epoch_results(self):
self.eval_loop_results = []
return results

def _track_callback_metrics(self, eval_results):
if len(eval_results) > 0 and (eval_results[0] is None or not isinstance(eval_results[0], Result)):
return

flat = {}
if isinstance(eval_results, list):
for eval_result in eval_results:
# with a scalar return, auto set it to "val_loss" for callbacks
if isinstance(eval_result, torch.Tensor):
flat = {'val_loss': eval_result}
elif isinstance(eval_result, dict):
flat = flatten_dict(eval_result)

self.trainer.logger_connector.callback_metrics.update(flat)
if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
else:
# with a scalar return, auto set it to "val_loss" for callbacks
if isinstance(eval_results, torch.Tensor):
flat = {'val_loss': eval_results}
else:
flat = flatten_dict(eval_results)

self.trainer.logger_connector.callback_metrics.update(flat)
if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)

def on_train_epoch_end(self):
# inform cached logger connector epoch finished
self.cached_results.has_batch_loop_finished = True
Expand Down
48 changes: 6 additions & 42 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,61 +190,25 @@ def evaluation_epoch_end(self):
# unset dataloder_idx in model
self.trainer.logger_connector.evaluation_epoch_end()

# call the model epoch end
deprecated_results = self.__run_eval_epoch_end(self.num_dataloaders)

# enable returning anything
for i, r in enumerate(deprecated_results):
if not isinstance(r, (dict, Result, torch.Tensor)):
deprecated_results[i] = []

return deprecated_results

def log_epoch_metrics_on_evaluation_end(self):
# get the final loop results
eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results()
return eval_loop_results

def __run_eval_epoch_end(self, num_dataloaders):
model = self.trainer.lightning_module

# with a single dataloader don't pass an array
outputs = self.outputs
# with a single dataloader don't pass an array
eval_results = outputs[0] if self.num_dataloaders == 1 else outputs

eval_results = outputs
if num_dataloaders == 1:
eval_results = outputs[0]

user_reduced = False
# call the model epoch end
model = self.trainer.lightning_module

if self.trainer.testing:
if is_overridden('test_epoch_end', model=model):
model._current_fx_name = 'test_epoch_end'
eval_results = model.test_epoch_end(eval_results)
user_reduced = True
model.test_epoch_end(eval_results)

else:
if is_overridden('validation_epoch_end', model=model):
model._current_fx_name = 'validation_epoch_end'
eval_results = model.validation_epoch_end(eval_results)
user_reduced = True
model.validation_epoch_end(eval_results)

# capture logging
self.trainer.logger_connector.cache_logged_metrics()
# depre warning
if eval_results is not None and user_reduced:
step = 'testing_epoch_end' if self.trainer.testing else 'validation_epoch_end'
self.warning_cache.warn(
f'The {step} should not return anything as of 9.1.'
' To log, use self.log(...) or self.write(...) directly in the LightningModule'
)

if not isinstance(eval_results, list):
eval_results = [eval_results]

self.trainer.logger_connector._track_callback_metrics(eval_results)

return eval_results

def __gather_epoch_end_eval_results(self, outputs):
eval_results = []
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def run_evaluation(self, on_epoch=False):
self.evaluation_loop.outputs.append(dl_outputs)

# lightning module method
deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end()
self.evaluation_loop.evaluation_epoch_end()

# hook
self.evaluation_loop.on_evaluation_epoch_end()
Expand All @@ -725,7 +725,7 @@ def run_evaluation(self, on_epoch=False):
self.evaluation_loop.on_evaluation_end()

# log epoch metrics
eval_loop_results = self.evaluation_loop.log_epoch_metrics_on_evaluation_end()
eval_loop_results = self.logger_connector.get_evaluate_epoch_results()

# save predictions to disk
self.evaluation_loop.predictions.to_disk()
Expand All @@ -735,7 +735,7 @@ def run_evaluation(self, on_epoch=False):

torch.set_grad_enabled(True)

return eval_loop_results, deprecated_eval_results
return eval_loop_results

def track_output_for_epoch_end(self, outputs, output):
if output is not None:
Expand All @@ -757,7 +757,7 @@ def run_evaluate(self):
assert self.evaluating

with self.profiler.profile(f"run_{self._running_stage}_evaluation"):
eval_loop_results, _ = self.run_evaluation()
eval_loop_results = self.run_evaluation()

if len(eval_loop_results) == 0:
return 1
Expand Down Expand Up @@ -831,7 +831,7 @@ def run_sanity_check(self, ref_model):
self.on_sanity_check_start()

# run eval step
_, eval_results = self.run_evaluation()
self.run_evaluation()

self.on_sanity_check_end()

Expand Down
13 changes: 3 additions & 10 deletions tests/trainer/data_flow/test_eval_loop_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests to ensure that the training loop works with a dict (1.0)
Tests the evaluation loop
"""

import pytest
import torch

from pytorch_lightning import Trainer
Expand Down Expand Up @@ -189,8 +188,6 @@ def validation_epoch_end(self, outputs):
assert out_a == self.out_a
assert out_b == self.out_b

return {'no returns needed'}

def backward(self, loss, optimizer, optimizer_idx):
return LightningModule.backward(self, loss, optimizer, optimizer_idx)

Expand All @@ -206,8 +203,7 @@ def backward(self, loss, optimizer, optimizer_idx):
weights_summary=None,
)

with pytest.warns(UserWarning, match=r".*should not return anything as of 9.1.*"):
trainer.fit(model)
trainer.fit(model)

# make sure correct steps were called
assert model.validation_step_called
Expand Down Expand Up @@ -254,8 +250,6 @@ def validation_epoch_end(self, outputs):
assert out_a == self.out_a
assert out_b == self.out_b

return {'no returns needed'}

def backward(self, loss, optimizer, optimizer_idx):
return LightningModule.backward(self, loss, optimizer, optimizer_idx)

Expand All @@ -270,8 +264,7 @@ def backward(self, loss, optimizer, optimizer_idx):
weights_summary=None,
)

with pytest.warns(UserWarning, match=r".*should not return anything as of 9.1.*"):
trainer.fit(model)
trainer.fit(model)

# make sure correct steps were called
assert model.validation_step_called
Expand Down

0 comments on commit e891ceb

Please sign in to comment.