From 705e5764178d82a77abe60e0d4027b572b2efdf2 Mon Sep 17 00:00:00 2001
From: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com>
Date: Mon, 2 Mar 2020 23:51:32 -0500
Subject: [PATCH] consolidate callbacks and hooks (#950)

* consolidate callbacks and hooks
* ensure callbacks receive proper arg types
* remove model from init callback events
* clean up early stopping event
* update changelog
* remove on_fit_start and on_fit_end
* fix args for on_init_start and on_init_end
* handle case where early stopping is not used
* show all callback methods
* wrap checkpoint callback logic into proper class
* fix check for main process in checkpoint callback
* move callbacks test to separate file
* refactor arg checks
* get model and call hook on same line
* define trainer_options dict in one call
* add more asserts to callback test
---
 CHANGELOG.md                                  |   1 +
 docs/source/callbacks.rst                     |   4 -
 pytorch_lightning/callbacks/base.py           |  12 +-
 pytorch_lightning/callbacks/early_stopping.py |   2 -
 .../callbacks/model_checkpoint.py             |   4 +
 pytorch_lightning/trainer/callback_hook.py    |  22 +--
 pytorch_lightning/trainer/evaluation_loop.py  |   7 +-
 pytorch_lightning/trainer/trainer.py          |  10 +-
 pytorch_lightning/trainer/training_loop.py    | 100 ++++++------
 tests/trainer/test_callbacks.py               | 153 ++++++++++++++++++
 tests/trainer/test_trainer.py                 | 119 --------------
 11 files changed, 220 insertions(+), 214 deletions(-)
 create mode 100644 tests/trainer/test_callbacks.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0a17244a7b6c9..201d850bf922a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added Tensor Processing Unit (TPU) support ([#868](https://github.com/PyTorchLightning/pytorch-lightning/pull/868))
 - Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876))
 - Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
+- Added support for user-defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
 - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))

 ### Changed
diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst
index 07aba6425510b..fdea966521966 100644
--- a/docs/source/callbacks.rst
+++ b/docs/source/callbacks.rst
@@ -44,8 +44,4 @@ Callback Class
       _del_model,
       _save_model,
       _abc_impl,
-      on_epoch_end,
-      on_train_end,
-      on_epoch_start,
       check_monitor_top_k,
-      on_train_start,
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index ea06cab75636a..6f04edbb3f988 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -12,19 +12,11 @@ class Callback(abc.ABC):
     """Abstract base class used to build new callbacks."""

     def on_init_start(self, trainer):
-        """Called when the trainer initialization begins."""
+        """Called when the trainer initialization begins; the model has not yet been set."""
         pass

     def on_init_end(self, trainer):
-        """Called when the trainer initialization ends."""
-        pass
-
-    def on_fit_start(self, trainer, pl_module):
-        """Called when the fit begins."""
-        pass
-
-    def on_fit_end(self, trainer, pl_module):
-        """Called when the fit ends."""
+        """Called when the trainer initialization ends; the model has not yet been set."""
         pass

     def on_epoch_start(self, trainer, pl_module):
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 435f0d533c500..10823e956adba 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -64,8 +64,6 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
         self.monitor_op = mode_dict[mode]
         self.min_delta *= 1 if self.monitor_op == np.greater else -1

-        self.on_train_start(None, None)
-
     def check_metrics(self, logs):
         monitor_val = logs.get(self.monitor)
         error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 24727033ff1d3..b3509fb68ccd0 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -118,6 +118,10 @@ def check_monitor_top_k(self, current):
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])

     def on_validation_end(self, trainer, pl_module):
+        # only run on main process
+        if trainer.proc_rank != 0:
+            return
+
         logs = trainer.callback_metrics
         epoch = trainer.current_epoch
         self.epochs_since_last_check += 1
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index f8c2848b07234..3ab7575fe8265 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -12,25 +12,15 @@ def __init__(self):
         self.callbacks: list[Callback] = []
         self.get_model: Callable = ...
-    def on_init_start(self, trainer):
-        """Called when the trainer initialization begins."""
+    def on_init_start(self):
+        """Called when the trainer initialization begins; the model has not yet been set."""
         for callback in self.callbacks:
-            callback.on_init_start(trainer)
+            callback.on_init_start(self)

-    def on_init_end(self, trainer):
-        """Called when the trainer initialization ends."""
+    def on_init_end(self):
+        """Called when the trainer initialization ends; the model has not yet been set."""
         for callback in self.callbacks:
-            callback.on_init_end(trainer)
-
-    def on_fit_start(self):
-        """Called when the fit begins."""
-        for callback in self.callbacks:
-            callback.on_fit_start(self, self.get_model())
-
-    def on_fit_end(self):
-        """Called when the fit ends."""
-        for callback in self.callbacks:
-            callback.on_fit_end(self, self.get_model())
+            callback.on_init_end(self)

     def on_epoch_start(self):
         """Called when the epoch begins."""
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index b7a5002b9c1da..4d9e1df3d4705 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -374,14 +374,13 @@ def run_evaluation(self, test_mode: bool = False):
         else:
             self.val_progress_bar.close()

-        # model checkpointing
-        if self.proc_rank == 0 and self.checkpoint_callback is not None and not test_mode:
-            self.checkpoint_callback.on_validation_end(self, self.get_model())
-
         # Validation/Test end callbacks
         if test_mode:
             self.on_test_end()
         else:
+            # model checkpointing
+            if self.checkpoint_callback is not None:
+                self.checkpoint_callback.on_validation_end(self, self.get_model())
             self.on_validation_end()

     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 5647d24a03404..a791d307cf3e8 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -618,7 +618,7 @@ def on_train_end(self):

         # Init callbacks
         self.callbacks = callbacks
-        self.on_init_start(self)
+        self.on_init_start()

         # benchmarking
         self.benchmark = benchmark
@@ -808,7 +808,7 @@ def on_train_end(self):
         self.init_amp(use_amp)

         # Callback system
-        self.on_init_end(self)
+        self.on_init_end()

     @property
     def slurm_job_id(self) -> int:
@@ -941,9 +941,6 @@ def fit(
         # bind logger
         model.logger = self.logger

-        # Fit begin callbacks
-        self.on_fit_start()
-
         # set up the passed in dataloaders (if needed)
         self.__set_fit_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)
@@ -1006,9 +1003,6 @@ def fit(

         self.run_pretrain_routine(model)

-        # Fit end callbacks
-        self.on_fit_end()
-
         # return 1 when finished
         # used for testing or when we need to know that training succeeded
         return 1
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index e166f6e991de7..94c9de74d4c91 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -302,9 +302,15 @@ def train(self):
         self.reset_train_dataloader(model)
         self.reset_val_dataloader(model)

-        # Train begin callbacks
-        model.on_train_start()
-        self.on_train_start()
+        # Train start events
+        with self.profiler.profile('on_train_start'):
+            # callbacks
+            self.on_train_start()
+            # initialize early stop callback
+            if self.early_stop_callback is not None:
+                self.early_stop_callback.on_train_start(self, self.get_model())
+            # model hooks
+            model.on_train_start()

         try:
             # run all epochs
@@ -347,9 +353,6 @@ def train(self):
                 desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
                 self.main_progress_bar.set_description(desc)

-                # changing gradient according accumulation_scheduler
-                self.accumulation_scheduler.on_epoch_start(self, self.get_model())
-
                 # -----------------
                 # RUN TNG EPOCH
                 # -----------------
@@ -369,15 +372,14 @@ def train(self):
                         self.reduce_lr_on_plateau_scheduler.step(val_loss)

                 if self.max_steps and self.max_steps == self.global_step:
-                    self.main_progress_bar.close()
-                    model.on_train_end()
-                    self.on_train_end()
+                    self.run_training_teardown()
                     return

                 # early stopping
                 met_min_epochs = epoch >= self.min_epochs - 1
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

+                # TODO wrap this logic into the callback
                 if self.enable_early_stop and not self.disable_validation and is_val_epoch:
                     if ((met_min_epochs and met_min_steps) or self.fast_dev_run):
                         should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
@@ -385,7 +387,6 @@ def train(self):
                         stop = should_stop and met_min_epochs
                         if stop:
                             self.run_training_teardown()
-                            self.on_train_end()
                             return

             self.run_training_teardown()
@@ -394,19 +395,17 @@ def train(self):
             log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
             self.run_training_teardown()

-        # Train end callbacks
-        self.on_train_end()
-
     def run_training_epoch(self):
-        # Epoch begin callbacks
-        self.on_epoch_start()
-
-        # before epoch hook
-        if self.is_function_implemented('on_epoch_start'):
-            model = self.get_model()
-            with self.profiler.profile('on_epoch_start'):
-                model.on_epoch_start()
+        # Epoch start events
+        with self.profiler.profile('on_epoch_start'):
+            # callbacks
+            self.on_epoch_start()
+            # change gradient accumulation according to the accumulation_scheduler
+            self.accumulation_scheduler.on_epoch_start(self, self.get_model())
+            # model hooks
+            if self.is_function_implemented('on_epoch_start'):
+                self.get_model().on_epoch_start()

         # reset train dataloader
         if self.reload_dataloaders_every_epoch:
@@ -485,14 +484,13 @@ def run_training_epoch(self):
             if early_stop_epoch or self.fast_dev_run:
                 break

-        # epoch end hook
-        if self.is_function_implemented('on_epoch_end'):
-            model = self.get_model()
-            with self.profiler.profile('on_epoch_end'):
-                model.on_epoch_end()
-
-        # Epoch begin callbacks
-        self.on_epoch_end()
+        # Epoch end events
+        with self.profiler.profile('on_epoch_end'):
+            # callbacks
+            self.on_epoch_end()
+            # model hooks
+            if self.is_function_implemented('on_epoch_end'):
+                self.get_model().on_epoch_end()

     def run_training_batch(self, batch, batch_idx):
         # track grad norms
@@ -507,17 +505,15 @@ def run_training_batch(self, batch, batch_idx):
         if batch is None:
             return 0, grad_norm_dic, {}

-        # Batch begin callbacks
-        self.on_batch_start()
-
-        # hook
-        if self.is_function_implemented('on_batch_start'):
-            model_ref = self.get_model()
-            with self.profiler.profile('on_batch_start'):
-                response = model_ref.on_batch_start(batch)
-
-            if response == -1:
-                return -1, grad_norm_dic, {}
+        # Batch start events
+        with self.profiler.profile('on_batch_start'):
+            # callbacks
+            self.on_batch_start()
+            # model hooks
+            if self.is_function_implemented('on_batch_start'):
+                response = self.get_model().on_batch_start(batch)
+                if response == -1:
+                    return -1, grad_norm_dic, {}

         splits = [batch]
         if self.truncated_bptt_steps is not None:
@@ -612,14 +608,13 @@ def optimizer_closure():
                     self.batch_loss_value = 0
                     self.avg_loss = np.mean(self.running_loss[-100:])

-        # activate batch end hook
-        if self.is_function_implemented('on_batch_end'):
-            model = self.get_model()
-            with self.profiler.profile('on_batch_end'):
-                model.on_batch_end()
-
-        # Batch end callbacks
-        self.on_batch_end()
+        # Batch end events
+        with self.profiler.profile('on_batch_end'):
+            # callbacks
+            self.on_batch_end()
+            # model hooks
+            if self.is_function_implemented('on_batch_end'):
+                self.get_model().on_batch_end()

         # update progress bar
         if batch_idx % self.progress_bar_refresh_rate == 0:
@@ -635,12 +630,15 @@ def optimizer_closure():
         return 0, grad_norm_dic, all_log_metrics

     def run_training_teardown(self):
-        model = self.get_model()
-
         self.main_progress_bar.close()

+        # Train end events
         with self.profiler.profile('on_train_end'):
-            model.on_train_end()
+            # callbacks
+            self.on_train_end()
+            # model hooks
+            if self.is_function_implemented('on_train_end'):
+                self.get_model().on_train_end()

         if self.logger is not None:
             self.logger.finalize("success")
diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py
new file mode 100644
index 0000000000000..23cea98107434
--- /dev/null
+++ b/tests/trainer/test_callbacks.py
@@ -0,0 +1,153 @@
+import tests.models.utils as tutils
+from pytorch_lightning import Trainer, LightningModule
+from tests.models import (
+    TestModelBase,
+    LightTrainDataloader,
+    LightValidationMixin,
+    LightTestMixin
+)
+
+from pytorch_lightning import Callback
+
+
+def test_trainer_callback_system(tmpdir):
+    """Test the callback system."""
+
+    class CurrentTestModel(
+            LightTrainDataloader,
+            LightTestMixin,
+            LightValidationMixin,
+            TestModelBase,
+    ):
+        pass
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    def _check_args(trainer, pl_module):
+        assert isinstance(trainer, Trainer)
+        assert isinstance(pl_module, LightningModule)
+
+    class TestCallback(Callback):
+        def __init__(self):
+            super().__init__()
+            self.on_init_start_called = False
+            self.on_init_end_called = False
+            self.on_epoch_start_called = False
+            self.on_epoch_end_called = False
+            self.on_batch_start_called = False
+            self.on_batch_end_called = False
+            self.on_train_start_called = False
+            self.on_train_end_called = False
+            self.on_validation_start_called = False
+            self.on_validation_end_called = False
+            self.on_test_start_called = False
+            self.on_test_end_called = False
+
+        def on_init_start(self, trainer):
+            assert isinstance(trainer, Trainer)
+            self.on_init_start_called = True
+
+        def on_init_end(self, trainer):
+            assert isinstance(trainer, Trainer)
+            self.on_init_end_called = True
+
+        def on_epoch_start(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_epoch_start_called = True
+
+        def on_epoch_end(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_epoch_end_called = True
+
+        def on_batch_start(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_batch_start_called = True
+
+        def on_batch_end(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_batch_end_called = True
+
+        def on_train_start(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_train_start_called = True
+
+        def on_train_end(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_train_end_called = True
+
+        def on_validation_start(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_validation_start_called = True
+
+        def on_validation_end(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_validation_end_called = True
+
+        def on_test_start(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_test_start_called = True
+
+        def on_test_end(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_test_end_called = True
+
+    test_callback = TestCallback()
+
+    trainer_options = {
+        'callbacks': [test_callback],
+        'max_epochs': 1,
+        'val_percent_check': 0.1,
+        'train_percent_check': 0.2,
+        'show_progress_bar': False
+    }
+
+    assert not test_callback.on_init_start_called
+    assert not test_callback.on_init_end_called
+    assert not test_callback.on_epoch_start_called
+    assert not test_callback.on_epoch_end_called
+    assert not test_callback.on_batch_start_called
+    assert not test_callback.on_batch_end_called
+    assert not test_callback.on_train_start_called
+    assert not test_callback.on_train_end_called
+    assert not test_callback.on_validation_start_called
+    assert not test_callback.on_validation_end_called
+    assert not test_callback.on_test_start_called
+    assert not test_callback.on_test_end_called
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+
+    assert trainer.callbacks[0] == test_callback
+    assert test_callback.on_init_start_called
+    assert test_callback.on_init_end_called
+    assert not test_callback.on_epoch_start_called
+    assert not test_callback.on_epoch_end_called
+    assert not test_callback.on_batch_start_called
+    assert not test_callback.on_batch_end_called
+    assert not test_callback.on_train_start_called
+    assert not test_callback.on_train_end_called
+    assert not test_callback.on_validation_start_called
+    assert not test_callback.on_validation_end_called
+    assert not test_callback.on_test_start_called
+    assert not test_callback.on_test_end_called
+
+    trainer.fit(model)
+
+    assert test_callback.on_init_start_called
+    assert test_callback.on_init_end_called
+    assert test_callback.on_epoch_start_called
+    assert test_callback.on_epoch_end_called
+    assert test_callback.on_batch_start_called
+    assert test_callback.on_batch_end_called
+    assert test_callback.on_train_start_called
+    assert test_callback.on_train_end_called
+    assert test_callback.on_validation_start_called
+    assert test_callback.on_validation_end_called
+    assert not test_callback.on_test_start_called
+    assert not test_callback.on_test_end_called
+
+    trainer.test()
+
+    assert test_callback.on_test_start_called
+    assert test_callback.on_test_end_called
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index a03e4087a702d..4c16c921296d2 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -600,122 +600,3 @@ def test_end(self, outputs):

     model = LightningTestModel(hparams)
     Trainer().test(model)
-
-
-def test_trainer_callback_system(tmpdir):
-    """Test the callback system."""
-
-    class CurrentTestModel(
-            LightTrainDataloader,
-            LightTestMixin,
-            LightValidationMixin,
-            TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_hparams()
-    model = CurrentTestModel(hparams)
-
-    class TestCallback(Callback):
-        def __init__(self):
-            super().__init__()
-            self.on_init_start_called = False
-            self.on_init_end_called = False
-            self.on_fit_start_called = False
-            self.on_fit_end_called = False
-            self.on_epoch_start_called = False
-            self.on_epoch_end_called = False
-            self.on_batch_start_called = False
-            self.on_batch_end_called = False
-            self.on_train_start_called = False
-            self.on_train_end_called = False
-            self.on_validation_start_called = False
-            self.on_validation_end_called = False
-            self.on_test_start_called = False
-            self.on_test_end_called = False
-
-        def on_init_start(self, trainer):
-            self.on_init_start_called = True
-
-        def on_init_end(self, trainer):
-            self.on_init_end_called = True
-
-        def on_fit_start(self, trainer, pl_module):
-            self.on_fit_start_called = True
-
-        def on_fit_end(self, trainer, pl_module):
-            self.on_fit_end_called = True
-
-        def on_epoch_start(self, trainer, pl_module):
-            self.on_epoch_start_called = True
-
-        def on_epoch_end(self, trainer, pl_module):
-            self.on_epoch_end_called = True
-
-        def on_batch_start(self, trainer, pl_module):
-            self.on_batch_start_called = True
-
-        def on_batch_end(self, trainer, pl_module):
-            self.on_batch_end_called = True
-
-        def on_train_start(self, trainer, pl_module):
-            self.on_train_start_called = True
-
-        def on_train_end(self, trainer, pl_module):
-            self.on_train_end_called = True
-
-        def on_validation_start(self, trainer, pl_module):
-            self.on_validation_start_called = True
-
-        def on_validation_end(self, trainer, pl_module):
-            self.on_validation_end_called = True
-
-        def on_test_start(self, trainer, pl_module):
-            self.on_test_start_called = True
-
-        def on_test_end(self, trainer, pl_module):
-            self.on_test_end_called = True
-
-    test_callback = TestCallback()
-
-    trainer_options = {}
-    trainer_options['callbacks'] = [test_callback]
-    trainer_options['max_epochs'] = 1
-    trainer_options['val_percent_check'] = 0.1
-    trainer_options['train_percent_check'] = 0.2
-    trainer_options['show_progress_bar'] = False
-
-    assert not test_callback.on_init_start_called
-    assert not test_callback.on_init_end_called
-
-    # fit model
-    trainer = Trainer(**trainer_options)
-
-    assert trainer.callbacks[0] == test_callback
-    assert test_callback.on_init_start_called
-    assert test_callback.on_init_end_called
-    assert not test_callback.on_fit_start_called
-    assert not test_callback.on_fit_start_called
-
-    trainer.fit(model)
-
-    assert test_callback.on_fit_start_called
-    assert test_callback.on_fit_end_called
-    assert test_callback.on_epoch_start_called
-    assert test_callback.on_epoch_start_called
-    assert test_callback.on_batch_start_called
-    assert test_callback.on_batch_end_called
-    assert test_callback.on_train_start_called
-    assert test_callback.on_train_end_called
-    assert test_callback.on_validation_start_called
-    assert test_callback.on_validation_end_called
-    assert not test_callback.on_test_start_called
-    assert not test_callback.on_test_end_called
-
-    trainer.test()
-
-    assert test_callback.on_test_start_called
-    assert test_callback.on_test_end_called
-
-# if __name__ == '__main__':
-#     pytest.main([__file__])
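
Note: for reference, a user-defined callback written against the API consolidated in
this patch would look like the sketch below. The example is illustrative and not part
of the diff: `PrintingCallback` and `LitModel` are hypothetical names, and the only
assumed APIs are the hook signatures defined in pytorch_lightning/callbacks/base.py
above and the `callbacks` Trainer argument exercised in tests/trainer/test_callbacks.py.

    from pytorch_lightning import Callback, Trainer


    class PrintingCallback(Callback):
        """Hypothetical example callback that prints a message at selected events."""

        def on_init_start(self, trainer):
            # init events receive only the trainer; the model has not yet been set
            print('Trainer init started')

        def on_train_start(self, trainer, pl_module):
            # every other event receives both the trainer and the LightningModule
            print('Training started')

        def on_train_end(self, trainer, pl_module):
            print('Training finished')


    # hypothetical usage; LitModel stands in for any LightningModule subclass
    trainer = Trainer(callbacks=[PrintingCallback()], max_epochs=1)
    # trainer.fit(LitModel())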