From 5723c6489448cb01888a3c6ecab721d3694d435a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 20:06:47 +0200 Subject: [PATCH 01/17] init --- pytorch_lightning/trainer/trainer.py | 11 +++-------- pytorch_lightning/trainer/training_loop.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2a6a53a7c192c..10f15a196e37b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -329,7 +329,9 @@ def __init__( self.checkpoint_connector = CheckpointConnector(self) self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) - self.train_loop = TrainLoop(self, multiple_trainloader_mode) + self.train_loop = TrainLoop( + self, multiple_trainloader_mode, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps + ) self.evaluation_loop = EvaluationLoop(self) self.predict_loop = PredictLoop(self) @@ -375,13 +377,6 @@ def __init__( truncated_bptt_steps, terminate_on_nan, ) - self.train_loop.on_trainer_init( - max_epochs, - min_epochs, - max_steps, - min_steps, - num_sanity_val_steps, - ) self.evaluation_loop.on_trainer_init() # configure tuner diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 790dc4c70bdeb..584097331f7a2 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -37,7 +37,15 @@ class TrainLoop: - def __init__(self, trainer, multiple_trainloader_mode: str): + def __init__( + self, + trainer, multiple_trainloader_mode: str, + max_epochs: Optional[int], + min_epochs: Optional[int], + max_steps: Optional[int], + min_steps: Optional[int], + num_sanity_val_steps: int, + ): self.trainer = trainer self.accumulated_loss = None self.warning_cache = WarningCache() @@ -50,14 +58,6 @@ def __init__(self, trainer, multiple_trainloader_mode: str): self.trainer._multiple_trainloader_mode = multiple_trainloader_mode self._optimizer_freq_cumsum = None - def on_trainer_init( - self, - max_epochs: Optional[int], - min_epochs: Optional[int], - max_steps: Optional[int], - min_steps: Optional[int], - num_sanity_val_steps: int, - ) -> None: self.trainer.global_step = 0 self.trainer.current_epoch = 0 self.trainer.should_stop = False From 437a68fe001059f1a47fdbbd044213e2552e2e3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 20:12:12 +0200 Subject: [PATCH 02/17] global step --- .../trainer/connectors/checkpoint_connector.py | 2 +- pytorch_lightning/trainer/properties.py | 6 +++++- pytorch_lightning/trainer/training_loop.py | 2 +- tests/trainer/test_trainer.py | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 7a1d198615f08..3ad476a4ceefe 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -150,7 +150,7 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True) # restore callback states self.trainer.on_load_checkpoint(checkpoint) - self.trainer.global_step = checkpoint['global_step'] + self.trainer.train_loop.global_step = checkpoint['global_step'] self.trainer.current_epoch = checkpoint['epoch'] # crash if max_epochs is lower then the current epoch from the checkpoint diff --git 
a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index effc7af117cdf..4bd55af749918 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -49,7 +49,6 @@ class TrainerProperties(ABC): _default_root_dir: str _lightning_optimizers = None _progress_bar_callback: ProgressBarBase - state: TrainerState _weights_save_path: str accelerator_connector: AcceleratorConnector @@ -58,6 +57,8 @@ class TrainerProperties(ABC): limit_val_batches: int logger: LightningLoggerBase logger_connector: LoggerConnector + state: TrainerState + train_loop: TrainLoop @property def accelerator(self) -> Accelerator: @@ -485,6 +486,9 @@ def sanity_checking(self, val: bool) -> None: elif self.sanity_checking: self.state.stage = None + @property + def global_step(self): + return self.train_loop.global_step # Used to represent the concrete type TrainerProperties class methods are called on. _T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 584097331f7a2..f65aed810d1e8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -58,7 +58,7 @@ def __init__( self.trainer._multiple_trainloader_mode = multiple_trainloader_mode self._optimizer_freq_cumsum = None - self.trainer.global_step = 0 + self.global_step = 0 self.trainer.current_epoch = 0 self.trainer.should_stop = False self.trainer.state = TrainerState() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index abf5abcaae2bf..0ae94ae667131 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -340,7 +340,7 @@ def mock_save_function(filepath, *args): # emulate callback's calls during the training for i, loss in enumerate(losses): trainer.current_epoch = i - trainer.global_step = i + trainer.train_loop.global_step = i trainer.logger_connector.callback_metrics = {"checkpoint_on": torch.tensor(loss)} checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) From 17e9b6a8087bc45d39e04549a6285dd31bfcdd7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 20:14:16 +0200 Subject: [PATCH 03/17] global step --- pytorch_lightning/trainer/properties.py | 1 + pytorch_lightning/trainer/training_loop.py | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 4bd55af749918..e88a7ff8c0825 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -33,6 +33,7 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.states import RunningStage, TrainerState, TrainerStatus +from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn from pytorch_lightning.utilities.argparse import ( add_argparse_args, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f65aed810d1e8..2a41ebaceff44 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -92,7 +92,7 @@ def optimizer_freq_cumsum(self): return self._optimizer_freq_cumsum def should_skip_training(self): - should_by_max_steps = 
self.trainer.max_steps is not None and self.trainer.global_step >= self.trainer.max_steps + should_by_max_steps = self.trainer.max_steps is not None and self.global_step >= self.trainer.max_steps should_by_epoch = self.trainer.max_epochs is not None and self.trainer.current_epoch >= self.trainer.max_epochs return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0 @@ -107,9 +107,9 @@ def on_train_end(self): # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 + self.global_step -= 1 self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 + self.global_step += 1 # hook self.trainer.call_hook("on_train_end") @@ -450,7 +450,7 @@ def track_and_norm_grad(self, optimizer): def _track_gradient_norm(self): grad_norm_dict = {} - if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: + if (self.global_step + 1) % self.trainer.log_every_n_steps == 0: if float(self.trainer.track_grad_norm) > 0: model = self.trainer.lightning_module grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm) @@ -530,7 +530,7 @@ def run_training_epoch(self): # max steps reached, end training if ( - self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1 + self.trainer.max_steps is not None and self.trainer.max_steps <= self.global_step + 1 and self._accumulated_batches_reached() ): break @@ -887,8 +887,8 @@ def increment_accumulated_grad_global_step(self): # progress global step according to grads progress if num_accumulated_batches_reached or num_training_batches_reached: - self.trainer.global_step = self.trainer.accelerator.update_global_step( - self.trainer.total_batch_idx, self.trainer.global_step + self.global_step = self.trainer.accelerator.update_global_step( + self.trainer.total_batch_idx, self.global_step ) def _accumulated_batches_reached(self): From 6b6d77a73cc5f5f4c5ff732da294d70aa5b30de6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 20:18:09 +0200 Subject: [PATCH 04/17] current epoch --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- pytorch_lightning/trainer/properties.py | 3 +++ pytorch_lightning/trainer/training_loop.py | 4 ++-- pytorch_lightning/tuner/batch_size_scaling.py | 4 ++-- tests/trainer/test_trainer.py | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 3ad476a4ceefe..1f61b33a74b9c 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -151,7 +151,7 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True) self.trainer.on_load_checkpoint(checkpoint) self.trainer.train_loop.global_step = checkpoint['global_step'] - self.trainer.current_epoch = checkpoint['epoch'] + self.trainer.train_loop.current_epoch = checkpoint['epoch'] # crash if max_epochs is lower then the current epoch from the checkpoint if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs: diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index e88a7ff8c0825..0e23606241ec2 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -491,5 
+491,8 @@ def sanity_checking(self, val: bool) -> None: def global_step(self): return self.train_loop.global_step + @property + def current_epoch(self): + return self.train_loop.current_epoch # Used to represent the concrete type TrainerProperties class methods are called on. _T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 2a41ebaceff44..a833c0eec5b59 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -59,7 +59,7 @@ def __init__( self._optimizer_freq_cumsum = None self.global_step = 0 - self.trainer.current_epoch = 0 + self.current_epoch = 0 self.trainer.should_stop = False self.trainer.state = TrainerState() @@ -145,7 +145,7 @@ def check_checkpoint_callback(self, should_update, is_last=False): def on_train_epoch_start(self, epoch): # update training progress in trainer - self.trainer.current_epoch = epoch + self.current_epoch = epoch model = self.trainer.lightning_module diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 303af3a117d81..4c0146cf3036d 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -115,7 +115,7 @@ def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None: def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None: trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times - trainer.current_epoch = 0 + trainer.train_loop.current_epoch = 0 trainer.max_steps = steps_per_trial # take few steps trainer.weights_summary = None # not needed before full run trainer.logger = DummyLogger() @@ -127,7 +127,7 @@ def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None: trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] - trainer.current_epoch = trainer.__dumped_params['current_epoch'] + trainer.train_loop.current_epoch = trainer.__dumped_params['current_epoch'] trainer.max_steps = trainer.__dumped_params['max_steps'] trainer.weights_summary = trainer.__dumped_params['weights_summary'] trainer.logger = trainer.__dumped_params['logger'] diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0ae94ae667131..b2fc8ff266122 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -339,7 +339,7 @@ def mock_save_function(filepath, *args): # emulate callback's calls during the training for i, loss in enumerate(losses): - trainer.current_epoch = i + trainer.train_loop.current_epoch = i trainer.train_loop.global_step = i trainer.logger_connector.callback_metrics = {"checkpoint_on": torch.tensor(loss)} checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) From cccefdd8ae8c13c965dbcdbbf941733868b28f2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 20:24:00 +0200 Subject: [PATCH 05/17] min/max epochs --- .../trainer/connectors/debugging_connector.py | 2 +- pytorch_lightning/trainer/properties.py | 8 ++++++++ pytorch_lightning/trainer/training_loop.py | 4 ++-- tests/checkpointing/test_model_checkpoint.py | 2 +- tests/trainer/test_trainer.py | 2 +- 5 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py 
b/pytorch_lightning/trainer/connectors/debugging_connector.py index 28c99f8f4de6d..b9d0e30ff19ce 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -60,7 +60,7 @@ def on_init_start( limit_predict_batches = fast_dev_run self.trainer.max_steps = fast_dev_run self.trainer.num_sanity_val_steps = 0 - self.trainer.max_epochs = 1 + self.trainer.train_loop.max_epochs = 1 val_check_interval = 1.0 self.trainer.check_val_every_n_epoch = 1 self.trainer.logger = DummyLogger() diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 0e23606241ec2..79edfda95470c 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -494,5 +494,13 @@ def global_step(self): @property def current_epoch(self): return self.train_loop.current_epoch + + @property + def max_epochs(self): + return self.train_loop.max_epochs + + @property + def min_epochs(self): + return self.train_loop.min_epochs # Used to represent the concrete type TrainerProperties class methods are called on. _T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a833c0eec5b59..886e80708bc3d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -69,9 +69,9 @@ def __init__( self.trainer.train_dataloader = None # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000 - self.trainer.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs + self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 - self.trainer.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs + self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs self.trainer.max_steps = max_steps self.trainer.min_steps = min_steps diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index d35d8e4badb44..41b3bc89f224a 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1244,7 +1244,7 @@ def test_ckpt_version_after_rerun_same_trainer(tmpdir): progress_bar_refresh_rate=0, ) trainer.fit(BoringModel()) - trainer.max_epochs = 4 + trainer.train_loop.max_epochs = 4 trainer.fit(BoringModel()) ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b2fc8ff266122..f04061a23e096 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1924,7 +1924,7 @@ def compare_optimizers(): trainer.fit(model) compare_optimizers() - trainer.max_epochs = 2 # simulate multiple fit calls + trainer.train_loop.max_epochs = 2 # simulate multiple fit calls trainer.fit(model) compare_optimizers() From ecc3a432b9950aa33d99baaad4d337a25002648f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 20:38:07 +0200 Subject: [PATCH 06/17] max / min steps --- .../trainer/connectors/debugging_connector.py | 2 +- pytorch_lightning/trainer/properties.py | 10 ++++++++++ pytorch_lightning/trainer/training_loop.py | 4 ++-- pytorch_lightning/tuner/batch_size_scaling.py | 4 ++-- pytorch_lightning/tuner/lr_finder.py | 6 +++--- 5 
files changed, 18 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index b9d0e30ff19ce..e250e90d24f62 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -58,7 +58,7 @@ def on_init_start( limit_val_batches = fast_dev_run limit_test_batches = fast_dev_run limit_predict_batches = fast_dev_run - self.trainer.max_steps = fast_dev_run + self.trainer.train_loop.max_steps = fast_dev_run self.trainer.num_sanity_val_steps = 0 self.trainer.train_loop.max_epochs = 1 val_check_interval = 1.0 diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 79edfda95470c..b0f0c819fe7f8 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -502,5 +502,15 @@ def max_epochs(self): @property def min_epochs(self): return self.train_loop.min_epochs + + @property + def max_steps(self): + return self.train_loop.min_epochs + + @property + def min_steps(self): + return self.train_loop.min_epochs + + # Used to represent the concrete type TrainerProperties class methods are called on. _T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 886e80708bc3d..4e495266df6a3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -72,8 +72,8 @@ def __init__( self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - self.trainer.max_steps = max_steps - self.trainer.min_steps = min_steps + self.max_steps = max_steps + self.min_steps = min_steps if num_sanity_val_steps == -1: self.trainer.num_sanity_val_steps = float("inf") diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 4c0146cf3036d..71750ab06a249 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -116,7 +116,7 @@ def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times trainer.train_loop.current_epoch = 0 - trainer.max_steps = steps_per_trial # take few steps + trainer.train_loop.max_steps = steps_per_trial # take few steps trainer.weights_summary = None # not needed before full run trainer.logger = DummyLogger() trainer.callbacks = [] # not needed before full run @@ -128,7 +128,7 @@ def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None: trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] trainer.train_loop.current_epoch = trainer.__dumped_params['current_epoch'] - trainer.max_steps = trainer.__dumped_params['max_steps'] + trainer.train_loop.max_steps = trainer.__dumped_params['max_steps'] trainer.weights_summary = trainer.__dumped_params['weights_summary'] trainer.logger = trainer.__dumped_params['logger'] trainer.callbacks = trainer.__dumped_params['callbacks'] diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 
01f48c66ad201..01a233601f0ba 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -230,7 +230,7 @@ def lr_find( trainer.logger = DummyLogger() # Max step set to number of iterations - trainer.max_steps = num_training + trainer.train_loop.max_steps = num_training # Disable standard progress bar for fit if trainer.progress_bar_callback: @@ -296,7 +296,7 @@ def __lr_finder_restore_params(trainer, model): trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] trainer.logger = trainer.__dumped_params['logger'] trainer.callbacks = trainer.__dumped_params['callbacks'] - trainer.max_steps = trainer.__dumped_params['max_steps'] + trainer.train_loop.max_steps = trainer.__dumped_params['max_steps'] model.configure_optimizers = trainer.__dumped_params['configure_optimizers'] del trainer.__dumped_params @@ -364,7 +364,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data # Check if we diverging if self.early_stop_threshold is not None: if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss: - trainer.max_steps = current_step # stop signal + trainer.train_loop.max_steps = current_step # stop signal if self.progress_bar: self.progress_bar.close() From 418b80b83caa2027ac25e14f80d7bbec489cfc05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 20:42:36 +0200 Subject: [PATCH 07/17] min /max steps --- pytorch_lightning/trainer/properties.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index b0f0c819fe7f8..5ffbd1aa0136d 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -505,11 +505,11 @@ def min_epochs(self): @property def max_steps(self): - return self.train_loop.min_epochs + return self.train_loop.max_steps @property def min_steps(self): - return self.train_loop.min_epochs + return self.train_loop.min_steps # Used to represent the concrete type TrainerProperties class methods are called on. 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4e495266df6a3..5f55d02bf5760 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -92,7 +92,7 @@ def optimizer_freq_cumsum(self): return self._optimizer_freq_cumsum def should_skip_training(self): - should_by_max_steps = self.trainer.max_steps is not None and self.global_step >= self.trainer.max_steps + should_by_max_steps = self.max_steps is not None and self.global_step >= self.max_steps should_by_epoch = self.trainer.max_epochs is not None and self.trainer.current_epoch >= self.trainer.max_epochs return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0 @@ -530,7 +530,7 @@ def run_training_epoch(self): # max steps reached, end training if ( - self.trainer.max_steps is not None and self.trainer.max_steps <= self.global_step + 1 + self.max_steps is not None and self.max_steps <= self.global_step + 1 and self._accumulated_batches_reached() ): break From 2c6a37f41bfd32e3b5c2f0e2865786ba3cfea458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 20:48:17 +0200 Subject: [PATCH 08/17] batch idx --- pytorch_lightning/callbacks/progress.py | 2 +- .../logger_connector/epoch_result_store.py | 2 +- .../trainer/connectors/optimizer_connector.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/trainer/training_loop.py | 14 +++++++------- pytorch_lightning/tuner/lr_finder.py | 6 +++--- tests/callbacks/test_progress_bar.py | 4 ++-- tests/models/test_hooks.py | 4 ++-- 8 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index be9d2f44356f5..45e9e55e69bf0 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -200,7 +200,7 @@ def on_init_end(self, trainer): self._trainer = trainer def on_train_start(self, trainer, pl_module): - self._train_batch_idx = trainer.batch_idx + self._train_batch_idx = trainer.train_loop.batch_idx def on_train_epoch_start(self, trainer, pl_module): self._train_batch_idx = 0 diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index a265ac8a35d55..09d66c13502a2 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -253,7 +253,7 @@ def info(self): """ model_ref = self.trainer.lightning_module return { - "batch_idx": self.trainer.batch_idx, + "batch_idx": self.trainer.train_loop.batch_idx, "fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name, "dataloader_idx": model_ref._current_dataloader_idx or 0, "opt_idx": self._opt_idx or 0, diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index d45cbad927936..e7fbdf9b18c02 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -46,7 +46,7 @@ def update_learning_rates( if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices: continue - current_idx = self.trainer.batch_idx if interval == 'step' else self.trainer.current_epoch + current_idx = self.trainer.train_loop.batch_idx if interval == 'step' else 
self.trainer.current_epoch current_idx += 1 # account for both batch and epoch starts from 0 # Take step if call to update_learning_rates matches the interval key and # the current step modulo the schedulers frequency is zero @@ -86,7 +86,7 @@ def update_learning_rates( if self.trainer.dev_debugger.enabled: self.trainer.dev_debugger.track_lr_schedulers_update( - self.trainer.batch_idx, + self.trainer.train_loop.batch_idx, interval, scheduler_idx, old_lr, diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 10f15a196e37b..e1918ba2d0c9a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -992,7 +992,7 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: opt_indices=[ opt_idx for opt_idx, _ in self.train_loop.get_optimizers_iterable(batch_idx=( - self.total_batch_idx - 1 + self.train_loop.total_batch_idx - 1 )) # Select the optimizers which were used in the last batch of the epoch ], ) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5f55d02bf5760..0b7411fb632cc 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -63,8 +63,8 @@ def __init__( self.trainer.should_stop = False self.trainer.state = TrainerState() - self.trainer.total_batch_idx = 0 - self.trainer.batch_idx = 0 + self.total_batch_idx = 0 + self.batch_idx = 0 self.trainer.num_training_batches = 0 self.trainer.train_dataloader = None @@ -242,7 +242,7 @@ def get_optimizers_iterable(self, batch_idx=None): return list(enumerate(self.trainer.optimizers)) if batch_idx is None: - batch_idx = self.trainer.total_batch_idx + batch_idx = self.total_batch_idx optimizers_loop_length = self.optimizer_freq_cumsum[-1] current_place_in_loop = batch_idx % optimizers_loop_length @@ -480,7 +480,7 @@ def run_training_epoch(self): is_last_batch = None for batch_idx, (batch, is_last_batch) in train_dataloader: - self.trainer.batch_idx = batch_idx + self.batch_idx = batch_idx self.trainer.is_last_batch = is_last_batch # ------------------------------------ @@ -541,7 +541,7 @@ def run_training_epoch(self): if self.trainer.should_stop: break - self.trainer.total_batch_idx += 1 + self.total_batch_idx += 1 # stop epoch if we limited the number of training batches if self._num_training_batches_reached(is_last_batch): @@ -892,10 +892,10 @@ def increment_accumulated_grad_global_step(self): ) def _accumulated_batches_reached(self): - return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 + return (self.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 def _num_training_batches_reached(self, is_last_batch=False): - return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch + return (self.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch def should_accumulate(self): # checks if backward or backward + optimizer step (via closure) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 01a233601f0ba..ffc83b9f3c1ec 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -255,7 +255,7 @@ def lr_find( # Transfer results from callback to lr finder object lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses}) - lr_finder._total_batch_idx = trainer.total_batch_idx # for debug purpose + lr_finder._total_batch_idx = trainer.train_loop.total_batch_idx # for 
debug purpose # Reset model state if trainer.is_global_zero: @@ -338,7 +338,7 @@ def __init__( def on_batch_start(self, trainer, pl_module): """ Called before each training batch, logs the lr that will be used """ - if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0: + if (trainer.train_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: return if self.progress_bar_refresh_rate and self.progress_bar is None: @@ -348,7 +348,7 @@ def on_batch_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): """ Called when the training batch ends, logs the calculated loss """ - if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0: + if (trainer.train_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: return if self.progress_bar: diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 76f1e4cb0570f..2a33fbf0c1455 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -191,11 +191,11 @@ class CurrentProgressBar(ProgressBar): def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): super().on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx) - assert self.train_batch_idx == trainer.batch_idx + assert self.train_batch_idx == trainer.train_loop.batch_idx def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - assert self.train_batch_idx == trainer.batch_idx + 1 + assert self.train_batch_idx == trainer.train_loop.batch_idx + 1 if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: assert self.main_progress_bar.n == self.train_batch_idx self.train_batches_seen += 1 diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 24bf29a9e2eac..f19716e31b74d 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -238,10 +238,10 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): trainer = Trainer(max_epochs=max_epochs) trainer.fit(model) if batch_idx_ > len(model.val_dataloader()) - 1: - assert trainer.batch_idx == len(model.val_dataloader()) - 1 + assert trainer.train_loop.batch_idx == len(model.val_dataloader()) - 1 assert trainer.global_step == len(model.val_dataloader()) * max_epochs else: - assert trainer.batch_idx == batch_idx_ + assert trainer.train_loop.batch_idx == batch_idx_ assert trainer.global_step == (batch_idx_ + 1) * max_epochs From e6425fc104900e600d6bdac4e4501f7a133e7c2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 20:48:48 +0200 Subject: [PATCH 09/17] attrs --- pytorch_lightning/trainer/training_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0b7411fb632cc..5f58504d41d2c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -93,7 +93,7 @@ def optimizer_freq_cumsum(self): def should_skip_training(self): should_by_max_steps = self.max_steps is not None and self.global_step >= self.max_steps - should_by_epoch = self.trainer.max_epochs is not None and self.trainer.current_epoch >= self.trainer.max_epochs + should_by_epoch = self.max_epochs is not None and self.current_epoch >= self.max_epochs return should_by_max_steps or should_by_epoch or 
self.trainer.num_training_batches == 0 def on_train_start(self): @@ -888,7 +888,7 @@ def increment_accumulated_grad_global_step(self): # progress global step according to grads progress if num_accumulated_batches_reached or num_training_batches_reached: self.global_step = self.trainer.accelerator.update_global_step( - self.trainer.total_batch_idx, self.global_step + self.total_batch_idx, self.global_step ) def _accumulated_batches_reached(self): From 1206ee99c829afc6f2e95f742554ccf327d259e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 20:52:03 +0200 Subject: [PATCH 10/17] formatting --- pytorch_lightning/trainer/trainer.py | 7 +++---- pytorch_lightning/trainer/training_loop.py | 19 +++++++++---------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e1918ba2d0c9a..55dfbaa4490d6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -990,10 +990,9 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: self.optimizer_connector.update_learning_rates( interval='epoch', opt_indices=[ - opt_idx - for opt_idx, _ in self.train_loop.get_optimizers_iterable(batch_idx=( - self.train_loop.total_batch_idx - 1 - )) # Select the optimizers which were used in the last batch of the epoch + opt_idx for opt_idx, _ in self.train_loop.get_optimizers_iterable( + batch_idx=(self.train_loop.total_batch_idx - 1) + ) # Select the optimizers which were used in the last batch of the epoch ], ) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5f58504d41d2c..ee07f2637485c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -38,13 +38,14 @@ class TrainLoop: def __init__( - self, - trainer, multiple_trainloader_mode: str, - max_epochs: Optional[int], - min_epochs: Optional[int], - max_steps: Optional[int], - min_steps: Optional[int], - num_sanity_val_steps: int, + self, + trainer, + multiple_trainloader_mode: str, + max_epochs: Optional[int], + min_epochs: Optional[int], + max_steps: Optional[int], + min_steps: Optional[int], + num_sanity_val_steps: int, ): self.trainer = trainer self.accumulated_loss = None @@ -887,9 +888,7 @@ def increment_accumulated_grad_global_step(self): # progress global step according to grads progress if num_accumulated_batches_reached or num_training_batches_reached: - self.global_step = self.trainer.accelerator.update_global_step( - self.total_batch_idx, self.global_step - ) + self.global_step = self.trainer.accelerator.update_global_step(self.total_batch_idx, self.global_step) def _accumulated_batches_reached(self): return (self.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 From 9464c8370c7d79f62456aa70dd729a832a19318f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 20:55:49 +0200 Subject: [PATCH 11/17] move state init --- pytorch_lightning/trainer/trainer.py | 3 ++- pytorch_lightning/trainer/training_loop.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 55dfbaa4490d6..41c8f32d9f5f6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -54,7 +54,7 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.predict_loop import PredictLoop from 
pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus +from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus, TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.lr_finder import _LRFinder @@ -308,6 +308,7 @@ def __init__( """ super().__init__() Trainer._log_api_event("init") + self.state = TrainerState() distributed_backend = distributed_backend or accelerator # init connectors diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ee07f2637485c..b0c9e3615003a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -22,7 +22,6 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import ParallelPlugin -from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType from pytorch_lightning.utilities.distributed import rank_zero_info @@ -62,7 +61,6 @@ def __init__( self.global_step = 0 self.current_epoch = 0 self.trainer.should_stop = False - self.trainer.state = TrainerState() self.total_batch_idx = 0 self.batch_idx = 0 From f050cd6746264cb5620b346bd7398af0cd5c0861 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 May 2021 19:00:55 +0000 Subject: [PATCH 12/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/governance.rst | 2 -- pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/source/governance.rst b/docs/source/governance.rst index fac8b68e1df53..5b1f9bd1916c1 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -38,5 +38,3 @@ Alumni - Jeff Ling (`jeffling `_) - Teddy Koker (`teddykoker `_) - Nate Raw (`nateraw `_) - - diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 41c8f32d9f5f6..01ba79303d283 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -54,7 +54,7 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus, TrainerState +from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.lr_finder import _LRFinder From 3678402f18530076162510c7db248126d33d07db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 23:19:34 +0200 Subject: [PATCH 13/17] update tuner --- pytorch_lightning/tuner/batch_size_scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 71750ab06a249..120a95a5084b1 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ 
b/pytorch_lightning/tuner/batch_size_scaling.py @@ -144,7 +144,7 @@ def _run_power_scaling( """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """ for _ in range(max_trials): garbage_collection_cuda() - trainer.global_step = 0 # reset after each try + trainer.train_loop.global_step = 0 # reset after each try try: # Try fit trainer.tuner._run(model) @@ -175,7 +175,7 @@ def _run_binsearch_scaling( count = 0 while True: garbage_collection_cuda() - trainer.global_step = 0 # reset after each try + trainer.train_loop.global_step = 0 # reset after each try try: # Try fit trainer.tuner._run(model) From 1799c8b711086cc35a7cdcf2c1cdc59446758720 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 23:24:16 +0200 Subject: [PATCH 14/17] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fdb02fb23e851..2d3e5c278f915 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025)) +- Refactored Loops + * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025)) ### Deprecated From 68619d86595d9bea5833fa9a8f8be6d6b0730976 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 23:51:03 +0200 Subject: [PATCH 15/17] update max_epocs --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index df3edf17729bd..3ec7774d5f8b6 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -159,7 +159,7 @@ def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): self._max_epochs = trainer.max_epochs if self._model_contains_batch_norm: # virtually increase max_epochs to perform batch norm update on latest epoch. - trainer.max_epochs += 1 + trainer.train_loop.max_epochs += 1 def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): if trainer.current_epoch == self.swa_start: @@ -232,7 +232,7 @@ def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): # BatchNorm epoch update. Reset state trainer.accumulate_grad_batches = self._accumulate_grad_batches trainer.num_training_batches -= 1 - trainer.max_epochs -= 1 + trainer.train_loop.max_epochs -= 1 self.reset_momenta() elif trainer.current_epoch == self.swa_end: # Last SWA epoch. 
Transfer weights from average model to pl_module From b78d756eba8d7d194f420ad0e5f5670762252d5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 7 May 2021 23:55:42 +0200 Subject: [PATCH 16/17] someone ordered types and left a big tip --- pytorch_lightning/trainer/properties.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 5ffbd1aa0136d..ff12e5c6e9053 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -488,27 +488,27 @@ def sanity_checking(self, val: bool) -> None: self.state.stage = None @property - def global_step(self): + def global_step(self) -> int: return self.train_loop.global_step @property - def current_epoch(self): + def current_epoch(self) -> int: return self.train_loop.current_epoch @property - def max_epochs(self): + def max_epochs(self) -> Optional[int]: return self.train_loop.max_epochs @property - def min_epochs(self): + def min_epochs(self) -> Optional[int]: return self.train_loop.min_epochs @property - def max_steps(self): + def max_steps(self) -> Optional[int]: return self.train_loop.max_steps @property - def min_steps(self): + def min_steps(self) -> Optional[int]: return self.train_loop.min_steps From efda349afd97db9412e4d9c2e94f21eb33b24505 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 8 May 2021 00:36:59 +0200 Subject: [PATCH 17/17] bool sig --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b0c9e3615003a..ecdadcbddc1d9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -90,7 +90,7 @@ def optimizer_freq_cumsum(self): self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) return self._optimizer_freq_cumsum - def should_skip_training(self): + def should_skip_training(self) -> bool: should_by_max_steps = self.max_steps is not None and self.global_step >= self.max_steps should_by_epoch = self.max_epochs is not None and self.current_epoch >= self.max_epochs return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0
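
Taken together, these patches establish one pattern: loop-progress state (global_step, current_epoch, max/min_epochs, max/min_steps, batch_idx, total_batch_idx) is owned by TrainLoop, and the Trainer exposes it through read-only properties, so existing read access keeps working while writes have to go through trainer.train_loop. The following is a minimal sketch of that delegation pattern, not the actual Lightning classes — the simplified names _Loop and _Trainer and the trimmed-down constructor are assumptions made purely for illustration:

from typing import Optional


class _Loop:
    """Owns the loop-progress counters, as TrainLoop does after this series."""

    def __init__(self, max_epochs: Optional[int], max_steps: Optional[int]) -> None:
        # mirrors the defaulting moved into TrainLoop.__init__ (patch 05/17):
        # max_epochs only falls back to 1000 when neither limit is given
        self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
        self.max_steps = max_steps
        self.global_step = 0
        self.current_epoch = 0
        self.batch_idx = 0
        self.total_batch_idx = 0


class _Trainer:
    """Exposes the loop state through read-only properties, as TrainerProperties does."""

    def __init__(self, max_epochs: Optional[int] = None, max_steps: Optional[int] = None) -> None:
        self.train_loop = _Loop(max_epochs, max_steps)

    @property
    def global_step(self) -> int:
        return self.train_loop.global_step

    @property
    def current_epoch(self) -> int:
        return self.train_loop.current_epoch

    @property
    def max_epochs(self) -> Optional[int]:
        return self.train_loop.max_epochs

    @property
    def max_steps(self) -> Optional[int]:
        return self.train_loop.max_steps


trainer = _Trainer(max_steps=100)
print(trainer.global_step)          # 0 -- reads delegate to the loop
trainer.train_loop.global_step = 5  # writes go through train_loop, as in the updated tuner and tests
print(trainer.global_step)          # 5
# trainer.global_step = 5           # would raise AttributeError: the property is read-only

Under this split, call sites that mutate the counters (the tuner, the debugging connector, the SWA callback, the tests) are updated in the series to assign through trainer.train_loop, while read-only consumers such as progress bars and schedulers can keep using the trainer-level properties unchanged.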