ref: refactored inner eval loop (#3141)
* refactored dataloader process hook

* refactored dataloader process hook

* refactored dataloader process hook
williamFalcon committed Aug 25, 2020
1 parent f064d74 commit ccc923c
Showing 2 changed files with 68 additions and 74 deletions.
65 changes: 61 additions & 4 deletions pytorch_lightning/trainer/evaluate_loop.py
@@ -1,6 +1,7 @@
 import torch
 from pytorch_lightning.trainer.supporters import PredictionCollection
-from pytorch_lightning.core.step_result import EvalResult
+from pytorch_lightning.core.step_result import Result, EvalResult
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class EvaluationLoop(object):
@@ -43,11 +44,38 @@ def on_evaluation_epoch_start(self, *args, **kwargs):
         else:
             self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)
 
-    def evaluation_step(self, *args, **kwargs):
+    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
+        # make dataloader_idx arg in validation_step optional
+        args = [batch, batch_idx]
+
+        multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
+        multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)
+
+        if multiple_test_loaders or multiple_val_loaders:
+            args.append(dataloader_idx)
+
+        return args
+
+    def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
+        # configure args
+        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
+
         # run actual test step
         if self.testing:
-            output = self.trainer.accelerator_backend.test_step(*args, **kwargs)
+            output = self.trainer.accelerator_backend.test_step(args)
         else:
-            output = self.trainer.accelerator_backend.validation_step(*args, **kwargs)
+            output = self.trainer.accelerator_backend.validation_step(args)
+
+        # track batch size for weighted average
+        is_result_obj = isinstance(output, Result)
+        if is_result_obj:
+            output.track_batch_size(len(batch))
+
+        # allow only EvalResult when using structured results (from val_step)
+        if is_result_obj and not isinstance(output, EvalResult):
+            m = 'only EvalResults or dicts are allowed from validation_step'
+            raise MisconfigurationException(m)
+
+        return output
 
     def evaluation_step_end(self, *args, **kwargs):
@@ -69,8 +97,37 @@ def on_evaluation_batch_end(self, *args, **kwargs):
         else:
             self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)
 
+    def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
+        # Add step predictions to prediction collection to write later
+        if output is not None:
+            do_write_predictions = isinstance(output, Result) and self.testing
+            if do_write_predictions:
+                self.predictions.add(output.pop('predictions', None))
+
+        # track debug metrics
+        self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)
+
     def on_evaluation_epoch_end(self, *args, **kwargs):
         if self.testing:
             self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
         else:
             self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
+
+    def log_metrics(self, output, batch_idx):
+        if self.trainer.running_sanity_check:
+            return
+
+        if isinstance(output, EvalResult):
+            step_log_metrics = output.batch_log_metrics
+            step_pbar_metrics = output.batch_pbar_metrics
+
+            if len(step_log_metrics) > 0:
+                # make the metrics appear as a different line in the same graph
+                metrics_by_epoch = {}
+                for k, v in step_log_metrics.items():
+                    metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v
+
+                self.trainer.log_metrics(metrics_by_epoch, {}, step=batch_idx)
+
+            if len(step_pbar_metrics) > 0:
+                self.trainer.add_progress_bar_metrics(step_pbar_metrics)
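
Note: the build_args/evaluation_step pair above is what keeps dataloader_idx optional in user code, since the index is only appended when more than one val/test dataloader is configured. A minimal sketch of the two accepted validation_step signatures; the LitModel class and its layer are made up for illustration and are not part of this commit:

import torch
from torch.nn import functional as F
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    # hypothetical example model, only to illustrate the signatures
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    # single val dataloader: dataloader_idx can be omitted, because
    # build_args only appends it when len(val_dataloaders) > 1
    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': F.cross_entropy(self(x), y)}

    # with multiple val dataloaders the extra index is passed through:
    # def validation_step(self, batch, batch_idx, dataloader_idx):
    #     ...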
77 changes: 7 additions & 70 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -132,8 +132,7 @@
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
-from pytorch_lightning.core.step_result import Result, EvalResult
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
 
 try:
@@ -273,55 +272,19 @@ def _evaluate(
                 if batch_idx >= dl_max_batches:
                     break
 
-                # -----------------
-                # eval_batch_start
-                # -----------------
+                # val loop hooks
                 self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
-
-                # -----------------
-                # RUN EVALUATION STEP
-                # -----------------
-                args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
-                output = self.evaluation_loop.evaluation_step(args)
-
-                # track batch size for weighted average
-                is_result_obj = isinstance(output, Result)
-                if is_result_obj:
-                    output.track_batch_size(len(batch))
-
-                # allow only EvalResult when using structured results (from val_step)
-                if is_result_obj and not isinstance(output, EvalResult):
-                    m = 'only EvalResults or dicts are allowed from validation_step'
-                    raise MisconfigurationException(m)
-
-                # ------------------
-                # EVAL STEP END
-                # ------------------
+                output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
                 output = self.evaluation_loop.evaluation_step_end(output)
-
-                # ------------------
-                # Hook: on_eval_batch_end
-                # ------------------
                 self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
 
-                # ----------------------
-                # Post processing
-                # ----------------------
-                # track outputs for collation
-                if output is not None:
-
-                    # Add step predictions to prediction collection to write later
-                    do_write_predictions = is_result_obj and test_mode
-                    if do_write_predictions:
-                        self.evaluation_loop.predictions.add(output.pop('predictions', None))
+                # clean up
+                self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
+                self.evaluation_loop.log_metrics(output, batch_idx)
 
+                if output is not None:
                     dl_outputs.append(output)
-
-                    self.__eval_add_step_metrics(output, batch_idx)
-
-                    # track debug metrics
-                    self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)
 
             self.evaluation_loop.outputs.append(dl_outputs)
 
         # ---------------------
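
Note: the slimmed-down body above runs inside _evaluate's per-dataloader and per-batch loops. A sketch of that flow as a standalone function; the loop scaffolding, the trainer argument, and the name run_eval_loop are assumptions for illustration, and only the evaluation_loop calls come from this diff:

def run_eval_loop(trainer, dataloaders, max_batches, test_mode=False):
    # hypothetical standalone rendering of the refactored flow; in the
    # actual code this body lives inside Trainer._evaluate
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_max_batches = max_batches[dataloader_idx]
        dl_outputs = []
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= dl_max_batches:
                break

            # val loop hooks
            trainer.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
            output = trainer.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
            output = trainer.evaluation_loop.evaluation_step_end(output)
            trainer.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)

            # clean up: predictions, debug metrics, step-level logging
            trainer.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
            trainer.evaluation_loop.log_metrics(output, batch_idx)

            if output is not None:
                dl_outputs.append(output)

        trainer.evaluation_loop.outputs.append(dl_outputs)
    return trainer.evaluation_loop.outputs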
@@ -454,23 +417,6 @@ def __gather_epoch_end_eval_results(self, outputs):
             eval_results = eval_results[0]
         return eval_results
 
-    def __eval_add_step_metrics(self, output, batch_idx):
-        # track step level metrics
-        if isinstance(output, EvalResult) and not self.running_sanity_check:
-            step_log_metrics = output.batch_log_metrics
-            step_pbar_metrics = output.batch_pbar_metrics
-
-            if len(step_log_metrics) > 0:
-                # make the metrics appear as a different line in the same graph
-                metrics_by_epoch = {}
-                for k, v in step_log_metrics.items():
-                    metrics_by_epoch[f'{k}/epoch_{self.current_epoch}'] = v
-
-                self.log_metrics(metrics_by_epoch, {}, step=batch_idx)
-
-            if len(step_pbar_metrics) > 0:
-                self.add_progress_bar_metrics(step_pbar_metrics)
-
     def __auto_reduce_result_objs(self, outputs):
         # outputs has a list of results per dataloader
         eval_results = []
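
Note: the deleted __eval_add_step_metrics above is the same per-step logging that now lives in EvaluationLoop.log_metrics in the first file: each step metric gets an epoch suffix so every epoch appears as its own line on the same chart. A minimal standalone sketch of that key renaming; the helper name and the example metrics are made up:

def suffix_step_metrics_with_epoch(step_log_metrics, current_epoch):
    # mirrors the renaming in EvaluationLoop.log_metrics:
    # 'val_acc' logged during epoch 3 becomes 'val_acc/epoch_3'
    return {f'{k}/epoch_{current_epoch}': v for k, v in step_log_metrics.items()}


print(suffix_step_metrics_with_epoch({'val_acc': 0.91, 'val_loss': 0.23}, current_epoch=3))
# {'val_acc/epoch_3': 0.91, 'val_loss/epoch_3': 0.23}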
@@ -588,12 +534,3 @@ def __log_evaluation_epoch_metrics(self, eval_results, test_mode):
             print('-' * 80)
 
         return eval_loop_results
-
-    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
-        # make dataloader_idx arg in validation_step optional
-        args = [batch, batch_idx]
-
-        if (test_mode and len(self.test_dataloaders) > 1) or (not test_mode and len(self.val_dataloaders) > 1):
-            args.append(dataloader_idx)
-
-        return args