Do not add return dict items to callback_metrics (#6682)
carmocca authored and lexierule committed Mar 30, 2021
1 parent f6d5782 commit 61b7fd5
Showing 18 changed files with 101 additions and 342 deletions.
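In short: extra keys returned from `training_step`/`*_epoch_end` dictionaries are no longer copied into `callback_metrics`; anything a callback should monitor must go through `self.log`/`self.log_dict`. A minimal before/after sketch of the migration, assuming a toy module (`LitExample` and its single linear layer are illustrative, not taken from this commit):

    import torch
    import pytorch_lightning as pl

    class LitExample(pl.LightningModule):  # hypothetical module for illustration
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        # Legacy pattern removed by this commit: dict keys such as 'log' entries
        # or stray keys were silently added to trainer.callback_metrics.
        # def training_step(self, batch, batch_idx):
        #     loss = self.layer(batch).sum()
        #     return {'loss': loss, 'log': {'train_loss': loss}}

        # Current pattern: log explicitly; EarlyStopping/ModelCheckpoint read logged metrics.
        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            self.log_dict({'train_loss': loss})
            return loss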
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](https://github.com/PyTorchLightning/pytorch-lightning/pull/6498))
 
+### Removed
+
+- Removed legacy code to include `step` dictionary returns in `callback_metrics`. Use `self.log_dict` instead. ([#6682](https://github.com/PyTorchLightning/pytorch-lightning/pull/6682))
+
 ### Fixed
 
 - Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))
10 changes: 5 additions & 5 deletions docs/source/ecosystem/asr_nlp_tts.rst
@@ -270,12 +270,12 @@ with PyTorch Lightning since every NeMo model is a Lightning Module.
             log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len
         )
         wer_num, wer_denom = self._wer(predictions, transcript, transcript_len)
-        tensorboard_logs = {
+        self.log_dict({
             'train_loss': loss_value,
             'training_batch_wer': wer_num / wer_denom,
             'learning_rate': self._optimizer.param_groups[0]['lr'],
-        }
-        return {'loss': loss_value, 'log': tensorboard_logs}
+        })
+        return loss_value
 
 Neural Types in NeMo ASR
 ------------------------
@@ -539,8 +539,8 @@ since every NeMo model is a Lightning Module.
         logits = self(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)
         loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask)
-        tensorboard_logs = {'train_loss': loss, 'lr': self._optimizer.param_groups[0]['lr']}
-        return {'loss': loss, 'log': tensorboard_logs}
+        self.log_dict({'train_loss': loss, 'lr': self._optimizer.param_groups[0]['lr']})
+        return loss
     ...
 
 Neural Types in NeMo NLP
4 changes: 2 additions & 2 deletions docs/source/ecosystem/bolts.rst
@@ -68,8 +68,8 @@ you can trust the implementations and use them to bootstrap your research much faster.
         loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1).long())
-        logs = {"loss": loss}
-        return {"loss": loss, "log": logs}
+        self.log("loss", loss)
+        return loss
 
 ----------

2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -490,7 +490,7 @@ def _validate_monitor_key(self, trainer):
         m = (
             f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
             f" {list(metrics.keys())}. "
-            f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
+            f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?"
         )
         raise MisconfigurationException(m)

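The reworded hint reflects that `self.log` accepts plain Python numbers as well as tensors. A small sketch of how the monitored key gets satisfied, assuming a `val_loss` monitor (the names here are illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # ModelCheckpoint only finds 'val_loss' if the LightningModule logs it, e.g.
    # self.log('val_loss', loss) inside validation_step or validation_epoch_end.
    trainer = Trainer(callbacks=[ModelCheckpoint(monitor='val_loss')])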
@@ -346,7 +346,6 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]:
 
         # update callback_metrics
         logger_connector._callback_metrics.update(callback_metrics)
-        logger_connector._callback_metrics.pop("epoch", None)
 
         batch_pbar_metrics.pop("debug_epoch", None)
         return batch_pbar_metrics, batch_log_metrics
@@ -78,7 +78,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:
 
     @property
     def cached_results(self) -> Union[EpochResultStore, None]:
-        return self._cached_results.get(self.trainer._running_stage)  # type: ignore
+        return self._cached_results.get(self.trainer._running_stage)
 
     def get_metrics(self, key: str) -> Dict:
         metrics_holder = getattr(self, f"_{key}", None)
@@ -125,8 +125,6 @@ def cache_logged_metrics(self):
     def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool):
         # logging
         self.configure_logger(logger)
-        # todo: IDE is complaining, these shall be initialized in the Trainer init at leas as placeholders
-        # and assign here the desired value
         self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
         self.trainer.log_every_n_steps = log_every_n_steps
         self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
@@ -189,9 +187,6 @@ def cache_training_step_metrics(self, opt_closure_result):
         batch_log_metrics = opt_closure_result.training_step_output.log_metrics
         logged_metrics_tmp.update(batch_log_metrics)
 
-        callback_metrics = opt_closure_result.training_step_output.callback_metrics
-        callback_metrics_tmp.update(callback_metrics)
-
         batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
         pbar_metrics_tmp.update(batch_pbar_metrics)
 
@@ -214,9 +209,6 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
             metrics (dict): Metric values
             grad_norm_dic (dict): Gradient norms
             step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
-            log_train_step_metrics (bool): Used to track if `log_metrics` function is being called in during training
-                steps. In training steps, we will log metrics on step: `total_nb_idx` (for accumulated gradients)
-                and global_step for the rest.
         """
         # add gpu memory
         if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
@@ -350,27 +342,6 @@ def _track_callback_metrics(self, eval_results):
         if self.trainer.testing:
             self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
 
-    def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
-        # eval loop returns all metrics
-        dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}
-
-        # add metrics to prog bar
-        self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
-
-        # log metrics
-        if len(log_metrics) > 0:
-            self.trainer.logger_connector.log_metrics(log_metrics, {})
-
-        # track metrics for callbacks (all prog bar, logged and callback metrics)
-        callback_metrics.update(log_metrics)
-        callback_metrics.update(prog_bar_metrics)
-        self.trainer.logger_connector.callback_metrics.update(callback_metrics)
-        if self.trainer.testing:
-            self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)
-
-        if len(dataloader_result_metrics) > 0:
-            self.eval_loop_results.append(dataloader_result_metrics)
-
     def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
         if self.trainer.running_sanity_check:
             return
@@ -381,21 +352,21 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
         if not isinstance(eval_results, list):
             eval_results = [eval_results]
 
-        num_loaders: int = self.trainer.evaluation_loop.num_dataloaders
-        prog_bar_metrics, log_metrics, callback_metrics = {}, {}, {}
-
         for result_idx, result in enumerate(eval_results):
-            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result)
+            _, prog_bar_metrics, log_metrics, _ = self.trainer.process_dict_result(result)
 
+            # eval loop returns all metrics
+            dataloader_result_metrics = {**prog_bar_metrics, **log_metrics}
+
+            # add metrics to prog bar
+            self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
 
-            if num_loaders > 1:
-                self.__process_eval_epoch_end_results_and_log_legacy_update(
-                    prog_bar_metrics, log_metrics, callback_metrics
-                )
+            # log metrics
+            if len(log_metrics) > 0:
+                self.trainer.logger_connector.log_metrics(log_metrics, {})
 
-        if num_loaders == 1:
-            self.__process_eval_epoch_end_results_and_log_legacy_update(
-                prog_bar_metrics, log_metrics, callback_metrics
-            )
+            if len(dataloader_result_metrics) > 0:
+                self.eval_loop_results.append(dataloader_result_metrics)
 
     def on_train_epoch_end(self):
         # inform cached logger connector epoch finished
@@ -448,10 +419,9 @@ def log_train_epoch_end_metrics(
 
         # TODO: deprecate 1.0
         else:
-            out = self.__run_legacy_training_epoch_end(
-                num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics
+            epoch_log_metrics, epoch_progress_bar_metrics = self.__run_legacy_training_epoch_end(
+                num_optimizers, epoch_output, model, is_result_obj
             )
-            epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out
 
         # it will perform reduction over epoch and return log metrics
         cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics()
@@ -503,9 +473,7 @@ def training_epoch_end(self, model, epoch_output, num_optimizers):
         # capture logging
         self.trainer.logger_connector.cache_logged_metrics()
 
-    def __run_legacy_training_epoch_end(
-        self, num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics
-    ):
+    def __run_legacy_training_epoch_end(self, num_optimizers, epoch_output, model, is_result_obj):
 
         epoch_log_metrics = {}
         epoch_progress_bar_metrics = {}
@@ -536,15 +504,14 @@ def __run_legacy_training_epoch_end(
             _processed_outputs = self.trainer.process_dict_result(epoch_output)
             epoch_progress_bar_metrics = _processed_outputs[1]
             epoch_log_metrics = _processed_outputs[2]
-            epoch_callback_metrics = _processed_outputs[3]
 
         # --------------------------
         # Structured Result (auto epoch end)
         # --------------------------
         elif is_result_obj:
             epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
 
-        return epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics
+        return epoch_log_metrics, epoch_progress_bar_metrics
 
     def __auto_reduce_results_on_epoch_end(self, epoch_output):
         epoch_log_metrics = {}
31 changes: 9 additions & 22 deletions pytorch_lightning/trainer/logging.py
@@ -21,6 +21,7 @@
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities import DeviceType, DistributedType
 from pytorch_lightning.utilities.distributed import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach


@@ -42,8 +43,14 @@ class TrainerLoggingMixin(ABC):
 
     def metrics_to_scalars(self, metrics):
         new_metrics = {}
+        # TODO: this is duplicated in MetricsHolder. should be unified
         for k, v in metrics.items():
             if isinstance(v, torch.Tensor):
+                if v.numel() != 1:
+                    raise MisconfigurationException(
+                        f"The metric `{k}` does not contain a single element"
+                        f" thus it cannot be converted to float. Found `{v}`"
+                    )
                 v = v.item()
 
             if isinstance(v, dict):
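The added guard makes the failure mode explicit: `.item()` only works on single-element tensors, so a multi-element tensor metric now raises `MisconfigurationException` naming the offending key. A sketch of what converts and what does not (values are illustrative):

    import torch

    ok = {'acc': torch.tensor(0.91)}            # numel() == 1 -> 0.91 via .item()
    bad = {'acc_per_class': torch.ones(3)}      # numel() == 3 -> MisconfigurationException
    fixed = {'acc_mean': torch.ones(3).mean()}  # reduce to a scalar first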
@@ -81,23 +88,8 @@ def process_dict_result(self, output, train=False):
         if isinstance(output, torch.Tensor):
             progress_bar_metrics = {}
             log_metrics = {}
-            callback_metrics = {}
             hiddens = None
-            return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens
-
-        # ---------------
-        # EXTRACT CALLBACK KEYS
-        # ---------------
-        # all keys not progress_bar or log are candidates for callbacks
-        callback_metrics = {}
-        if isinstance(output, Mapping):
-            for k, v in output.items():
-                if k not in ['progress_bar', 'log', 'hiddens']:
-                    callback_metrics[k] = v
-
-        if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
-            num_gpus = self.num_gpus
-            callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)
+            return output, progress_bar_metrics, log_metrics, hiddens
 
         # ---------------
         # EXTRACT PROGRESS BAR KEYS
@@ -159,17 +151,12 @@ def process_dict_result(self, output, train=False):
         # ---------------
         hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None
 
-        # use every metric passed in as a candidate for callback
-        callback_metrics.update(progress_bar_metrics)
-        callback_metrics.update(log_metrics)
-
         # detach all metrics for callbacks to prevent memory leaks
         # no .item() because it will slow things down
-        callback_metrics = recursive_detach(callback_metrics)
         progress_bar_metrics = recursive_detach(progress_bar_metrics)
         log_metrics = recursive_detach(log_metrics)
 
-        return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
+        return loss, progress_bar_metrics, log_metrics, hiddens
 
     def reduce_distributed_output(self, output, num_gpus):
         if num_gpus <= 1:
9 changes: 0 additions & 9 deletions pytorch_lightning/trainer/trainer.py
@@ -859,15 +859,6 @@ def run_sanity_check(self, ref_model):
         # run eval step
         _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches)
 
-        # allow no returns from eval
-        if eval_results is not None and len(eval_results) > 0:
-            # when we get a list back, used only the last item
-            if isinstance(eval_results, list):
-                eval_results = eval_results[-1]
-
-            _, _, _, callback_metrics, _ = self.process_dict_result(eval_results)
-            self.logger_connector.callback_metrics = callback_metrics
-
         self.on_sanity_check_end()
         self.running_sanity_check = False

3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -360,8 +360,7 @@ def _process_training_step_output(self, training_step_output, split_batch):
             batch_loss=training_step_output[0],
             pbar_on_batch_end=training_step_output[1],
             log_metrics=training_step_output[2],
-            callback_metrics=training_step_output[3],
-            hiddens=training_step_output[4],
+            hiddens=training_step_output[3],
         )
         # if the user decides to finally reduce things in epoch_end, save raw output without graphs
         if isinstance(training_step_output_for_epoch_end, torch.Tensor):
5 changes: 2 additions & 3 deletions tests/base/model_valid_epoch_ends.py
@@ -43,9 +43,8 @@ def _mean(res, key):
         val_loss_mean = val_loss_mean.item()
         val_acc_mean = val_acc_mean.item()
 
-        metrics_dict = {'early_stop_on': val_loss_mean, 'val_acc': val_acc_mean}
-        results = {'progress_bar': metrics_dict, 'log': metrics_dict}
-        return results
+        self.log('early_stop_on', val_loss_mean, prog_bar=True)
+        self.log('val_acc', val_acc_mean, prog_bar=True)
 
     def validation_epoch_end__multiple_dataloaders(self, outputs):
         """
30 changes: 4 additions & 26 deletions tests/callbacks/test_early_stopping.py
@@ -127,7 +127,7 @@ class ModelOverrideValidationReturn(BoringModel):
 
         def validation_epoch_end(self, outputs):
             loss = self.validation_return_values[self.current_epoch]
-            return {"test_val_loss": loss}
+            self.log("test_val_loss", loss)
 
     model = ModelOverrideValidationReturn()
     early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
@@ -217,7 +217,7 @@ class CurrentModel(BoringModel):
         def validation_epoch_end(self, outputs):
             losses = [8, 4, 2, 3, 4, 5, 8, 10]
             val_loss = losses[self.current_epoch]
-            self.log('abc', torch.tensor(val_loss))
+            self.log('abc', val_loss)
 
     model = CurrentModel()
 
@@ -231,28 +231,6 @@ def validation_epoch_end(self, outputs):
     assert trainer.current_epoch == 5, 'early_stopping failed'
 
 
-def test_early_stopping_functionality_arbitrary_key(tmpdir):
-    """Tests whether early stopping works with a custom key and dictionary results on val step."""
-
-    class CurrentModel(BoringModel):
-
-        def validation_epoch_end(self, outputs):
-            losses = [8, 4, 2, 3, 4, 5, 8, 10]
-            val_loss = losses[self.current_epoch]
-            return {'jiraffe': torch.tensor(val_loss)}
-
-    model = CurrentModel()
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        callbacks=[EarlyStopping(monitor='jiraffe')],
-        overfit_batches=0.20,
-        max_epochs=20,
-    )
-    trainer.fit(model)
-    assert trainer.current_epoch >= 5, 'early_stopping failed'
-
-
 @pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
 def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, min_steps, min_epochs):
     """Excepted Behaviour:
@@ -269,7 +247,7 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, min_steps, min_epochs):
         when `early_stopping` is being triggered,
         THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached.
-    Caviat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
+    Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
 
     This test validate those expected behaviours
@@ -306,7 +284,7 @@ def validation_epoch_end(self, outputs):
             self._count_decrease += 1
             self._loss_value -= self._eps
             self._values.append(_mean)
-            return {"test_val_loss": _mean}
+            self.log('test_val_loss', _mean)
 
     model = Model(step_freeze)
     model.training_step_end = None