diff --git a/mmcv/utils/registry.py b/mmcv/utils/registry.py index 64b83f1d75..7f96b6cc67 100644 --- a/mmcv/utils/registry.py +++ b/mmcv/utils/registry.py @@ -2,7 +2,7 @@ import warnings from functools import partial -from .misc import is_str +from .misc import is_seq_of class Registry: @@ -54,10 +54,18 @@ def _register_module(self, module_class, module_name=None, force=False): if module_name is None: module_name = module_class.__name__ - if not force and module_name in self._module_dict: - raise KeyError(f'{module_name} is already registered ' - f'in {self.name}') - self._module_dict[module_name] = module_class + if isinstance(module_name, str): + module_name = [module_name] + else: + assert is_seq_of( + module_name, + str), ('module_name should be either of None, an ' + f'instance of str or list, but got {type(module_name)}') + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f'{name} is already registered ' + f'in {self.name}') + self._module_dict[name] = module_class def deprecated_register_module(self, cls=None, force=False): warnings.warn( @@ -157,7 +165,7 @@ def build_from_cfg(cfg, registry, default_args=None): args.setdefault(name, value) obj_type = args.pop('type') - if is_str(obj_type): + if isinstance(obj_type, str): obj_cls = registry.get(obj_type) if obj_cls is None: raise KeyError( diff --git a/tests/test_utils/test_registry.py b/tests/test_utils/test_registry.py index 104cc1964c..3106c39e56 100644 --- a/tests/test_utils/test_registry.py +++ b/tests/test_utils/test_registry.py @@ -58,6 +58,9 @@ class SphynxCat: CATS.register_module(name='Sphynx', module=SphynxCat) assert CATS.get('Sphynx') is SphynxCat + CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat) + assert CATS.get('Sphynx2') is SphynxCat + repr_str = 'Registry(name=cat, items={' repr_str += ("'BritishShorthair': .BritishShorthair'>, ") @@ -66,10 +69,18 @@ class SphynxCat: repr_str += ("'Siamese': .SiameseCat'>, ") repr_str += ("'Sphynx': .SphynxCat'>, ") + repr_str += ("'Sphynx1': .SphynxCat'>, ") + repr_str += ("'Sphynx2': .SphynxCat'>") repr_str += '})' assert repr(CATS) == repr_str + # name type + with pytest.raises(AssertionError): + CATS.register_module(name=7474741, module=SphynxCat) + # the registered module should be a class with pytest.raises(TypeError): CATS.register_module(0)