
Commit

ref: add .log to lightning module (1/n) (#3686)
williamFalcon committed Sep 28, 2020
1 parent f37e9e8 commit a41704e
Showing 1 changed file with 75 additions and 0 deletions.
pytorch_lightning/core/lightning.py (75 additions, 0 deletions)
@@ -32,6 +32,7 @@
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import (
    AttributeDict,
    collect_init_args,
@@ -102,6 +103,8 @@ def __init__(self, *args, **kwargs):
        # optionally can be set by user
        self._example_input_array = None
        self._datamodule = None
        self._results = None
        self._current_fx_name = ''

    @property
    def example_input_array(self) -> Any:
@@ -146,6 +149,78 @@ def forward(self, x):
        if self.trainer.is_global_zero:
            print(*args, **kwargs)

    def log(
        self,
        name: str,
        value: Any,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = True,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        tbptt_reduce_fx: Callable = torch.mean,
        tbptt_pad_token: int = 0,
        enable_graph: bool = False,
        sync_dist: bool = False,
        sync_dist_op: Union[Any, str] = 'mean',
        sync_dist_group: Optional[Any] = None,
    ):
"""
Log a key, value
Example::
result.log('train_loss', loss)
# defaults used
result.log(
name,
value,
on_step=False,
on_epoch=False,
logger=True,
prog_bar=False,
reduce_fx=torch.mean,
enable_graph=False
)
Args:
name: key name
value: value name
prog_bar: if True logs to the progress base
logger: if True logs to the logger
on_step: if True logs the output of validation_step or test_step
on_epoch: if True, logs the output of the training loop aggregated
reduce_fx: Torch.mean by default
tbptt_reduce_fx: function to reduce on truncated back prop
tbptt_pad_token: token to use for padding
enable_graph: if True, will not auto detach the graph
sync_dist: if True, reduces the metric across GPUs/TPUs
sync_dist_op: the op to sync across
sync_dist_group: the ddp group
"""
        if self._results is not None:
            # inside any *_epoch_end hook we can't log step metrics, only epoch metrics
            if 'epoch_end' in self._current_fx_name and on_step:
                on_step = False

            self._results.log(
                name,
                value,
                prog_bar,
                logger,
                on_step,
                on_epoch,
                reduce_fx,
                tbptt_reduce_fx,
                tbptt_pad_token,
                enable_graph,
                sync_dist,
                sync_dist_op,
                sync_dist_group
            )

    def forward(self, *args, **kwargs):
        r"""
        Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
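For context, a minimal sketch of how the new `self.log` API is intended to be used from a `LightningModule` once the trainer side is wired up in the follow-up PRs of this series (1/n). The module, metric names, and loss computation below are illustrative assumptions, not part of this commit:

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl


    class LitClassifier(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self.layer(x), y)
            # on_step and on_epoch both default to True, so the metric is
            # logged per step and also aggregated (torch.mean) per epoch
            self.log('train_loss', loss, prog_bar=True)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            return F.cross_entropy(self.layer(x), y)

        def validation_epoch_end(self, outputs):
            # inside an *_epoch_end hook, on_step is forced to False,
            # so only the epoch-level value is logged
            self.log('val_loss_epoch', torch.stack(outputs).mean())

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

Note that `self.log` simply forwards to `self._results.log` when a `Result` object is attached; calls made outside a trainer-managed hook (where `self._results` is None) are a no-op in this commit.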
