rename old Trainer.train_loop -> Trainer.fit_loop #8025

Merged · 8 commits · Jun 22, 2021
CHANGELOG.md (3 additions, 0 deletions)

@@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))


+- Deprecated the `Trainer.train_loop` property in favor of `Trainer.fit_loop` ([#8025](https://github.com/PyTorchLightning/pytorch-lightning/pull/8025))
+
+
 ### Removed

 - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
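For users, the migration is mechanical: every `trainer.train_loop` access becomes `trainer.fit_loop`, and the old name keeps working with a deprecation warning until v1.6. A minimal before/after sketch (assumes a `LightningModule` instance named `model`):

```python
from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=1)
trainer.fit(model)  # `model` is an assumed LightningModule

step = trainer.train_loop.global_step  # deprecated alias: warns, removed in v1.6
step = trainer.fit_loop.global_step    # preferred going forward
```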
benchmarks/test_basic_parity.py (1 addition, 1 deletion)

@@ -174,4 +174,4 @@ def lightning_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10):
     )
     trainer.fit(model)

-    return trainer.train_loop.running_loss.last().item(), _hook_memory()
+    return trainer.fit_loop.running_loss.last().item(), _hook_memory()
pytorch_lightning/callbacks/finetuning.py (1 addition, 1 deletion)

@@ -285,7 +285,7 @@ def _store(

     def on_train_epoch_start(self, trainer, pl_module):
         """Called when the epoch begins."""
-        for opt_idx, optimizer in trainer.train_loop.get_active_optimizers():
+        for opt_idx, optimizer in trainer.fit_loop.get_active_optimizers():
             num_param_groups = len(optimizer.param_groups)
             self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
             current_param_groups = optimizer.param_groups
pytorch_lightning/callbacks/progress.py (1 addition, 1 deletion)

@@ -200,7 +200,7 @@ def on_init_end(self, trainer):
         self._trainer = trainer

     def on_train_start(self, trainer, pl_module):
-        self._train_batch_idx = trainer.train_loop.batch_idx
+        self._train_batch_idx = trainer.fit_loop.batch_idx

     def on_train_epoch_start(self, trainer, pl_module):
         self._train_batch_idx = 0
pytorch_lightning/callbacks/stochastic_weight_avg.py (4 additions, 4 deletions)

@@ -159,7 +159,7 @@ def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
         self._max_epochs = trainer.max_epochs
         if self._model_contains_batch_norm:
             # virtually increase max_epochs to perform batch norm update on latest epoch.
-            trainer.train_loop.max_epochs += 1
+            trainer.fit_loop.max_epochs += 1

     def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
         if trainer.current_epoch == self.swa_start:
@@ -220,19 +220,19 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
             # performing only one pass over the train data-loader to compute activation statistics
             # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
             trainer.num_training_batches += 1
-            trainer.train_loop._skip_backward = True
+            trainer.fit_loop._skip_backward = True
             self._accumulate_grad_batches = trainer.accumulate_grad_batches
             trainer.accumulate_grad_batches = len(trainer.train_dataloader)

     def on_train_epoch_end(self, trainer: 'pl.Trainer', *args):
-        trainer.train_loop._skip_backward = False
+        trainer.fit_loop._skip_backward = False

     def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
         if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1:
             # BatchNorm epoch update. Reset state
             trainer.accumulate_grad_batches = self._accumulate_grad_batches
             trainer.num_training_batches -= 1
-            trainer.train_loop.max_epochs -= 1
+            trainer.fit_loop.max_epochs -= 1
             self.reset_momenta()
         elif trainer.current_epoch == self.swa_end:
             # Last SWA epoch. Transfer weights from average model to pl_module
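The comments in this hunk explain why SWA appends one "virtual" epoch: BatchNorm running statistics must be recomputed for the averaged weights with a forward-only pass. A minimal, framework-free sketch of that pass (the `(input, target)` loader shape is an assumption; `torch.optim.swa_utils.update_bn` does the same with extra momenta handling):

```python
import torch

@torch.no_grad()
def update_bn_stats(model: torch.nn.Module, loader) -> None:
    """Forward-only pass so BatchNorm layers refresh their running stats."""
    was_training = model.training
    model.train()  # BN updates running_mean/running_var only in train mode
    for batch, _ in loader:  # assumes (input, target) batches
        model(batch)         # no backward, no optimizer step
    model.train(was_training)
```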
pytorch_lightning/core/lightning.py (2 additions, 2 deletions)

@@ -1675,7 +1675,7 @@ def get_progress_bar_dict(self):
             Dictionary with the items to be displayed in the progress bar.
         """
         # call .item() only once but store elements without graphs
-        running_train_loss = self.trainer.train_loop.running_loss.mean()
+        running_train_loss = self.trainer.fit_loop.running_loss.mean()
         avg_training_loss = None
         if running_train_loss is not None:
             avg_training_loss = running_train_loss.cpu().item()
@@ -1689,7 +1689,7 @@ def get_progress_bar_dict(self):
         module_tbptt_enabled = self.truncated_bptt_steps > 0
         trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
         if module_tbptt_enabled or trainer_tbptt_enabled:
-            tqdm_dict["split_idx"] = self.trainer.train_loop.split_idx
+            tqdm_dict["split_idx"] = self.trainer.fit_loop.split_idx

         if self.trainer.logger is not None and self.trainer.logger.version is not None:
             version = self.trainer.logger.version
pytorch_lightning/trainer/connectors/checkpoint_connector.py (2 additions, 2 deletions)

@@ -198,8 +198,8 @@ def restore_progress(self) -> None:
         if not self._loaded_checkpoint:
             return

-        self.trainer.train_loop.global_step = self._loaded_checkpoint['global_step']
-        self.trainer.train_loop.current_epoch = self._loaded_checkpoint['epoch']
+        self.trainer.fit_loop.global_step = self._loaded_checkpoint['global_step']
+        self.trainer.fit_loop.current_epoch = self._loaded_checkpoint['epoch']

         # crash if max_epochs is lower then the current epoch from the checkpoint
         if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
pytorch_lightning/trainer/connectors/debugging_connector.py (2 additions, 2 deletions)

@@ -58,9 +58,9 @@ def on_init_start(
            limit_val_batches = fast_dev_run
            limit_test_batches = fast_dev_run
            limit_predict_batches = fast_dev_run
-           self.trainer.train_loop.max_steps = fast_dev_run
+           self.trainer.fit_loop.max_steps = fast_dev_run
            self.trainer.num_sanity_val_steps = 0
-           self.trainer.train_loop.max_epochs = 1
+           self.trainer.fit_loop.max_epochs = 1
            val_check_interval = 1.0
            self.trainer.check_val_every_n_epoch = 1
            self.trainer.logger = DummyLogger()
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py (2 additions, 2 deletions)

@@ -211,7 +211,7 @@ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any):
         self._split_idx = split_idx

     def update_train_step_metrics(self) -> None:
-        if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
+        if self.trainer.fit_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
             return

         # when metrics should be logged
@@ -299,6 +299,6 @@ def progress_bar_metrics(self) -> Dict[str, float]:
         return self._progress_bar_metrics

     def teardown(self):
-        self.trainer.train_loop.results.cpu()
+        self.trainer.fit_loop.results.cpu()
         self.trainer.evaluation_loop._val_results.cpu()
         self.trainer.evaluation_loop._test_results.cpu()
pytorch_lightning/trainer/connectors/optimizer_connector.py (2 additions, 2 deletions)

@@ -55,7 +55,7 @@ def update_learning_rates(
            if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
                continue

-           current_idx = self.trainer.train_loop.batch_idx if interval == 'step' else self.trainer.current_epoch
+           current_idx = self.trainer.fit_loop.batch_idx if interval == 'step' else self.trainer.current_epoch
            current_idx += 1  # account for both batch and epoch starts from 0
            # Take step if call to update_learning_rates matches the interval key and
            # the current step modulo the schedulers frequency is zero
@@ -92,7 +92,7 @@ def update_learning_rates(

            if self.trainer.dev_debugger.enabled:
                self.trainer.dev_debugger.track_lr_schedulers_update(
-                   self.trainer.train_loop.batch_idx,
+                   self.trainer.fit_loop.batch_idx,
                    interval,
                    scheduler_idx,
                    old_lr,
pytorch_lightning/trainer/deprecated_api.py (9 additions, 1 deletion)

@@ -11,17 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from pytorch_lightning.loops import FitLoop
 from pytorch_lightning.utilities import rank_zero_deprecation


 class DeprecatedTrainerAttributes:

     sanity_checking: bool
+    fit_loop: FitLoop

     @property
     def running_sanity_check(self) -> bool:
         rank_zero_deprecation(
             "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."
         )
         return self.sanity_checking
+
+    @property
+    def train_loop(self) -> FitLoop:
+        rank_zero_deprecation(
+            "`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6."
+        )
+        return self.fit_loop
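From user code, the shim is transparent apart from the warning. A quick sketch of how it behaves (assumes the warning category emitted by `rank_zero_deprecation` subclasses `DeprecationWarning`):

```python
import warnings
from pytorch_lightning import Trainer

trainer = Trainer()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    loop = trainer.train_loop  # old name: warns but still works
assert loop is trainer.fit_loop
assert any("fit_loop" in str(w.message) for w in caught)
```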
pytorch_lightning/trainer/properties.py (7 additions, 12 deletions)

@@ -483,39 +483,34 @@ def sanity_checking(self, val: bool) -> None:
     Loop properties
     """

-    @property
-    def train_loop(self) -> FitLoop:
-        # FIXME(@awaelchli): the current train_loop should be renamed to fit_loop
-        return self.fit_loop
-
     @property
     def global_step(self) -> int:
-        return self.train_loop.global_step
+        return self.fit_loop.global_step

     @property
     def current_epoch(self) -> int:
-        return self.train_loop.current_epoch
+        return self.fit_loop.current_epoch

     @property
     def max_epochs(self) -> Optional[int]:
-        return self.train_loop.max_epochs
+        return self.fit_loop.max_epochs

     @property
     def min_epochs(self) -> Optional[int]:
-        return self.train_loop.min_epochs
+        return self.fit_loop.min_epochs

     @property
     def max_steps(self) -> Optional[int]:
-        return self.train_loop.max_steps
+        return self.fit_loop.max_steps

     @property
     def min_steps(self) -> Optional[int]:
-        return self.train_loop.min_steps
+        return self.fit_loop.min_steps

     @property
     def _active_loop(self) -> Optional[Union[FitLoop, EvaluationDataLoaderLoop]]:
         if self.training:
-            return self.train_loop
+            return self.fit_loop
         elif self.sanity_checking or self.evaluating:
             return self.evaluation_loop
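These properties are thin read-through accessors: the counters live on the `FitLoop`, and the `Trainer` only forwards to it. A stripped-down sketch of the pattern (not Lightning's actual classes):

```python
class FitLoop:
    def __init__(self) -> None:
        self.global_step = 0   # owned and mutated by the loop
        self.current_epoch = 0

class Trainer:
    def __init__(self) -> None:
        self.fit_loop = FitLoop()

    @property
    def global_step(self) -> int:
        # read-only view; the loop stays the single source of truth
        return self.fit_loop.global_step

trainer = Trainer()
trainer.fit_loop.global_step += 1
assert trainer.global_step == 1
```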
pytorch_lightning/tuner/batch_size_scaling.py (6 additions, 6 deletions)

@@ -115,8 +115,8 @@ def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None:
 def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None:
     trainer.auto_scale_batch_size = None  # prevent recursion
     trainer.auto_lr_find = False  # avoid lr find being called multiple times
-    trainer.train_loop.current_epoch = 0
-    trainer.train_loop.max_steps = steps_per_trial  # take few steps
+    trainer.fit_loop.current_epoch = 0
+    trainer.fit_loop.max_steps = steps_per_trial  # take few steps
     trainer.weights_summary = None  # not needed before full run
     trainer.logger = DummyLogger()
     trainer.callbacks = []  # not needed before full run
@@ -127,8 +127,8 @@ def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None:

 def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None:
     trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
-    trainer.train_loop.current_epoch = trainer.__dumped_params['current_epoch']
-    trainer.train_loop.max_steps = trainer.__dumped_params['max_steps']
+    trainer.fit_loop.current_epoch = trainer.__dumped_params['current_epoch']
+    trainer.fit_loop.max_steps = trainer.__dumped_params['max_steps']
     trainer.weights_summary = trainer.__dumped_params['weights_summary']
     trainer.logger = trainer.__dumped_params['logger']
     trainer.callbacks = trainer.__dumped_params['callbacks']
@@ -144,7 +144,7 @@ def _run_power_scaling(
     """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """
     for _ in range(max_trials):
         garbage_collection_cuda()
-        trainer.train_loop.global_step = 0  # reset after each try
+        trainer.fit_loop.global_step = 0  # reset after each try
         try:
             # Try fit
             trainer.tuner._run(model)
@@ -178,7 +178,7 @@ def _run_binsearch_scaling(
     count = 0
     while True:
         garbage_collection_cuda()
-        trainer.train_loop.global_step = 0  # reset after each try
+        trainer.fit_loop.global_step = 0  # reset after each try
         try:
             # Try fit
             trainer.tuner._run(model)
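For context, `_run_power_scaling` doubles the batch size after each successful trial fit until an out-of-memory error, resetting `fit_loop.global_step` between tries so each trial starts fresh. A hedged sketch of that control flow (`run_trial` stands in for `trainer.tuner._run(model)`; the halve-on-OOM fallback is an assumption, not copied from the file):

```python
def power_scale(run_trial, initial_size: int, max_trials: int) -> int:
    """Double the batch size until a trial raises a CUDA OOM."""
    size = initial_size
    for _ in range(max_trials):
        try:
            run_trial(size)  # short fit limited to a few steps
        except RuntimeError as err:
            if "out of memory" in str(err):  # CUDA OOM surfaces as RuntimeError
                return max(1, size // 2)     # fall back to the last size that fit
            raise
        size *= 2
    return size
```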
pytorch_lightning/tuner/lr_finder.py (8 additions, 8 deletions)

@@ -230,7 +230,7 @@ def lr_find(
     trainer.logger = DummyLogger()

     # Max step set to number of iterations
-    trainer.train_loop.max_steps = num_training
+    trainer.fit_loop.max_steps = num_training

     # Disable standard progress bar for fit
     if trainer.progress_bar_callback:
@@ -255,7 +255,7 @@ def lr_find(

     # Transfer results from callback to lr finder object
     lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses})
-    lr_finder._total_batch_idx = trainer.train_loop.total_batch_idx  # for debug purpose
+    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose

     # Reset model state
     if trainer.is_global_zero:
@@ -297,8 +297,8 @@ def __lr_finder_restore_params(trainer, model):
     trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
     trainer.logger = trainer.__dumped_params['logger']
     trainer.callbacks = trainer.__dumped_params['callbacks']
-    trainer.train_loop.max_steps = trainer.__dumped_params['max_steps']
-    trainer.train_loop.current_epoch = trainer.__dumped_params['current_epoch']
+    trainer.fit_loop.max_steps = trainer.__dumped_params['max_steps']
+    trainer.fit_loop.current_epoch = trainer.__dumped_params['current_epoch']
     model.configure_optimizers = trainer.__dumped_params['configure_optimizers']
     del trainer.__dumped_params

@@ -340,7 +340,7 @@ def __init__(

     def on_batch_start(self, trainer, pl_module):
         """ Called before each training batch, logs the lr that will be used """
-        if (trainer.train_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
+        if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return

         if self.progress_bar_refresh_rate and self.progress_bar is None:
@@ -350,13 +350,13 @@ def on_batch_start(self, trainer, pl_module):

     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         """ Called when the training batch ends, logs the calculated loss """
-        if (trainer.train_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
+        if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return

         if self.progress_bar:
             self.progress_bar.update()

-        current_loss = trainer.train_loop.running_loss.last().item()
+        current_loss = trainer.fit_loop.running_loss.last().item()
         current_step = trainer.global_step

         # Avg loss (loss with momentum) + smoothing
@@ -366,7 +366,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         # Check if we diverging
         if self.early_stop_threshold is not None:
             if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
-                trainer.train_loop.max_steps = current_step  # stop signal
+                trainer.fit_loop.max_steps = current_step  # stop signal
                 if self.progress_bar:
                     self.progress_bar.close()
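The "loss with momentum + smoothing" comment refers to an exponentially weighted running loss with bias correction; the early stop then fires once the smoothed loss diverges past `early_stop_threshold * best_loss`. A sketch of that bookkeeping (the default beta and the exact recurrence are assumptions, not copied from the file):

```python
def smoothed_loss(avg_loss: float, loss: float, step: int, beta: float = 0.98):
    """One EWMA update of the running loss, with warm-up bias correction."""
    avg_loss = beta * avg_loss + (1 - beta) * loss
    corrected = avg_loss / (1 - beta ** (step + 1))  # de-bias early steps
    return avg_loss, corrected
```

The divergence check then reduces to `corrected > early_stop_threshold * best_loss`.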
tests/callbacks/test_progress_bar.py (2 additions, 2 deletions)

@@ -194,11 +194,11 @@ class CurrentProgressBar(ProgressBar):

     def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         super().on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
-        assert self.train_batch_idx == trainer.train_loop.batch_idx
+        assert self.train_batch_idx == trainer.fit_loop.batch_idx

     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        assert self.train_batch_idx == trainer.train_loop.batch_idx + 1
+        assert self.train_batch_idx == trainer.fit_loop.batch_idx + 1
         if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0:
             assert self.main_progress_bar.n == self.train_batch_idx
         self.train_batches_seen += 1
tests/callbacks/test_stochastic_weight_avg.py (2 additions, 2 deletions)

@@ -74,7 +74,7 @@ def transfer_weights(self, *args, **kwargs):

     def on_train_epoch_start(self, trainer, *args):
         super().on_train_epoch_start(trainer, *args)
-        assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end)
+        assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end)
         if self.swa_start <= trainer.current_epoch:
             assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR)
             assert trainer.lr_schedulers[0]["interval"] == "epoch"
@@ -92,7 +92,7 @@ def on_train_end(self, trainer, pl_module):
         super().on_train_end(trainer, pl_module)

         # make sure these are correctly set again
-        assert not trainer.train_loop._skip_backward
+        assert not trainer.fit_loop._skip_backward
         assert trainer.accumulate_grad_batches == 2
         assert trainer.num_training_batches == 5
tests/checkpointing/test_model_checkpoint.py (1 addition, 1 deletion)

@@ -1301,7 +1301,7 @@ def test_ckpt_version_after_rerun_same_trainer(tmpdir):
         progress_bar_refresh_rate=0,
     )
     trainer.fit(BoringModel())
-    trainer.train_loop.max_epochs = 4
+    trainer.fit_loop.max_epochs = 4
     trainer.fit(BoringModel())

     ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION)
tests/deprecated_api/test_remove_1-6.py (8 additions, 0 deletions)

@@ -227,3 +227,11 @@ def training_step(self, *args):
     match = r"\{'foo'\} has a `grad_fn`.*behaviour will change in v1\.6"
     with pytest.deprecated_call(match=match):
         trainer.fit(model)
+
+
+def test_v1_6_0_train_loop(tmpdir):
+    trainer = Trainer()
+    with pytest.deprecated_call(
+        match=r"`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6."
+    ):
+        _ = trainer.train_loop
tests/loggers/test_tensorboard.py (1 addition, 1 deletion)

@@ -276,7 +276,7 @@ def __init__(self):

     def training_step(self, *args):
         self.log('foo', 1, on_step=True, on_epoch=True)
-        if not self.trainer.train_loop.should_accumulate():
+        if not self.trainer.fit_loop.should_accumulate():
             if self.trainer.logger_connector.should_update_logs:
                 self.indexes.append(self.trainer.global_step)
         return super().training_step(*args)
tests/trainer/loops/test_training_loop.py (2 additions, 2 deletions)

@@ -108,10 +108,10 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
     trainer = Trainer(max_epochs=max_epochs, limit_train_batches=10)
     trainer.fit(model)
     if batch_idx_ > trainer.num_training_batches - 1:
-        assert trainer.train_loop.batch_idx == trainer.num_training_batches - 1
+        assert trainer.fit_loop.batch_idx == trainer.num_training_batches - 1
         assert trainer.global_step == trainer.num_training_batches * max_epochs
     else:
-        assert trainer.train_loop.batch_idx == batch_idx_
+        assert trainer.fit_loop.batch_idx == batch_idx_
         assert trainer.global_step == batch_idx_ * max_epochs