Skip to content

Commit

Permalink
Make training_epoch_end behave like validation_epoch_end (Lightning-A…
Browse files Browse the repository at this point in the history
…I#1357)

* Make training_epoch_end behave like validation_epoch_end + minor fixes in docstrings.

* Minor fixes (Borda's comments).

* Detach tensors in batch_output (to avoid possible memory leak) + doc fix.

Co-authored-by: Jean-Baptiste SCHIRATTI <jean-baptisteschiratti@MacBook-Pro-de-Jean-Baptiste.local>
  • Loading branch information
2 people authored and akarnachev committed Apr 4, 2020
1 parent 761658e commit 3f74a3c
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 23 deletions.
78 changes: 75 additions & 3 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,78 @@ def training_end(self, *args, **kwargs):
Deprecated in v0.7.0. use training_step_end instead
"""

def training_epoch_end(
self,
outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
) -> Dict[str, Dict[str, Tensor]]:
"""Called at the end of training epoch with the outputs of all training_steps
.. code-block:: python
# the pseudocode for these calls
train_outs = []
for train_batch in train_data:
out = training_step(train_batch)
train_outs.append(out)
training_epoch_end(val_outs)
Args:
outputs: List of outputs you defined in training_step, or if there are multiple
dataloaders, a list containing a list of outputs for each dataloader
Return:
Dict or OrderedDict (dict): Dict has the following optional keys:
progress_bar -> Dict for progress bar display. Must have only tensors
log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
.. note:: If this method is not overridden, this won't be called.
- The outputs here are strictly for logging or progress bar.
- If you don't need to display anything, don't return anything.
- If you want to manually set current step, you can specify the 'step' key in the 'log' Dict
Examples:
With a single dataloader
.. code-block:: python
def training_epoch_end(self, outputs):
train_acc_mean = 0
for output in outputs:
train_acc_mean += output['train_acc']
train_acc_mean /= len(outputs)
# log training accuracy at the end of an epoch
results = {
'log': {'train_acc': train_acc_mean.item()}
}
return results
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each validation step for that dataloader.
.. code-block:: python
def training_epoch_end(self, outputs):
train_acc_mean = 0
i = 0
for dataloader_outputs in outputs:
for output in dataloader_outputs:
train_acc_mean += output['train_acc']
i += 1
train_acc_mean /= i
# log training accuracy at the end of an epoch
results = {
'log': {'train_acc': train_acc_mean.item(), 'step': self.current_epoch}
}
return results
"""

def training_step_end(self, *args, **kwargs) -> Dict[
str, Union[Tensor, Dict[str, Tensor]]
]:
Expand Down Expand Up @@ -453,7 +525,7 @@ def validation_epoch_end(
outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
) -> Dict[str, Dict[str, Tensor]]:
"""
Called at end of validation epoch with the output of all validation_steps
Called at end of validation epoch with the outputs of all validation_steps
.. code-block:: python
Expand All @@ -462,7 +534,7 @@ def validation_epoch_end(
val_outs = []
for val_batch in val_data:
out = validation_step(train_batch)
train_outs.append(out)
val_outs.append(out)
validation_epoch_end(val_outs)
Args:
Expand Down Expand Up @@ -493,7 +565,7 @@ def validation_epoch_end(self, outputs):
val_acc_mean /= len(outputs)
tqdm_dict = {'val_acc': val_acc_mean.item()}
# show val_loss and val_acc in progress bar but only log val_loss
# show val_acc in progress bar but only log val_loss
results = {
'progress_bar': tqdm_dict,
'log': {'val_acc': val_acc_mean.item()}
Expand Down
79 changes: 61 additions & 18 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningMean

Expand Down Expand Up @@ -390,14 +391,17 @@ def train(self):

def run_training_epoch(self):

# get model
model = self.get_model()

# Epoch start events
with self.profiler.profile('on_epoch_start'):
# callbacks
self.on_epoch_start()

# model hooks
if self.is_function_implemented('on_epoch_start'):
self.get_model().on_epoch_start()
model.on_epoch_start()

# track local dataloader so TPU can wrap each epoch
train_dataloader = self.train_dataloader
Expand All @@ -408,6 +412,9 @@ def run_training_epoch(self):
train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
train_dataloader = train_dataloader.per_device_loader(device)

# bookkeeping
outputs = []

# run epoch
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
enumerate(_with_is_last(train_dataloader)), "get_train_batch"
Expand All @@ -418,14 +425,15 @@ def run_training_epoch(self):

self.batch_idx = batch_idx

model = self.get_model()
model.global_step = self.global_step

# ---------------
# RUN TRAIN STEP
# ---------------
output = self.run_training_batch(batch, batch_idx)
batch_result, grad_norm_dic, batch_step_metrics = output
_outputs = self.run_training_batch(batch, batch_idx)
batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
# detach tensors in batch_output before appending to outputs
outputs.append(_recursive_detach(batch_output))

# when returning -1 from train_step, we end epoch early
early_stop_epoch = batch_result == -1
Expand Down Expand Up @@ -484,6 +492,18 @@ def run_training_epoch(self):
if early_stop_epoch or self.fast_dev_run:
break

# process epoch outputs
if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
model = model.module

if self.is_overriden('training_epoch_end', model=model):
epoch_output = model.training_epoch_end(outputs)
_processed_outputs = self.process_output(epoch_output)
log_epoch_metrics = _processed_outputs[2]
callback_epoch_metrics = _processed_outputs[3]
self.log_metrics(log_epoch_metrics, {})
self.callback_metrics.update(callback_epoch_metrics)

# in case validation step is missing and you are not running fast-dev to duplicate last batch
if not self.is_overriden('validation_step') and not (self.fast_dev_run or should_check_val):
self.call_checkpoint_callback()
Expand All @@ -497,7 +517,7 @@ def run_training_epoch(self):
self.on_epoch_end()
# model hooks
if self.is_function_implemented('on_epoch_end'):
self.get_model().on_epoch_end()
model.on_epoch_end()

def run_training_batch(self, batch, batch_idx):
# track grad norms
Expand Down Expand Up @@ -546,14 +566,13 @@ def run_training_batch(self, batch, batch_idx):
def optimizer_closure():
# forward pass
with self.profiler.profile('model_forward'):
output = self.training_forward(
output_dict = self.training_forward(
split_batch, batch_idx, opt_idx, self.hiddens)

closure_loss = output[0]
progress_bar_metrics = output[1]
log_metrics = output[2]
callback_metrics = output[3]
self.hiddens = output[4]
# format and reduce outputs accordingly
processed_output = self.process_output(output_dict, train=True)

closure_loss, progress_bar_metrics, log_metrics, callback_metrics, self.hiddens = processed_output

# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
Expand All @@ -577,10 +596,10 @@ def optimizer_closure():
with self.profiler.profile('on_after_backward'):
model_ref.on_after_backward()

return closure_loss
return closure_loss, output_dict

# calculate loss
loss = optimizer_closure()
loss, batch_output = optimizer_closure()

# check if loss or model weights are nan
self.detect_nan_tensors(loss)
Expand All @@ -606,7 +625,8 @@ def optimizer_closure():
model = self.get_model()
with self.profiler.profile('optimizer_step'):
model.optimizer_step(self.current_epoch, batch_idx,
optimizer, opt_idx, optimizer_closure)
optimizer, opt_idx,
lambda: optimizer_closure()[0])

# calculate running loss for display
self.running_loss.append(self.batch_loss_value.mean())
Expand All @@ -633,7 +653,7 @@ def optimizer_closure():
# track all metrics for callbacks
self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})

return 0, grad_norm_dic, all_log_metrics
return 0, grad_norm_dic, all_log_metrics, batch_output

def _get_optimizers_iterable(self):
if not self.optimizer_frequencies:
Expand Down Expand Up @@ -732,9 +752,6 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
warnings.warn('`training_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use training_epoch_end instead', DeprecationWarning)

# format and reduce outputs accordingly
output = self.process_output(output, train=True)

return output

def update_learning_rates(self, interval: str):
Expand Down Expand Up @@ -784,3 +801,29 @@ def _with_is_last(iterable):
last = val
# yield last, no longer has next
yield last, True


def _recursive_detach(in_dict):
"""Detach all tensors in `in_dict`.
May operate recursively if some of the values in `in_dict` are dictionaries
which contain instances of `torch.Tensor`. Other types in `in_dict` are
not affected by this utility function.
Parameters
----------
in_dict : dict
Returns
-------
out_dict : dict
"""
out_dict = {}
for k, v in in_dict.items():
if isinstance(v, dict):
out_dict.update({k: _recursive_detach(v)})
elif callable(getattr(v, 'detach', None)):
out_dict.update({k: v.detach()})
else:
out_dict.update({k: v})
return out_dict
10 changes: 8 additions & 2 deletions tests/loggers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,16 @@ def test_multiple_loggers_pickle(tmpdir):
def test_adding_step_key(tmpdir):
logged_step = 0

def _validation_end(outputs):
def _validation_epoch_end(outputs):
nonlocal logged_step
logged_step += 1
return {"log": {"step": logged_step, "val_acc": logged_step / 10}}

def _training_epoch_end(outputs):
nonlocal logged_step
logged_step += 1
return {"log": {"step": logged_step, "train_acc": logged_step / 10}}

def _log_metrics_decorator(log_metrics_fn):
def decorated(metrics, step):
if "val_acc" in metrics:
Expand All @@ -148,7 +153,8 @@ def decorated(metrics, step):
return decorated

model, hparams = tutils.get_default_model()
model.validation_epoch_end = _validation_end
model.validation_epoch_end = _validation_epoch_end
model.training_epoch_end = _training_epoch_end
trainer_options = dict(
max_epochs=4,
default_save_path=tmpdir,
Expand Down

0 comments on commit 3f74a3c

Please sign in to comment.