consolidate callbacks and hooks #950

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Tensor Processing Unit (TPU) support ([#868](https://github.com/PyTorchLightning/pytorch-lightning/pull/868))
- Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876))
- Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
- Support for user-defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))

### Changed
4 changes: 0 additions & 4 deletions docs/source/callbacks.rst
@@ -44,8 +44,4 @@ Callback Class
_del_model,
_save_model,
_abc_impl,
on_epoch_end,
on_train_end,
on_epoch_start,
check_monitor_top_k,
on_train_start,
12 changes: 2 additions & 10 deletions pytorch_lightning/callbacks/base.py
@@ -12,19 +12,11 @@ class Callback(abc.ABC):
"""Abstract base class used to build new callbacks."""

def on_init_start(self, trainer):
"""Called when the trainer initialization begins."""
"""Called when the trainer initialization begins, model has not yet been set."""
pass

def on_init_end(self, trainer):
"""Called when the trainer initialization ends."""
pass

def on_fit_start(self, trainer, pl_module):
"""Called when the fit begins."""
pass

def on_fit_end(self, trainer, pl_module):
"""Called when the fit ends."""
"""Called when the trainer initialization ends, model has not yet been set."""
pass

def on_epoch_start(self, trainer, pl_module):
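As a point of reference for the consolidated interface above, a user-defined callback could look like the following sketch. Only the hook signatures come from this diff; the class name, the printed messages, and the assumption that Callback is importable from pytorch_lightning.callbacks are illustrative.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class PrintingCallback(Callback):
    """Illustrative callback: prints around trainer init and at each epoch end."""

    def on_init_start(self, trainer):
        # per the docstring change above, no model is attached to the trainer yet
        print('Trainer __init__ started')

    def on_init_end(self, trainer):
        print('Trainer __init__ finished')

    def on_epoch_end(self, trainer, pl_module):
        print(f'finished epoch {trainer.current_epoch}')


# callbacks are passed to the Trainer constructor (support added in #889)
trainer = Trainer(callbacks=[PrintingCallback()])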
2 changes: 0 additions & 2 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -64,8 +64,6 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
self.monitor_op = mode_dict[mode]
self.min_delta *= 1 if self.monitor_op == np.greater else -1

self.on_train_start(None, None)

def check_metrics(self, logs):
monitor_val = logs.get(self.monitor)
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
4 changes: 4 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -118,6 +118,10 @@ def check_monitor_top_k(self, current):
return self.monitor_op(current, self.best_k_models[self.kth_best_model])

def on_validation_end(self, trainer, pl_module):
# only run on main process
if trainer.proc_rank != 0:
return

logs = trainer.callback_metrics
epoch = trainer.current_epoch
self.epochs_since_last_check += 1
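The same guard can be reused in user callbacks that perform I/O under distributed training. In the sketch below, trainer.proc_rank, trainer.current_epoch, and trainer.callback_metrics are taken from this diff; the class and the log file name are assumptions.

from pytorch_lightning.callbacks import Callback


class RankZeroValLogger(Callback):
    """Illustrative callback that writes validation metrics from the main process only."""

    def on_validation_end(self, trainer, pl_module):
        # mirror the guard added above: skip side effects on non-zero ranks
        if trainer.proc_rank != 0:
            return
        with open('val_metrics.log', 'a') as f:
            f.write(f'epoch={trainer.current_epoch} metrics={trainer.callback_metrics}\n')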
22 changes: 6 additions & 16 deletions pytorch_lightning/trainer/callback_hook.py
@@ -12,25 +12,15 @@ def __init__(self):
self.callbacks: list[Callback] = []
self.get_model: Callable = ...

def on_init_start(self, trainer):
"""Called when the trainer initialization begins."""
def on_init_start(self):
"""Called when the trainer initialization begins, model has not yet been set."""
for callback in self.callbacks:
callback.on_init_start(trainer)
callback.on_init_start(self)

def on_init_end(self, trainer):
"""Called when the trainer initialization ends."""
def on_init_end(self):
"""Called when the trainer initialization ends, model has not yet been set."""
for callback in self.callbacks:
callback.on_init_end(trainer)

def on_fit_start(self):
"""Called when the fit begins."""
for callback in self.callbacks:
callback.on_fit_start(self, self.get_model())

def on_fit_end(self):
"""Called when the fit ends."""
for callback in self.callbacks:
callback.on_fit_end(self, self.get_model())
callback.on_init_end(self)

def on_epoch_start(self):
"""Called when the epoch begins."""
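For events that fire after the model has been attached, the mixin passes both the trainer and the LightningModule. A sketch of that convention for the on_epoch_start method truncated above, assuming get_model() returns the attached module as it does elsewhere in this PR:

def on_epoch_start(self):
    """Called when the epoch begins."""
    for callback in self.callbacks:
        # post-init events receive the trainer (self) and the attached model
        callback.on_epoch_start(self, self.get_model())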
7 changes: 3 additions & 4 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -366,14 +366,13 @@ def run_evaluation(self, test_mode: bool = False):
else:
self.val_progress_bar.close()

# model checkpointing
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test_mode:
self.checkpoint_callback.on_validation_end(self, self.get_model())

# Validation/Test end callbacks
if test_mode:
self.on_test_end()
else:
# model checkpointing
if self.checkpoint_callback is not None:
self.checkpoint_callback.on_validation_end(self, self.get_model())
self.on_validation_end()

def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
10 changes: 2 additions & 8 deletions pytorch_lightning/trainer/trainer.py
@@ -618,7 +618,7 @@ def on_train_end(self):

# Init callbacks
self.callbacks = callbacks
self.on_init_start(self)
self.on_init_start()

# benchmarking
self.benchmark = benchmark
@@ -808,7 +808,7 @@ def on_train_end(self):
self.init_amp(use_amp)

# Callback system
self.on_init_end(self)
self.on_init_end()

@property
def slurm_job_id(self) -> int:
@@ -941,9 +941,6 @@ def fit(
# bind logger
model.logger = self.logger

# Fit begin callbacks
self.on_fit_start()

# set up the passed in dataloaders (if needed)
self.__set_fit_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)

@@ -987,9 +984,6 @@

self.run_pretrain_routine(model)

# Fit end callbacks
self.on_fit_end()

# return 1 when finished
# used for testing or when we need to know that training succeeded
return 1
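Since the fit-level callback events are dropped here, setup and teardown that a callback previously did in on_fit_start/on_fit_end would now hang off the train-level events dispatched from the training loop (next file). A hedged sketch, assuming the train hooks keep the (trainer, pl_module) signature used by the rest of the Callback base class:

import time

from pytorch_lightning.callbacks import Callback


class TrainTimer(Callback):
    """Illustrative stand-in for work previously attached to the removed fit events."""

    def on_train_start(self, trainer, pl_module):
        self._start = time.monotonic()

    def on_train_end(self, trainer, pl_module):
        print(f'training took {time.monotonic() - self._start:.1f}s')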
99 changes: 51 additions & 48 deletions pytorch_lightning/trainer/training_loop.py
@@ -333,9 +333,15 @@ def train(self):
self.reset_train_dataloader(model)
self.reset_val_dataloader(model)

# Train begin callbacks
model.on_train_start()
self.on_train_start()
# Train start events
with self.profiler.profile('on_train_start'):
# callbacks
self.on_train_start()
# initialize early stop callback
if self.early_stop_callback is not None:
self.early_stop_callback.on_train_start(self, self.get_model())
# model hooks
model.on_train_start()

try:
# run all epochs
@@ -378,9 +384,6 @@ def train(self):
desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
self.main_progress_bar.set_description(desc)

# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_start(self, self.get_model())

# -----------------
# RUN TNG EPOCH
# -----------------
@@ -400,23 +403,21 @@
self.reduce_lr_on_plateau_scheduler.step(val_loss)

if self.max_steps and self.max_steps == self.global_step:
self.main_progress_bar.close()
model.on_train_end()
self.on_train_end()
self.run_training_teardown()
return

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

# TODO wrap this logic into the callback
if self.enable_early_stop and not self.disable_validation and is_val_epoch:
if ((met_min_epochs and met_min_steps) or self.fast_dev_run):
should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
# stop training
stop = should_stop and met_min_epochs
if stop:
self.run_training_teardown()
self.on_train_end()
return

self.run_training_teardown()
@@ -425,18 +426,17 @@ def train(self):
log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
self.run_training_teardown()

# Train end callbacks
self.on_train_end()

def run_training_epoch(self):

# Epoch begin callbacks
self.on_epoch_start()

# before epoch hook
if self.is_function_implemented('on_epoch_start'):
model = self.get_model()
with self.profiler.profile('on_epoch_start'):
# Epoch start events
with self.profiler.profile('on_epoch_start'):
# callbacks
self.on_epoch_start()
# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_start(self, self.get_model())
# model hooks
if self.is_function_implemented('on_epoch_start'):
model = self.get_model()
model.on_epoch_start()

# reset train dataloader
@@ -516,15 +516,15 @@ def run_training_epoch(self):
if early_stop_epoch or self.fast_dev_run:
break

# epoch end hook
if self.is_function_implemented('on_epoch_end'):
model = self.get_model()
with self.profiler.profile('on_epoch_end'):
# Epoch end events
with self.profiler.profile('on_epoch_end'):
# callbacks
self.on_epoch_end()
# model hooks
if self.is_function_implemented('on_epoch_end'):
model = self.get_model()
model.on_epoch_end()

# Epoch begin callbacks
self.on_epoch_end()

def run_training_batch(self, batch, batch_idx):
# track grad norms
grad_norm_dic = {}
@@ -538,17 +538,16 @@ def run_training_batch(self, batch, batch_idx):
if batch is None:
return 0, grad_norm_dic, {}

# Batch begin callbacks
self.on_batch_start()

# hook
if self.is_function_implemented('on_batch_start'):
model_ref = self.get_model()
with self.profiler.profile('on_batch_start'):
response = model_ref.on_batch_start(batch)

if response == -1:
return -1, grad_norm_dic, {}
# Batch start events
with self.profiler.profile('on_batch_start'):
# callbacks
self.on_batch_start()
# hooks
if self.is_function_implemented('on_batch_start'):
model = self.get_model()
response = model.on_batch_start(batch)
if response == -1:
return -1, grad_norm_dic, {}

splits = [batch]
if self.truncated_bptt_steps is not None:
@@ -643,15 +642,15 @@ def optimizer_closure():
self.batch_loss_value = 0
self.avg_loss = np.mean(self.running_loss[-100:])

# activate batch end hook
if self.is_function_implemented('on_batch_end'):
model = self.get_model()
with self.profiler.profile('on_batch_end'):
# Batch end events
with self.profiler.profile('on_batch_end'):
# callbacks
self.on_batch_end()
# model hooks
if self.is_function_implemented('on_batch_end'):
model = self.get_model()
model.on_batch_end()

# Batch end callbacks
self.on_batch_end()

# update progress bar
if batch_idx % self.progress_bar_refresh_rate == 0:
self.main_progress_bar.update(self.progress_bar_refresh_rate)
@@ -666,12 +665,16 @@ def optimizer_closure():
return 0, grad_norm_dic, all_log_metrics

def run_training_teardown(self):
model = self.get_model()

self.main_progress_bar.close()

# Train end events
with self.profiler.profile('on_train_end'):
model.on_train_end()
# callbacks
self.on_train_end()
# model hooks
if self.is_function_implemented('on_train_end'):
model = self.get_model()
model.on_train_end()

if self.logger is not None:
self.logger.finalize("success")
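The training-loop changes above all follow one pattern: each event runs inside a single profiler region, trainer-level callbacks fire first, then the optional model hook. A condensed sketch of that ordering follows; the helper name _fire_event is hypothetical, while profiler.profile, is_function_implemented, and get_model appear in the diff, and per-event control flow (return values, early-stop checks) is simplified away.

def _fire_event(self, name):
    """Hypothetical helper illustrating the consolidated event ordering used above."""
    with self.profiler.profile(name):
        # 1. trainer-level callbacks (Callback objects registered on the Trainer)
        getattr(self, name)()
        # 2. model hook, only when the LightningModule overrides it
        if self.is_function_implemented(name):
            getattr(self.get_model(), name)()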