From a41704ee93032279150d4cf3e9f6fb02450a0020 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 27 Sep 2020 20:26:16 -0400
Subject: [PATCH] ref: add .log to lightning module (1/n) (#3686)

---
 pytorch_lightning/core/lightning.py | 75 +++++++++++++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 8a3c16dccbb49..6fc20dcac63d8 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -32,6 +32,7 @@
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
+from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.utilities.parsing import (
     AttributeDict,
     collect_init_args,
@@ -102,6 +103,8 @@ def __init__(self, *args, **kwargs):
         # optionally can be set by user
         self._example_input_array = None
         self._datamodule = None
+        self._results = None
+        self._current_fx_name = ''
 
     @property
     def example_input_array(self) -> Any:
@@ -146,6 +149,78 @@ def forward(self, x):
         if self.trainer.is_global_zero:
             print(*args, **kwargs)
 
+    def log(
+        self,
+        name: str,
+        value: Any,
+        prog_bar: bool = False,
+        logger: bool = True,
+        on_step: bool = True,
+        on_epoch: bool = True,
+        reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
+        enable_graph: bool = False,
+        sync_dist: bool = False,
+        sync_dist_op: Union[Any, str] = 'mean',
+        sync_dist_group: Optional[Any] = None,
+    ):
+        """
+        Log a key, value pair.
+
+        Example::
+
+            self.log('train_loss', loss)
+
+            # defaults used
+            self.log(
+                name,
+                value,
+                on_step=True,
+                on_epoch=True,
+                logger=True,
+                prog_bar=False,
+                reduce_fx=torch.mean,
+                enable_graph=False
+            )
+
+        Args:
+            name: key name
+            value: value to log
+            prog_bar: if True, logs to the progress bar
+            logger: if True, logs to the logger
+            on_step: if True, logs the metric at the current step
+            on_epoch: if True, logs the metric aggregated over the epoch
+            reduce_fx: reduction function over step values for the epoch; torch.mean by default
+            tbptt_reduce_fx: function to reduce on truncated back-propagation through time
+            tbptt_pad_token: token to use for padding
+            enable_graph: if True, will not auto-detach the graph
+            sync_dist: if True, reduces the metric across GPUs/TPUs
+            sync_dist_op: the op to use when syncing across devices
+            sync_dist_group: the DDP group to sync across
+        """
+        if self._results is not None:
+            # step metrics cannot be logged from any *_epoch_end hook (only epoch metrics)
+            if 'epoch_end' in self._current_fx_name and on_step:
+                on_step = False
+
+            self._results.log(
+                name,
+                value,
+                prog_bar,
+                logger,
+                on_step,
+                on_epoch,
+                reduce_fx,
+                tbptt_reduce_fx,
+                tbptt_pad_token,
+                enable_graph,
+                sync_dist,
+                sync_dist_op,
+                sync_dist_group
+            )
+
     def forward(self, *args, **kwargs):
         r"""
         Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
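
Usage sketch (not part of the patch): this PR is step 1/n, and the Trainer wiring that
populates ``self._results`` before each hook runs lands in follow-up PRs, so the example
below shows the intended call pattern rather than something this commit alone enables.
It assumes a PyTorch Lightning version where ``Trainer.fit`` dispatches ``training_step``::

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class LitClassifier(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self.layer(x), y)
            # the new hook: log per step and as an epoch aggregate, and show it in the progress bar
            self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    ds = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    pl.Trainer(max_epochs=1).fit(LitClassifier(), DataLoader(ds, batch_size=16))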
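
To make the ``_current_fx_name`` guard concrete, here is a standalone restatement of that
branch (``resolve_on_step`` is a hypothetical helper name, not part of the patch): inside any
``*_epoch_end`` hook only epoch-level metrics make sense, so a step-level request is silently
downgraded::

    def resolve_on_step(current_fx_name: str, on_step: bool) -> bool:
        # 'epoch_end' appears in hook names like 'training_epoch_end',
        # where per-step logging is meaningless
        if 'epoch_end' in current_fx_name and on_step:
            return False
        return on_step

    assert resolve_on_step('training_step', True) is True
    assert resolve_on_step('training_epoch_end', True) is False
    assert resolve_on_step('validation_epoch_end', False) is False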