Skip to content

Commit

Permalink
Do not add return dict items to callback_metrics (#6682)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Mar 26, 2021
1 parent 6b990f3 commit bc61361
Show file tree
Hide file tree
Showing 16 changed files with 84 additions and 341 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))


- Removed legacy code to include `step` dictionary returns in `callback_metrics`. Use `self.log_dict` instead. ([#6682](https://github.com/PyTorchLightning/pytorch-lightning/pull/6682))


- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))


Expand Down
10 changes: 5 additions & 5 deletions docs/source/ecosystem/asr_nlp_tts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,12 @@ with PyTorch Lightning since every NeMo model is a Lightning Module.
log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len
)
wer_num, wer_denom = self._wer(predictions, transcript, transcript_len)
tensorboard_logs = {
self.log_dict({
'train_loss': loss_value,
'training_batch_wer': wer_num / wer_denom,
'learning_rate': self._optimizer.param_groups[0]['lr'],
}
return {'loss': loss_value, 'log': tensorboard_logs}
})
return loss_value
Neural Types in NeMo ASR
------------------------
Expand Down Expand Up @@ -539,8 +539,8 @@ since every NeMo model is a Lightning Module.
logits = self(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)
loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask)
tensorboard_logs = {'train_loss': loss, 'lr': self._optimizer.param_groups[0]['lr']}
return {'loss': loss, 'log': tensorboard_logs}
self.log_dict({'train_loss': loss, 'lr': self._optimizer.param_groups[0]['lr']})
return loss
...
Neural Types in NeMo NLP
Expand Down
4 changes: 2 additions & 2 deletions docs/source/ecosystem/bolts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ you can trust the implementations and use them to bootstrap your research much f
loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1).long())
logs = {"loss": loss}
return {"loss": loss, "log": logs}
self.log("loss", loss)
return loss
----------

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def _validate_monitor_key(self, trainer):
m = (
f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
f" {list(metrics.keys())}. "
f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?"
)
raise MisconfigurationException(m)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]:

# update callback_metrics
logger_connector._callback_metrics.update(callback_metrics)
logger_connector._callback_metrics.pop("epoch", None)

batch_pbar_metrics.pop("debug_epoch", None)
return batch_pbar_metrics, batch_log_metrics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:

@property
def cached_results(self) -> Union[EpochResultStore, None]:
return self._cached_results.get(self.trainer._running_stage) # type: ignore
return self._cached_results.get(self.trainer._running_stage)

def get_metrics(self, key: str) -> Dict:
metrics_holder: MetricsHolder = getattr(self, f"_{key}")
Expand Down Expand Up @@ -121,8 +121,6 @@ def cache_logged_metrics(self):
def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool):
# logging
self.configure_logger(logger)
# todo: IDE is complaining, these shall be initialized in the Trainer init at leas as placeholders
# and assign here the desired value
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
self.trainer.log_every_n_steps = log_every_n_steps
self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
Expand Down Expand Up @@ -185,9 +183,6 @@ def cache_training_step_metrics(self, opt_closure_result):
batch_log_metrics = opt_closure_result.training_step_output.log_metrics
logged_metrics_tmp.update(batch_log_metrics)

callback_metrics = opt_closure_result.training_step_output.callback_metrics
callback_metrics_tmp.update(callback_metrics)

batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
pbar_metrics_tmp.update(batch_pbar_metrics)

Expand All @@ -210,9 +205,6 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
metrics (dict): Metric values
grad_norm_dic (dict): Gradient norms
step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
log_train_step_metrics (bool): Used to track if `log_metrics` function is being called in during training
steps. In training steps, we will log metrics on step: `total_nb_idx` (for accumulated gradients)
and global_step for the rest.
"""
# add gpu memory
if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
Expand Down Expand Up @@ -348,27 +340,6 @@ def _track_callback_metrics(self, eval_results):
if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)

def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
# eval loop returns all metrics
dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}

# add metrics to prog bar
self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)

# log metrics
if len(log_metrics) > 0:
self.trainer.logger_connector.log_metrics(log_metrics, {})

# track metrics for callbacks (all prog bar, logged and callback metrics)
callback_metrics.update(log_metrics)
callback_metrics.update(prog_bar_metrics)
self.trainer.logger_connector.callback_metrics.update(callback_metrics)
if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)

if len(dataloader_result_metrics) > 0:
self.eval_loop_results.append(dataloader_result_metrics)

def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
if self.trainer.sanity_checking:
return
Expand All @@ -379,21 +350,21 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
if not isinstance(eval_results, list):
eval_results = [eval_results]

num_loaders: int = self.trainer.evaluation_loop.num_dataloaders
prog_bar_metrics, log_metrics, callback_metrics = {}, {}, {}

for result_idx, result in enumerate(eval_results):
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result)
_, prog_bar_metrics, log_metrics, _ = self.trainer.process_dict_result(result)

# eval loop returns all metrics
dataloader_result_metrics = {**prog_bar_metrics, **log_metrics}

# add metrics to prog bar
self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)

if num_loaders > 1:
self.__process_eval_epoch_end_results_and_log_legacy_update(
prog_bar_metrics, log_metrics, callback_metrics
)
# log metrics
if len(log_metrics) > 0:
self.trainer.logger_connector.log_metrics(log_metrics, {})

if num_loaders == 1:
self.__process_eval_epoch_end_results_and_log_legacy_update(
prog_bar_metrics, log_metrics, callback_metrics
)
if len(dataloader_result_metrics) > 0:
self.eval_loop_results.append(dataloader_result_metrics)

def on_train_epoch_end(self):
# inform cached logger connector epoch finished
Expand Down Expand Up @@ -446,10 +417,9 @@ def log_train_epoch_end_metrics(

# TODO: deprecate 1.0
else:
out = self.__run_legacy_training_epoch_end(
num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics
epoch_log_metrics, epoch_progress_bar_metrics = self.__run_legacy_training_epoch_end(
num_optimizers, epoch_output, model, is_result_obj
)
epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out

# it will perform reduction over epoch and return log metrics
cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics()
Expand Down Expand Up @@ -501,9 +471,7 @@ def training_epoch_end(self, model, epoch_output, num_optimizers):
# capture logging
self.trainer.logger_connector.cache_logged_metrics()

def __run_legacy_training_epoch_end(
self, num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics
):
def __run_legacy_training_epoch_end(self, num_optimizers, epoch_output, model, is_result_obj):

epoch_log_metrics = {}
epoch_progress_bar_metrics = {}
Expand Down Expand Up @@ -534,15 +502,14 @@ def __run_legacy_training_epoch_end(
_processed_outputs = self.trainer.process_dict_result(epoch_output)
epoch_progress_bar_metrics = _processed_outputs[1]
epoch_log_metrics = _processed_outputs[2]
epoch_callback_metrics = _processed_outputs[3]

# --------------------------
# Structured Result (auto epoch end)
# --------------------------
elif is_result_obj:
epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)

return epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics
return epoch_log_metrics, epoch_progress_bar_metrics

def __auto_reduce_results_on_epoch_end(self, epoch_output):
epoch_log_metrics = {}
Expand Down
31 changes: 9 additions & 22 deletions pytorch_lightning/trainer/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from pytorch_lightning.utilities import DistributedType
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach


Expand All @@ -32,8 +33,14 @@ class TrainerLoggingMixin(ABC):

def metrics_to_scalars(self, metrics):
new_metrics = {}
# TODO: this is duplicated in MetricsHolder. should be unified
for k, v in metrics.items():
if isinstance(v, torch.Tensor):
if v.numel() != 1:
raise MisconfigurationException(
f"The metric `{k}` does not contain a single element"
f" thus it cannot be converted to float. Found `{v}`"
)
v = v.item()

if isinstance(v, dict):
Expand Down Expand Up @@ -71,23 +78,8 @@ def process_dict_result(self, output, train=False):
if isinstance(output, torch.Tensor):
progress_bar_metrics = {}
log_metrics = {}
callback_metrics = {}
hiddens = None
return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens

# ---------------
# EXTRACT CALLBACK KEYS
# ---------------
# all keys not progress_bar or log are candidates for callbacks
callback_metrics = {}
if isinstance(output, Mapping):
for k, v in output.items():
if k not in ['progress_bar', 'log', 'hiddens']:
callback_metrics[k] = v

if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
num_gpus = self.num_gpus
callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)
return output, progress_bar_metrics, log_metrics, hiddens

# ---------------
# EXTRACT PROGRESS BAR KEYS
Expand Down Expand Up @@ -149,17 +141,12 @@ def process_dict_result(self, output, train=False):
# ---------------
hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None

# use every metric passed in as a candidate for callback
callback_metrics.update(progress_bar_metrics)
callback_metrics.update(log_metrics)

# detach all metrics for callbacks to prevent memory leaks
# no .item() because it will slow things down
callback_metrics = recursive_detach(callback_metrics)
progress_bar_metrics = recursive_detach(progress_bar_metrics)
log_metrics = recursive_detach(log_metrics)

return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
return loss, progress_bar_metrics, log_metrics, hiddens

def reduce_distributed_output(self, output, num_gpus):
if num_gpus <= 1:
Expand Down
9 changes: 0 additions & 9 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,15 +823,6 @@ def run_sanity_check(self, ref_model):
# run eval step
_, eval_results = self.run_evaluation()

# allow no returns from eval
if eval_results is not None and len(eval_results) > 0:
# when we get a list back, used only the last item
if isinstance(eval_results, list):
eval_results = eval_results[-1]

_, _, _, callback_metrics, _ = self.process_dict_result(eval_results)
self.logger_connector.callback_metrics = callback_metrics

self.on_sanity_check_end()

self._running_stage = stage
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,7 @@ def _process_training_step_output(self, training_step_output, split_batch):
batch_loss=training_step_output[0],
pbar_on_batch_end=training_step_output[1],
log_metrics=training_step_output[2],
callback_metrics=training_step_output[3],
hiddens=training_step_output[4],
hiddens=training_step_output[3],
)
# if the user decides to finally reduce things in epoch_end, save raw output without graphs
if isinstance(training_step_output_for_epoch_end, torch.Tensor):
Expand Down
5 changes: 2 additions & 3 deletions tests/base/model_valid_epoch_ends.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ def _mean(res, key):
val_loss_mean = val_loss_mean.item()
val_acc_mean = val_acc_mean.item()

metrics_dict = {'early_stop_on': val_loss_mean, 'val_acc': val_acc_mean}
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
return results
self.log('early_stop_on', val_loss_mean, prog_bar=True)
self.log('val_acc', val_acc_mean, prog_bar=True)

def validation_epoch_end__multiple_dataloaders(self, outputs):
"""
Expand Down
30 changes: 4 additions & 26 deletions tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ModelOverrideValidationReturn(BoringModel):

def validation_epoch_end(self, outputs):
loss = self.validation_return_values[self.current_epoch]
return {"test_val_loss": loss}
self.log("test_val_loss", loss)

model = ModelOverrideValidationReturn()
early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
Expand Down Expand Up @@ -220,7 +220,7 @@ class CurrentModel(BoringModel):
def validation_epoch_end(self, outputs):
losses = [8, 4, 2, 3, 4, 5, 8, 10]
val_loss = losses[self.current_epoch]
self.log('abc', torch.tensor(val_loss))
self.log('abc', val_loss)

model = CurrentModel()

Expand All @@ -234,28 +234,6 @@ def validation_epoch_end(self, outputs):
assert trainer.current_epoch == 5, 'early_stopping failed'


def test_early_stopping_functionality_arbitrary_key(tmpdir):
"""Tests whether early stopping works with a custom key and dictionary results on val step."""

class CurrentModel(BoringModel):

def validation_epoch_end(self, outputs):
losses = [8, 4, 2, 3, 4, 5, 8, 10]
val_loss = losses[self.current_epoch]
return {'jiraffe': torch.tensor(val_loss)}

model = CurrentModel()

trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[EarlyStopping(monitor='jiraffe')],
overfit_batches=0.20,
max_epochs=20,
)
trainer.fit(model)
assert trainer.current_epoch >= 5, 'early_stopping failed'


@pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int):
"""Excepted Behaviour:
Expand All @@ -272,7 +250,7 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: in
when `early_stopping` is being triggered,
THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached.
Caviat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
This test validate those expected behaviours
"""
Expand Down Expand Up @@ -309,7 +287,7 @@ def validation_epoch_end(self, outputs):
self._count_decrease += 1
self._loss_value -= self._eps
self._values.append(_mean)
return {"test_val_loss": _mean}
self.log('test_val_loss', _mean)

model = Model(step_freeze)
model.training_step_end = None
Expand Down
Loading

0 comments on commit bc61361

Please sign in to comment.