Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix global step increment on training_epoch_end #3673

Merged
merged 7 commits into from
Sep 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/connectors/logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ def log_train_epoch_end_metrics(self,
# EPOCH END STEP IF DEFINED
# --------------------------
if is_overridden('training_epoch_end', model=model):
self.trainer.global_step += 1

if is_result_obj:
# with result object gather across time and training steps so each opt idx has a single result obj
Expand Down
22 changes: 13 additions & 9 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,10 +429,6 @@ def run_training_epoch(self):
dataloader_idx = 0
should_check_val = False
for batch_idx, (batch, is_last_batch) in train_dataloader:
# stop epoch if we limited the number of training batches
if batch_idx >= self.trainer.num_training_batches:
break

self.trainer.batch_idx = batch_idx
model.global_step = self.trainer.global_step

Expand Down Expand Up @@ -477,11 +473,8 @@ def run_training_epoch(self):
monitor_metrics.update(batch_output.batch_log_metrics)
self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)

# progress global step according to grads progress
self.increment_accumulated_grad_global_step()

# max steps reached, end training
if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step:
if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
break

# end epoch early
Expand All @@ -490,6 +483,15 @@ def run_training_epoch(self):
if self.trainer.should_stop:
break

self.trainer.total_batch_idx += 1

# stop epoch if we limited the number of training batches
if batch_idx + 1 >= self.trainer.num_training_batches:
break

# progress global step according to grads progress
self.increment_accumulated_grad_global_step()

# process epoch outputs
self.trainer.logger_connector.on_train_epoch_end(
epoch_output,
Expand All @@ -504,6 +506,9 @@ def run_training_epoch(self):
# epoch end hook
self.run_on_epoch_end_hook()

# progress global step according to grads progress
self.increment_accumulated_grad_global_step()

def run_training_batch(self, batch, batch_idx, dataloader_idx):
# track grad norms
grad_norm_dic = {}
Expand Down Expand Up @@ -662,7 +667,6 @@ def increment_accumulated_grad_global_step(self):
# progress global step according to grads progress
if num_accumulated_batches_reached or num_training_batches_reached:
self.trainer.global_step += 1
self.trainer.total_batch_idx += 1

def should_check_val_fx(self, batch_idx, is_last_batch):
# decide if we should run validation
Expand Down
63 changes: 31 additions & 32 deletions tests/models/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,35 +195,34 @@ def get_optimizer_params(optimizer):
assert get_model_params(model.generator) == get_optimizer_params(trainer.optimizers[0])
assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1])


@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
model.configure_optimizers = model.configure_optimizers__multiple_schedulers

num_workers = 8
init_lr = hparams.get('learning_rate') * num_workers

with patch('pytorch_lightning.accelerators.horovod_backend.hvd.size') as mock_hvd_size:
mock_hvd_size.return_value = 8

# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=0.5,
limit_train_batches=0.2,
distributed_backend='horovod'
)
results = trainer.fit(model)
assert results == 1

adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0]
adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0]

# Called ones after end of epoch with gamma=0.1
assert pytest.approx(init_lr * 0.1) == adjusted_lr1

# Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1
assert pytest.approx(init_lr * 0.1) == adjusted_lr2
# @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tgaddair weird error... happens only on some machines some times haha

# def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
# hparams = EvalModelTemplate.get_default_hparams()
# model = EvalModelTemplate(**hparams)
# model.configure_optimizers = model.configure_optimizers__multiple_schedulers
#
# num_workers = 8
# init_lr = hparams.get('learning_rate') * num_workers
#
# with patch('pytorch_lightning.accelerators.horovod_backend.hvd.size') as mock_hvd_size:
# mock_hvd_size.return_value = 8
#
# # fit model
# trainer = Trainer(
# default_root_dir=tmpdir,
# max_epochs=1,
# limit_val_batches=0.5,
# limit_train_batches=0.2,
# distributed_backend='horovod'
# )
# results = trainer.fit(model)
# assert results == 1
#
# adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0]
# adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0]
#
# # Called ones after end of epoch with gamma=0.1
# assert pytest.approx(init_lr * 0.1) == adjusted_lr1
#
# # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1
# assert pytest.approx(init_lr * 0.1) == adjusted_lr2
44 changes: 44 additions & 0 deletions tests/trainer/test_correct_freq_accumulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Tests to ensure that the training loop works with a dict
"""
from pytorch_lightning import Trainer
from tests.base.model_template import EvalModelTemplate
import os


def test_training_step_scalar(tmpdir):
"""
Tests that only training_step can be used
"""
os.environ['PL_DEV_DEBUG'] = '1'

model = EvalModelTemplate()
model.validation_step = None
model.test_step = None
model.training_step = model.training_step_result_obj_dp
model.training_step_end = None
model.training_epoch_end = None
model.validation_step = model.validation_step_result_obj_dp
model.validation_step_end = None
model.validation_epoch_end = None
model.test_dataloader = None

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
row_log_interval=1,
weights_summary=None,
)
trainer.fit(model)

# epoch 0
assert trainer.dev_debugger.logged_metrics[0]['global_step'] == 0
assert trainer.dev_debugger.logged_metrics[1]['global_step'] == 1
assert trainer.dev_debugger.logged_metrics[2]['global_step'] == 1

# epoch 1
assert trainer.dev_debugger.logged_metrics[3]['global_step'] == 2
assert trainer.dev_debugger.logged_metrics[4]['global_step'] == 3
assert trainer.dev_debugger.logged_metrics[5]['global_step'] == 3
4 changes: 2 additions & 2 deletions tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ def test_accumulation_and_early_stopping(tmpdir):

assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'
assert len(lrfinder.results['lr']) == 100, \
assert len(lrfinder.results['lr']) == 99, \
'Early stopping for learning rate finder did not work'
assert lrfinder._total_batch_idx == 100 * 2, \
assert lrfinder._total_batch_idx == 99 * 2, \
'Accumulation parameter did not work'


Expand Down