From 2ade6343e547f239642b8be6c106464f61ad18d5 Mon Sep 17 00:00:00 2001 From: mzr1996 Date: Tue, 20 Apr 2021 10:09:32 +0800 Subject: [PATCH 1/4] Assign different priority to default hooks, and add custom hook register in base runner. --- mmcv/runner/base_runner.py | 44 ++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/mmcv/runner/base_runner.py b/mmcv/runner/base_runner.py index 879ae9a885..57edce3eae 100644 --- a/mmcv/runner/base_runner.py +++ b/mmcv/runner/base_runner.py @@ -386,7 +386,7 @@ def register_lr_hook(self, lr_config): hook = mmcv.build_from_cfg(lr_config, HOOKS) else: hook = lr_config - self.register_hook(hook) + self.register_hook(hook, priority=10) def register_momentum_hook(self, momentum_config): if momentum_config is None: @@ -407,7 +407,7 @@ def register_momentum_hook(self, momentum_config): hook = mmcv.build_from_cfg(momentum_config, HOOKS) else: hook = momentum_config - self.register_hook(hook) + self.register_hook(hook, priority=30) def register_optimizer_hook(self, optimizer_config): if optimizer_config is None: @@ -417,7 +417,7 @@ def register_optimizer_hook(self, optimizer_config): hook = mmcv.build_from_cfg(optimizer_config, HOOKS) else: hook = optimizer_config - self.register_hook(hook) + self.register_hook(hook, priority=50) def register_checkpoint_hook(self, checkpoint_config): if checkpoint_config is None: @@ -427,7 +427,7 @@ def register_checkpoint_hook(self, checkpoint_config): hook = mmcv.build_from_cfg(checkpoint_config, HOOKS) else: hook = checkpoint_config - self.register_hook(hook) + self.register_hook(hook, priority=70) def register_logger_hooks(self, log_config): if log_config is None: @@ -436,7 +436,7 @@ def register_logger_hooks(self, log_config): for info in log_config['hooks']: logger_hook = mmcv.build_from_cfg( info, HOOKS, default_args=dict(interval=log_interval)) - self.register_hook(logger_hook, priority='VERY_LOW') + self.register_hook(logger_hook, priority=90) def register_timer_hook(self, timer_config): if timer_config is None: @@ -446,7 +446,20 @@ def register_timer_hook(self, timer_config): hook = mmcv.build_from_cfg(timer_config_, HOOKS) else: hook = timer_config - self.register_hook(hook) + self.register_hook(hook, priority=80) + + def register_custom_hooks(self, custom_config): + if custom_config is None: + return + + if not isinstance(custom_config, list): + custom_config = [custom_config] + + for item in custom_config: + if isinstance(item, dict): + self.register_hook_from_cfg(item) + else: + self.register_hook(item, priority='NORMAL') def register_profiler_hook(self, profiler_config): if profiler_config is None: @@ -464,8 +477,9 @@ def register_training_hooks(self, checkpoint_config=None, log_config=None, momentum_config=None, - timer_config=dict(type='IterTimerHook')): - """Register default hooks for training. + timer_config=dict(type='IterTimerHook'), + custom_hooks_config=None): + """Register default and custom hooks for training. Default hooks include: @@ -476,9 +490,11 @@ def register_training_hooks(self, - IterTimerHook - LoggerHook(s) """ - self.register_lr_hook(lr_config) - self.register_momentum_hook(momentum_config) - self.register_optimizer_hook(optimizer_config) - self.register_checkpoint_hook(checkpoint_config) - self.register_timer_hook(timer_config) - self.register_logger_hooks(log_config) + # priority + self.register_lr_hook(lr_config) # 10 + self.register_momentum_hook(momentum_config) # 30 + self.register_optimizer_hook(optimizer_config) # 50 + self.register_checkpoint_hook(checkpoint_config) # 70 + self.register_timer_hook(timer_config) # 80 + self.register_logger_hooks(log_config) # 90 + self.register_custom_hooks(custom_hooks_config) # 50 (default) From 77c20e34172bfb73aaa71069d839063c01ae9576 Mon Sep 17 00:00:00 2001 From: mzr1996 Date: Tue, 20 Apr 2021 10:26:20 +0800 Subject: [PATCH 2/4] Add custom hook register in example train file --- examples/train_cifar10.py | 3 ++- mmcv/runner/base_runner.py | 33 +++++++++++++++++---------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/examples/train_cifar10.py b/examples/train_cifar10.py index 63d90ee988..d1fcaf857c 100644 --- a/examples/train_cifar10.py +++ b/examples/train_cifar10.py @@ -159,7 +159,8 @@ def main(): lr_config=cfg.lr_config, optimizer_config=cfg.optimizer_config, checkpoint_config=cfg.checkpoint_config, - log_config=cfg.log_config) + log_config=cfg.log_config, + custom_hooks_config=cfg.get('custom_train_hooks', None)) if dist: runner.register_hook(DistSamplerSeedHook()) diff --git a/mmcv/runner/base_runner.py b/mmcv/runner/base_runner.py index 57edce3eae..de29bcefb0 100644 --- a/mmcv/runner/base_runner.py +++ b/mmcv/runner/base_runner.py @@ -481,20 +481,21 @@ def register_training_hooks(self, custom_hooks_config=None): """Register default and custom hooks for training. - Default hooks include: - - - LrUpdaterHook - - MomentumUpdaterHook - - OptimizerStepperHook - - CheckpointSaverHook - - IterTimerHook - - LoggerHook(s) + Default and custom hooks include: + + Hooks Priority + - LrUpdaterHook 10 + - MomentumUpdaterHook 30 + - OptimizerStepperHook 50 + - CheckpointSaverHook 70 + - IterTimerHook 80 + - LoggerHook(s) 90 + - CustomHook(s) 50 (default) """ - # priority - self.register_lr_hook(lr_config) # 10 - self.register_momentum_hook(momentum_config) # 30 - self.register_optimizer_hook(optimizer_config) # 50 - self.register_checkpoint_hook(checkpoint_config) # 70 - self.register_timer_hook(timer_config) # 80 - self.register_logger_hooks(log_config) # 90 - self.register_custom_hooks(custom_hooks_config) # 50 (default) + self.register_lr_hook(lr_config) + self.register_momentum_hook(momentum_config) + self.register_optimizer_hook(optimizer_config) + self.register_checkpoint_hook(checkpoint_config) + self.register_timer_hook(timer_config) + self.register_logger_hooks(log_config) + self.register_custom_hooks(custom_hooks_config) From 8db13b0aa366b5999db964d0d7935185628e2f12 Mon Sep 17 00:00:00 2001 From: mzr1996 Date: Tue, 11 May 2021 10:34:05 +0800 Subject: [PATCH 3/4] Add unittest of custom hook --- tests/test_runner/test_hooks.py | 76 +++++++++++++++++++++++++++++---- 1 file changed, 67 insertions(+), 9 deletions(-) diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index 6a6e72960d..7f0054ca7f 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -24,6 +24,7 @@ from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook, OneCycleLrUpdaterHook, StepLrUpdaterHook) +from mmcv.runner.hooks.hook import HOOKS, Hook def test_checkpoint_hook(): @@ -122,6 +123,51 @@ def val_step(self, x, optimizer, **kwargs): shutil.rmtree(work_dir) +def test_custom_hook(): + @HOOKS.register_module() + class ToyHook(Hook): + def __init__(self, info, *args, **kwargs): + super().__init__() + self.info = info + + runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1) + # test if custom_hooks is None + runner.register_custom_hooks(None) + assert len(runner.hooks) == 0 + # test if custom_hooks is dict list + custom_hooks_cfg = [ + dict(type='ToyHook', priority=51, info=51), + dict(type='ToyHook', priority=49, info=49) + ] + runner.register_custom_hooks(custom_hooks_cfg) + assert [hook.info for hook in runner.hooks] == [49, 51] + # test if custom_hooks is object and without priority + runner.register_custom_hooks(ToyHook(info='default')) + assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default' + shutil.rmtree(runner.work_dir) + + runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1) + # test register_training_hooks order + custom_hooks_cfg = [ + dict(type='ToyHook', priority=1, info='custom 1'), + dict(type='ToyHook', priority=89, info='custom 89') + ] + runner.register_training_hooks( + lr_config=ToyHook('lr'), + optimizer_config=ToyHook('optimizer'), + checkpoint_config=ToyHook('checkpoint'), + log_config=dict(interval=1, hooks=[dict(type='ToyHook', info='log')]), + momentum_config=ToyHook('momentum'), + timer_config=ToyHook('timer'), + custom_hooks_config=custom_hooks_cfg) + hooks_order = [ + 'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', 'timer', + 'custom 89', 'log' + ] + assert [hook.info for hook in runner.hooks] == hooks_order + shutil.rmtree(runner.work_dir) + + def test_pavi_hook(): sys.modules['pavi'] = MagicMock() @@ -760,10 +806,10 @@ def test_wandb_hook(): hook.wandb.join.assert_called_with() -def _build_demo_runner(runner_type='EpochBasedRunner', - max_epochs=1, - max_iters=None, - multi_optimziers=False): +def _build_demo_runner_without_hook(runner_type='EpochBasedRunner', + max_epochs=1, + max_iters=None, + multi_optimziers=False): class Model(nn.Module): @@ -793,11 +839,6 @@ def val_step(self, x, optimizer, **kwargs): else: optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95) - log_config = dict( - interval=1, hooks=[ - dict(type='TextLoggerHook'), - ]) - tmp_dir = tempfile.mkdtemp() runner = build_runner( dict(type=runner_type), @@ -808,6 +849,23 @@ def val_step(self, x, optimizer, **kwargs): logger=logging.getLogger(), max_epochs=max_epochs, max_iters=max_iters)) + return runner + + +def _build_demo_runner(runner_type='EpochBasedRunner', + max_epochs=1, + max_iters=None, + multi_optimziers=False): + + + log_config = dict( + interval=1, hooks=[ + dict(type='TextLoggerHook'), + ]) + + runner = _build_demo_runner_without_hook(runner_type, max_epochs, + max_iters, multi_optimziers) + runner.register_checkpoint_hook(dict(interval=1)) runner.register_logger_hooks(log_config) return runner From 858ab1f67025f295ac5e2df95a37199f8c2e2fdb Mon Sep 17 00:00:00 2001 From: mzr1996 Date: Tue, 11 May 2021 15:29:00 +0800 Subject: [PATCH 4/4] Code format --- tests/test_runner/test_hooks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index 7f0054ca7f..1598742652 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -21,10 +21,10 @@ from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook, MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook, build_runner) +from mmcv.runner.hooks.hook import HOOKS, Hook from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook, OneCycleLrUpdaterHook, StepLrUpdaterHook) -from mmcv.runner.hooks.hook import HOOKS, Hook def test_checkpoint_hook(): @@ -124,8 +124,10 @@ def val_step(self, x, optimizer, **kwargs): def test_custom_hook(): + @HOOKS.register_module() class ToyHook(Hook): + def __init__(self, info, *args, **kwargs): super().__init__() self.info = info @@ -857,7 +859,6 @@ def _build_demo_runner(runner_type='EpochBasedRunner', max_iters=None, multi_optimziers=False): - log_config = dict( interval=1, hooks=[ dict(type='TextLoggerHook'),