-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Custom Shufflenetv2 for xsmall and dense objects
- Loading branch information
Showing
3 changed files
with
326 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
#Config File example | ||
save_dir: workspace/nanodet_m_5strides | ||
model: | ||
weight_averager: | ||
name: ExpMovingAverager | ||
decay: 0.9998 | ||
arch: | ||
name: NanoDetPlus | ||
detach_epoch: 10 | ||
backbone: | ||
name: ShuffleNetV2Dense | ||
model_size: 1.0x | ||
out_stages: [1,2,3,4] | ||
activation: LeakyReLU | ||
pretrain: False | ||
fpn: | ||
name: GhostPAN | ||
in_channels: [58, 116, 232, 464] | ||
out_channels: 96 | ||
kernel_size: 5 | ||
num_extra_level: 1 | ||
use_depthwise: True | ||
activation: LeakyReLU | ||
head: | ||
name: NanoDetPlusHead | ||
num_classes: 80 | ||
input_channel: 96 | ||
feat_channels: 96 | ||
stacked_convs: 2 | ||
kernel_size: 5 | ||
strides: [4, 8, 16, 32, 64] | ||
activation: LeakyReLU | ||
reg_max: 7 | ||
norm_cfg: | ||
type: BN | ||
loss: | ||
loss_qfl: | ||
name: QualityFocalLoss | ||
use_sigmoid: True | ||
beta: 2.0 | ||
loss_weight: 1.0 | ||
loss_dfl: | ||
name: DistributionFocalLoss | ||
loss_weight: 0.25 | ||
loss_bbox: | ||
name: GIoULoss | ||
loss_weight: 2.0 | ||
# Auxiliary head, only use in training time. | ||
aux_head: | ||
name: SimpleConvHead | ||
num_classes: 80 | ||
input_channel: 192 | ||
feat_channels: 192 | ||
stacked_convs: 4 | ||
strides: [4, 8, 16, 32, 64] | ||
activation: LeakyReLU | ||
reg_max: 7 | ||
|
||
class_names: &class_names ['NAME1', 'NAME2', 'NAME3', 'NAME4', '...'] #Please fill in the category names (not include background category) | ||
data: | ||
train: | ||
name: XMLDataset | ||
class_names: *class_names | ||
img_path: TRAIN_IMAGE_FOLDER #Please fill in train image path | ||
ann_path: TRAIN_XML_FOLDER #Please fill in train xml path | ||
input_size: [320,320] #[w,h] | ||
keep_ratio: True | ||
pipeline: | ||
perspective: 0.0 | ||
scale: [0.6, 1.4] | ||
stretch: [[1, 1], [1, 1]] | ||
rotation: 0 | ||
shear: 0 | ||
translate: 0.2 | ||
flip: 0.5 | ||
brightness: 0.2 | ||
contrast: [0.8, 1.2] | ||
saturation: [0.8, 1.2] | ||
normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]] | ||
val: | ||
name: XMLDataset | ||
class_names: *class_names | ||
img_path: VAL_IMAGE_FOLDER #Please fill in val image path | ||
ann_path: VAL_XML_FOLDER #Please fill in val xml path | ||
input_size: [320,320] #[w,h] | ||
keep_ratio: True | ||
pipeline: | ||
normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]] | ||
device: | ||
gpu_ids: [0] # Set like [0, 1, 2, 3] if you have multi-GPUs | ||
workers_per_gpu: 8 | ||
batchsize_per_gpu: 96 | ||
precision: 32 # set to 16 to use AMP training | ||
schedule: | ||
# resume: | ||
# load_model: YOUR_MODEL_PATH | ||
optimizer: | ||
name: AdamW | ||
lr: 0.001 | ||
weight_decay: 0.05 | ||
warmup: | ||
name: linear | ||
steps: 500 | ||
ratio: 0.0001 | ||
total_epochs: 300 | ||
lr_schedule: | ||
name: CosineAnnealingLR | ||
T_max: 300 | ||
eta_min: 0.00005 | ||
val_intervals: 10 | ||
grad_clip: 35 | ||
evaluator: | ||
name: CocoDetectionEvaluator | ||
save_key: mAP | ||
|
||
log: | ||
interval: 10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.utils.model_zoo as model_zoo | ||
|
||
from ..module.activation import act_layers | ||
|
||
model_urls = { | ||
"shufflenetv2_0.5x": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", # noqa: E501 | ||
"shufflenetv2_1.0x": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", # noqa: E501 | ||
"shufflenetv2_1.5x": "https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth", # noqa: E501 | ||
"shufflenetv2_2.0x": "https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth", # noqa: E501 | ||
} | ||
|
||
|
||
def channel_shuffle(x, groups): | ||
# type: (torch.Tensor, int) -> torch.Tensor | ||
batchsize, num_channels, height, width = x.data.size() | ||
channels_per_group = num_channels // groups | ||
|
||
# reshape | ||
x = x.view(batchsize, groups, channels_per_group, height, width) | ||
|
||
x = torch.transpose(x, 1, 2).contiguous() | ||
|
||
# flatten | ||
x = x.view(batchsize, -1, height, width) | ||
|
||
return x | ||
|
||
|
||
class ShuffleV2Block(nn.Module): | ||
def __init__(self, inp, oup, stride, activation="ReLU"): | ||
super(ShuffleV2Block, self).__init__() | ||
|
||
if not (1 <= stride <= 3): | ||
raise ValueError("illegal stride value") | ||
self.stride = stride | ||
|
||
branch_features = oup // 2 | ||
assert (self.stride != 1) or (inp == branch_features << 1) | ||
|
||
if self.stride > 1: | ||
self.branch1 = nn.Sequential( | ||
self.depthwise_conv( | ||
inp, inp, kernel_size=3, stride=self.stride, padding=1 | ||
), | ||
nn.BatchNorm2d(inp), | ||
nn.Conv2d( | ||
inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False | ||
), | ||
nn.BatchNorm2d(branch_features), | ||
act_layers(activation), | ||
) | ||
else: | ||
self.branch1 = nn.Sequential() | ||
|
||
self.branch2 = nn.Sequential( | ||
nn.Conv2d( | ||
inp if (self.stride > 1) else branch_features, | ||
branch_features, | ||
kernel_size=1, | ||
stride=1, | ||
padding=0, | ||
bias=False, | ||
), | ||
nn.BatchNorm2d(branch_features), | ||
act_layers(activation), | ||
self.depthwise_conv( | ||
branch_features, | ||
branch_features, | ||
kernel_size=3, | ||
stride=self.stride, | ||
padding=1, | ||
), | ||
nn.BatchNorm2d(branch_features), | ||
nn.Conv2d( | ||
branch_features, | ||
branch_features, | ||
kernel_size=1, | ||
stride=1, | ||
padding=0, | ||
bias=False, | ||
), | ||
nn.BatchNorm2d(branch_features), | ||
act_layers(activation), | ||
) | ||
|
||
@staticmethod | ||
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): | ||
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) | ||
|
||
def forward(self, x): | ||
if self.stride == 1: | ||
x1, x2 = x.chunk(2, dim=1) | ||
out = torch.cat((x1, self.branch2(x2)), dim=1) | ||
else: | ||
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) | ||
|
||
out = channel_shuffle(out, 2) | ||
|
||
return out | ||
|
||
|
||
class ShuffleNetV2Dense(nn.Module): | ||
def __init__( | ||
self, | ||
model_size="1.5x", | ||
out_stages=(1, 2, 3, 4), | ||
with_last_conv=False, | ||
kernal_size=3, | ||
activation="ReLU", | ||
pretrain=True, | ||
): | ||
super(ShuffleNetV2Dense, self).__init__() | ||
# out_stages can only be a subset of (1, 2, 3, 4) | ||
assert set(out_stages).issubset((1, 2, 3, 4)) | ||
|
||
print("model size is ", model_size) | ||
|
||
self.stage_repeats = [4, 4, 8, 4] | ||
self.model_size = model_size | ||
self.out_stages = out_stages | ||
self.with_last_conv = with_last_conv | ||
self.kernal_size = kernal_size | ||
self.activation = activation | ||
if model_size == "0.5x": | ||
self._stage_out_channels = [24, 24, 48, 96, 192, 1024] | ||
elif model_size == "1.0x": | ||
self._stage_out_channels = [24, 58, 116, 232, 464, 1024] | ||
elif model_size == "1.5x": | ||
self._stage_out_channels = [24, 88, 176, 352, 704, 1024] | ||
elif model_size == "2.0x": | ||
self._stage_out_channels = [24, 122, 244, 488, 976, 2048] | ||
else: | ||
raise NotImplementedError | ||
|
||
# building first layer | ||
input_channels = 3 | ||
output_channels = self._stage_out_channels[0] | ||
self.conv1 = nn.Sequential( | ||
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), | ||
nn.BatchNorm2d(output_channels), | ||
act_layers(activation), | ||
) | ||
input_channels = output_channels | ||
|
||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||
|
||
stage_names = ["stage{}".format(i) for i in [1, 2, 3, 4]] | ||
for name, repeats, output_channels in zip( | ||
stage_names, self.stage_repeats, self._stage_out_channels[1:] | ||
): | ||
seq = [ | ||
ShuffleV2Block( | ||
input_channels, output_channels, 2, activation=activation | ||
) | ||
] | ||
for i in range(repeats - 1): | ||
seq.append( | ||
ShuffleV2Block( | ||
output_channels, output_channels, 1, activation=activation | ||
) | ||
) | ||
setattr(self, name, nn.Sequential(*seq)) | ||
input_channels = output_channels | ||
output_channels = self._stage_out_channels[-1] | ||
if self.with_last_conv: | ||
conv5 = nn.Sequential( | ||
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), | ||
nn.BatchNorm2d(output_channels), | ||
act_layers(activation), | ||
) | ||
self.stage4.add_module("conv5", conv5) | ||
self._initialize_weights(pretrain) | ||
|
||
def forward(self, x): | ||
x = self.conv1(x) | ||
output = [] | ||
for i in range(1, 5): | ||
stage = getattr(self, "stage{}".format(i)) | ||
x = stage(x) | ||
if i in self.out_stages: | ||
output.append(x) | ||
return tuple(output) | ||
|
||
def _initialize_weights(self, pretrain=True): | ||
print("init weights...") | ||
for name, m in self.named_modules(): | ||
if isinstance(m, nn.Conv2d): | ||
if "first" in name: | ||
nn.init.normal_(m.weight, 0, 0.01) | ||
else: | ||
nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1]) | ||
if m.bias is not None: | ||
nn.init.constant_(m.bias, 0) | ||
elif isinstance(m, nn.BatchNorm2d): | ||
nn.init.constant_(m.weight, 1) | ||
if m.bias is not None: | ||
nn.init.constant_(m.bias, 0.0001) | ||
nn.init.constant_(m.running_mean, 0) | ||
if pretrain: | ||
url = model_urls["shufflenetv2_{}".format(self.model_size)] | ||
if url is not None: | ||
pretrained_state_dict = model_zoo.load_url(url) | ||
print("=> loading pretrained model {}".format(url)) | ||
self.load_state_dict(pretrained_state_dict, strict=False) | ||