quick fix (#4697)
tchaton committed Nov 17, 2020
1 parent a3576ea commit 1ac42ab
Showing 2 changed files with 22 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -55,6 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed docs typo ([#4659](https://github.com/PyTorchLightning/pytorch-lightning/pull/4659), [#4670](https://github.com/PyTorchLightning/pytorch-lightning/pull/4670))
- Fixed notebooks typo ([#4657](https://github.com/PyTorchLightning/pytorch-lightning/pull/4657))
- Allowed decorating model init with saving `hparams` inside ([#4662](https://github.com/PyTorchLightning/pytorch-lightning/pull/4662))
+- Fixed `split_idx` not being set on the `Trainer`: `LoggerConnector` now initializes it in `on_trainer_init` ([#4697](https://github.com/PyTorchLightning/pytorch-lightning/pull/4697))



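The entry above is easier to see with a toy example. A minimal sketch, assuming the failure mode was an attribute that simply did not exist before the first batch split (the `Trainer` stand-in below is hypothetical, not the library's class):

# Hypothetical stand-in for the relevant Trainer state; not PyTorch Lightning code.
class Trainer:
    def __init__(self):
        # Before this commit, nothing assigned split_idx during trainer init,
        # so code reading trainer.split_idx ahead of the first batch split
        # could hit an AttributeError. The fix initializes it to None up front.
        self.split_idx = None

trainer = Trainer()
assert trainer.split_idx is None  # safe to inspect before training starts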
28 changes: 21 additions & 7 deletions pytorch_lightning/trainer/connectors/logger_connector.py
@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+from collections import ChainMap
+from copy import deepcopy
+from pprint import pprint
+from typing import Iterable

import torch

from pytorch_lightning.core import memory
-from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection
-from pytorch_lightning.utilities import flatten_dict
-from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.core.step_result import EvalResult, Result
+from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
+from pytorch_lightning.utilities import flatten_dict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pprint import pprint
-from typing import Iterable
-from copy import deepcopy
-from collections import ChainMap
+from pytorch_lightning.utilities.model_utils import is_overridden


class LoggerConnector:
@@ -42,6 +44,18 @@ def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps):
        self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
        self.trainer.log_every_n_steps = log_every_n_steps

+        self.trainer.split_idx = None
+
+    @property
+    def should_flush_logs(self):
+        should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
+        return should_flush or self.trainer.should_stop
+
+    @property
+    def should_update_logs(self):
+        should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
+        return should_log_every_n_steps or self.trainer.should_stop
+
    def configure_logger(self, logger):
        if logger is True:
            version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)
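To make the behavior of the two new properties concrete, here is a self-contained sketch of the same cadence logic as a plain function (a hypothetical helper, not part of the library). Because `global_step` counts from zero, the `+ 1` makes the check fire on steps N-1, 2N-1, and so on, while `should_stop` forces one final update when training ends early:

def should_update_logs(global_step, log_every_n_steps, should_stop):
    # Mirrors LoggerConnector.should_update_logs: fire every N zero-based
    # steps, or unconditionally once the trainer wants to stop.
    return (global_step + 1) % log_every_n_steps == 0 or should_stop

# With log_every_n_steps=3, updates fire on zero-based steps 2, 5, 8, ...
assert [s for s in range(9) if should_update_logs(s, 3, False)] == [2, 5, 8]
assert should_update_logs(0, 3, True)  # stopping forces an update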
