Skip to content

Commit

Permalink
[Refactor] Remove some code in mmdet/apis/train.py (#6576)
Browse files Browse the repository at this point in the history
* remove some code about custom hooks in apis/train.py

* files were modified by yapf
  • Loading branch information
Czm369 authored and ZwwWayne committed Nov 30, 2021
1 parent 43dc1dc commit a8a1ea3
Showing 1 changed file with 9 additions and 19 deletions.
28 changes: 9 additions & 19 deletions mmdet/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner, get_dist_info)
from mmcv.utils import build_from_cfg

from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
Expand Down Expand Up @@ -162,9 +161,14 @@ def train_detector(model,
optimizer_config = cfg.optimizer_config

# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
runner.register_training_hooks(
cfg.lr_config,
optimizer_config,
cfg.checkpoint_config,
cfg.log_config,
cfg.get('momentum_config', None),
custom_hooks_config=cfg.get('custom_hooks', None))

if distributed:
if isinstance(runner, EpochBasedRunner):
runner.register_hook(DistSamplerSeedHook())
Expand Down Expand Up @@ -192,20 +196,6 @@ def train_detector(model,
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')

# user-defined hooks
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
assert isinstance(custom_hooks, list), \
f'custom_hooks expect list type, but got {type(custom_hooks)}'
for hook_cfg in cfg.custom_hooks:
assert isinstance(hook_cfg, dict), \
'Each item in custom_hooks expects dict type, but got ' \
f'{type(hook_cfg)}'
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)

if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
Expand Down

0 comments on commit a8a1ea3

Please sign in to comment.