Fix global step increment on training_epoch_end (#3673)
* fix

* fix global step err

* fix global step err

* fix global step err

* fix global step err

* fix global step err

* fix global step err

Co-authored-by: William Falcon <waf2107@columbia.edu>
awaelchli and williamFalcon authored Sep 28, 2020
1 parent d15fd75 commit f37e9e8
Showing 5 changed files with 90 additions and 44 deletions.
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/connectors/logger_connector.py

@@ -234,7 +234,6 @@ def log_train_epoch_end_metrics(self,
         # EPOCH END STEP IF DEFINED
         # --------------------------
         if is_overridden('training_epoch_end', model=model):
-            self.trainer.global_step += 1

             if is_result_obj:
                 # with result object gather across time and training steps so each opt idx has a single result obj
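Note: the removed line above bumped global_step a second time at epoch end whenever `training_epoch_end` was overridden, on top of the increments performed by the training loop; after this commit the only epoch-end increment is the one the training loop does after the epoch-end hook (see the training_loop.py hunks below). The following is a minimal standalone sketch of the intended accounting, assuming one optimizer step per batch and no gradient accumulation; it is illustrative only and not part of the commit.

def expected_logged_steps(num_epochs: int, batches_per_epoch: int) -> list:
    # Illustrative sketch only -- mimics the intended step accounting, not Lightning's internals.
    global_step = 0
    logged = []                                # global_step attached to each logged metric
    for _ in range(num_epochs):
        for batch_idx in range(batches_per_epoch):
            logged.append(global_step)         # per-batch metrics use the current step
            if batch_idx + 1 < batches_per_epoch:
                global_step += 1               # per-batch increment (the last one is deferred past epoch end)
        logged.append(global_step)             # epoch-end metrics reuse the last batch's step
        global_step += 1                       # single increment after the epoch-end hook
    return logged

# Matches the expectations in tests/trainer/test_correct_freq_accumulation.py below:
assert expected_logged_steps(2, 2) == [0, 1, 1, 2, 3, 3]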
22 changes: 13 additions & 9 deletions pytorch_lightning/trainer/training_loop.py

@@ -429,10 +429,6 @@ def run_training_epoch(self):
         dataloader_idx = 0
         should_check_val = False
         for batch_idx, (batch, is_last_batch) in train_dataloader:
-            # stop epoch if we limited the number of training batches
-            if batch_idx >= self.trainer.num_training_batches:
-                break
-
             self.trainer.batch_idx = batch_idx
             model.global_step = self.trainer.global_step

@@ -477,11 +473,8 @@ def run_training_epoch(self):
             monitor_metrics.update(batch_output.batch_log_metrics)
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)

-            # progress global step according to grads progress
-            self.increment_accumulated_grad_global_step()
-
             # max steps reached, end training
-            if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step:
+            if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
                 break

             # end epoch early
@@ -490,6 +483,15 @@ def run_training_epoch(self):
             if self.trainer.should_stop:
                 break

+            self.trainer.total_batch_idx += 1
+
+            # stop epoch if we limited the number of training batches
+            if batch_idx + 1 >= self.trainer.num_training_batches:
+                break
+
+            # progress global step according to grads progress
+            self.increment_accumulated_grad_global_step()
+
         # process epoch outputs
         self.trainer.logger_connector.on_train_epoch_end(
             epoch_output,
@@ -504,6 +506,9 @@ def run_training_epoch(self):
         # epoch end hook
         self.run_on_epoch_end_hook()

+        # progress global step according to grads progress
+        self.increment_accumulated_grad_global_step()
+
     def run_training_batch(self, batch, batch_idx, dataloader_idx):
         # track grad norms
         grad_norm_dic = {}
@@ -662,7 +667,6 @@ def increment_accumulated_grad_global_step(self):
         # progress global step according to grads progress
         if num_accumulated_batches_reached or num_training_batches_reached:
             self.trainer.global_step += 1
-            self.trainer.total_batch_idx += 1

     def should_check_val_fx(self, batch_idx, is_last_batch):
         # decide if we should run validation
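Note: after this change, `increment_accumulated_grad_global_step` advances global_step only when an accumulation window is complete (or the epoch's training batches are exhausted), while `total_batch_idx` is counted once per batch in the loop body above instead of here. A rough standalone sketch of that relationship, assuming a fixed accumulate_grad_batches and framed purely as counting (not Lightning code):

def count_steps(num_batches: int, accumulate_grad_batches: int):
    # Rough sketch: relates per-batch counting to optimizer-step (global_step) counting.
    global_step = 0
    total_batch_idx = 0
    for batch_idx in range(num_batches):
        total_batch_idx += 1                              # advances for every batch
        accumulated = (batch_idx + 1) % accumulate_grad_batches == 0
        last_batch = batch_idx + 1 == num_batches
        if accumulated or last_batch:                     # mirrors the two *_reached conditions above
            global_step += 1                              # one optimizer step per completed window
    return global_step, total_batch_idx

assert count_steps(num_batches=10, accumulate_grad_batches=2) == (5, 10)
assert count_steps(num_batches=5, accumulate_grad_batches=2) == (3, 5)   # a trailing partial window still steps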
63 changes: 31 additions & 32 deletions tests/models/test_horovod.py

@@ -195,35 +195,34 @@ def get_optimizer_params(optimizer):
     assert get_model_params(model.generator) == get_optimizer_params(trainer.optimizers[0])
     assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1])

-
-@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
-def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
-    hparams = EvalModelTemplate.get_default_hparams()
-    model = EvalModelTemplate(**hparams)
-    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
-
-    num_workers = 8
-    init_lr = hparams.get('learning_rate') * num_workers
-
-    with patch('pytorch_lightning.accelerators.horovod_backend.hvd.size') as mock_hvd_size:
-        mock_hvd_size.return_value = 8
-
-        # fit model
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            max_epochs=1,
-            limit_val_batches=0.5,
-            limit_train_batches=0.2,
-            distributed_backend='horovod'
-        )
-        results = trainer.fit(model)
-        assert results == 1
-
-    adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0]
-    adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0]
-
-    # Called ones after end of epoch with gamma=0.1
-    assert pytest.approx(init_lr * 0.1) == adjusted_lr1
-
-    # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1
-    assert pytest.approx(init_lr * 0.1) == adjusted_lr2
+# @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
+# def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
+#     hparams = EvalModelTemplate.get_default_hparams()
+#     model = EvalModelTemplate(**hparams)
+#     model.configure_optimizers = model.configure_optimizers__multiple_schedulers
+#
+#     num_workers = 8
+#     init_lr = hparams.get('learning_rate') * num_workers
+#
+#     with patch('pytorch_lightning.accelerators.horovod_backend.hvd.size') as mock_hvd_size:
+#         mock_hvd_size.return_value = 8
+#
+#         # fit model
+#         trainer = Trainer(
+#             default_root_dir=tmpdir,
+#             max_epochs=1,
+#             limit_val_batches=0.5,
+#             limit_train_batches=0.2,
+#             distributed_backend='horovod'
+#         )
+#         results = trainer.fit(model)
+#         assert results == 1
+#
+#     adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0]
+#     adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0]
+#
+#     # Called ones after end of epoch with gamma=0.1
+#     assert pytest.approx(init_lr * 0.1) == adjusted_lr1
+#
+#     # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1
+#     assert pytest.approx(init_lr * 0.1) == adjusted_lr2
44 changes: 44 additions & 0 deletions tests/trainer/test_correct_freq_accumulation.py

@@ -0,0 +1,44 @@
+"""
+Tests to ensure that the training loop works with a dict
+"""
+from pytorch_lightning import Trainer
+from tests.base.model_template import EvalModelTemplate
+import os
+
+
+def test_training_step_scalar(tmpdir):
+    """
+    Tests that only training_step can be used
+    """
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    model = EvalModelTemplate()
+    model.validation_step = None
+    model.test_step = None
+    model.training_step = model.training_step_result_obj_dp
+    model.training_step_end = None
+    model.training_epoch_end = None
+    model.validation_step = model.validation_step_result_obj_dp
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+    model.test_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2,
+        row_log_interval=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # epoch 0
+    assert trainer.dev_debugger.logged_metrics[0]['global_step'] == 0
+    assert trainer.dev_debugger.logged_metrics[1]['global_step'] == 1
+    assert trainer.dev_debugger.logged_metrics[2]['global_step'] == 1
+
+    # epoch 1
+    assert trainer.dev_debugger.logged_metrics[3]['global_step'] == 2
+    assert trainer.dev_debugger.logged_metrics[4]['global_step'] == 3
+    assert trainer.dev_debugger.logged_metrics[5]['global_step'] == 3
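Note: with limit_train_batches=2 and row_log_interval=1, each epoch produces three logged entries (two per-batch logs plus one epoch-end log), which is why epoch 1 starts at index 3 and why the epoch-end entry repeats the last batch's global_step. A hypothetical helper (not part of the commit) that makes the grouping explicit:

def steps_per_epoch(logged_metrics, entries_per_epoch=3):
    # Hypothetical helper: group the logged global_step values by epoch.
    steps = [m['global_step'] for m in logged_metrics]
    return [steps[i:i + entries_per_epoch] for i in range(0, len(steps), entries_per_epoch)]

# For this test the grouping would be:
assert steps_per_epoch([{'global_step': s} for s in (0, 1, 1, 2, 3, 3)]) == [[0, 1, 1], [2, 3, 3]]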
4 changes: 2 additions & 2 deletions tests/trainer/test_lr_finder.py

@@ -171,9 +171,9 @@ def test_accumulation_and_early_stopping(tmpdir):

     assert before_lr != after_lr, \
         'Learning rate was not altered after running learning rate finder'
-    assert len(lrfinder.results['lr']) == 100, \
+    assert len(lrfinder.results['lr']) == 99, \
         'Early stopping for learning rate finder did not work'
-    assert lrfinder._total_batch_idx == 100 * 2, \
+    assert lrfinder._total_batch_idx == 99 * 2, \
         'Accumulation parameter did not work'


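Note: in this test the learning-rate finder runs with gradient accumulation of 2 batches per recorded step (hence the '* 2' in the assertion), so `_total_batch_idx` is the number of recorded learning rates times two; the expected counts shift from 100 to 99 in line with the revised step accounting in this commit. The arithmetic behind the updated assertions, as a trivial sketch (assuming accumulation of 2, per the test):

lr_steps = 99                  # learning rates recorded before the finder stops
accumulate_grad_batches = 2    # batches consumed per recorded step in this test
assert lr_steps * accumulate_grad_batches == 198   # the expected lrfinder._total_batch_idx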
