ref: clean up hooks in run_evaluation (#3156)
* clean up hooks in run_evaluation
williamFalcon committed Aug 25, 2020
1 parent 22b9642 commit 50aed42
Showing 2 changed files with 61 additions and 27 deletions.
48 changes: 48 additions & 0 deletions pytorch_lightning/trainer/evaluate_loop.py
@@ -13,6 +13,54 @@ def __init__(self, trainer):
         self.predictions = None
         self.max_batches = None

+    def get_evaluation_dataloaders(self):
+        # select dataloaders
+        model = self.trainer.get_model()
+
+        if self.testing:
+            self.trainer.reset_test_dataloader(model)
+            dataloaders = self.trainer.test_dataloaders
+            max_batches = self.trainer.num_test_batches
+        else:
+            if self.trainer.val_dataloaders is None:
+                self.trainer.reset_val_dataloader(model)
+
+            dataloaders = self.trainer.val_dataloaders
+            max_batches = self.trainer.num_val_batches
+
+        return dataloaders, max_batches
+
+    def should_skip_evaluation(self, dataloaders, max_batches):
+        # skip when dataloaders aren't defined
+        if dataloaders is None:
+            return True
+
+        # enable disabling validation step with limit_val_batches = 0
+        should_skip = sum(max_batches) == 0
+        if should_skip:
+            return True
+
+        return False
+
+    def on_evaluation_start(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_start', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_start', *args, **kwargs)
+
+    def on_evaluation_end(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_end', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_end', *args, **kwargs)
+
+    def reload_evaluation_dataloaders(self):
+        model = self.trainer.get_model()
+        if self.testing:
+            self.trainer.reset_test_dataloader(model)
+        else:
+            self.trainer.reset_val_dataloader(model)
+
+    def is_using_eval_results(self):
+        outputs = self.outputs
+        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
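The helpers above fold the recurring val/test branching into one object: each lifecycle method checks a single testing flag and dispatches to the matching user hook. What follows is a minimal, self-contained sketch of that dispatch pattern, not Lightning code; StubTrainer and MiniEvaluationLoop are illustrative stand-ins.

class StubTrainer:
    def call_hook(self, name, *args, **kwargs):
        print(f"hook called: {name}")

class MiniEvaluationLoop:
    def __init__(self, trainer, testing=False):
        self.trainer = trainer
        self.testing = testing

    def on_evaluation_start(self, *args, **kwargs):
        # the flag picks the hook name, so callers never branch on val vs. test
        hook = 'on_test_start' if self.testing else 'on_validation_start'
        self.trainer.call_hook(hook, *args, **kwargs)

    def should_skip_evaluation(self, dataloaders, max_batches):
        # skip when no loaders exist, or every loader is capped at 0 batches
        return dataloaders is None or sum(max_batches) == 0

loop = MiniEvaluationLoop(StubTrainer(), testing=True)
loop.on_evaluation_start()                     # -> hook called: on_test_start
print(loop.should_skip_evaluation([[]], [0]))  # -> True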
40 changes: 13 additions & 27 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -245,8 +245,6 @@ def _evaluate(
                 entry is the number of batches to process in the corresponding dataloader.
             test_mode:
         """
-        # set up the loop for val/test
-        self.evaluation_loop.testing = test_mode

         # enable eval mode + no grads
         model.zero_grad()
@@ -310,7 +308,10 @@ def _evaluate(
         return eval_results

     def run_evaluation(self, test_mode: bool = False):
-        # hook
+        # set up the loop for val/test
+        self.evaluation_loop.testing = test_mode
+
+        # TODO: deprecate
         model = self.get_model()
         model.on_pre_performance_check()

@@ -328,20 +329,13 @@ def run_evaluation(self, test_mode: bool = False):
             dataloaders = self.val_dataloaders
             max_batches = self.num_val_batches

-        if dataloaders is None:
-            return [], []
-
-        # Validation/Test begin callbacks
-        if test_mode:
-            self.on_test_start()
-        else:
-            self.on_validation_start()
-
-        # enable disabling validation step with limit_val_batches = 0
-        should_skip = sum(max_batches) == 0
-        if should_skip:
+        if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
             return [], []

+        # TODO: deprecate
+        self.evaluation_loop.on_evaluation_start()
+
         # run evaluation (val_step + val_step_end + val_epoch_end)
         eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)

@@ -351,20 +345,12 @@ def run_evaluation(self, test_mode: bool = False):
         # hook
         model.on_post_performance_check()

-        # eventual dataset reloading
-        if test_mode:
-            if self.reload_dataloaders_every_epoch:
-                self.reset_test_dataloader(model)
-        else:
-            # val
-            if self.reload_dataloaders_every_epoch:
-                self.reset_val_dataloader(model)
+        # user may want to reload every epoch
+        if self.reload_dataloaders_every_epoch:
+            self.evaluation_loop.reload_evaluation_dataloaders()

-        # Validation/Test end callbacks
-        if test_mode:
-            self.on_test_end()
-        else:
-            self.on_validation_end()
+        # TODO: deprecate
+        self.evaluation_loop.on_evaluation_end()

         return eval_loop_results, eval_results

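For reference, a hedged, self-contained sketch of the control flow run_evaluation is left with after this commit, with the trainer collapsed into plain arguments. FakeLoop and evaluate_batches are hypothetical stand-ins for the real EvaluationLoop and _evaluate, not Lightning API.

def evaluate_batches(dataloaders, max_batches):
    # placeholder for _evaluate(): pretend to run each loader up to its cap
    return [min(len(dl), cap) for dl, cap in zip(dataloaders, max_batches)]

def run_evaluation(loop, reload_every_epoch=False, test_mode=False):
    # set up the loop for val/test
    loop.testing = test_mode

    dataloaders, max_batches = loop.get_evaluation_dataloaders()
    if loop.should_skip_evaluation(dataloaders, max_batches):
        return [], []

    loop.on_evaluation_start()
    eval_results = evaluate_batches(dataloaders, max_batches)

    # user may want fresh dataloaders each epoch
    if reload_every_epoch:
        loop.reload_evaluation_dataloaders()

    loop.on_evaluation_end()
    return eval_results

class FakeLoop:
    # duck-typed stand-in exposing the helpers added in evaluate_loop.py
    testing = False

    def get_evaluation_dataloaders(self):
        return [range(3)], [2]  # one loader, capped at 2 batches

    def should_skip_evaluation(self, dataloaders, max_batches):
        return dataloaders is None or sum(max_batches) == 0

    def on_evaluation_start(self):
        print("start hook")

    def on_evaluation_end(self):
        print("end hook")

    def reload_evaluation_dataloaders(self):
        pass

print(run_evaluation(FakeLoop()))  # start hook / end hook / [2]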
