From edc267bb54d3a034264012333aba89d82df48a82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 18 Jun 2021 03:58:43 +0200 Subject: [PATCH 1/3] rename all --- CHANGELOG.md | 3 +++ benchmarks/test_basic_parity.py | 2 +- pytorch_lightning/callbacks/finetuning.py | 2 +- pytorch_lightning/callbacks/progress.py | 2 +- .../callbacks/stochastic_weight_avg.py | 8 ++++---- pytorch_lightning/core/lightning.py | 4 ++-- .../connectors/checkpoint_connector.py | 4 ++-- .../trainer/connectors/debugging_connector.py | 4 ++-- .../logger_connector/logger_connector.py | 4 ++-- .../trainer/connectors/optimizer_connector.py | 4 ++-- pytorch_lightning/trainer/deprecated_api.py | 10 +++++++++- pytorch_lightning/trainer/properties.py | 19 +++++++------------ pytorch_lightning/tuner/batch_size_scaling.py | 12 ++++++------ pytorch_lightning/tuner/lr_finder.py | 16 ++++++++-------- tests/callbacks/test_progress_bar.py | 4 ++-- tests/callbacks/test_stochastic_weight_avg.py | 4 ++-- tests/checkpointing/test_model_checkpoint.py | 2 +- tests/deprecated_api/test_remove_1-6.py | 8 ++++++++ tests/loggers/test_tensorboard.py | 2 +- tests/trainer/loops/test_training_loop.py | 4 ++-- .../optimization/test_manual_optimization.py | 2 +- tests/trainer/test_trainer.py | 6 +++--- 22 files changed, 70 insertions(+), 56 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 23ba6c4f26411..d230309311d62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -227,6 +227,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652)) +- Deprecated the `Trainer.train_loop` property in favor of `Trainer.fit_loop` ([#xxxx](https://github.com/PyTorchLightning/pytorch-lightning/pull/xxxx)) + + ### Removed - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654)) diff --git a/benchmarks/test_basic_parity.py b/benchmarks/test_basic_parity.py index 53f303693ffdb..bf2ddae2c0084 100644 --- a/benchmarks/test_basic_parity.py +++ b/benchmarks/test_basic_parity.py @@ -174,4 +174,4 @@ def lightning_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10): ) trainer.fit(model) - return trainer.train_loop.running_loss.last().item(), _hook_memory() + return trainer.fit_loop.running_loss.last().item(), _hook_memory() diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 29430d866288d..b40bbf1ffb3e8 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -285,7 +285,7 @@ def _store( def on_train_epoch_start(self, trainer, pl_module): """Called when the epoch begins.""" - for opt_idx, optimizer in trainer.train_loop.get_active_optimizers(): + for opt_idx, optimizer in trainer.fit_loop.get_active_optimizers(): num_param_groups = len(optimizer.param_groups) self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx) current_param_groups = optimizer.param_groups diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 0fe05ff812e20..2fd4b8c25df19 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -200,7 +200,7 @@ def on_init_end(self, trainer): self._trainer = trainer def on_train_start(self, trainer, pl_module): - self._train_batch_idx = trainer.train_loop.batch_idx + 
self._train_batch_idx = trainer.fit_loop.batch_idx def on_train_epoch_start(self, trainer, pl_module): self._train_batch_idx = 0 diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 3ec7774d5f8b6..ce36ce0273e24 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -159,7 +159,7 @@ def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): self._max_epochs = trainer.max_epochs if self._model_contains_batch_norm: # virtually increase max_epochs to perform batch norm update on latest epoch. - trainer.train_loop.max_epochs += 1 + trainer.fit_loop.max_epochs += 1 def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): if trainer.current_epoch == self.swa_start: @@ -220,19 +220,19 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo # performing only one pass over the train data-loader to compute activation statistics # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward. trainer.num_training_batches += 1 - trainer.train_loop._skip_backward = True + trainer.fit_loop._skip_backward = True self._accumulate_grad_batches = trainer.accumulate_grad_batches trainer.accumulate_grad_batches = len(trainer.train_dataloader) def on_train_epoch_end(self, trainer: 'pl.Trainer', *args): - trainer.train_loop._skip_backward = False + trainer.fit_loop._skip_backward = False def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1: # BatchNorm epoch update. Reset state trainer.accumulate_grad_batches = self._accumulate_grad_batches trainer.num_training_batches -= 1 - trainer.train_loop.max_epochs -= 1 + trainer.fit_loop.max_epochs -= 1 self.reset_momenta() elif trainer.current_epoch == self.swa_end: # Last SWA epoch. Transfer weights from average model to pl_module diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c41423948f8c8..b124a39960217 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1674,7 +1674,7 @@ def get_progress_bar_dict(self): Dictionary with the items to be displayed in the progress bar. 
""" # call .item() only once but store elements without graphs - running_train_loss = self.trainer.train_loop.running_loss.mean() + running_train_loss = self.trainer.fit_loop.running_loss.mean() avg_training_loss = None if running_train_loss is not None: avg_training_loss = running_train_loss.cpu().item() @@ -1688,7 +1688,7 @@ def get_progress_bar_dict(self): module_tbptt_enabled = self.truncated_bptt_steps > 0 trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0 if module_tbptt_enabled or trainer_tbptt_enabled: - tqdm_dict["split_idx"] = self.trainer.train_loop.split_idx + tqdm_dict["split_idx"] = self.trainer.fit_loop.split_idx if self.trainer.logger is not None and self.trainer.logger.version is not None: version = self.trainer.logger.version diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index b599caf91e20d..c2a0411c0df36 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -198,8 +198,8 @@ def restore_progress(self) -> None: if not self._loaded_checkpoint: return - self.trainer.train_loop.global_step = self._loaded_checkpoint['global_step'] - self.trainer.train_loop.current_epoch = self._loaded_checkpoint['epoch'] + self.trainer.fit_loop.global_step = self._loaded_checkpoint['global_step'] + self.trainer.fit_loop.current_epoch = self._loaded_checkpoint['epoch'] # crash if max_epochs is lower then the current epoch from the checkpoint if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs: diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index 0108a1045698f..e49c40bfe6baf 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -58,9 +58,9 @@ def on_init_start( limit_val_batches = fast_dev_run limit_test_batches = fast_dev_run limit_predict_batches = fast_dev_run - self.trainer.train_loop.max_steps = fast_dev_run + self.trainer.fit_loop.max_steps = fast_dev_run self.trainer.num_sanity_val_steps = 0 - self.trainer.train_loop.max_epochs = 1 + self.trainer.fit_loop.max_epochs = 1 val_check_interval = 1.0 self.trainer.check_val_every_n_epoch = 1 self.trainer.logger = DummyLogger() diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index bc6e2b54f584d..883f8ee3fec8c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -211,7 +211,7 @@ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) self._split_idx = split_idx def update_train_step_metrics(self) -> None: - if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: + if self.trainer.fit_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization: return # when metrics should be logged @@ -299,6 +299,6 @@ def progress_bar_metrics(self) -> Dict[str, float]: return self._progress_bar_metrics def teardown(self): - self.trainer.train_loop.results.cpu() + self.trainer.fit_loop.results.cpu() self.trainer.evaluation_loop._val_results.cpu() 
self.trainer.evaluation_loop._test_results.cpu() diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 2797504288bd3..eb2fb5b4e7723 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -46,7 +46,7 @@ def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices: continue - current_idx = self.trainer.train_loop.batch_idx if interval == 'step' else self.trainer.current_epoch + current_idx = self.trainer.fit_loop.batch_idx if interval == 'step' else self.trainer.current_epoch current_idx += 1 # account for both batch and epoch starts from 0 # Take step if call to update_learning_rates matches the interval key and # the current step modulo the schedulers frequency is zero @@ -83,7 +83,7 @@ def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] if self.trainer.dev_debugger.enabled: self.trainer.dev_debugger.track_lr_schedulers_update( - self.trainer.train_loop.batch_idx, + self.trainer.fit_loop.batch_idx, interval, scheduler_idx, old_lr, diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 7e7817d277dae..a650c6bfe73e8 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from pytorch_lightning.loops import FitLoop from pytorch_lightning.utilities import rank_zero_deprecation class DeprecatedTrainerAttributes: sanity_checking: bool + fit_loop: FitLoop @property def running_sanity_check(self) -> bool: @@ -25,3 +26,10 @@ def running_sanity_check(self) -> bool: "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5." ) return self.sanity_checking + + @property + def train_loop(self) -> FitLoop: + rank_zero_deprecation( + "`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6." 
+ ) + return self.fit_loop diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 811d7eaa80291..c6a4ffc10d28e 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -483,39 +483,34 @@ def sanity_checking(self, val: bool) -> None: Loop properties """ - @property - def train_loop(self) -> FitLoop: - # FIXME(@awaelchli): the current train_loop should be renamed to fit_loop - return self.fit_loop - @property def global_step(self) -> int: - return self.train_loop.global_step + return self.fit_loop.global_step @property def current_epoch(self) -> int: - return self.train_loop.current_epoch + return self.fit_loop.current_epoch @property def max_epochs(self) -> Optional[int]: - return self.train_loop.max_epochs + return self.fit_loop.max_epochs @property def min_epochs(self) -> Optional[int]: - return self.train_loop.min_epochs + return self.fit_loop.min_epochs @property def max_steps(self) -> Optional[int]: - return self.train_loop.max_steps + return self.fit_loop.max_steps @property def min_steps(self) -> Optional[int]: - return self.train_loop.min_steps + return self.fit_loop.min_steps @property def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop]]: if self.training: - return self.train_loop + return self.fit_loop elif self.sanity_checking or self.evaluating: return self.evaluation_loop diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 0103dcd4c1805..f23a7f883c5a2 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -115,8 +115,8 @@ def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None: def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None: trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times - trainer.train_loop.current_epoch = 0 - trainer.train_loop.max_steps = steps_per_trial # take few steps + trainer.fit_loop.current_epoch = 0 + trainer.fit_loop.max_steps = steps_per_trial # take few steps trainer.weights_summary = None # not needed before full run trainer.logger = DummyLogger() trainer.callbacks = [] # not needed before full run @@ -127,8 +127,8 @@ def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None: trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] - trainer.train_loop.current_epoch = trainer.__dumped_params['current_epoch'] - trainer.train_loop.max_steps = trainer.__dumped_params['max_steps'] + trainer.fit_loop.current_epoch = trainer.__dumped_params['current_epoch'] + trainer.fit_loop.max_steps = trainer.__dumped_params['max_steps'] trainer.weights_summary = trainer.__dumped_params['weights_summary'] trainer.logger = trainer.__dumped_params['logger'] trainer.callbacks = trainer.__dumped_params['callbacks'] @@ -144,7 +144,7 @@ def _run_power_scaling( """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. 
""" for _ in range(max_trials): garbage_collection_cuda() - trainer.train_loop.global_step = 0 # reset after each try + trainer.fit_loop.global_step = 0 # reset after each try try: # Try fit trainer.tuner._run(model) @@ -178,7 +178,7 @@ def _run_binsearch_scaling( count = 0 while True: garbage_collection_cuda() - trainer.train_loop.global_step = 0 # reset after each try + trainer.fit_loop.global_step = 0 # reset after each try try: # Try fit trainer.tuner._run(model) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 8a595fc9da35a..29a93d3916aea 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -230,7 +230,7 @@ def lr_find( trainer.logger = DummyLogger() # Max step set to number of iterations - trainer.train_loop.max_steps = num_training + trainer.fit_loop.max_steps = num_training # Disable standard progress bar for fit if trainer.progress_bar_callback: @@ -255,7 +255,7 @@ def lr_find( # Transfer results from callback to lr finder object lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses}) - lr_finder._total_batch_idx = trainer.train_loop.total_batch_idx # for debug purpose + lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose # Reset model state if trainer.is_global_zero: @@ -297,8 +297,8 @@ def __lr_finder_restore_params(trainer, model): trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] trainer.logger = trainer.__dumped_params['logger'] trainer.callbacks = trainer.__dumped_params['callbacks'] - trainer.train_loop.max_steps = trainer.__dumped_params['max_steps'] - trainer.train_loop.current_epoch = trainer.__dumped_params['current_epoch'] + trainer.fit_loop.max_steps = trainer.__dumped_params['max_steps'] + trainer.fit_loop.current_epoch = trainer.__dumped_params['current_epoch'] model.configure_optimizers = trainer.__dumped_params['configure_optimizers'] del trainer.__dumped_params @@ -340,7 +340,7 @@ def __init__( def on_batch_start(self, trainer, pl_module): """ Called before each training batch, logs the lr that will be used """ - if (trainer.train_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: + if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: return if self.progress_bar_refresh_rate and self.progress_bar is None: @@ -350,13 +350,13 @@ def on_batch_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): """ Called when the training batch ends, logs the calculated loss """ - if (trainer.train_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: + if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: return if self.progress_bar: self.progress_bar.update() - current_loss = trainer.train_loop.running_loss.last().item() + current_loss = trainer.fit_loop.running_loss.last().item() current_step = trainer.global_step # Avg loss (loss with momentum) + smoothing @@ -366,7 +366,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data # Check if we diverging if self.early_stop_threshold is not None: if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss: - trainer.train_loop.max_steps = current_step # stop signal + trainer.fit_loop.max_steps = current_step # stop signal if self.progress_bar: self.progress_bar.close() diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 
5e768bb8fec8d..2eb1bdf690c30 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -194,11 +194,11 @@ class CurrentProgressBar(ProgressBar): def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): super().on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx) - assert self.train_batch_idx == trainer.train_loop.batch_idx + assert self.train_batch_idx == trainer.fit_loop.batch_idx def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - assert self.train_batch_idx == trainer.train_loop.batch_idx + 1 + assert self.train_batch_idx == trainer.fit_loop.batch_idx + 1 if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: assert self.main_progress_bar.n == self.train_batch_idx self.train_batches_seen += 1 diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 81efc12b34662..e92f8e71da086 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -74,7 +74,7 @@ def transfer_weights(self, *args, **kwargs): def on_train_epoch_start(self, trainer, *args): super().on_train_epoch_start(trainer, *args) - assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end) + assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end) if self.swa_start <= trainer.current_epoch: assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR) assert trainer.lr_schedulers[0]["interval"] == "epoch" @@ -92,7 +92,7 @@ def on_train_end(self, trainer, pl_module): super().on_train_end(trainer, pl_module) # make sure these are correctly set again - assert not trainer.train_loop._skip_backward + assert not trainer.fit_loop._skip_backward assert trainer.accumulate_grad_batches == 2 assert trainer.num_training_batches == 5 diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 62b9d8364b01c..0a4311001fc96 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1297,7 +1297,7 @@ def test_ckpt_version_after_rerun_same_trainer(tmpdir): progress_bar_refresh_rate=0, ) trainer.fit(BoringModel()) - trainer.train_loop.max_epochs = 4 + trainer.fit_loop.max_epochs = 4 trainer.fit(BoringModel()) ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index cb150cb013ec2..51c8e52713a31 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -212,3 +212,11 @@ def test_v1_6_0_early_stopping_monitor(tmpdir): " For backward compatibility, setting this to `early_stop_on`." ): EarlyStopping() + + +def test_v1_6_0_train_loop(tmpdir): + trainer = Trainer() + with pytest.deprecated_call( + match=r"`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6." 
+ ): + _ = trainer.train_loop diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index ffd89a0c14984..b8bafae8508e8 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -276,7 +276,7 @@ def __init__(self): def training_step(self, *args): self.log('foo', 1, on_step=True, on_epoch=True) - if not self.trainer.train_loop.should_accumulate(): + if not self.trainer.fit_loop.should_accumulate(): if self.trainer.logger_connector.should_update_logs: self.indexes.append(self.trainer.global_step) return super().training_step(*args) diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index 193399473dc37..c0fde2983985d 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -108,10 +108,10 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): trainer = Trainer(max_epochs=max_epochs, limit_train_batches=10) trainer.fit(model) if batch_idx_ > trainer.num_training_batches - 1: - assert trainer.train_loop.batch_idx == trainer.num_training_batches - 1 + assert trainer.fit_loop.batch_idx == trainer.num_training_batches - 1 assert trainer.global_step == trainer.num_training_batches * max_epochs else: - assert trainer.train_loop.batch_idx == batch_idx_ + assert trainer.fit_loop.batch_idx == batch_idx_ assert trainer.global_step == batch_idx_ * max_epochs diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 9abc9b47ab82a..75a509e07c26b 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -611,7 +611,7 @@ def optimizer_closure(): opt.step(closure=optimizer_closure) weight_after = self.layer.weight.clone() - if not self.trainer.train_loop.should_accumulate(): + if not self.trainer.fit_loop.should_accumulate(): assert not torch.equal(weight_before, weight_after) else: assert self.layer.weight.grad is not None diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index bdfaf277f86f1..6c93c43d0cfe7 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -371,8 +371,8 @@ def mock_save_function(filepath, *args): # emulate callback's calls during the training for i, loss in enumerate(losses): - trainer.train_loop.current_epoch = i - trainer.train_loop.global_step = i + trainer.fit_loop.current_epoch = i + trainer.fit_loop.global_step = i trainer.logger_connector.callback_metrics.update({"checkpoint_on": loss}) checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) @@ -1765,7 +1765,7 @@ def compare_optimizers(): trainer.fit(model) compare_optimizers() - trainer.train_loop.max_epochs = 2 # simulate multiple fit calls + trainer.fit_loop.max_epochs = 2 # simulate multiple fit calls trainer.fit(model) compare_optimizers() From 4383ce113372d89fee49a137c91848d45e67a4a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 18 Jun 2021 04:01:54 +0200 Subject: [PATCH 2/3] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d230309311d62..00bdfc666270c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -227,7 +227,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652)) -- Deprecated the `Trainer.train_loop` property in favor of `Trainer.fit_loop` ([#xxxx](https://github.com/PyTorchLightning/pytorch-lightning/pull/xxxx)) +- Deprecated the `Trainer.train_loop` property in favor of `Trainer.fit_loop` ([#8025](https://github.com/PyTorchLightning/pytorch-lightning/pull/8025)) ### Removed From 6cdea7d056c916ad832eef109f6a347f48f02664 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 02:02:12 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/deprecated_api/test_remove_1-6.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 51c8e52713a31..2a2eed7a2739d 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -217,6 +217,6 @@ def test_v1_6_0_early_stopping_monitor(tmpdir): def test_v1_6_0_train_loop(tmpdir): trainer = Trainer() with pytest.deprecated_call( - match=r"`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6." + match=r"`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6." ): _ = trainer.train_loop
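
For reference, the user-facing effect of this series is that `trainer.train_loop` becomes `trainer.fit_loop`, with the old name kept for one more release as a deprecated alias (see `pytorch_lightning/trainer/deprecated_api.py` in patch 1/3). The snippet below is a minimal, standalone sketch of that shim pattern, not code from the patch: it substitutes the standard `warnings` module for `rank_zero_deprecation` and stubs out `FitLoop` and `Trainer` so it runs on its own.

```python
# Standalone approximation of the deprecation shim added in
# pytorch_lightning/trainer/deprecated_api.py. The real code calls
# pytorch_lightning.utilities.rank_zero_deprecation; plain `warnings`
# is used here only to keep the example self-contained.
import warnings


class FitLoop:
    """Stand-in for pytorch_lightning.loops.FitLoop."""


class Trainer:
    def __init__(self) -> None:
        self.fit_loop = FitLoop()

    @property
    def train_loop(self) -> FitLoop:
        # Old name kept as a thin alias so existing user code keeps
        # working until the property is removed in v1.6.
        warnings.warn(
            "`Trainer.train_loop` has been renamed to `Trainer.fit_loop`"
            " and will be removed in v1.6.",
            DeprecationWarning,
        )
        return self.fit_loop


trainer = Trainer()
assert trainer.train_loop is trainer.fit_loop  # emits a DeprecationWarning
```

Callers migrate by swapping the attribute name only (e.g. `trainer.fit_loop.max_epochs`, `trainer.fit_loop.should_accumulate()`), which is exactly what the bulk renames across callbacks, connectors, tuner, and tests in patch 1/3 do.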