-
Notifications
You must be signed in to change notification settings - Fork 9.5k
/
yolof_r50-c5_8xb8-1x_coco.py
116 lines (113 loc) · 3.51 KB
/
yolof_r50-c5_8xb8-1x_coco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Base configs this file builds on: the COCO detection dataset settings,
# the 1x schedule, and the default runtime settings.
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py',
    '../_base_/default_runtime.py',
]
# YOLOF detector definition: a ResNet-50 backbone, a DilatedEncoder neck and
# a YOLOFHead predicting 80 classes.
model = dict(
    type='YOLOF',
    # Input normalisation: per-channel mean subtraction with unit std, no
    # BGR->RGB swap, and a pad size divisor of 32.
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[103.53, 116.28, 123.675],
        std=[1.0, 1.0, 1.0],
        bgr_to_rgb=False,
        pad_size_divisor=32,
    ),
    # Caffe-style ResNet-50; only stage index 3 (of 4 stages) is output.
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3,),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='caffe',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='open-mmlab://detectron/resnet50_caffe',
        ),
    ),
    # The neck takes 2048 input channels and produces the 512 channels the
    # head is configured to consume.
    neck=dict(
        type='DilatedEncoder',
        in_channels=2048,
        out_channels=512,
        block_mid_channels=128,
        num_residual_blocks=4,
        block_dilations=[2, 4, 6, 8],
    ),
    bbox_head=dict(
        type='YOLOFHead',
        num_classes=80,
        in_channels=512,
        reg_decoded_bbox=True,
        # One aspect ratio, five scales, and a single stride of 32.
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[1, 2, 4, 8, 16],
            strides=[32],
        ),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0.0, 0.0, 0.0, 0.0],
            target_stds=[1.0, 1.0, 1.0, 1.0],
            add_ctr_clamp=True,
            ctr_clamp=32,
        ),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0,
        ),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.0),
    ),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='UniformAssigner',
            pos_ignore_thr=0.15,
            neg_ignore_thr=0.7,
        ),
        allowed_border=-1,
        pos_weight=-1,
        debug=False,
    ),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100,
    ),
)
# optimizer: SGD settings plus per-parameter overrides in `paramwise_cfg`
optim_wrapper = dict(
    optimizer=dict(
        type='SGD',
        lr=0.12,
        momentum=0.9,
        weight_decay=0.0001,
    ),
    paramwise_cfg=dict(
        norm_decay_mult=0.0,
        # entries keyed by 'backbone' carry a learning-rate multiplier of 1/3
        custom_keys={'backbone': dict(lr_mult=1.0 / 3)},
    ),
)
# learning rate: an iteration-based LinearLR phase spanning iterations
# 0-1500, combined with an epoch-based MultiStepLR (gamma 0.1, milestones at
# epochs 8 and 11) spanning epochs 0-12.
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.00066667,
        by_epoch=False,
        begin=0,
        end=1500,
    ),
    dict(
        type='MultiStepLR',
        begin=0,
        end=12,
        by_epoch=True,
        milestones=[8, 11],
        gamma=0.1,
    ),
]
# Training pipeline, in order: load the image, load box annotations, resize
# to scale (1333, 800) with keep_ratio, random flip (prob 0.5), random shift
# (prob 0.5, max_shift_px 32), then pack the detector inputs.
# `{{_base_.backend_args}}` is a placeholder referring to `backend_args`
# from the base configs.
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        backend_args={{_base_.backend_args}}),
    dict(
        type='LoadAnnotations',
        with_bbox=True),
    dict(
        type='Resize',
        scale=(1333, 800),
        keep_ratio=True),
    dict(
        type='RandomFlip',
        prob=0.5),
    dict(
        type='RandomShift',
        prob=0.5,
        max_shift_px=32),
    dict(type='PackDetInputs'),
]
# Test-time pipeline, in order: load the image, resize to scale (1333, 800)
# with keep_ratio, load box annotations, then pack the inputs together with
# the listed meta keys.
# `{{_base_.backend_args}}` is a placeholder referring to `backend_args`
# from the base configs.
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        backend_args={{_base_.backend_args}}),
    dict(
        type='Resize',
        scale=(1333, 800),
        keep_ratio=True),
    dict(
        type='LoadAnnotations',
        with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=(
            'img_id',
            'img_path',
            'ori_shape',
            'img_shape',
            'scale_factor',
        )),
]
# Training dataloader: batch size 8, 8 workers, dataset using
# `train_pipeline`.
train_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(pipeline=train_pipeline),
)
# Validation dataset uses `test_pipeline`; the test dataloader is the very
# same config object as the validation one.
val_dataloader = dict(
    dataset=dict(pipeline=test_pipeline),
)
test_dataloader = val_dataloader
# NOTE: `auto_scale_lr` is used for automatic LR scaling.
# Users should NOT change its values.
# base_batch_size = (8 GPUs) x (8 samples per GPU)
auto_scale_lr = dict(
    base_batch_size=8 * 8,
)