Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable logger connector re-design #7891

Merged
merged 22 commits into from
Jun 9, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* `trainer.{logged,progress_bar,callback}_metrics` are now updated on-demand ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* Completely overhaul the `Result` object in favor of `ResultMetric` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* Improve epoch-level reduction time and overall memory usage ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* Allow passing `self.log(batch_size=...)` ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))
* Allow passing `self.log(metric_attribute='your_metric')` to properly serialize the state of any `torchmetrics.Metric`s in your model ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))
* Each of the training loops now keeps its own metrics ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))
carmocca marked this conversation as resolved.
Show resolved Hide resolved


- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))

Expand Down Expand Up @@ -161,6 +165,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))


- Deprecated `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`. ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))


### Removed

- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
Expand Down
4 changes: 4 additions & 0 deletions docs/source/extensions/logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ except functions with `batch_start` in their names.
def training_step(self, batch, batch_idx):
self.log('my_metric', x)

# or a dict
def training_step(self, batch, batch_idx):
self.log('performance', {'acc': acc, 'recall': recall})

Depending on where log is called from, Lightning auto-determines the correct logging mode for you. \
But of course you can override the default behavior by manually setting the :func:`~pytorch_lightning.core.lightning.LightningModule.log` parameters.

Expand Down
113 changes: 67 additions & 46 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
import uuid
from abc import ABC
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch import ScriptModule, Tensor
Expand All @@ -43,16 +42,13 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

if TYPE_CHECKING:
from pytorch_lightning.trainer.connectors.logger_connector.result import Result

warning_cache = WarningCache()
log = logging.getLogger(__name__)

Expand Down Expand Up @@ -109,13 +105,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# optionally can be set by user
self._example_input_array = None
self._datamodule = None
self._results: Optional['Result'] = None
self._current_fx_name: Optional[str] = None
self._running_manual_backward: bool = False
self._current_dataloader_idx: Optional[int] = None
self._automatic_optimization: bool = True
self._truncated_bptt_steps: int = 0
self._param_requires_grad_state = dict()
self._metric_attributes: Optional[Dict[int, str]] = None

def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
if use_pl_optimizer:
Expand Down Expand Up @@ -267,14 +263,16 @@ def log(
logger: bool = True,
on_step: Optional[bool] = None,
on_epoch: Optional[bool] = None,
reduce_fx: Callable = torch.mean,
reduce_fx: Union[str, Callable] = 'default', # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6
tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6
tbptt_pad_token: Optional = None, # noqa: Remove in 1.6
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_op: Optional = None, # noqa: Remove in 1.6
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
) -> None:
"""
Log a key, value
Expand All @@ -298,19 +296,22 @@ def log(

Args:
name: key to log
value: value to log
value: value to log. Can be a ``float``, ``Tensor``, ``Metric``, or a dictionary of the former.
prog_bar: if True logs to the progress bar
logger: if True logs to the logger
on_step: if True logs at this step. None auto-logs at the training_step but not validation/test_step
on_epoch: if True logs epoch accumulated metrics. None auto-logs at the val/test step but not training_step
reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
enable_graph: if True, will not auto detach the graph
sync_dist: if True, reduces the metric across GPUs/TPUs
sync_dist_op: the op to sync across GPUs/TPUs
sync_dist_group: the ddp group to sync across
add_dataloader_idx: if True, appends the index of the current dataloader to
the name (when using multiple). If False, user needs to give unique names for
each dataloader to not mix values
batch_size: Current batch_size. This will be directly inferred from the loaded batch,
but some data structures might need to explicitly provide it.
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
metric_attribute: The attribute name for the metric in the LightningModule.
Necessary to save/restore its state.
"""
if tbptt_reduce_fx is not None:
rank_zero_deprecation(
Expand All @@ -324,6 +325,15 @@ def log(
' Please, open a discussion explaining your use-case in'
' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
)
if sync_dist_op is not None:
rank_zero_deprecation(
f"`self.log(sync_dist_op='{sync_dist_op}')` is deprecated and will be removed in v.1.6."
f" Use `self.log(reduce_fx={sync_dist_op})` instead."
)
if reduce_fx == 'default':
reduce_fx = sync_dist_op
elif reduce_fx == 'default':
reduce_fx = 'mean'

# check for invalid values
apply_to_collection(value, dict, self.__check_not_nested, name)
Expand All @@ -335,8 +345,10 @@ def log(
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

results = self.trainer.results
assert results is not None
assert self._current_fx_name is not None
self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)
results.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
Expand All @@ -345,18 +357,35 @@ def log(
" but it should not contain information about `dataloader_idx`"
)

sync_fn = partial(
self.__sync,
sync_fn=self.trainer.training_type_plugin.reduce,
sync_dist=sync_dist,
sync_dist_op=sync_dist_op,
sync_dist_group=sync_dist_group,
device=self.device,
)
value = apply_to_collection(value, (torch.Tensor, numbers.Number), sync_fn)
if metric_attribute is None and isinstance(value, Metric):
if self._metric_attributes is None:
# compute once
self._metric_attributes = {
id(module): name
for name, module in self.named_children() if isinstance(module, Metric)
}
if not self._metric_attributes:
raise MisconfigurationException(
"Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
carmocca marked this conversation as resolved.
Show resolved Hide resolved
" You can fix this by setting an attribute for the metric in your `LightningModule`."
)
# try to find the passed metric in the LightningModule
metric_attribute = self._metric_attributes.get(id(value))
if metric_attribute is None:
raise MisconfigurationException(
"Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one"
f" of {list(self._metric_attributes.values())}"
)

value = apply_to_collection(value, numbers.Number, self.__to_float)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we conserve the logged type ?

Copy link
Contributor Author

@carmocca carmocca Jun 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So not converting to float tensor but just wrapping it in tensor?

We can, but I don't think this matters as ResultMetric.update will convert it to float anyways

edit: changed to __to_tensor


assert self._results is not None
self._results.log(
if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name):
# when restarting a new epoch, reset the tensors
carmocca marked this conversation as resolved.
Show resolved Hide resolved
results.reset(metrics=False, fx=self._current_fx_name)

results.log(
self._current_fx_name,
name,
value,
prog_bar=prog_bar,
Expand All @@ -366,21 +395,28 @@ def log(
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
batch_size=batch_size,
metric_attribute=metric_attribute,
sync_dist=sync_dist,
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available,
sync_dist_group=sync_dist_group,
)

self.trainer.logger_connector._current_fx = self._current_fx_name

def log_dict(
self,
dictionary: Mapping[str, _METRIC_COLLECTION],
prog_bar: bool = False,
logger: bool = True,
on_step: Optional[bool] = None,
on_epoch: Optional[bool] = None,
reduce_fx: Callable = torch.mean,
reduce_fx: Union[str, Callable] = 'default', # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6
tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6
tbptt_pad_token: Optional = None, # noqa: Remove in 1.6
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_op: Optional = None, # noqa: Remove in 1.6
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
) -> None:
Expand All @@ -393,15 +429,15 @@ def log_dict(
self.log_dict(values)

Args:
dictionary: key value pairs (str, tensors)
dictionary: key value pairs.
The values can be a ``float``, ``Tensor``, ``Metric``, or a dictionary of the former.
prog_bar: if True logs to the progress bar
logger: if True logs to the logger
on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step
on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
enable_graph: if True, will not auto detach the graph
sync_dist: if True, reduces the metric across GPUs/TPUs
sync_dist_op: the op to sync across GPUs/TPUs
sync_dist_group: the ddp group to sync across
add_dataloader_idx: if True, appends the index of the current dataloader to
the name (when using multiple). If False, user needs to give unique names for
Expand All @@ -426,25 +462,7 @@ def log_dict(
)

@staticmethod
def __sync(
value: Union[torch.Tensor, numbers.Number],
sync_fn: Optional[Callable] = None,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
device: torch.device = None,
) -> torch.Tensor:
"""Sync across workers when using distributed training"""
if isinstance(value, numbers.Number):
value = torch.tensor(value, device=device, dtype=torch.float)
sync_fn = sync_fn or sync_ddp_if_available
dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()
if not sync_dist or not dist_available:
return value
return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

@staticmethod
def __check_not_nested(value: dict, name: str) -> None:
def __check_not_nested(value: dict, name: str) -> dict:
# self-imposed restriction. for simplicity
if any(isinstance(v, dict) for v in value.values()):
raise ValueError(f'`self.log({name}, {value})` was called, but nested dictionaries cannot be logged')
Expand All @@ -454,6 +472,9 @@ def __check_not_nested(value: dict, name: str) -> None:
def __check_allowed(v: Any, name: str, value: Any) -> None:
raise ValueError(f'`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged')

def __to_float(self, value: numbers.Number) -> torch.Tensor:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
return torch.tensor(value, device=self.device, dtype=torch.float)

def log_grad_norm(self, grad_norm_dict: Dict[str, torch.Tensor]) -> None:
"""Override this method to change the default behaviour of ``log_grad_norm``.

Expand Down
Loading