Remove LoggerConnector.on_trainer_init #11121

Closed
@@ -11,13 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Any, Dict, Optional

import torch

import pytorch_lightning as pl
-from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
-from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import _AcceleratorType, memory
@@ -46,25 +44,6 @@ def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None)
self._batch_idx: Optional[int] = None
self._split_idx: Optional[int] = None

-    def on_trainer_init(
Review comment (Contributor): In #10417 we propose to make the logger owned by the connector. The two methods on_trainer_init and configure_logger would then essentially become part of the LoggerConnector __init__ without reference to the trainer. What are your thoughts?

Review comment (Contributor): I would definitely prefer this approach. Right now, the Trainer is becoming a bigger class and it is getting harder for users to parse it.

-        self,
-        logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]],
-        flush_logs_every_n_steps: Optional[int],
-        log_every_n_steps: int,
-        move_metrics_to_cpu: bool,
-    ) -> None:
-        self.configure_logger(logger)
-        if flush_logs_every_n_steps is not None:
-            rank_zero_deprecation(
-                f"Setting `Trainer(flush_logs_every_n_steps={flush_logs_every_n_steps})` is deprecated in v1.5 "
-                "and will be removed in v1.7. Please configure flushing in the logger instead."
-            )
-        else:
-            flush_logs_every_n_steps = 100  # original default parameter
-        self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
-        self.trainer.log_every_n_steps = log_every_n_steps
-        self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
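
As an aside on the review thread above: a minimal sketch of what the proposal in #10417 could look like if the connector owned the logger, folding the configure_logger logic into the connector's own __init__. The constructor signature and defaults here are illustrative assumptions, not taken from this PR.

from typing import Iterable, Optional, Union

from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger


class LoggerConnector:
    # Hypothetical __init__ sketched from the #10417 proposal: the connector
    # owns the logger, so on_trainer_init/configure_logger are no longer needed.
    def __init__(
        self,
        logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]] = True,
        default_root_dir: str = ".",
    ) -> None:
        if isinstance(logger, bool):
            # `True` selects the default TensorBoardLogger, `False` disables logging
            self.logger: Optional[LightningLoggerBase] = (
                TensorBoardLogger(save_dir=default_root_dir, name="lightning_logs") if logger else None
            )
        elif isinstance(logger, Iterable):
            self.logger = LoggerCollection(logger)
        else:
            self.logger = logger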

@property
def should_flush_logs(self) -> bool:
should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
@@ -75,21 +54,6 @@ def should_update_logs(self) -> bool:
should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
return should_log_every_n_steps or self.trainer.should_stop

-    def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
-        if isinstance(logger, bool):
-            # default logger
-            self.trainer.logger = (
-                TensorBoardLogger(
-                    save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id(), name="lightning_logs"
-                )
-                if logger
-                else None
-            )
-        elif isinstance(logger, Iterable):
-            self.trainer.logger = LoggerCollection(logger)
-        else:
-            self.trainer.logger = logger

def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
"""Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses
metrics["step"] as a step.
pytorch_lightning/trainer/trainer.py (35 changes: 33 additions & 2 deletions)
@@ -569,9 +569,9 @@ def __init__(
# configure profiler
self.__init_profiler(profiler)

-        # init logger flags
+        # configure logger flags
self.logger: Optional[LightningLoggerBase]
Review comment (Contributor): I played a bit with this and this definition needs to be in the __init__, otherwise mypy will treat it as Any.
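
A small standalone illustration of the point above (not part of the diff): a bare annotation in __init__ gives mypy a concrete type for self.logger even though the assignment happens in a private helper. TrainerSketch and its helper are simplified stand-ins, and the "treated as Any" behaviour is taken from the reviewer's report.

from typing import Optional

from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger


class TrainerSketch:
    def __init__(self, logger: bool = True) -> None:
        # Declaring the attribute here lets mypy check later assignments against
        # Optional[LightningLoggerBase]; per the comment above, omitting it makes
        # mypy fall back to Any for `self.logger`.
        self.logger: Optional[LightningLoggerBase]
        self.__init_logger_flags(logger)

    def __init_logger_flags(self, logger: bool) -> None:
        self.logger = TensorBoardLogger(save_dir="lightning_logs") if logger else None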

-        self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)
+        self.__init_logger_flags(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)

# init debugging flags
self._init_debugging_flags(
@@ -1632,6 +1632,37 @@ def __setup_profiler(self) -> None:
self.profiler._lightning_module = proxy(self.lightning_module)
self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)

+    def __init_logger_flags(
+        self,
+        logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]],
+        flush_logs_every_n_steps: Optional[int],
+        log_every_n_steps: int,
+        move_metrics_to_cpu: bool,
+    ) -> None:
+        if flush_logs_every_n_steps is not None:
+            rank_zero_deprecation(
+                f"Setting `Trainer(flush_logs_every_n_steps={flush_logs_every_n_steps})` is deprecated in v1.5 "
+                "and will be removed in v1.7. Please configure flushing in the logger instead."
+            )
+        else:
+            flush_logs_every_n_steps = 100  # original default parameter
+
+        self.flush_logs_every_n_steps = flush_logs_every_n_steps
+        self.log_every_n_steps = log_every_n_steps
+        self.move_metrics_to_cpu = move_metrics_to_cpu
+
+        if logger is True:
+            # default logger
+            self.logger = TensorBoardLogger(
+                save_dir=self.default_root_dir, version=SLURMEnvironment.job_id(), name="lightning_logs"
+            )
+        elif logger is False:
+            self.logger = None
+        elif isinstance(logger, Iterable):
+            self.logger = LoggerCollection(logger)
+        else:
+            self.logger = logger

def _log_device_info(self) -> None:
rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == _AcceleratorType.GPU}")

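Finally, a short usage sketch of how the logger argument is resolved by the new __init_logger_flags above, assuming the behaviour shown in this diff (paths and logger instances are illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger

# logger=True (the default): a TensorBoardLogger is created under default_root_dir
assert isinstance(Trainer(logger=True).logger, TensorBoardLogger)

# logger=False: logging is disabled and trainer.logger is None
assert Trainer(logger=False).logger is None

# an iterable of loggers is wrapped in a LoggerCollection
loggers = [TensorBoardLogger("logs/a"), TensorBoardLogger("logs/b")]
assert isinstance(Trainer(logger=loggers).logger, LoggerCollection)

# a single logger instance is used as-is
tb = TensorBoardLogger("logs/c")
assert Trainer(logger=tb).logger is tb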