Skip to content

Commit

Permalink
Add unittest of custom hook
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed May 11, 2021
1 parent b523d41 commit 3ce7470
Showing 1 changed file with 67 additions and 9 deletions.
76 changes: 67 additions & 9 deletions tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
build_runner)
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
OneCycleLrUpdaterHook)
from mmcv.runner.hooks.hook import HOOKS, Hook


def test_checkpoint_hook():
Expand Down Expand Up @@ -121,6 +122,51 @@ def val_step(self, x, optimizer, **kwargs):
shutil.rmtree(work_dir)


def test_custom_hook():
@HOOKS.register_module()
class ToyHook(Hook):
def __init__(self, info, *args, **kwargs):
super().__init__()
self.info = info

runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test if custom_hooks is None
runner.register_custom_hooks(None)
assert len(runner.hooks) == 0
# test if custom_hooks is dict list
custom_hooks_cfg = [
dict(type='ToyHook', priority=51, info=51),
dict(type='ToyHook', priority=49, info=49)
]
runner.register_custom_hooks(custom_hooks_cfg)
assert [hook.info for hook in runner.hooks] == [49, 51]
# test if custom_hooks is object and without priority
runner.register_custom_hooks(ToyHook(info='default'))
assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default'
shutil.rmtree(runner.work_dir)

runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test register_training_hooks order
custom_hooks_cfg = [
dict(type='ToyHook', priority=1, info='custom 1'),
dict(type='ToyHook', priority=89, info='custom 89')
]
runner.register_training_hooks(
lr_config=ToyHook('lr'),
optimizer_config=ToyHook('optimizer'),
checkpoint_config=ToyHook('checkpoint'),
log_config=dict(interval=1, hooks=[dict(type='ToyHook', info='log')]),
momentum_config=ToyHook('momentum'),
timer_config=ToyHook('timer'),
custom_hooks_config=custom_hooks_cfg)
hooks_order = [
'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', 'timer',
'custom 89', 'log'
]
assert [hook.info for hook in runner.hooks] == hooks_order
shutil.rmtree(runner.work_dir)


def test_pavi_hook():
sys.modules['pavi'] = MagicMock()

Expand Down Expand Up @@ -583,10 +629,10 @@ def test_wandb_hook():
hook.wandb.join.assert_called_with()


def _build_demo_runner(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):

class Model(nn.Module):

Expand Down Expand Up @@ -616,11 +662,6 @@ def val_step(self, x, optimizer, **kwargs):
else:
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)

log_config = dict(
interval=1, hooks=[
dict(type='TextLoggerHook'),
])

tmp_dir = tempfile.mkdtemp()
runner = build_runner(
dict(type=runner_type),
Expand All @@ -631,6 +672,23 @@ def val_step(self, x, optimizer, **kwargs):
logger=logging.getLogger(),
max_epochs=max_epochs,
max_iters=max_iters))
return runner


def _build_demo_runner(runner_type='EpochBasedRunner',
max_epochs=1,
max_iters=None,
multi_optimziers=False):


log_config = dict(
interval=1, hooks=[
dict(type='TextLoggerHook'),
])

runner = _build_demo_runner_without_hook(runner_type, max_epochs,
max_iters, multi_optimziers)

runner.register_checkpoint_hook(dict(interval=1))
runner.register_logger_hooks(log_config)
return runner
Expand Down

0 comments on commit 3ce7470

Please sign in to comment.