Skip to content

Commit

Permalink
Custom Shufflenetv2 for xsmall and dense objects
Browse files Browse the repository at this point in the history
  • Loading branch information
vuongnp committed Sep 6, 2023
1 parent 226d8cc commit 4e25651
Show file tree
Hide file tree
Showing 3 changed files with 326 additions and 0 deletions.
117 changes: 117 additions & 0 deletions config/nanodet_custom_xml_5strides.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#Config File example
save_dir: workspace/nanodet_m_5strides
model:
weight_averager:
name: ExpMovingAverager
decay: 0.9998
arch:
name: NanoDetPlus
detach_epoch: 10
backbone:
name: ShuffleNetV2Dense
model_size: 1.0x
out_stages: [1,2,3,4]
activation: LeakyReLU
pretrain: False
fpn:
name: GhostPAN
in_channels: [58, 116, 232, 464]
out_channels: 96
kernel_size: 5
num_extra_level: 1
use_depthwise: True
activation: LeakyReLU
head:
name: NanoDetPlusHead
num_classes: 80
input_channel: 96
feat_channels: 96
stacked_convs: 2
kernel_size: 5
strides: [4, 8, 16, 32, 64]
activation: LeakyReLU
reg_max: 7
norm_cfg:
type: BN
loss:
loss_qfl:
name: QualityFocalLoss
use_sigmoid: True
beta: 2.0
loss_weight: 1.0
loss_dfl:
name: DistributionFocalLoss
loss_weight: 0.25
loss_bbox:
name: GIoULoss
loss_weight: 2.0
# Auxiliary head, only use in training time.
aux_head:
name: SimpleConvHead
num_classes: 80
input_channel: 192
feat_channels: 192
stacked_convs: 4
strides: [4, 8, 16, 32, 64]
activation: LeakyReLU
reg_max: 7

class_names: &class_names ['NAME1', 'NAME2', 'NAME3', 'NAME4', '...'] #Please fill in the category names (not include background category)
data:
train:
name: XMLDataset
class_names: *class_names
img_path: TRAIN_IMAGE_FOLDER #Please fill in train image path
ann_path: TRAIN_XML_FOLDER #Please fill in train xml path
input_size: [320,320] #[w,h]
keep_ratio: True
pipeline:
perspective: 0.0
scale: [0.6, 1.4]
stretch: [[1, 1], [1, 1]]
rotation: 0
shear: 0
translate: 0.2
flip: 0.5
brightness: 0.2
contrast: [0.8, 1.2]
saturation: [0.8, 1.2]
normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]]
val:
name: XMLDataset
class_names: *class_names
img_path: VAL_IMAGE_FOLDER #Please fill in val image path
ann_path: VAL_XML_FOLDER #Please fill in val xml path
input_size: [320,320] #[w,h]
keep_ratio: True
pipeline:
normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]]
device:
gpu_ids: [0] # Set like [0, 1, 2, 3] if you have multi-GPUs
workers_per_gpu: 8
batchsize_per_gpu: 96
precision: 32 # set to 16 to use AMP training
schedule:
# resume:
# load_model: YOUR_MODEL_PATH
optimizer:
name: AdamW
lr: 0.001
weight_decay: 0.05
warmup:
name: linear
steps: 500
ratio: 0.0001
total_epochs: 300
lr_schedule:
name: CosineAnnealingLR
T_max: 300
eta_min: 0.00005
val_intervals: 10
grad_clip: 35
evaluator:
name: CocoDetectionEvaluator
save_key: mAP

log:
interval: 10
3 changes: 3 additions & 0 deletions nanodet/model/backbone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .repvgg import RepVGG
from .resnet import ResNet
from .shufflenetv2 import ShuffleNetV2
from .shufflenetv2_dense import ShuffleNetV2Dense
from .timm_wrapper import TIMMWrapper


Expand All @@ -43,5 +44,7 @@ def build_backbone(cfg):
return RepVGG(**backbone_cfg)
elif name == "TIMMWrapper":
return TIMMWrapper(**backbone_cfg)
elif name == "ShuffleNetV2Dense":
return ShuffleNetV2Dense(**backbone_cfg)
else:
raise NotImplementedError
206 changes: 206 additions & 0 deletions nanodet/model/backbone/shufflenetv2_dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

from ..module.activation import act_layers

model_urls = {
"shufflenetv2_0.5x": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", # noqa: E501
"shufflenetv2_1.0x": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", # noqa: E501
"shufflenetv2_1.5x": "https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth", # noqa: E501
"shufflenetv2_2.0x": "https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth", # noqa: E501
}


def channel_shuffle(x, groups):
# type: (torch.Tensor, int) -> torch.Tensor
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups

Check warning on line 18 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L17-L18

Added lines #L17 - L18 were not covered by tests

# reshape
x = x.view(batchsize, groups, channels_per_group, height, width)

Check warning on line 21 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L21

Added line #L21 was not covered by tests

x = torch.transpose(x, 1, 2).contiguous()

Check warning on line 23 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L23

Added line #L23 was not covered by tests

# flatten
x = x.view(batchsize, -1, height, width)

Check warning on line 26 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L26

Added line #L26 was not covered by tests

return x

Check warning on line 28 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L28

Added line #L28 was not covered by tests


class ShuffleV2Block(nn.Module):
def __init__(self, inp, oup, stride, activation="ReLU"):
super(ShuffleV2Block, self).__init__()

if not (1 <= stride <= 3):
raise ValueError("illegal stride value")

Check warning on line 36 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L36

Added line #L36 was not covered by tests
self.stride = stride

branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1)

if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(
inp, inp, kernel_size=3, stride=self.stride, padding=1
),
nn.BatchNorm2d(inp),
nn.Conv2d(
inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False
),
nn.BatchNorm2d(branch_features),
act_layers(activation),
)
else:
self.branch1 = nn.Sequential()

self.branch2 = nn.Sequential(
nn.Conv2d(
inp if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
act_layers(activation),
self.depthwise_conv(
branch_features,
branch_features,
kernel_size=3,
stride=self.stride,
padding=1,
),
nn.BatchNorm2d(branch_features),
nn.Conv2d(
branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
act_layers(activation),
)

@staticmethod
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

def forward(self, x):
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)

Check warning on line 95 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L94-L95

Added lines #L94 - L95 were not covered by tests
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

Check warning on line 97 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L97

Added line #L97 was not covered by tests

out = channel_shuffle(out, 2)

Check warning on line 99 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L99

Added line #L99 was not covered by tests

return out

Check warning on line 101 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L101

Added line #L101 was not covered by tests


class ShuffleNetV2Dense(nn.Module):
def __init__(
self,
model_size="1.5x",
out_stages=(1, 2, 3, 4),
with_last_conv=False,
kernal_size=3,
activation="ReLU",
pretrain=True,
):
super(ShuffleNetV2Dense, self).__init__()
# out_stages can only be a subset of (1, 2, 3, 4)
assert set(out_stages).issubset((1, 2, 3, 4))

print("model size is ", model_size)

self.stage_repeats = [4, 4, 8, 4]
self.model_size = model_size
self.out_stages = out_stages
self.with_last_conv = with_last_conv
self.kernal_size = kernal_size
self.activation = activation
if model_size == "0.5x":
self._stage_out_channels = [24, 24, 48, 96, 192, 1024]

Check warning on line 127 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L127

Added line #L127 was not covered by tests
elif model_size == "1.0x":
self._stage_out_channels = [24, 58, 116, 232, 464, 1024]
elif model_size == "1.5x":
self._stage_out_channels = [24, 88, 176, 352, 704, 1024]

Check warning on line 131 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L131

Added line #L131 was not covered by tests
elif model_size == "2.0x":
self._stage_out_channels = [24, 122, 244, 488, 976, 2048]

Check warning on line 133 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L133

Added line #L133 was not covered by tests
else:
raise NotImplementedError

Check warning on line 135 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L135

Added line #L135 was not covered by tests

# building first layer
input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels),
act_layers(activation),
)
input_channels = output_channels

self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

stage_names = ["stage{}".format(i) for i in [1, 2, 3, 4]]
for name, repeats, output_channels in zip(
stage_names, self.stage_repeats, self._stage_out_channels[1:]
):
seq = [
ShuffleV2Block(
input_channels, output_channels, 2, activation=activation
)
]
for i in range(repeats - 1):
seq.append(
ShuffleV2Block(
output_channels, output_channels, 1, activation=activation
)
)
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels
output_channels = self._stage_out_channels[-1]
if self.with_last_conv:
conv5 = nn.Sequential(

Check warning on line 168 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L168

Added line #L168 was not covered by tests
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(output_channels),
act_layers(activation),
)
self.stage4.add_module("conv5", conv5)

Check warning on line 173 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L173

Added line #L173 was not covered by tests
self._initialize_weights(pretrain)

def forward(self, x):
x = self.conv1(x)
output = []

Check warning on line 178 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L177-L178

Added lines #L177 - L178 were not covered by tests
for i in range(1, 5):
stage = getattr(self, "stage{}".format(i))
x = stage(x)

Check warning on line 181 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L180-L181

Added lines #L180 - L181 were not covered by tests
if i in self.out_stages:
output.append(x)
return tuple(output)

Check warning on line 184 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L183-L184

Added lines #L183 - L184 were not covered by tests

def _initialize_weights(self, pretrain=True):
print("init weights...")
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if "first" in name:
nn.init.normal_(m.weight, 0, 0.01)

Check warning on line 191 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L191

Added line #L191 was not covered by tests
else:
nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
if m.bias is not None:
nn.init.constant_(m.bias, 0)

Check warning on line 195 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L195

Added line #L195 was not covered by tests
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0.0001)
nn.init.constant_(m.running_mean, 0)
if pretrain:
url = model_urls["shufflenetv2_{}".format(self.model_size)]

Check warning on line 202 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L202

Added line #L202 was not covered by tests
if url is not None:
pretrained_state_dict = model_zoo.load_url(url)
print("=> loading pretrained model {}".format(url))
self.load_state_dict(pretrained_state_dict, strict=False)

Check warning on line 206 in nanodet/model/backbone/shufflenetv2_dense.py

View check run for this annotation

Codecov / codecov/patch

nanodet/model/backbone/shufflenetv2_dense.py#L204-L206

Added lines #L204 - L206 were not covered by tests

0 comments on commit 4e25651

Please sign in to comment.