Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 committed Nov 9, 2021
1 parent 51a7d36 commit c2a3626
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
8 changes: 4 additions & 4 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,10 @@ def log_dict(
raise MisconfigurationException(
"You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
)

# set the default depending on the fx_name
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
_FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

if isinstance(reduce_fx, str):
Expand Down Expand Up @@ -481,10 +485,6 @@ def log_dict(
elif reduce_fx == "default":
reduce_fx = "mean"

# set the default depending on the fx_name
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name):
# if we started a new epoch (running it's first batch) the hook name has changed
# reset any tensors for the new hook name
Expand Down
6 changes: 1 addition & 5 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def training_step(self, batch, batch_idx):

trainer = Trainer(default_root_dir=tmpdir)
model = TestModel()
with pytest.raises(MisconfigurationException, match="`self.log` with the key `foo/dataloader_idx_0`"):
with pytest.raises(MisconfigurationException, match="`self.log` with the key `'foo/dataloader_idx_0'`"):
trainer.fit(model)

class TestModel(BoringModel):
Expand Down Expand Up @@ -717,19 +717,15 @@ def on_validation_epoch_end(self):
assert all(v == 3 for v in self.trainer.callback_metrics.values())

def on_train_batch_start(self, batch, batch_idx):
assert self.trainer._results.batch_size == 2
self.log("on_train_batch_start", 1.0, reduce_fx="sum")

def on_train_batch_end(self, outputs, batch, batch_idx):
assert self.trainer._results.batch_size == 2
self.log("on_train_batch_end", 1.0, reduce_fx="sum")

def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
assert self.trainer._results.batch_size == 2
self.log("on_validation_batch_start", 1.0, reduce_fx="sum")

def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
assert self.trainer._results.batch_size == 2
self.log("on_validation_batch_end", 1.0, reduce_fx="sum")

def training_epoch_end(self, *_) -> None:
Expand Down

0 comments on commit c2a3626

Please sign in to comment.