Skip to content

Commit

Permalink
Add custom hook by config file (#970)
Browse files Browse the repository at this point in the history
* Assign different priority to default hooks, and add custom hook register in base runner.

* Add custom hook register in example train file

* Add unittest of custom hook

* Code format
  • Loading branch information
mzr1996 authored May 13, 2021
1 parent 9b8dd08 commit 15bcaa9
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 27 deletions.
3 changes: 2 additions & 1 deletion examples/train_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def main():
lr_config=cfg.lr_config,
optimizer_config=cfg.optimizer_config,
checkpoint_config=cfg.checkpoint_config,
log_config=cfg.log_config)
log_config=cfg.log_config,
custom_hooks_config=cfg.get('custom_train_hooks', None))
if dist:
runner.register_hook(DistSamplerSeedHook())

Expand Down
51 changes: 34 additions & 17 deletions mmcv/runner/base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def register_lr_hook(self, lr_config):
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
self.register_hook(hook)
self.register_hook(hook, priority=10)

def register_momentum_hook(self, momentum_config):
if momentum_config is None:
Expand All @@ -412,7 +412,7 @@ def register_momentum_hook(self, momentum_config):
hook = mmcv.build_from_cfg(momentum_config, HOOKS)
else:
hook = momentum_config
self.register_hook(hook)
self.register_hook(hook, priority=30)

def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
Expand All @@ -422,7 +422,7 @@ def register_optimizer_hook(self, optimizer_config):
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
hook = optimizer_config
self.register_hook(hook)
self.register_hook(hook, priority=50)

def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
Expand All @@ -432,7 +432,7 @@ def register_checkpoint_hook(self, checkpoint_config):
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook)
self.register_hook(hook, priority=70)

def register_logger_hooks(self, log_config):
if log_config is None:
Expand All @@ -441,7 +441,7 @@ def register_logger_hooks(self, log_config):
for info in log_config['hooks']:
logger_hook = mmcv.build_from_cfg(
info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority='VERY_LOW')
self.register_hook(logger_hook, priority=90)

def register_timer_hook(self, timer_config):
if timer_config is None:
Expand All @@ -451,7 +451,20 @@ def register_timer_hook(self, timer_config):
hook = mmcv.build_from_cfg(timer_config_, HOOKS)
else:
hook = timer_config
self.register_hook(hook)
self.register_hook(hook, priority=80)

def register_custom_hooks(self, custom_config):
if custom_config is None:
return

if not isinstance(custom_config, list):
custom_config = [custom_config]

for item in custom_config:
if isinstance(item, dict):
self.register_hook_from_cfg(item)
else:
self.register_hook(item, priority='NORMAL')

def register_profiler_hook(self, profiler_config):
if profiler_config is None:
Expand All @@ -469,21 +482,25 @@ def register_training_hooks(self,
checkpoint_config=None,
log_config=None,
momentum_config=None,
timer_config=dict(type='IterTimerHook')):
"""Register default hooks for training.
Default hooks include:
- LrUpdaterHook
- MomentumUpdaterHook
- OptimizerStepperHook
- CheckpointSaverHook
- IterTimerHook
- LoggerHook(s)
timer_config=dict(type='IterTimerHook'),
custom_hooks_config=None):
"""Register default and custom hooks for training.
Default and custom hooks include:
Hooks Priority
- LrUpdaterHook 10
- MomentumUpdaterHook 30
- OptimizerStepperHook 50
- CheckpointSaverHook 70
- IterTimerHook 80
- LoggerHook(s) 90
- CustomHook(s) 50 (default)
"""
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_timer_hook(timer_config)
self.register_logger_hooks(log_config)
self.register_custom_hooks(custom_hooks_config)
77 changes: 68 additions & 9 deletions tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
build_runner)
from mmcv.runner.hooks.hook import HOOKS, Hook
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
CyclicLrUpdaterHook,
OneCycleLrUpdaterHook,
Expand Down Expand Up @@ -123,6 +124,53 @@ def val_step(self, x, optimizer, **kwargs):
shutil.rmtree(work_dir)


def test_custom_hook():

@HOOKS.register_module()
class ToyHook(Hook):

def __init__(self, info, *args, **kwargs):
super().__init__()
self.info = info

runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test if custom_hooks is None
runner.register_custom_hooks(None)
assert len(runner.hooks) == 0
# test if custom_hooks is dict list
custom_hooks_cfg = [
dict(type='ToyHook', priority=51, info=51),
dict(type='ToyHook', priority=49, info=49)
]
runner.register_custom_hooks(custom_hooks_cfg)
assert [hook.info for hook in runner.hooks] == [49, 51]
# test if custom_hooks is object and without priority
runner.register_custom_hooks(ToyHook(info='default'))
assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default'
shutil.rmtree(runner.work_dir)

runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test register_training_hooks order
custom_hooks_cfg = [
dict(type='ToyHook', priority=1, info='custom 1'),
dict(type='ToyHook', priority=89, info='custom 89')
]
runner.register_training_hooks(
lr_config=ToyHook('lr'),
optimizer_config=ToyHook('optimizer'),
checkpoint_config=ToyHook('checkpoint'),
log_config=dict(interval=1, hooks=[dict(type='ToyHook', info='log')]),
momentum_config=ToyHook('momentum'),
timer_config=ToyHook('timer'),
custom_hooks_config=custom_hooks_cfg)
hooks_order = [
'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', 'timer',
'custom 89', 'log'
]
assert [hook.info for hook in runner.hooks] == hooks_order
shutil.rmtree(runner.work_dir)


def test_pavi_hook():
sys.modules['pavi'] = MagicMock()

Expand Down Expand Up @@ -867,10 +915,10 @@ def test_wandb_hook():
hook.wandb.join.assert_called_with()


def _build_demo_runner(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):

class Model(nn.Module):

Expand Down Expand Up @@ -900,11 +948,6 @@ def val_step(self, x, optimizer, **kwargs):
else:
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)

log_config = dict(
interval=1, hooks=[
dict(type='TextLoggerHook'),
])

tmp_dir = tempfile.mkdtemp()
runner = build_runner(
dict(type=runner_type),
Expand All @@ -915,6 +958,22 @@ def val_step(self, x, optimizer, **kwargs):
logger=logging.getLogger(),
max_epochs=max_epochs,
max_iters=max_iters))
return runner


def _build_demo_runner(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):

log_config = dict(
interval=1, hooks=[
dict(type='TextLoggerHook'),
])

runner = _build_demo_runner_without_hook(runner_type, max_epochs,
max_iters, multi_optimziers)

runner.register_checkpoint_hook(dict(interval=1))
runner.register_logger_hooks(log_config)
return runner
Expand Down

0 comments on commit 15bcaa9

Please sign in to comment.