diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index c8f9f9485ae..d1aca853377 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -6,10 +6,9 @@ import torch import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel -from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, +from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, Fp16OptimizerHook, OptimizerHook, build_optimizer, build_runner, get_dist_info) -from mmcv.utils import build_from_cfg from mmdet.core import DistEvalHook, EvalHook from mmdet.datasets import (build_dataloader, build_dataset, @@ -162,9 +161,14 @@ def train_detector(model, optimizer_config = cfg.optimizer_config # register hooks - runner.register_training_hooks(cfg.lr_config, optimizer_config, - cfg.checkpoint_config, cfg.log_config, - cfg.get('momentum_config', None)) + runner.register_training_hooks( + cfg.lr_config, + optimizer_config, + cfg.checkpoint_config, + cfg.log_config, + cfg.get('momentum_config', None), + custom_hooks_config=cfg.get('custom_hooks', None)) + if distributed: if isinstance(runner, EpochBasedRunner): runner.register_hook(DistSamplerSeedHook()) @@ -192,20 +196,6 @@ def train_detector(model, runner.register_hook( eval_hook(val_dataloader, **eval_cfg), priority='LOW') - # user-defined hooks - if cfg.get('custom_hooks', None): - custom_hooks = cfg.custom_hooks - assert isinstance(custom_hooks, list), \ - f'custom_hooks expect list type, but got {type(custom_hooks)}' - for hook_cfg in cfg.custom_hooks: - assert isinstance(hook_cfg, dict), \ - 'Each item in custom_hooks expects dict type, but got ' \ - f'{type(hook_cfg)}' - hook_cfg = hook_cfg.copy() - priority = hook_cfg.pop('priority', 'NORMAL') - hook = build_from_cfg(hook_cfg, HOOKS) - runner.register_hook(hook, priority=priority) - if cfg.resume_from: runner.resume(cfg.resume_from) elif cfg.load_from: