Skip to content

Commit

Permalink
[Enhancement] Make build_xxx_layer allow accepting a class type (#2782)
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE authored May 11, 2023
1 parent b4dee63 commit 59c1418
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 116 deletions.
4 changes: 3 additions & 1 deletion mmcv/cnn/bricks/conv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict, Optional

from mmengine.registry import MODELS
Expand Down Expand Up @@ -35,7 +36,8 @@ def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
cfg_ = cfg.copy()

layer_type = cfg_.pop('type')

if inspect.isclass(layer_type):
return layer_type(*args, **kwargs, **cfg_) # type: ignore
# Switch registry to the target scope. If `conv_layer` cannot be found
# in the registry, fallback to search `conv_layer` in the
# mmengine.MODELS.
Expand Down
21 changes: 12 additions & 9 deletions mmcv/cnn/bricks/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,25 @@ def build_norm_layer(cfg: Dict,

layer_type = cfg_.pop('type')

# Switch registry to the target scope. If `norm_layer` cannot be found
# in the registry, fallback to search `norm_layer` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
norm_layer = registry.get(layer_type)
if norm_layer is None:
raise KeyError(f'Cannot find {norm_layer} in registry under scope '
f'name {registry.scope}')
if inspect.isclass(layer_type):
norm_layer = layer_type
else:
# Switch registry to the target scope. If `norm_layer` cannot be found
# in the registry, fallback to search `norm_layer` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
norm_layer = registry.get(layer_type)
if norm_layer is None:
raise KeyError(f'Cannot find {norm_layer} in registry under '
f'scope name {registry.scope}')
abbr = infer_abbr(norm_layer)

assert isinstance(postfix, (int, str))
name = abbr + str(postfix)

requires_grad = cfg_.pop('requires_grad', True)
cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN':
if norm_layer is not nn.GroupNorm:
layer = norm_layer(num_features, **cfg_)
if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
layer._specify_ddp_gpu_num(1)
Expand Down
4 changes: 3 additions & 1 deletion mmcv/cnn/bricks/padding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict

import torch.nn as nn
Expand Down Expand Up @@ -27,7 +28,8 @@ def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:

cfg_ = cfg.copy()
padding_type = cfg_.pop('type')

if inspect.isclass(padding_type):
return padding_type(*args, **kwargs, **cfg_)
# Switch registry to the target scope. If `padding_layer` cannot be found
# in the registry, fallback to search `padding_layer` in the
# mmengine.MODELS.
Expand Down
21 changes: 12 additions & 9 deletions mmcv/cnn/bricks/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,18 @@ def build_plugin_layer(cfg: Dict,
cfg_ = cfg.copy()

layer_type = cfg_.pop('type')

# Switch registry to the target scope. If `plugin_layer` cannot be found
# in the registry, fallback to search `plugin_layer` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
plugin_layer = registry.get(layer_type)
if plugin_layer is None:
raise KeyError(f'Cannot find {plugin_layer} in registry under scope '
f'name {registry.scope}')
if inspect.isclass(layer_type):
plugin_layer = layer_type
else:
# Switch registry to the target scope. If `plugin_layer` cannot be
# found in the registry, fallback to search `plugin_layer` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
plugin_layer = registry.get(layer_type)
if plugin_layer is None:
raise KeyError(
f'Cannot find {plugin_layer} in registry under scope '
f'name {registry.scope}')
abbr = infer_abbr(plugin_layer)

assert isinstance(postfix, (int, str))
Expand Down
18 changes: 11 additions & 7 deletions mmcv/cnn/bricks/upsample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict

import torch
Expand Down Expand Up @@ -76,15 +77,18 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:

layer_type = cfg_.pop('type')

if inspect.isclass(layer_type):
upsample = layer_type
# Switch registry to the target scope. If `upsample` cannot be found
# in the registry, fallback to search `upsample` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
upsample = registry.get(layer_type)
if upsample is None:
raise KeyError(f'Cannot find {upsample} in registry under scope '
f'name {registry.scope}')
if upsample is nn.Upsample:
cfg_['mode'] = layer_type
else:
with MODELS.switch_scope_and_registry(None) as registry:
upsample = registry.get(layer_type)
if upsample is None:
raise KeyError(f'Cannot find {upsample} in registry under scope '
f'name {registry.scope}')
if upsample is nn.Upsample:
cfg_['mode'] = layer_type
layer = upsample(*args, **kwargs, **cfg_)
return layer
5 changes: 3 additions & 2 deletions mmcv/ops/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,9 @@ def batched_nms(boxes: Tensor,
max_coordinate + torch.tensor(1).to(boxes))
boxes_for_nms = boxes + offsets[:, None]

nms_type = nms_cfg_.pop('type', 'nms')
nms_op = eval(nms_type)
nms_op = nms_cfg_.pop('type', 'nms')
if isinstance(nms_op, str):
nms_op = eval(nms_op)

split_thr = nms_cfg_.pop('split_thr', 10000)
# Won't split to multiple nms nodes when exporting to onnx
Expand Down
Loading

0 comments on commit 59c1418

Please sign in to comment.