
Remove on epoch guard from the should stop validation check #7701

Merged: 2 commits merged on May 25, 2021
4 changes: 1 addition & 3 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py

@@ -211,6 +211,4 @@ def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]:

     @staticmethod
     def _should_log(trainer) -> bool:
-        should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
-
-        return should_log
+        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
4 changes: 1 addition & 3 deletions pytorch_lightning/callbacks/lr_monitor.py

@@ -202,6 +202,4 @@ def _find_names(self, lr_schedulers) -> List[str]:

     @staticmethod
     def _should_log(trainer) -> bool:
-        should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
-
-        return should_log
+        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
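For illustration, here is a minimal standalone sketch of the simplified `_should_log` predicate shared by both callbacks. `FakeTrainer` and its attribute values are hypothetical stand-ins, not the real Trainer API; they only mirror the three attributes the predicate reads.

    # Hypothetical stand-in mirroring the three trainer attributes the predicate reads.
    from dataclasses import dataclass


    @dataclass
    class FakeTrainer:
        global_step: int = 0
        log_every_n_steps: int = 50
        should_stop: bool = False


    def should_log(trainer: FakeTrainer) -> bool:
        # Log every `log_every_n_steps` steps, or unconditionally once a stop has been requested.
        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop


    assert not should_log(FakeTrainer(global_step=10))                  # 11 % 50 != 0
    assert should_log(FakeTrainer(global_step=49))                      # 50 % 50 == 0
    assert should_log(FakeTrainer(global_step=10, should_stop=True))    # stop flag forces a final log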
18 changes: 4 additions & 14 deletions pytorch_lightning/trainer/training_loop.py

@@ -529,21 +529,11 @@ def run_training_epoch(self):

             self.total_batch_idx += 1

-            # max steps reached, end training
-            if (
+            max_steps_reached = (
                 self.max_steps is not None and self.max_steps <= self.global_step + 1
                 and self._accumulated_batches_reached()
-            ):
-                break
-
-            # end epoch early
-            # stop when the flag is changed or we've gone past the amount
-            # requested in the batches
-            if self.trainer.should_stop:
-                break
-
-            # stop epoch if we limited the number of training batches
-            if self._num_training_batches_reached(is_last_batch):
+            )
+            if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
                 break

             # progress global step according to grads progress

@@ -906,7 +896,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo

         if on_epoch and is_last_batch and is_infinite_dataset:
             return True

-        if on_epoch and self.trainer.should_stop:
+        if self.trainer.should_stop:
            return True

         # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
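To summarize the loop change above, here is a hedged sketch of the consolidated exit condition; the function and parameter names are simplified stand-ins for the trainer attributes in the diff, not the actual trainer code. Combined with the `_should_check_val_fx` hunk, the effect is that a mid-epoch `should_stop` both exits the batch loop and triggers the validation check, without waiting for the end-of-epoch case.

    # Simplified stand-in for the batch-loop exit decision after this PR.
    from typing import Optional


    def should_break(
        max_steps: Optional[int],
        global_step: int,
        accumulated_batches_reached: bool,
        should_stop: bool,
        num_training_batches_reached: bool,
    ) -> bool:
        # The three previously separate `break` checks are now evaluated in one place.
        max_steps_reached = (
            max_steps is not None and max_steps <= global_step + 1 and accumulated_batches_reached
        )
        return max_steps_reached or should_stop or num_training_batches_reached


    # A user-requested stop mid-epoch now exits the loop on its own.
    assert should_break(max_steps=None, global_step=4, accumulated_batches_reached=False,
                        should_stop=True, num_training_batches_reached=False)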
32 changes: 32 additions & 0 deletions tests/trainer/loops/test_training_loop.py

@@ -110,3 +110,35 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
     else:
         assert trainer.train_loop.batch_idx == batch_idx_
         assert trainer.global_step == batch_idx_ * max_epochs
+
+
+def test_should_stop_mid_epoch(tmpdir):
+    """Test that training correctly stops mid epoch and that validation is still called at the right time"""
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.validation_called_at = None
+
+        def training_step(self, batch, batch_idx):
+            if batch_idx == 4:
+                self.trainer.should_stop = True
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, *args):
+            self.validation_called_at = (self.trainer.current_epoch, self.trainer.global_step)
+            return super().validation_step(*args)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=10,
+        limit_val_batches=1,
+    )
+    trainer.fit(model)
+
+    assert trainer.current_epoch == 0
+    assert trainer.global_step == 5
+    assert model.validation_called_at == (0, 4)  # TODO(@carmocca): should be 5 - will be fixed in next PR
Contributor
Why 5? Given the way global_step is currently used, this value is expected.
What are you planning to fix here?

Contributor Author
Exactly, but the current way is "wrong" (although expected given the design in master). That's why I used the verb "fix".

The fix will be incrementing the global_step count appropriately.

Contributor
But we said we would use progress tracking to count optimizer steps?
The only difference between global_step and the real optimizer step count is that the global_step update is delayed for logging.
Sorry if I'm mixing something up again here.
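To make the ordering being discussed concrete, here is a hedged, heavily simplified sketch of why the test above records `(0, 4)` while `global_step` ends at 5; the loop is illustrative only and does not reflect the actual trainer internals.

    # Illustrative only: should_stop is set while processing batch 4, the validation
    # check runs before the "delayed" global_step update for that batch, so
    # validation still observes 4; the counter then advances to 5 before the loop exits.
    global_step = 0
    validation_called_at = None

    for batch_idx in range(10):                      # limit_train_batches=10
        should_stop = batch_idx == 4                 # set inside training_step for batch 4
        if should_stop:
            validation_called_at = (0, global_step)  # observes the not-yet-updated count
        global_step += 1                             # delayed update, after the check
        if should_stop:
            break

    assert global_step == 5
    assert validation_called_at == (0, 4)            # the follow-up PR would make this (0, 5)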