ref: moved eval loop logging to loggers 1/n (#3408)
williamFalcon committed Sep 9, 2020
1 parent 8f6b115 commit 0c2e315
Showing 3 changed files with 66 additions and 75 deletions.
pytorch_lightning/trainer/evaluate_loop.py (9 changes: 7 additions & 2 deletions)
@@ -141,9 +141,14 @@ def evaluation_epoch_end(self, num_dataloaders):
         eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
         return eval_results

-    def log_epoch_metrics(self, eval_results):
+    def log_epoch_metrics(self, eval_results, test_mode):
         using_eval_result = self.is_using_eval_results()
-        self.trainer.logger_connector.on_evaluation_epoch_end(eval_results, using_eval_result)
+        eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end(
+            eval_results,
+            using_eval_result,
+            test_mode
+        )
+        return eval_loop_results

     def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
         model = self.trainer.get_model()
pytorch_lightning/trainer/evaluation_loop.py (74 changes: 3 additions & 71 deletions)
@@ -124,33 +124,16 @@
"""

from abc import ABC, abstractmethod
from pprint import pprint
from typing import Callable, List, Union
from typing import Callable, List

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
from pytorch_lightning.trainer.logger_connector import LoggerConnector

try:
import torch_xla.distributed.parallel_loader as xla_pl
import torch_xla.core.xla_model as xm
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True

try:
import horovod.torch as hvd
except (ModuleNotFoundError, ImportError):
HOROVOD_AVAILABLE = False
else:
HOROVOD_AVAILABLE = True


class TrainerEvaluationLoopMixin(ABC):

@@ -265,15 +248,12 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
         eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))

         # bookkeeping
-        self.evaluation_loop.log_epoch_metrics(eval_results)
+        eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode)
         self.evaluation_loop.predictions.to_disk()

         # hook
         self.evaluation_loop.on_evaluation_epoch_end()

-        # log the final eval loop metrics
-        eval_loop_results = self.__log_evaluation_epoch_metrics(eval_results, test_mode)
-
         # enable train mode again
         model.train()
         torch.set_grad_enabled(True)
@@ -282,51 +262,3 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
         self.evaluation_loop.on_evaluation_end()

         return eval_loop_results, eval_results
-
-    def __log_evaluation_epoch_metrics(self, eval_results, test_mode):
-        if self.running_sanity_check:
-            return
-
-        eval_loop_results = []
-        if eval_results is not None and len(eval_results) > 0:
-
-            # in eval, the user may return something at every validation step without final reduction
-            if not isinstance(eval_results, list):
-                eval_results = [eval_results]
-
-            for result_idx, result in enumerate(eval_results):
-                if isinstance(result, EvalResult):
-                    prog_bar_metrics = result.epoch_pbar_metrics
-                    log_metrics = result.epoch_log_metrics
-                    callback_metrics = result.callback_metrics
-
-                    # in testing we don't need the callback metrics
-                    if test_mode:
-                        callback_metrics = {}
-                else:
-                    _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(result)
-
-                # eval loop returns all metrics
-                dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}
-
-                # add metrics to prog bar
-                self.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
-
-                # log metrics
-                self.logger_connector.log_metrics(log_metrics, {})
-
-                # track metrics for callbacks
-                self.logger_connector.callback_metrics.update(callback_metrics)
-
-                if len(dataloader_result_metrics) > 0:
-                    eval_loop_results.append(dataloader_result_metrics)
-
-        # log results of test
-        if test_mode and self.is_global_zero and self.verbose_test:
-            print('-' * 80)
-            for result_idx, results in enumerate(eval_loop_results):
-                print(f'DATALOADER:{result_idx} TEST RESULTS')
-                pprint(results)
-                print('-' * 80)
-
-        return eval_loop_results
pytorch_lightning/trainer/logger_connector.py (58 changes: 56 additions & 2 deletions)
Expand Up @@ -15,7 +15,8 @@
 from pytorch_lightning.core import memory
 from pytorch_lightning.utilities import flatten_dict
 from pytorch_lightning.utilities.model_utils import is_overridden
-from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.core.step_result import EvalResult, Result
+from pprint import pprint


 class LoggerConnector:
@@ -73,7 +74,12 @@ def add_progress_bar_metrics(self, metrics):

         self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

-    def on_evaluation_epoch_end(self, eval_results, using_eval_result):
+    def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode):
+        # TODO: merge both functions?
+        self._log_on_evaluation_epoch_end_metrics(eval_results, using_eval_result)
+        return self.__log_evaluation_epoch_metrics_2(eval_results, test_mode)
+
+    def _log_on_evaluation_epoch_end_metrics(self, eval_results, using_eval_result):
         if using_eval_result:
             if isinstance(eval_results, list):
                 for eval_result in eval_results:
@@ -97,6 +103,54 @@ def on_evaluation_epoch_end(self, eval_results, using_eval_result):
                 flat = flatten_dict(eval_results)
                 self.trainer.logger_connector.callback_metrics.update(flat)

+    def __log_evaluation_epoch_metrics_2(self, eval_results, test_mode):
+        if self.trainer.running_sanity_check:
+            return
+
+        eval_loop_results = []
+        if eval_results is not None and len(eval_results) > 0:
+
+            # in eval, the user may return something at every validation step without final reduction
+            if not isinstance(eval_results, list):
+                eval_results = [eval_results]
+
+            for result_idx, result in enumerate(eval_results):
+                if isinstance(result, EvalResult):
+                    prog_bar_metrics = result.epoch_pbar_metrics
+                    log_metrics = result.epoch_log_metrics
+                    callback_metrics = result.callback_metrics
+
+                    # in testing we don't need the callback metrics
+                    if test_mode:
+                        callback_metrics = {}
+                else:
+                    _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_output(result)
+
+                # eval loop returns all metrics
+                dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}
+
+                # add metrics to prog bar
+                self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
+
+                # log metrics
+                self.trainer.logger_connector.log_metrics(log_metrics, {})
+
+                # track metrics for callbacks
+                self.trainer.logger_connector.callback_metrics.update(callback_metrics)
+
+                if len(dataloader_result_metrics) > 0:
+                    eval_loop_results.append(dataloader_result_metrics)
+
+        # log results of test
+        if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
+            print('-' * 80)
+            for result_idx, results in enumerate(eval_loop_results):
+                print(f'DATALOADER:{result_idx} TEST RESULTS')
+                pprint(results)
+                print('-' * 80)
+
+        return eval_loop_results
+
     def on_train_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
         self.log_train_epoch_end_metrics(epoch_output, checkpoint_accumulator,
                                          early_stopping_accumulator, num_optimizers)
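
Taken together, the three files move the end-of-evaluation logging out of the trainer's evaluation mixin and into the LoggerConnector: run_evaluation now receives eval_loop_results from evaluation_loop.log_epoch_metrics, which simply forwards to logger_connector.on_evaluation_epoch_end. Below is a minimal, self-contained sketch of that call flow; the simplified constructors, dictionary-based metrics, and the using_eval_result handling are illustrative stand-ins, not the actual Trainer internals.

# Minimal sketch of the call flow after this refactor. The classes below are
# simplified stand-ins for the real Trainer internals, not the actual API.
from pprint import pprint


class LoggerConnector:
    """Owns the logging done at the end of an evaluation epoch."""

    def __init__(self, verbose_test=True):
        self.callback_metrics = {}
        self.verbose_test = verbose_test

    def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode):
        # using_eval_result mirrors the real signature; unused in this sketch
        # normalize to one entry per dataloader
        if not isinstance(eval_results, list):
            eval_results = [eval_results]

        eval_loop_results = []
        for result in eval_results:
            # track metrics for callbacks and collect what the loop returns
            self.callback_metrics.update(result)
            eval_loop_results.append(dict(result))

        # test results are printed here instead of in the evaluation loop
        if test_mode and self.verbose_test:
            for idx, results in enumerate(eval_loop_results):
                print(f'DATALOADER:{idx} TEST RESULTS')
                pprint(results)

        return eval_loop_results


class EvaluationLoop:
    """The loop no longer formats or prints anything itself; it delegates."""

    def __init__(self, logger_connector):
        self.logger_connector = logger_connector

    def log_epoch_metrics(self, eval_results, test_mode):
        # hand the connector's return value straight back to the trainer
        return self.logger_connector.on_evaluation_epoch_end(
            eval_results,
            using_eval_result=True,
            test_mode=test_mode,
        )


# what run_evaluation() now does with the return value
loop = EvaluationLoop(LoggerConnector())
eval_loop_results = loop.log_epoch_metrics([{'val_loss': 0.42}], test_mode=True)
print(eval_loop_results)  # [{'val_loss': 0.42}]

Routing the return value back through log_epoch_metrics keeps the (eval_loop_results, eval_results) tuple that run_evaluation returns intact, while the printing and metric bookkeeping now live in one place.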
