Skip to content

Commit

Permalink
consolidate callbacks and hooks (#950)
Browse files Browse the repository at this point in the history
* consolidate callbacks and hooks

* ensure callbacks recieve proper arg types

* remove model from init callback events

* clean up early stopping event

* update changelog

* remove on_fit_start and on_fit_end

* fix args for on_init_start and on_init_end

* handle case where early stopping is not used

* show all callback methods

* wrap checkpoint callback logic into proper class

* fix check for main process in checkpoint callback

* move callbacks test to separate file

* refactor arg checks

* get model and call hook on same line

* define trainer_options dict in one call

* add more asserts to callback test
  • Loading branch information
jeremyjordan committed Mar 3, 2020
1 parent 1789165 commit 705e576
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 214 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Tensor Processing Unit (TPU) support ([#868](https://github.com/PyTorchLightning/pytorch-lightning/pull/868))
- Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876))
- Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))

### Changed
Expand Down
4 changes: 0 additions & 4 deletions docs/source/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,4 @@ Callback Class
_del_model,
_save_model,
_abc_impl,
on_epoch_end,
on_train_end,
on_epoch_start,
check_monitor_top_k,
on_train_start,
12 changes: 2 additions & 10 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,11 @@ class Callback(abc.ABC):
"""Abstract base class used to build new callbacks."""

def on_init_start(self, trainer):
"""Called when the trainer initialization begins."""
"""Called when the trainer initialization begins, model has not yet been set."""
pass

def on_init_end(self, trainer):
"""Called when the trainer initialization ends."""
pass

def on_fit_start(self, trainer, pl_module):
"""Called when the fit begins."""
pass

def on_fit_end(self, trainer, pl_module):
"""Called when the fit ends."""
"""Called when the trainer initialization ends, model has not yet been set."""
pass

def on_epoch_start(self, trainer, pl_module):
Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
self.monitor_op = mode_dict[mode]
self.min_delta *= 1 if self.monitor_op == np.greater else -1

self.on_train_start(None, None)

def check_metrics(self, logs):
monitor_val = logs.get(self.monitor)
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def check_monitor_top_k(self, current):
return self.monitor_op(current, self.best_k_models[self.kth_best_model])

def on_validation_end(self, trainer, pl_module):
# only run on main process
if trainer.proc_rank != 0:
return

logs = trainer.callback_metrics
epoch = trainer.current_epoch
self.epochs_since_last_check += 1
Expand Down
22 changes: 6 additions & 16 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,15 @@ def __init__(self):
self.callbacks: list[Callback] = []
self.get_model: Callable = ...

def on_init_start(self, trainer):
"""Called when the trainer initialization begins."""
def on_init_start(self):
"""Called when the trainer initialization begins, model has not yet been set."""
for callback in self.callbacks:
callback.on_init_start(trainer)
callback.on_init_start(self)

def on_init_end(self, trainer):
"""Called when the trainer initialization ends."""
def on_init_end(self):
"""Called when the trainer initialization ends, model has not yet been set."""
for callback in self.callbacks:
callback.on_init_end(trainer)

def on_fit_start(self):
"""Called when the fit begins."""
for callback in self.callbacks:
callback.on_fit_start(self, self.get_model())

def on_fit_end(self):
"""Called when the fit ends."""
for callback in self.callbacks:
callback.on_fit_end(self, self.get_model())
callback.on_init_end(self)

def on_epoch_start(self):
"""Called when the epoch begins."""
Expand Down
7 changes: 3 additions & 4 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,14 +374,13 @@ def run_evaluation(self, test_mode: bool = False):
else:
self.val_progress_bar.close()

# model checkpointing
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test_mode:
self.checkpoint_callback.on_validation_end(self, self.get_model())

# Validation/Test end callbacks
if test_mode:
self.on_test_end()
else:
# model checkpointing
if self.checkpoint_callback is not None:
self.checkpoint_callback.on_validation_end(self, self.get_model())
self.on_validation_end()

def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
Expand Down
10 changes: 2 additions & 8 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def on_train_end(self):

# Init callbacks
self.callbacks = callbacks
self.on_init_start(self)
self.on_init_start()

# benchmarking
self.benchmark = benchmark
Expand Down Expand Up @@ -808,7 +808,7 @@ def on_train_end(self):
self.init_amp(use_amp)

# Callback system
self.on_init_end(self)
self.on_init_end()

@property
def slurm_job_id(self) -> int:
Expand Down Expand Up @@ -941,9 +941,6 @@ def fit(
# bind logger
model.logger = self.logger

# Fit begin callbacks
self.on_fit_start()

# set up the passed in dataloaders (if needed)
self.__set_fit_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)

Expand Down Expand Up @@ -1006,9 +1003,6 @@ def fit(

self.run_pretrain_routine(model)

# Fit end callbacks
self.on_fit_end()

# return 1 when finished
# used for testing or when we need to know that training succeeded
return 1
Expand Down
100 changes: 49 additions & 51 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,15 @@ def train(self):
self.reset_train_dataloader(model)
self.reset_val_dataloader(model)

# Train begin callbacks
model.on_train_start()
self.on_train_start()
# Train start events
with self.profiler.profile('on_train_start'):
# callbacks
self.on_train_start()
# initialize early stop callback
if self.early_stop_callback is not None:
self.early_stop_callback.on_train_start(self, self.get_model())
# model hooks
model.on_train_start()

try:
# run all epochs
Expand Down Expand Up @@ -347,9 +353,6 @@ def train(self):
desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
self.main_progress_bar.set_description(desc)

# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_start(self, self.get_model())

# -----------------
# RUN TNG EPOCH
# -----------------
Expand All @@ -369,23 +372,21 @@ def train(self):
self.reduce_lr_on_plateau_scheduler.step(val_loss)

if self.max_steps and self.max_steps == self.global_step:
self.main_progress_bar.close()
model.on_train_end()
self.on_train_end()
self.run_training_teardown()
return

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

# TODO wrap this logic into the callback
if self.enable_early_stop and not self.disable_validation and is_val_epoch:
if ((met_min_epochs and met_min_steps) or self.fast_dev_run):
should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
# stop training
stop = should_stop and met_min_epochs
if stop:
self.run_training_teardown()
self.on_train_end()
return

self.run_training_teardown()
Expand All @@ -394,19 +395,17 @@ def train(self):
log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
self.run_training_teardown()

# Train end callbacks
self.on_train_end()

def run_training_epoch(self):

# Epoch begin callbacks
self.on_epoch_start()

# before epoch hook
if self.is_function_implemented('on_epoch_start'):
model = self.get_model()
with self.profiler.profile('on_epoch_start'):
model.on_epoch_start()
# Epoch start events
with self.profiler.profile('on_epoch_start'):
# callbacks
self.on_epoch_start()
# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_start(self, self.get_model())
# model hooks
if self.is_function_implemented('on_epoch_start'):
self.get_model().on_epoch_start()

# reset train dataloader
if self.reload_dataloaders_every_epoch:
Expand Down Expand Up @@ -485,14 +484,13 @@ def run_training_epoch(self):
if early_stop_epoch or self.fast_dev_run:
break

# epoch end hook
if self.is_function_implemented('on_epoch_end'):
model = self.get_model()
with self.profiler.profile('on_epoch_end'):
model.on_epoch_end()

# Epoch begin callbacks
self.on_epoch_end()
# Epoch end events
with self.profiler.profile('on_epoch_end'):
# callbacks
self.on_epoch_end()
# model hooks
if self.is_function_implemented('on_epoch_end'):
self.get_model().on_epoch_end()

def run_training_batch(self, batch, batch_idx):
# track grad norms
Expand All @@ -507,17 +505,15 @@ def run_training_batch(self, batch, batch_idx):
if batch is None:
return 0, grad_norm_dic, {}

# Batch begin callbacks
self.on_batch_start()

# hook
if self.is_function_implemented('on_batch_start'):
model_ref = self.get_model()
with self.profiler.profile('on_batch_start'):
response = model_ref.on_batch_start(batch)

if response == -1:
return -1, grad_norm_dic, {}
# Batch start events
with self.profiler.profile('on_batch_start'):
# callbacks
self.on_batch_start()
# hooks
if self.is_function_implemented('on_batch_start'):
response = self.get_model().on_batch_start(batch)
if response == -1:
return -1, grad_norm_dic, {}

splits = [batch]
if self.truncated_bptt_steps is not None:
Expand Down Expand Up @@ -612,14 +608,13 @@ def optimizer_closure():
self.batch_loss_value = 0
self.avg_loss = np.mean(self.running_loss[-100:])

# activate batch end hook
if self.is_function_implemented('on_batch_end'):
model = self.get_model()
with self.profiler.profile('on_batch_end'):
model.on_batch_end()

# Batch end callbacks
self.on_batch_end()
# Batch end events
with self.profiler.profile('on_batch_end'):
# callbacks
self.on_batch_end()
# model hooks
if self.is_function_implemented('on_batch_end'):
self.get_model().on_batch_end()

# update progress bar
if batch_idx % self.progress_bar_refresh_rate == 0:
Expand All @@ -635,12 +630,15 @@ def optimizer_closure():
return 0, grad_norm_dic, all_log_metrics

def run_training_teardown(self):
model = self.get_model()

self.main_progress_bar.close()

# Train end events
with self.profiler.profile('on_train_end'):
model.on_train_end()
# callbacks
self.on_train_end()
# model hooks
if self.is_function_implemented('on_train_end'):
self.get_model().on_train_end()

if self.logger is not None:
self.logger.finalize("success")
Expand Down
Loading

0 comments on commit 705e576

Please sign in to comment.