Add custom hook by config file #970

Merged · 4 commits · May 13, 2021
3 changes: 2 additions & 1 deletion examples/train_cifar10.py
@@ -159,7 +159,8 @@ def main():
lr_config=cfg.lr_config,
optimizer_config=cfg.optimizer_config,
checkpoint_config=cfg.checkpoint_config,
log_config=cfg.log_config)
log_config=cfg.log_config,
custom_hooks_config=cfg.get('custom_train_hooks', None))
if dist:
runner.register_hook(DistSamplerSeedHook())

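The example script now forwards an optional `custom_train_hooks` entry from the config file to `register_training_hooks`. A minimal sketch of what such an entry could look like in the config (the hook choice and priority are illustrative; any class registered in the HOOKS registry works, and `priority` is optional):

```python
# sketch of a config-file entry picked up by cfg.get('custom_train_hooks', None);
# EMAHook is a built-in mmcv hook, the priority shown is only an example
custom_train_hooks = [
    dict(type='EMAHook', priority=49),  # registered just ahead of OptimizerHook (50)
]
```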
51 changes: 34 additions & 17 deletions mmcv/runner/base_runner.py
@@ -386,7 +386,7 @@ def register_lr_hook(self, lr_config):
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
self.register_hook(hook)
self.register_hook(hook, priority=10)

def register_momentum_hook(self, momentum_config):
if momentum_config is None:
@@ -407,7 +407,7 @@ def register_momentum_hook(self, momentum_config):
hook = mmcv.build_from_cfg(momentum_config, HOOKS)
else:
hook = momentum_config
self.register_hook(hook)
self.register_hook(hook, priority=30)

def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
@@ -417,7 +417,7 @@ def register_optimizer_hook(self, optimizer_config):
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
hook = optimizer_config
self.register_hook(hook)
self.register_hook(hook, priority=50)

def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
@@ -427,7 +427,7 @@ def register_checkpoint_hook(self, checkpoint_config):
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook)
self.register_hook(hook, priority=70)

def register_logger_hooks(self, log_config):
if log_config is None:
@@ -436,7 +436,7 @@ def register_logger_hooks(self, log_config):
for info in log_config['hooks']:
logger_hook = mmcv.build_from_cfg(
info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority='VERY_LOW')
self.register_hook(logger_hook, priority=90)

def register_timer_hook(self, timer_config):
if timer_config is None:
@@ -446,7 +446,20 @@ def register_timer_hook(self, timer_config):
hook = mmcv.build_from_cfg(timer_config_, HOOKS)
else:
hook = timer_config
self.register_hook(hook)
self.register_hook(hook, priority=80)

def register_custom_hooks(self, custom_config):
if custom_config is None:
return

if not isinstance(custom_config, list):
custom_config = [custom_config]

for item in custom_config:
if isinstance(item, dict):
self.register_hook_from_cfg(item)
else:
self.register_hook(item, priority='NORMAL')

def register_profiler_hook(self, profiler_config):
if profiler_config is None:
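For dict entries, `register_custom_hooks` delegates to the existing `register_hook_from_cfg`, which this diff does not touch. Roughly, that helper pops the optional `priority` key before building the hook from the registry, which is why config dicts such as `dict(type='ToyHook', priority=49)` in the tests below end up sorted by priority. A sketch of that assumed behavior:

```python
# sketch of the assumed behavior of BaseRunner.register_hook_from_cfg
# (the real method already exists in mmcv/runner/base_runner.py)
def register_hook_from_cfg(self, hook_cfg):
    hook_cfg = hook_cfg.copy()
    priority = hook_cfg.pop('priority', 'NORMAL')  # priority key is optional
    hook = mmcv.build_from_cfg(hook_cfg, HOOKS)    # build the hook from the registry
    self.register_hook(hook, priority=priority)    # insert into the sorted hook list
```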
@@ -464,21 +477,25 @@ def register_training_hooks(self,
checkpoint_config=None,
log_config=None,
momentum_config=None,
timer_config=dict(type='IterTimerHook')):
"""Register default hooks for training.

Default hooks include:

- LrUpdaterHook
- MomentumUpdaterHook
- OptimizerStepperHook
- CheckpointSaverHook
- IterTimerHook
- LoggerHook(s)
timer_config=dict(type='IterTimerHook'),
custom_hooks_config=None):
"""Register default and custom hooks for training.

Default and custom hooks include:

Hooks Priority
- LrUpdaterHook 10
- MomentumUpdaterHook 30
- OptimizerStepperHook 50
- CheckpointSaverHook 70
- IterTimerHook 80
- LoggerHook(s) 90
- CustomHook(s) 50 (default)
"""
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_timer_hook(timer_config)
self.register_logger_hooks(log_config)
self.register_custom_hooks(custom_hooks_config)
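Taken together, the changes above give every default hook an explicit numeric priority and let `register_training_hooks` accept custom hooks either as config dicts or as `Hook` instances. A minimal usage sketch, assuming an already-built `runner` and a hypothetical hook class `MyHook` registered in HOOKS; the lr/optimizer/checkpoint/log configs use mmcv's usual shorthand forms:

```python
# illustrative only: MyHook is a hypothetical hook class registered in HOOKS
runner.register_training_hooks(
    lr_config=dict(policy='step', step=[1]),    # -> StepLrUpdaterHook, priority 10
    optimizer_config=dict(grad_clip=None),      # -> OptimizerHook, priority 50
    checkpoint_config=dict(interval=1),         # -> CheckpointHook, priority 70
    log_config=dict(
        interval=1, hooks=[dict(type='TextLoggerHook')]),  # priority 90
    custom_hooks_config=[
        dict(type='MyHook', priority=5),  # registered ahead of every default hook
        MyHook(),                         # plain Hook instance, priority 'NORMAL' (50)
    ])
```

The timer and momentum hooks keep their defaults here: `IterTimerHook` at priority 80 and no momentum hook, since `momentum_config` stays `None`.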
77 changes: 68 additions & 9 deletions tests/test_runner/test_hooks.py
@@ -21,6 +21,7 @@
from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
build_runner)
from mmcv.runner.hooks.hook import HOOKS, Hook
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
OneCycleLrUpdaterHook,
StepLrUpdaterHook)
@@ -122,6 +123,53 @@ def val_step(self, x, optimizer, **kwargs):
shutil.rmtree(work_dir)


def test_custom_hook():

@HOOKS.register_module()
class ToyHook(Hook):

def __init__(self, info, *args, **kwargs):
super().__init__()
self.info = info

runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test if custom_hooks is None
runner.register_custom_hooks(None)
assert len(runner.hooks) == 0
# test if custom_hooks is dict list
custom_hooks_cfg = [
dict(type='ToyHook', priority=51, info=51),
dict(type='ToyHook', priority=49, info=49)
]
runner.register_custom_hooks(custom_hooks_cfg)
assert [hook.info for hook in runner.hooks] == [49, 51]
# test if custom_hooks is object and without priority
runner.register_custom_hooks(ToyHook(info='default'))
assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default'
shutil.rmtree(runner.work_dir)

runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test register_training_hooks order
custom_hooks_cfg = [
dict(type='ToyHook', priority=1, info='custom 1'),
dict(type='ToyHook', priority=89, info='custom 89')
]
runner.register_training_hooks(
lr_config=ToyHook('lr'),
optimizer_config=ToyHook('optimizer'),
checkpoint_config=ToyHook('checkpoint'),
log_config=dict(interval=1, hooks=[dict(type='ToyHook', info='log')]),
momentum_config=ToyHook('momentum'),
timer_config=ToyHook('timer'),
custom_hooks_config=custom_hooks_cfg)
hooks_order = [
'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', 'timer',
'custom 89', 'log'
]
assert [hook.info for hook in runner.hooks] == hooks_order
shutil.rmtree(runner.work_dir)
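The expected order in this test follows directly from the numeric priorities: custom 1 (priority 1), lr (10), momentum (30), optimizer (50), checkpoint (70), timer (80), custom 89 (89), log (90).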


def test_pavi_hook():
sys.modules['pavi'] = MagicMock()

@@ -760,10 +808,10 @@ def test_wandb_hook():
hook.wandb.join.assert_called_with()


def _build_demo_runner(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):

class Model(nn.Module):

@@ -793,11 +841,6 @@ def val_step(self, x, optimizer, **kwargs):
else:
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)

log_config = dict(
interval=1, hooks=[
dict(type='TextLoggerHook'),
])

tmp_dir = tempfile.mkdtemp()
runner = build_runner(
dict(type=runner_type),
@@ -808,6 +851,22 @@ def val_step(self, x, optimizer, **kwargs):
logger=logging.getLogger(),
max_epochs=max_epochs,
max_iters=max_iters))
return runner


def _build_demo_runner(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):

log_config = dict(
interval=1, hooks=[
dict(type='TextLoggerHook'),
])

runner = _build_demo_runner_without_hook(runner_type, max_epochs,
max_iters, multi_optimziers)

runner.register_checkpoint_hook(dict(interval=1))
runner.register_logger_hooks(log_config)
return runner