[Fix] Fix skip_layer for RF-Next (#2489)
* judge skip_layer by fullname

* lint

* skip_layer first

* update unit test
lzyhha authored Dec 28, 2022
1 parent 30d975a commit 935ba78
Showing 2 changed files with 58 additions and 41 deletions.
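The core of this fix: skip_layer patterns are now matched against each module's full dotted name (e.g. 'module.layer0.1') rather than its local child name, so nested layers can be skipped. Below is a minimal standalone sketch of the matching rule, assuming the substring semantics visible in the diff; the helper name should_skip is illustrative, not an MMCV API.

    from typing import List, Optional

    def should_skip(fullname: str, skip_layer: Optional[List[str]]) -> bool:
        # A pattern matches if it occurs anywhere in the full dotted name.
        if skip_layer is None:
            return False
        return any(layer in fullname for layer in skip_layer)

    # With local names only, a nested pattern such as 'layer0.1' could never
    # match. Matching on fullnames fixes that, and because the real code
    # `continue`s before recursing, children of a skipped parent are skipped too.
    assert should_skip('module.layer0.1', ['stem', 'layer0.1'])
    assert should_skip('module.stem.conv', ['stem', 'layer0.1'])
    assert not should_skip('module.layer0.0', ['stem', 'layer0.1'])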
36 changes: 20 additions & 16 deletions mmcv/cnn/rfsearch/search.py
@@ -143,7 +143,10 @@ def estimate_and_expand(self, model: nn.Module):
                 module.estimate_rates()
                 module.expand_rates()
 
-    def wrap_model(self, model: nn.Module, search_op: str = 'Conv2d'):
+    def wrap_model(self,
+                   model: nn.Module,
+                   search_op: str = 'Conv2d',
+                   prefix: str = ''):
         """wrap model to support searchable conv op.
 
         Args:
@@ -152,9 +155,18 @@ def wrap_model(self, model: nn.Module, search_op: str = 'Conv2d'):
                 Defaults to 'Conv2d'.
             init_rates (int, optional): Set to other initial dilation rates.
                 Defaults to None.
+            prefix (str): Prefix for function recursion. Defaults to ''.
         """
         op = 'torch.nn.' + search_op
         for name, module in model.named_children():
+            if prefix == '':
+                fullname = 'module.' + name
+            else:
+                fullname = prefix + '.' + name
+            if self.config['search']['skip_layer'] is not None:
+                if any(layer in fullname
+                       for layer in self.config['search']['skip_layer']):
+                    continue
             if isinstance(module, eval(op)):
                 if 1 < module.kernel_size[0] and \
                         0 != module.kernel_size[0] % 2 or \
@@ -167,14 +179,8 @@ def wrap_model(self, model: nn.Module, search_op: str = 'Conv2d'):
                     logger.info('Wrap model %s to %s.' %
                                 (str(module), str(moduleWrap)))
                     setattr(model, name, moduleWrap)
-            elif isinstance(module, BaseConvRFSearchOp):
-                pass
-            else:
-                if self.config['search']['skip_layer'] is not None:
-                    if any(layer in name
-                           for layer in self.config['search']['skip_layer']):
-                        continue
-                self.wrap_model(module, search_op)
+            elif not isinstance(module, BaseConvRFSearchOp):
+                self.wrap_model(module, search_op, fullname)
 
     def set_model(self,
                   model: nn.Module,
@@ -198,6 +204,10 @@ def set_model(self,
                 fullname = 'module.' + name
             else:
                 fullname = prefix + '.' + name
+            if self.config['search']['skip_layer'] is not None:
+                if any(layer in fullname
+                       for layer in self.config['search']['skip_layer']):
+                    continue
             if isinstance(module, eval(op)):
                 if 1 < module.kernel_size[0] and \
                         0 != module.kernel_size[0] % 2 or \
@@ -224,11 +234,5 @@ def set_model(self,
                     logger.info(
                         'Set module %s dilation as: [%d %d]' %
                         (fullname, module.dilation[0], module.dilation[1]))
-            elif isinstance(module, BaseConvRFSearchOp):
-                pass
-            else:
-                if self.config['search']['skip_layer'] is not None:
-                    if any(layer in fullname
-                           for layer in self.config['search']['skip_layer']):
-                        continue
+            elif not isinstance(module, BaseConvRFSearchOp):
                 self.set_model(module, search_op, init_rates, fullname)
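For reference, the prefix argument threads the dotted path through the recursion: the top level gets a fixed 'module.' prefix (matching the config keys such as 'module.conv2' in the test below) and each child extends it with '.<child name>'. A simplified sketch of just the name construction, with a hypothetical Net module; the real wrap_model/set_model also wrap, configure, or skip modules along the way.

    import torch.nn as nn

    class Net(nn.Module):  # hypothetical model, for illustration only
        def __init__(self):
            super().__init__()
            self.stem = nn.Conv2d(1, 2, 3)
            self.layer0 = nn.Sequential(nn.Conv2d(2, 2, 3), nn.Conv2d(2, 2, 3))

    def walk(model: nn.Module, prefix: str = '') -> None:
        # Same name construction as in the diff above: 'module.' at the top
        # level, then prefix + '.' + name for every nested child.
        for name, child in model.named_children():
            fullname = 'module.' + name if prefix == '' else prefix + '.' + name
            print(fullname)
            walk(child, fullname)

    walk(Net())
    # module.stem
    # module.layer0
    # module.layer0.0
    # module.layer0.1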
63 changes: 38 additions & 25 deletions tests/test_cnn/test_rfsearch/test_search.py
@@ -16,36 +16,36 @@
 
 def test_rfsearchhook():
 
+    def conv(in_channels, out_channels, kernel_size, stride, padding,
+             dilation):
+        return nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation)
+
     class Model(nn.Module):
 
         def __init__(self):
             super().__init__()
-            self.conv1 = nn.Conv2d(
-                in_channels=1,
-                out_channels=2,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                dilation=1)
-            self.conv2 = nn.Conv2d(
-                in_channels=2,
-                out_channels=2,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                dilation=1)
-            self.conv3 = nn.Conv2d(
-                in_channels=1,
-                out_channels=2,
-                kernel_size=(1, 3),
-                stride=1,
-                padding=(0, 1),
-                dilation=1)
+            self.stem = conv(1, 2, 3, 1, 1, 1)
+            self.conv0 = conv(2, 2, 3, 1, 1, 1)
+            self.layer0 = nn.Sequential(
+                conv(2, 2, 3, 1, 1, 1), conv(2, 2, 3, 1, 1, 1))
+            self.conv1 = conv(2, 2, 1, 1, 0, 1)
+            self.conv2 = conv(2, 2, 3, 1, 1, 1)
+            self.conv3 = conv(2, 2, (1, 3), 1, (0, 1), 1)
 
         def forward(self, x):
-            x1 = self.conv1(x)
-            x2 = self.conv2(x1)
-            return x2
+            x1 = self.stem(x)
+            x2 = self.layer0(x1)
+            x3 = self.conv0(x2)
+            x4 = self.conv1(x3)
+            x5 = self.conv2(x4)
+            x6 = self.conv3(x5)
+            return x6
 
         def train_step(self, x, optimizer, **kwargs):
             return dict(loss=self(x).mean(), num_samples=x.shape[0])
@@ -63,13 +63,14 @@ def train_step(self, x, optimizer, **kwargs):
             mmin=1,
             mmax=24,
             num_branches=2,
-            skip_layer=['stem', 'layer1'])),
+            skip_layer=['stem', 'conv0', 'layer0.1'])),
     )
 
     # hook for search
     rfsearchhook_search = RFSearchHook(
         'search', rfsearch_cfg['config'], by_epoch=True, verbose=True)
     rfsearchhook_search.config['structure'] = {
+        'module.layer0.0': [1, 1],
         'module.conv2': [2, 2],
         'module.conv3': [1, 1]
     }
@@ -80,6 +81,7 @@ def train_step(self, x, optimizer, **kwargs):
         by_epoch=True,
         verbose=True)
     rfsearchhook_fixed_single_branch.config['structure'] = {
+        'module.layer0.0': [1, 1],
         'module.conv2': [2, 2],
         'module.conv3': [1, 1]
     }
@@ -90,14 +92,22 @@ def train_step(self, x, optimizer, **kwargs):
         by_epoch=True,
         verbose=True)
     rfsearchhook_fixed_multi_branch.config['structure'] = {
+        'module.layer0.0': [1, 1],
         'module.conv2': [2, 2],
         'module.conv3': [1, 1]
     }
 
+    def test_skip_layer():
+        assert not isinstance(model.stem, Conv2dRFSearchOp)
+        assert not isinstance(model.conv0, Conv2dRFSearchOp)
+        assert isinstance(model.layer0[0], Conv2dRFSearchOp)
+        assert not isinstance(model.layer0[1], Conv2dRFSearchOp)
+
     # 1. test init_model() with mode of search
     model = Model()
     rfsearchhook_search.init_model(model)
 
+    test_skip_layer()
     assert not isinstance(model.conv1, Conv2dRFSearchOp)
     assert isinstance(model.conv2, Conv2dRFSearchOp)
     assert isinstance(model.conv3, Conv2dRFSearchOp)
@@ -111,6 +121,7 @@ def train_step(self, x, optimizer, **kwargs):
     runner.register_hook(rfsearchhook_search)
     runner.run([loader], [('train', 1)])
 
+    test_skip_layer()
     assert not isinstance(model.conv1, Conv2dRFSearchOp)
     assert isinstance(model.conv2, Conv2dRFSearchOp)
     assert isinstance(model.conv3, Conv2dRFSearchOp)
@@ -145,6 +156,7 @@ def train_step(self, x, optimizer, **kwargs):
     model = Model()
     rfsearchhook_fixed_multi_branch.init_model(model)
 
+    test_skip_layer()
     assert not isinstance(model.conv1, Conv2dRFSearchOp)
     assert isinstance(model.conv2, Conv2dRFSearchOp)
     assert isinstance(model.conv3, Conv2dRFSearchOp)
@@ -157,6 +169,7 @@ def train_step(self, x, optimizer, **kwargs):
     runner.register_hook(rfsearchhook_fixed_multi_branch)
     runner.run([loader], [('train', 1)])
 
+    test_skip_layer()
     assert not isinstance(model.conv1, Conv2dRFSearchOp)
     assert isinstance(model.conv2, Conv2dRFSearchOp)
     assert isinstance(model.conv3, Conv2dRFSearchOp)
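Applying the substring rule to the fullnames of the new test Model shows why test_skip_layer asserts what it does. A sketch only; the skipped set follows from skip_layer=['stem', 'conv0', 'layer0.1'] in the config above.

    fullnames = [
        'module.stem', 'module.conv0', 'module.layer0.0', 'module.layer0.1',
        'module.conv1', 'module.conv2', 'module.conv3'
    ]
    skip_layer = ['stem', 'conv0', 'layer0.1']

    skipped = [f for f in fullnames if any(p in f for p in skip_layer)]
    print(skipped)  # ['module.stem', 'module.conv0', 'module.layer0.1']
    # Hence stem, conv0 and layer0[1] stay plain convs while layer0[0], conv2
    # and conv3 get wrapped; conv1 is a 1x1 conv, which wrap_model never wraps.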
