Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref: clean up hooks in run_evaluation #3156

Merged
merged 7 commits into from
Aug 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions pytorch_lightning/trainer/evaluate_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,54 @@ def __init__(self, trainer):
self.predictions = None
self.max_batches = None

def get_evaluation_dataloaders(self):
# select dataloaders
model = self.trainer.get_model()

if self.testing:
self.trainer.reset_test_dataloader(model)
dataloaders = self.trainer.test_dataloaders
max_batches = self.trainer.num_test_batches
else:
if self.trainer.val_dataloaders is None:
self.trainer.reset_val_dataloader(model)

dataloaders = self.trainer.val_dataloaders
max_batches = self.trainer.num_val_batches

return dataloaders, max_batches

def should_skip_evaluation(self, dataloaders, max_batches):
# skip when dataloaders aren't defined
if dataloaders is None:
return True

# enable disabling validation step with limit_val_batches = 0
should_skip = sum(max_batches) == 0
if should_skip:
return True

return False

def on_evaluation_start(self, *args, **kwargs):
if self.testing:
self.trainer.call_hook('on_test_start', *args, **kwargs)
else:
self.trainer.call_hook('on_validation_start', *args, **kwargs)

def on_evaluation_end(self, *args, **kwargs):
if self.testing:
self.trainer.call_hook('on_test_end', *args, **kwargs)
else:
self.trainer.call_hook('on_validation_end', *args, **kwargs)

def reload_evaluation_dataloaders(self):
model = self.trainer.get_model()
if self.testing:
self.trainer.reset_test_dataloader(model)
else:
self.trainer.reset_val_dataloader(model)

def is_using_eval_results(self):
outputs = self.outputs
using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
Expand Down
40 changes: 13 additions & 27 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,6 @@ def _evaluate(
entry is the number of batches to process in the corresponding dataloader.
test_mode:
"""
# set up the loop for val/test
self.evaluation_loop.testing = test_mode

# enable eval mode + no grads
model.zero_grad()
Expand Down Expand Up @@ -310,7 +308,10 @@ def _evaluate(
return eval_results

def run_evaluation(self, test_mode: bool = False):
# hook
# set up the loop for val/test
self.evaluation_loop.testing = test_mode

# TODO: deprecate
model = self.get_model()
model.on_pre_performance_check()

Expand All @@ -328,20 +329,13 @@ def run_evaluation(self, test_mode: bool = False):
dataloaders = self.val_dataloaders
max_batches = self.num_val_batches

if dataloaders is None:
return [], []

# Validation/Test begin callbacks
if test_mode:
self.on_test_start()
else:
self.on_validation_start()

# enable disabling validation step with limit_val_batches = 0
should_skip = sum(max_batches) == 0
if should_skip:
if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
return [], []

# TODO: deprecate
self.evaluation_loop.on_evaluation_start()

# run evaluation (val_step + val_step_end + val_epoch_end)
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)

Expand All @@ -351,20 +345,12 @@ def run_evaluation(self, test_mode: bool = False):
# hook
model.on_post_performance_check()

# eventual dataset reloading
if test_mode:
if self.reload_dataloaders_every_epoch:
self.reset_test_dataloader(model)
else:
# val
if self.reload_dataloaders_every_epoch:
self.reset_val_dataloader(model)
# user may want to reload every epoch
if self.reload_dataloaders_every_epoch:
self.evaluation_loop.reload_evaluation_dataloaders()

# Validation/Test end callbacks
if test_mode:
self.on_test_end()
else:
self.on_validation_end()
# TODO: deprecate
self.evaluation_loop.on_evaluation_end()

return eval_loop_results, eval_results

Expand Down