Enable logger connector re-design #7891

Merged · 22 commits · Jun 9, 2021
Changes from 19 commits
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -111,6 +111,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* `trainer.{logged,progress_bar,callback}_metrics` are now updated on-demand ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* Completely overhaul the `Result` object in favor of `ResultMetric` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* Improve epoch-level reduction time and overall memory usage ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* Allow passing `self.log(batch_size=...)` ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))
* Each of the training loops now keeps its own metrics ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))


- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))

@@ -161,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))


- Deprecated `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`. ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))


### Removed

- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
4 changes: 4 additions & 0 deletions docs/source/extensions/logging.rst
@@ -68,6 +68,10 @@ except functions with `batch_start` in their names.
def training_step(self, batch, batch_idx):
    self.log('my_metric', x)

# or a dict
def training_step(self, batch, batch_idx):
    self.log('performance', {'acc': acc, 'recall': recall})

Depending on where log is called from, Lightning auto-determines the correct logging mode for you. \
But of course you can override the default behavior by manually setting the :func:`~pytorch_lightning.core.lightning.LightningModule.log` parameters.
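
For instance, here is a sketch of overriding those defaults in a single call; ``my_loss`` and the ``loss`` variable are placeholders, not names from these docs:

def training_step(self, batch, batch_idx):
    loss = ...  # your loss computation
    self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

This logs the value at every step, accumulates it into an epoch-level aggregate, and shows it in the progress bar as well as the configured logger.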

88 changes: 42 additions & 46 deletions pytorch_lightning/core/lightning.py
@@ -24,9 +24,8 @@
import uuid
from abc import ABC
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch import ScriptModule, Tensor
@@ -43,16 +42,13 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

if TYPE_CHECKING:
from pytorch_lightning.trainer.connectors.logger_connector.result import Result

warning_cache = WarningCache()
log = logging.getLogger(__name__)

@@ -109,7 +105,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# optionally can be set by user
self._example_input_array = None
self._datamodule = None
self._results: Optional['Result'] = None
self._current_fx_name: Optional[str] = None
self._running_manual_backward: bool = False
self._current_dataloader_idx: Optional[int] = None
@@ -267,14 +262,15 @@ def log(
logger: bool = True,
on_step: Optional[bool] = None,
on_epoch: Optional[bool] = None,
reduce_fx: Callable = torch.mean,
reduce_fx: Union[str, Callable] = 'default', # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6
tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6
tbptt_pad_token: Optional = None, # noqa: Remove in 1.6
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_op: Optional = None, # noqa: Remove in 1.6
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
) -> None:
"""
Log a key, value
@@ -298,19 +294,20 @@ def log(

Args:
name: key to log
value: value to log
value: value to log. Can be a ``float``, ``Tensor``, ``Metric``, or a dictionary of the former.
prog_bar: if True logs to the progress bar
logger: if True logs to the logger
on_step: if True logs at this step. None auto-logs at the training_step but not validation/test_step
on_epoch: if True logs epoch accumulated metrics. None auto-logs at the val/test step but not training_step
reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
enable_graph: if True, will not auto detach the graph
sync_dist: if True, reduces the metric across GPUs/TPUs
sync_dist_op: the op to sync across GPUs/TPUs
sync_dist_group: the ddp group to sync across
add_dataloader_idx: if True, appends the index of the current dataloader to
the name (when using multiple). If False, user needs to give unique names for
each dataloader to not mix values
batch_size: Current batch_size. This is inferred directly from the loaded batch,
but for some data structures you might need to provide it explicitly.
Comment on lines +309 to +310

Contributor: Here are the docs for batch size; is this what you mean?

Contributor: Yeah, but what I'm missing is the "why": what is that parameter used for (inferred or otherwise)?

awaelchli (Contributor, Jun 10, 2021): To compute the correct average when we ask `self.log` to average the metric on epoch end. It has to be weighted by the batch size because the last batch often does not have the same size as the others: the dataset is not guaranteed to be divisible by the batch size, and `drop_last` in the PyTorch DataLoader is `False` by default.

Contributor: In the `ResultMetric` in result.py you will find the line `self.cumulated_batch_size += batch_size`, and `cumulated_batch_size` is then used in the `compute()` method.

Contributor: OK, that's what I thought it was for, but I couldn't find where in the code it actually does that. So does this mean that my `*_step` should log a scalar which is the mean of the current batch, and PL will correctly average (including across DDP processes) by multiplying by the batch size, summing, then dividing by the dataset size?

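The weighted averaging described in this thread can be sketched in a few lines. This is an editor's illustration of the mechanism, not the PR's actual implementation; the real logic lives in `ResultMetric` in result.py, and the class and attribute names below are assumptions:

import torch

class WeightedMean:
    # sketch of batch-size-weighted mean reduction for epoch-end metrics
    def __init__(self) -> None:
        self.value = torch.tensor(0.0)
        self.cumulated_batch_size = torch.tensor(0.0)

    def update(self, value: torch.Tensor, batch_size: int) -> None:
        # weight each step's batch mean by the number of samples it covers
        self.value += value * batch_size
        self.cumulated_batch_size += batch_size

    def compute(self) -> torch.Tensor:
        # weighted sum divided by the total number of samples this epoch
        return self.value / self.cumulated_batch_size

mean = WeightedMean()
mean.update(torch.tensor(0.5), batch_size=32)  # full batch
mean.update(torch.tensor(0.9), batch_size=8)   # smaller last batch
assert torch.isclose(mean.compute(), torch.tensor(0.58))  # naive mean would give 0.7

Under this sketch, the answer to the question above is yes: logging the per-batch mean together with the batch size recovers the exact dataset-level mean at epoch end.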
"""
if tbptt_reduce_fx is not None:
rank_zero_deprecation(
@@ -324,6 +321,15 @@ def log(
' Please, open a discussion explaining your use-case in'
' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
)
if sync_dist_op is not None:
    rank_zero_deprecation(
        f"`self.log(sync_dist_op='{sync_dist_op}')` is deprecated and will be removed in v1.6."
        f" Use `self.log(reduce_fx='{sync_dist_op}')` instead."
    )
    if reduce_fx == 'default':
        reduce_fx = sync_dist_op
elif reduce_fx == 'default':
    reduce_fx = 'mean'
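
# Editor's sketch of the migration this deprecation asks for; the metric
# name and the `loss` value below are placeholders, not from the PR:
#
#   self.log('train_loss', loss, sync_dist=True, sync_dist_op='mean')  # deprecated
#   self.log('train_loss', loss, sync_dist=True, reduce_fx='mean')     # replacement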

# check for invalid values
apply_to_collection(value, dict, self.__check_not_nested, name)
@@ -335,8 +341,10 @@ def log(
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

results = self.trainer._results
assert results is not None
assert self._current_fx_name is not None
self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)
results.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
@@ -345,18 +353,15 @@ def log(
" but it should not contain information about `dataloader_idx`"
)

sync_fn = partial(
self.__sync,
sync_fn=self.trainer.training_type_plugin.reduce,
sync_dist=sync_dist,
sync_dist_op=sync_dist_op,
sync_dist_group=sync_dist_group,
device=self.device,
)
value = apply_to_collection(value, (torch.Tensor, numbers.Number), sync_fn)
value = apply_to_collection(value, numbers.Number, self.__to_tensor)

assert self._results is not None
self._results.log(
if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name):
# if we started a new epoch (running its first batch) the hook name has changed
# reset any tensors for the new hook name
results.reset(metrics=False, fx=self._current_fx_name)

results.log(
self._current_fx_name,
name,
value,
prog_bar=prog_bar,
Expand All @@ -366,21 +371,27 @@ def log(
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
batch_size=batch_size,
sync_dist=sync_dist,
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available,
sync_dist_group=sync_dist_group,
)

self.trainer.logger_connector._current_fx = self._current_fx_name

def log_dict(
self,
dictionary: Mapping[str, _METRIC_COLLECTION],
prog_bar: bool = False,
logger: bool = True,
on_step: Optional[bool] = None,
on_epoch: Optional[bool] = None,
reduce_fx: Callable = torch.mean,
reduce_fx: Union[str, Callable] = 'default', # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6
tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6
tbptt_pad_token: Optional = None, # noqa: Remove in 1.6
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_op: Optional = None, # noqa: Remove in 1.6
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
) -> None:
@@ -393,15 +404,15 @@ def log_dict(
self.log_dict(values)

Args:
dictionary: key value pairs (str, tensors)
dictionary: key value pairs.
The values can be a ``float``, ``Tensor``, ``Metric``, or a dictionary of the former.
prog_bar: if True logs to the progress bar
logger: if True logs to the logger
on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step
on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
enable_graph: if True, will not auto detach the graph
sync_dist: if True, reduces the metric across GPUs/TPUs
sync_dist_op: the op to sync across GPUs/TPUs
sync_dist_group: the ddp group to sync across
add_dataloader_idx: if True, appends the index of the current dataloader to
the name (when using multiple). If False, user needs to give unique names for
Expand All @@ -426,25 +437,7 @@ def log_dict(
)

@staticmethod
def __sync(
value: Union[torch.Tensor, numbers.Number],
sync_fn: Optional[Callable] = None,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
device: torch.device = None,
) -> torch.Tensor:
"""Sync across workers when using distributed training"""
if isinstance(value, numbers.Number):
value = torch.tensor(value, device=device, dtype=torch.float)
sync_fn = sync_fn or sync_ddp_if_available
dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()
if not sync_dist or not dist_available:
return value
return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

@staticmethod
def __check_not_nested(value: dict, name: str) -> None:
def __check_not_nested(value: dict, name: str) -> dict:
# self-imposed restriction. for simplicity
if any(isinstance(v, dict) for v in value.values()):
raise ValueError(f'`self.log({name}, {value})` was called, but nested dictionaries cannot be logged')
@@ -454,6 +447,9 @@ def __check_not_nested(value: dict, name: str) -> None:
def __check_allowed(v: Any, name: str, value: Any) -> None:
raise ValueError(f'`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged')

def __to_tensor(self, value: numbers.Number) -> torch.Tensor:
return torch.tensor(value, device=self.device)

def log_grad_norm(self, grad_norm_dict: Dict[str, torch.Tensor]) -> None:
"""Override this method to change the default behaviour of ``log_grad_norm``.

@@ -141,7 +141,7 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

# track batch_size
self.trainer.results.extract_batch_size(batch)
self.trainer._results.extract_batch_size(batch)
self._batch_idx = batch_idx

def update_eval_step_metrics(self) -> None:
@@ -210,7 +210,7 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT:
"""

def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
self.trainer.results.extract_batch_size(split_batch)
self.trainer._results.extract_batch_size(split_batch)
self._batch_idx = batch_idx
self._split_idx = split_idx

@@ -232,7 +232,7 @@ def update_train_epoch_metrics(self) -> None:
self.log_metrics(metrics)

# reset result collection for next epoch
self.trainer.results.reset(metrics=True)
self.trainer._results.reset(metrics=True)

"""
Utilities and properties
@@ -273,7 +273,7 @@ def should_reset_tensors(self, fx: str) -> bool:
return is_different_fx and is_first_batch

def reset(self, metrics: Optional[bool] = None) -> None:
self.trainer.results.reset(metrics=metrics)
self.trainer._results.reset(metrics=metrics)
self._batch_idx = None
self._split_idx = None
self._current_fx = None
@@ -282,25 +282,25 @@ def reset(self, metrics: Optional[bool] = None) -> None:
def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]:
"""This function returns either batch or epoch metrics depending on ``_epoch_end_reached``."""
on_step = not self._epoch_end_reached
return self.trainer.results.metrics(on_step)
return self.trainer._results.metrics(on_step)

@property
def callback_metrics(self) -> Dict[str, _METRIC]:
if self.trainer.results:
if self.trainer._results:
metrics = self.metrics[MetricSource.CALLBACK]
self._callback_metrics.update(metrics)
return self._callback_metrics

@property
def logged_metrics(self) -> Dict[str, _METRIC]:
if self.trainer.results:
if self.trainer._results:
metrics = self.metrics[MetricSource.LOG]
self._logged_metrics.update(metrics)
return self._logged_metrics

@property
def progress_bar_metrics(self) -> Dict[str, float]:
if self.trainer.results:
if self.trainer._results:
metrics = self.metrics[MetricSource.PBAR]
self._progress_bar_metrics.update(metrics)
return self._progress_bar_metrics
@@ -65,7 +65,6 @@ class _Metadata:
reduce_fx: Union[str, Callable] = torch.mean
enable_graph: bool = False
dataloader_idx: Optional[int] = None
metric_attribute: Optional[str] = None
sync: _Sync = field(default_factory=_Sync)

def __post_init__(self) -> None:
@@ -225,7 +224,7 @@ class ResultCollection(dict):
Example:

# `device` needs to be provided before logging
result = ResultCollection(True, torch.device("cpu"))
result = ResultCollection(training=True, device=torch.device("cpu"))

# you can log to a specific collection.
# arguments: fx, key, value, metadata
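# an editor's continuation sketch of this docstring example; the metric
# name and value are illustrative, the keyword names follow `log` below
result.log('training_step', 'acc', torch.tensor(0.9), on_step=True, on_epoch=True)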
@@ -303,7 +302,6 @@ def log(
sync_dist_group: Optional[Any] = None,
dataloader_idx: Optional[int] = None,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
) -> None:
"""See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
# no metrics should be logged with graphs
@@ -331,7 +329,6 @@ def log(
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=dataloader_idx,
metric_attribute=metric_attribute,
sync=_Sync(
should=sync_dist,
fn=sync_dist_fn,