Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RunnerConstructor #1296

Merged
merged 12 commits into from
Aug 24, 2021
Merged
3 changes: 2 additions & 1 deletion mmcv/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .checkpoint import (CheckpointLoader, _load_checkpoint,
_load_checkpoint_with_prefix, load_checkpoint,
load_state_dict, save_checkpoint, weights_to_cpu)
from .default_constructor import DefaultRunnerConstructor
from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
init_dist, master_only)
from .epoch_based_runner import EpochBasedRunner, Runner
Expand Down Expand Up @@ -42,5 +43,5 @@
'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
'_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
'ModuleList', 'GradientCumulativeOptimizerHook',
'GradientCumulativeFp16OptimizerHook'
'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
]
20 changes: 18 additions & 2 deletions mmcv/runner/builder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..utils import Registry, build_from_cfg
import copy

from ..utils import Registry

RUNNERS = Registry('runner')
RUNNER_BUILDERS = Registry('runner builder')


def build_runner_constructor(cfg):
return RUNNER_BUILDERS.build(cfg)


def build_runner(cfg, default_args=None):
return build_from_cfg(cfg, RUNNERS, default_args=default_args)
runner_cfg = copy.deepcopy(cfg)
constructor_type = runner_cfg.pop('constructor',
'DefaultRunnerConstructor')
runner_constructor = build_runner_constructor(
dict(
type=constructor_type,
runner_cfg=runner_cfg,
default_args=default_args))
runner = runner_constructor()
return runner
44 changes: 44 additions & 0 deletions mmcv/runner/default_constructor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from .builder import RUNNER_BUILDERS, RUNNERS


@RUNNER_BUILDERS.register_module()
class DefaultRunnerConstructor:
"""Default constructor for runners.

Custom existing `Runner` like `EpocBasedRunner` though `RunnerConstructor`.
For example, We can inject some new properties and functions for `Runner`.

Example:
>>> from mmcv.runner import RUNNER_BUILDERS, build_runner
>>> # Define a new RunnerReconstructor
>>> @RUNNER_BUILDERS.register_module()
>>> class MyRunnerConstructor:
... def __init__(self, runner_cfg, default_args=None):
... if not isinstance(runner_cfg, dict):
... raise TypeError('runner_cfg should be a dict',
... f'but got {type(runner_cfg)}')
... self.runner_cfg = runner_cfg
... self.default_args = default_args
...
... def __call__(self):
... runner = RUNNERS.build(self.runner_cfg,
... default_args=self.default_args)
... # Add new properties for existing runner
... runner.my_name = 'my_runner'
... runner.my_function = lambda self: print(self.my_name)
... ...
>>> # build your runner
>>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
... constructor='MyRunnerConstructor')
>>> runner = build_runner(runner_cfg)
"""

def __init__(self, runner_cfg, default_args=None):
if not isinstance(runner_cfg, dict):
raise TypeError('runner_cfg should be a dict',
f'but got {type(runner_cfg)}')
self.runner_cfg = runner_cfg
self.default_args = default_args

def __call__(self):
return RUNNERS.build(self.runner_cfg, default_args=self.default_args)