Skip to content

Commit

Permalink
Use default_root_dir as the log_dir with LoggerCollections (#8187)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
3 people authored Jul 19, 2021
1 parent a6fd32a commit 6604fc1
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 3 deletions.
7 changes: 6 additions & 1 deletion pytorch_lightning/trainer/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loops import PredictionLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
Expand Down Expand Up @@ -226,8 +227,12 @@ def model(self, model: torch.nn.Module) -> None:
def log_dir(self) -> Optional[str]:
if self.logger is None:
dirpath = self.default_root_dir
elif isinstance(self.logger, TensorBoardLogger):
dirpath = self.logger.log_dir
elif isinstance(self.logger, LoggerCollection):
dirpath = self.default_root_dir
else:
dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')
dirpath = self.logger.save_dir

dirpath = self.accelerator.broadcast(dirpath)
return dirpath
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@ def __init__(
limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches)
logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
the default ``TensorBoardLogger``. ``False`` will disable logging.
the default ``TensorBoardLogger``. ``False`` will disable logging. If multiple loggers are
provided and the `save_dir` property of that logger is not set, local files (checkpoints,
profiler traces, etc.) are saved in ``default_root_dir`` rather than in the ``log_dir`` of any
of the individual loggers.
log_gpu_memory: None, 'min_max', 'all'. Might slow performance
Expand Down
32 changes: 32 additions & 0 deletions tests/profiler/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from packaging.version import Version

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profiler.pytorch import RegisterRecordFunction
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -430,6 +432,36 @@ def test_pytorch_profiler_nested(tmpdir):
assert events_name == expected, (events_name, torch.__version__, platform.system())


def test_pytorch_profiler_logger_collection(tmpdir):
"""
Tests whether the PyTorch profiler is able to write its trace locally when
the Trainer's logger is an instance of LoggerCollection. See issue #8157.
"""

def look_for_trace(trace_dir):
""" Determines if a directory contains a PyTorch trace """
return any("trace.json" in filename for filename in os.listdir(trace_dir))

# Sanity check
assert not look_for_trace(tmpdir)

model = BoringModel()

# Wrap the logger in a list so it becomes a LoggerCollection
logger = [TensorBoardLogger(save_dir=tmpdir)]
trainer = Trainer(
default_root_dir=tmpdir,
profiler="pytorch",
logger=logger,
limit_train_batches=5,
max_epochs=1,
)

assert isinstance(trainer.logger, LoggerCollection)
trainer.fit(model)
assert look_for_trace(tmpdir)


@RunIf(min_gpus=1, special=True)
def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
"""
Expand Down
19 changes: 18 additions & 1 deletion tests/trainer/properties/log_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from tests.helpers.boring_model import BoringModel


Expand Down Expand Up @@ -140,3 +140,20 @@ def test_logdir_custom_logger(tmpdir):
assert trainer.log_dir == expected
trainer.fit(model)
assert trainer.log_dir == expected


def test_logdir_logger_collection(tmpdir):
"""Tests that the logdir equals the default_root_dir when the logger is a LoggerCollection"""
default_root_dir = tmpdir / "default_root_dir"
save_dir = tmpdir / "save_dir"
model = TestModel(default_root_dir)
trainer = Trainer(
default_root_dir=default_root_dir,
max_steps=2,
logger=[TensorBoardLogger(save_dir=save_dir, name='custom_logs')]
)
assert isinstance(trainer.logger, LoggerCollection)
assert trainer.log_dir == default_root_dir

trainer.fit(model)
assert trainer.log_dir == default_root_dir

0 comments on commit 6604fc1

Please sign in to comment.