
ASFF - Learning Spatial Fusion for Single-Shot Object Detection - 63% mAP@0.5 with 45.5FPS #4382

Closed
Kyuuki93 opened this issue Nov 26, 2019 · 140 comments

Comments

@Kyuuki93

Learning Spatial Fusion for Single-Shot Object Detection

image

image

image

@AlexeyAB it seems worth taking a look

@AlexeyAB AlexeyAB added the "want enhancement" label (Want to improve accuracy, speed or functionality) Nov 26, 2019
@AlexeyAB
Owner

AlexeyAB commented Nov 26, 2019

ASFF significantly improves the box AP from 38.8% to 40.6% as shown in Table 3.

image


The following were also used:

  1. BoF (MixUp, ...) - +4.2 mAP@0.5...0.95, but +0 mAP@0.5 and +5.6% AP@0.7: BoF (Bag of Freebies) - Visually Coherent Image Mixup ~+4 AP@[.5, .95] #3272

  2. MegDet: A Large Mini-Batch Object Detector (synchronized batch normalization technique) - achieved mAP 52.5% in the COCO 2017 Challenge, where it won 1st place in the Detection task: https://arxiv.org/abs/1711.07240v4 - issue: Beta: Using CPU-RAM instead of GPU-VRAM for large Mini_batch=32 - 128 #4386

  3. Dropblock + Receptive field block gives +1.7% AP@0.5...0.95

  4. So ASFF gives only +1.8% AP@0.5...0.95, +1.5% AP@0.5 and +2.5% AP@0.7

  5. cosine learning rate: Implement stochastic gradient descent with warm restarts #2651

@Kyuuki93
Author

Kyuuki93 commented Nov 27, 2019

This paper is a bit confusing, so I took a look at his code: it uses conv_bn_leakyReLU for the level_weights instead of this formula before the Softmax
image

In short, ASFF maps the inputs x0, x1, x2 of yolo0, yolo1, yolo2 to each other to enhance detection, but I still wonder which layers' outputs correspond to x0, x1, x2

@AlexeyAB
Owner

@Kyuuki93

his code uses conv_bn_leakyReLU for the level_weights instead of this formula before the Softmax

Can you provide a link to these lines of code?

@Kyuuki93
Author

https://github.com/ruinmessi/ASFF/blob/master/models/network_blocks.py

calc weights:
image

weights func:
image

add_conv func:
image

@AlexeyAB
Owner

@Kyuuki93

This paper is a bit confusing, so I took a look at his code: it uses conv_bn_leakyReLU for the level_weights instead of this formula before the Softmax
image

This formula seems to be softmax a = exp(x1) / (exp(x1) + exp(x2) + exp(x3)) https://en.wikipedia.org/wiki/Softmax_function

I added fixes to implement ASFF and BiFPN (from EfficientDet): #3772 (comment)


In short, ASFF maps the inputs x0, x1, x2 of yolo0, yolo1, yolo2 to each other to enhance detection, but I still wonder which layers' outputs correspond to x0, x1, x2?

It seems layers: 17, 24, 32

https://github.com/ruinmessi/ASFF/blob/c74e08591b2756e5f773892628dd9a6d605f4b77/models/yolov3_asff.py#L142

https://github.com/ruinmessi/ASFF/blob/c74e08591b2756e5f773892628dd9a6d605f4b77/models/yolov3_asff.py#L129

@AlexeyAB AlexeyAB changed the title Learning Spatial Fusion for Single-Shot Object Detection - 63% mAP@0.5 with 45.5FPS ASFF - Learning Spatial Fusion for Single-Shot Object Detection - 63% mAP@0.5 with 45.5FPS Nov 28, 2019
@isgursoy

waiting for improvements, good things happening here

@Kyuuki93
Author

This formula seems to be softmax a = exp(x1) / (exp(x1) + exp(x2) + exp(x3)) https://en.wikipedia.org/wiki/Softmax_function

Yeah, I got it - the fusion is done with a 1x1 conv, softmax, and sum.

I added fixes to implement ASFF and BiFPN (from EfficientDet): #3772 (comment)

I will try to implement the ASFF and BiFPN modules and run some tests
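For illustration, here is a minimal NumPy sketch of the fusion just described - a 1x1 conv (reduced here to a plain weight matrix) produces three weight channels, softmax normalizes them across levels, and the levels are summed with those weights. `asff_fuse`, `w_1x1` and all shapes are made-up names for this sketch, not code from the ASFF repo:

```python
import numpy as np

def softmax(z, axis=0):
    # numerically stable softmax across the level axis
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def asff_fuse(levels, w_1x1):
    """levels: list of 3 feature maps, each (C, H, W), already resized to the same shape.
    w_1x1: (3, 3*C) weights of an illustrative 1x1 conv mapping the concatenated
    maps to 3 weight channels (one per level)."""
    x = np.concatenate(levels, axis=0)                 # (3*C, H, W)
    c3, h, w = x.shape
    logits = np.tensordot(w_1x1, x.reshape(c3, h * w), axes=1).reshape(3, h, w)
    alpha = softmax(logits, axis=0)                    # alpha+beta+gamma == 1 per pixel
    fused = sum(a[None] * lvl for a, lvl in zip(alpha, levels))
    return fused, alpha

rng = np.random.default_rng(0)
levels = [rng.standard_normal((4, 8, 8)) for _ in range(3)]
w = rng.standard_normal((3, 12)) * 0.1
fused, alpha = asff_fuse(levels, w)
```

The key property is that the three weight maps sum to 1 at every spatial position, so the fusion is a convex combination of the three levels.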

@Kyuuki93
Author

Kyuuki93 commented Dec 3, 2019

For up-sampling, we first apply a 1x1 convolution layer to compress the number of channels of features to that in level l, and then upscale the resolutions respectively with interpolation.

@AlexeyAB How to implement this upscale in .cfg file?

@AlexeyAB
Owner

AlexeyAB commented Dec 3, 2019

@Kyuuki93 [upsample] layer with stride=2 or stride=4
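As a hedged sketch (not from any repo cfg - the filter count is illustrative), the paper's "1x1 compression then upscale" step could look like this in a .cfg fragment:

```
[convolutional]
batch_normalize=1
size=1
stride=1
filters=128
activation=leaky

[upsample]
stride=4
```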

@Kyuuki93
Author

Kyuuki93 commented Dec 4, 2019

   layer   filters  size/strd(dil)      input                output
   0 conv     32       3 x 3/ 1    416 x 416 x   3 ->  416 x 416 x  32 0.299 BF
   1 conv     64       3 x 3/ 2    416 x 416 x  32 ->  208 x 208 x  64 1.595 BF
   2 conv     32       1 x 1/ 1    208 x 208 x  64 ->  208 x 208 x  32 0.177 BF
   3 conv     64       3 x 3/ 1    208 x 208 x  32 ->  208 x 208 x  64 1.595 BF
   4 Shortcut Layer: 1
   5 conv    128       3 x 3/ 2    208 x 208 x  64 ->  104 x 104 x 128 1.595 BF
   6 conv     64       1 x 1/ 1    104 x 104 x 128 ->  104 x 104 x  64 0.177 BF
   7 conv    128       3 x 3/ 1    104 x 104 x  64 ->  104 x 104 x 128 1.595 BF
   8 Shortcut Layer: 5
   9 conv     64       1 x 1/ 1    104 x 104 x 128 ->  104 x 104 x  64 0.177 BF
  10 conv    128       3 x 3/ 1    104 x 104 x  64 ->  104 x 104 x 128 1.595 BF
  11 Shortcut Layer: 8
  12 conv    256       3 x 3/ 2    104 x 104 x 128 ->   52 x  52 x 256 1.595 BF
  13 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  14 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  15 Shortcut Layer: 12
  16 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  17 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  18 Shortcut Layer: 15
  19 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  20 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  21 Shortcut Layer: 18
  22 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  23 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  24 Shortcut Layer: 21
  25 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  26 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  27 Shortcut Layer: 24
  28 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  29 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  30 Shortcut Layer: 27
  31 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  32 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  33 Shortcut Layer: 30
  34 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
  35 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
  36 Shortcut Layer: 33
  37 conv    512       3 x 3/ 2     52 x  52 x 256 ->   26 x  26 x 512 1.595 BF
  38 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  39 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  40 Shortcut Layer: 37
  41 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  42 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  43 Shortcut Layer: 40
  44 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  45 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  46 Shortcut Layer: 43
  47 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  48 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  49 Shortcut Layer: 46
  50 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  51 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  52 Shortcut Layer: 49
  53 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  54 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  55 Shortcut Layer: 52
  56 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  57 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  58 Shortcut Layer: 55
  59 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  60 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  61 Shortcut Layer: 58
  62 conv   1024       3 x 3/ 2     26 x  26 x 512 ->   13 x  13 x1024 1.595 BF
  63 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  64 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  65 Shortcut Layer: 62
  66 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  67 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  68 Shortcut Layer: 65
  69 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  70 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  71 Shortcut Layer: 68
  72 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  73 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  74 Shortcut Layer: 71
  75 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  76 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  77 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
  78 max                5x 5/ 1     13 x  13 x 512 ->   13 x  13 x 512 0.002 BF
  79 route  77 		                           ->   13 x  13 x 512 
  80 max                9x 9/ 1     13 x  13 x 512 ->   13 x  13 x 512 0.007 BF
  81 route  77 		                           ->   13 x  13 x 512 
  82 max               13x13/ 1     13 x  13 x 512 ->   13 x  13 x 512 0.015 BF
  83 route  82 80 78 77 	                   ->   13 x  13 x2048 
# END SPP #
  84 conv    512       1 x 1/ 1     13 x  13 x2048 ->   13 x  13 x 512 0.354 BF
  85 conv   1024       3 x 3/ 1     13 x  13 x 512 ->   13 x  13 x1024 1.595 BF
  86 conv    512       1 x 1/ 1     13 x  13 x1024 ->   13 x  13 x 512 0.177 BF
# A(/32 Feature Map) #
  87 conv    256       1 x 1/ 1     13 x  13 x 512 ->   13 x  13 x 256 0.044 BF
  88 upsample                 2x    13 x  13 x 256 ->   26 x  26 x 256
# A -> B # 
  89 route  86 		                           ->   13 x  13 x 512 
  90 conv    128       1 x 1/ 1     13 x  13 x 512 ->   13 x  13 x 128 0.022 BF
  91 upsample                 4x    13 x  13 x 128 ->   52 x  52 x 128
# A -> C #
  92 route  86 		                           ->   13 x  13 x512
  93 conv    256       1 x 1/ 1     13 x  13 x512 ->   13 x  13 x 256 0.044 BF
  94 upsample                 2x    13 x  13 x 256 ->   26 x  26 x 256
  95 route  94 61 	                           ->   26 x  26 x 768 
  96 conv    256       1 x 1/ 1     26 x  26 x 768 ->   26 x  26 x 256 0.266 BF
  97 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
  98 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
  99 conv    512       3 x 3/ 1     26 x  26 x 256 ->   26 x  26 x 512 1.595 BF
 100 conv    256       1 x 1/ 1     26 x  26 x 512 ->   26 x  26 x 256 0.177 BF
# B(/16 Feature Map) #
 101 conv    512       3 x 3/ 2     26 x  26 x 256 ->   13 x  13 x 512 0.399 BF
# B -> A #
 102 route  100 		                           ->   26 x  26 x 256 
 103 conv    128       1 x 1/ 1     26 x  26 x 256 ->   26 x  26 x 128 0.044 BF
 104 upsample                 2x    26 x  26 x 128 ->   52 x  52 x 128
# B -> C #
 105 route  100 		                           ->   26 x  26 x 256 
 106 conv    128       1 x 1/ 1     26 x  26 x 256 ->   26 x  26 x 128 0.044 BF
 107 upsample                 2x    26 x  26 x 128 ->   52 x  52 x 128
 108 route  107 36 	                           ->   52 x  52 x 384 
 109 conv    128       1 x 1/ 1     52 x  52 x 384 ->   52 x  52 x 128 0.266 BF
 110 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
 111 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
 112 conv    256       3 x 3/ 1     52 x  52 x 128 ->   52 x  52 x 256 1.595 BF
 113 conv    128       1 x 1/ 1     52 x  52 x 256 ->   52 x  52 x 128 0.177 BF
# C(/8 Feature Map) #
 114 max                2x 2/ 2     52 x  52 x 128 ->   26 x  26 x 128 0.000 BF
 115 conv    512       3 x 3/ 2     26 x  26 x 128 ->   13 x  13 x 512 0.199 BF
# C -> A #
 116 route  113 		                           ->   52 x  52 x 128 
 117 conv    256       3 x 3/ 2     52 x  52 x 128 ->   26 x  26 x 256 0.399 BF
# C -> B #
 118 route  86 101 115 	                           ->   13 x  13 x1536 
 119 conv      3       1 x 1/ 1     13 x  13 x1536 ->   13 x  13 x   3 0.002 BF
 120 route  119 		                       0/3 ->   13 x  13 x   1 
 121 scale Layer: 86
darknet: ./src/scale_channels_layer.c:23: make_scale_channels_layer: Assertion `l.out_c == l.c' failed.
Aborted (core dumped)

@AlexeyAB I created an asff.cfg based on yolov3-spp.cfg; there is an error - it seems layer 86 is 13x13x512 while layer 119 (i.e. alpha) is 13x13x1. In [scale_channels], should those layers' outputs be the same?

@AlexeyAB
Owner

AlexeyAB commented Dec 4, 2019

@Kyuuki93 It seems I fixed it: 5ddf9c7#diff-35a105a0ce468de87dbd554c901a45eeR23

@Kyuuki93
Author

Kyuuki93 commented Dec 6, 2019

[route]
layers=22,33,44 # 3-layers which are already resized to the same WxHxC

[convolutional]
stride=1
size=1
filters=3
activation=normalize_channels # ReLU is integrated to activation=normalize_channels

[route]
layers=-1
group_id=0
groups=3

[scale_channels]
from=22
scale_wh=1

[route]
layers=-3
group_id=1
groups=3

[scale_channels]
from=33
scale_wh=1

[route]
layers=-5
group_id=2
groups=3

[scale_channels]
from=44
scale_wh=1

[shortcut]
from=-3
activation=linear

[shortcut]
from=-6
activation=linear

@AlexeyAB
In your ASFF-like module, what exactly does activation=normalize_channels do?

If activation=normalize_channels uses ReLU to calculate gradients,
I think it should be activation=linear with a separate softmax over (x1, x2, x3) to match the formula alpha = exp(x1) / (exp(x1) + exp(x2) + exp(x3)), or activation=softmax for SoftmaxBackward?

https://github.com/ruinmessi/ASFF/blob/f7814211b1fd1e6cde5e144503796f4676933667/models/network_blocks.py#L242

levels_weight = F.softmax(levels_weight, dim=1)
levels_weight.shape was torch.Size([1,3,13,13])

Is activation=normalize_channels the same as this F.softmax?

If activation=normalize_channels actually executes this code - normalize_channels with a ReLU-style function - negative values are removed:

darknet/src/activations.c

Lines 151 to 177 in 9bb3c53

void activate_array_normalize_channels(float *x, const int n, int batch, int channels, int wh_step, float *output)
{
    int size = n / channels;
    int i;
    #pragma omp parallel for
    for (i = 0; i < size; ++i) {
        int wh_i = i % wh_step;
        int b = i / wh_step;
        const float eps = 0.0001;
        if (i < size) {
            float sum = eps;
            int k;
            for (k = 0; k < channels; ++k) {
                float val = x[wh_i + k * wh_step + b*wh_step*channels];
                if (val > 0) sum += val;
            }
            for (k = 0; k < channels; ++k) {
                float val = x[wh_i + k * wh_step + b*wh_step*channels];
                if (val > 0) val = val / sum;
                else val = 0;
                output[wh_i + k * wh_step + b*wh_step*channels] = val;
            }
        }
    }
}

maybe this explains this result:

| Model | chart | cfg |
| --- | --- | --- |
| spp,mse | yolov3-spp-chart | yolov3-spp.cfg.txt |
| spp,mse,asff | chart | yolov3-spp-asff.cfg.txt |

I think the normalization with the constraint channels_sum() = 1 is crucial; it indicates which ASFF feature an object belongs to.

And this ASFF module is a little different from your example; instead of

[route]
layers = 22,33,44  # 3 layers which are already resized to the same WxHxC

...

use

[route]
layers = 22

[convolutional]
batch_normalize=1
size=1
stride=1
filters=8
activation=leaky

[route]
layers = 33

[convolutional]
batch_normalize=1
size=1
stride=1
filters=8
activation=leaky

[route]
layers = 44

[convolutional]
batch_normalize=1
size=1
stride=1
filters=8
activation=leaky

[route]
layers = -1,-3,-5

[convolutional]
stride=1
size=1
filters=3
activation= normalize_channels

...

@AlexeyAB
Owner

AlexeyAB commented Dec 6, 2019

@Kyuuki93

I think the normalization with the constraint channels_sum() = 1 is crucial; it indicates which ASFF feature an object belongs to.

What do you mean?

And this ASFF module is a little different from your example; instead of

Why?


In your ASFF-like module, what exactly does activation=normalize_channels do?

If activation=normalize_channels uses ReLU to calculate gradients,
I think it should be activation=linear with a separate softmax over (x1, x2, x3) to match the formula alpha = exp(x1) / (exp(x1) + exp(x2) + exp(x3)), or activation=softmax for SoftmaxBackward?

normalize_channels implements Fast Normalized Fusion, which should give the same accuracy but better speed than softmax across channels; it is used in BiFPN for EfficientDet: #4346
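To illustrate the difference, here is a small NumPy sketch (my own approximation, not the darknet source): Fast Normalized Fusion replaces softmax's exp() with a ReLU followed by plain normalization, which is where the speed win comes from:

```python
import numpy as np

def softmax_fusion(x, axis=0):
    # standard softmax across the channel axis
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def fast_normalized_fusion(x, axis=0, eps=1e-4):
    # ReLU keeps the weights non-negative, then a plain division
    # normalizes them to (approximately) sum to 1 - no exp() needed
    r = np.maximum(x, 0.0)
    return r / (eps + r.sum(axis=axis, keepdims=True))

x = np.array([1.5, -0.2, 0.8])   # per-channel logits at one spatial position
sm = softmax_fusion(x)
fn = fast_normalized_fusion(x)
```

Note that unlike softmax, the fast variant hard-zeros any negative input (fn[1] here), which is exactly the behavior discussed below.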

Later I will add activation=normalize_channels_softmax

image

@Kyuuki93
Author

Kyuuki93 commented Dec 6, 2019

I think the normalization with the constraint channels_sum() = 1 is crucial; it indicates which ASFF feature an object belongs to.

What do you mean?

Sorry, let me be clear:

alpha(i,j) + beta(i,j) + gamma(i,j) = 1,
 and alpha(i,j)> 0, beta(i,j)>0, gamma(i,j)>0

In normalize_channels, this may result from this code:

if (val > 0) val = val / sum; 
else val = 0;

many alpha, beta, gamma values were set to 0, so the ReLU gradients were 0 too - the gradients vanished at the very beginning, and training didn't work properly. E.g. after 25k iterations the best mAP@0.5 was just 10.41%, and during training the Obj: value was very hard to increase.

And this ASFF module is a little different from your example; instead of

Why?

I checked the author's model: layers 22,33,44 are never concatenated, I just implemented his network structure. In his model the coefficients are calculated from layers 22,33,44 separately, and the channels change like

512 -> 8
512 -> 8  (cat to) 24 -> 3 
512 -> 8

instead of
512 -> 3

There is in the normalize_channels implemented Fast Normalized Fusion that should have the same Accuracy but faster Speed than SoftMax across channels, that is used in BiFPN for EfficientDet: #4346

I will try to find out why BiFPN can work with the ReLU-style normalize_channels but ASFF cannot. I have a thought - just let me check it out.

Later I will add activation=normalize_channels_softmax

I will run another test then

@AlexeyAB
Owner

AlexeyAB commented Dec 6, 2019

@Kyuuki93

I checked the author's model: layers 22,33,44 are never concatenated, I just implemented his network structure.

You did it right. I have not yet verified the entire cfg file as a whole.

Here we are not talking about layers with indices exactly 22, 33, 44 - this is just an example.
It means that some layers with indices XX,YY,ZZ have already been resized to the same WxHx8. It is assumed that these layers have already been applied: conv_stride_2, maxpool_stride_2, upsample_stride_2 and 4, and then a conv-layer with filters=8.
And these 3 layers with size WxHx8 will be concatenated: https://github.com/ruinmessi/ASFF/blob/master/models/network_blocks.py#L240

That's how you did it.


In normalize_channels, this may result from this code:

if (val > 0) val = val / sum;
else val = 0;

many alpha, beta, gamma values were set to 0, so the ReLU gradients were 0 too - the gradients vanished at the very beginning, and training didn't work properly. E.g. after 25k iterations the best mAP@0.5 was just 10.41%, and during training the Obj: value was very hard to increase.

Yes, for one image some outputs (alpha or beta, gamma) will have zeros, and for another image other outputs will have zeros. There will not be dead neurons in Yolo, since all other layers use leaky-ReLU rather than ReLU.

This is a common problem for ReLU, called dead neurons: https://datascience.stackexchange.com/questions/5706/what-is-the-dying-relu-problem-in-neural-networks
This applies to all modern neural networks that use ReLU: MobileNet v1, ResNet-101, ...
Leaky-ReLU, Swish or Mish solves this problem.

There will be a dead-neuron problem only if at least 2 conv-layers with ReLU go in a row, one after another. The output of conv-1 will always be >=0, so both input and output of conv-2 will always be >=0. In this case, since the input of conv-2 is always >=0, if weights[i] < 0 then the output of ReLU will always be 0 and the gradient will always be 0 - so there will be dead neurons, and this weights[i]<0 will never be changed.

But if the conv-1 layer has leaky-ReLU (as in Yolo) or Swish or Mish activation, then the input of conv-2 can be >0 or <0, so regardless of weights[i] (if weights[i] != 0) the gradient will not always be 0, and this weights[i]<0 will be changed at some point.
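This can be checked numerically. A small NumPy sketch (illustrative, not darknet code) of why two ReLU layers in a row with a negative weight give a permanently zero output and zero gradient, while a leaky first layer does not:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def leaky(x):
    return np.where(x > 0, x, 0.1 * x)

rng = np.random.default_rng(1)
inputs = rng.standard_normal(1000)
w = -0.5                         # a negative weight in "conv-2"

# conv-1 with ReLU: conv-2 sees only values >= 0, so w*x <= 0 everywhere
out_relu = relu(w * relu(inputs))
# ReLU's gradient is 0 wherever its output is 0 -> w never gets updated
grad_alive_relu = np.count_nonzero(out_relu > 0)

# conv-1 with leaky-ReLU: conv-2 sees both signs, so some outputs pass through
out_leaky = relu(w * leaky(inputs))
grad_alive_leaky = np.count_nonzero(out_leaky > 0)
```

With the plain-ReLU first layer, no sample produces a positive output, so the negative weight is never updated; with the leaky first layer, some samples do.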

@AlexeyAB
Owner

AlexeyAB commented Dec 6, 2019

@Kyuuki93

Also you can try to use

[convolutional]
stride=1
size=1
filters=3
activation=logistic

instead of

[convolutional]
stride=1
size=1
filters=3
activation=normalize_channels

@AlexeyAB
Owner

AlexeyAB commented Dec 6, 2019

@Kyuuki93

I added [convolutional] activation=normalize_channels_softmax
Check whether there are bugs: c9c745c and 4f52ba1


Page 4: https://arxiv.org/pdf/1911.09516v2.pdf

image

@Kyuuki93
Author

Kyuuki93 commented Dec 7, 2019

Here we are not talking about layers with indices exactly 22, 33, 44. This is just an example.

Yes, I'm aware that layer 22 is actually layer 86 in darknet's yolov3-spp.cfg, and so on.

There will be a dead-neuron problem only if at least 2 conv-layers with ReLU go in a row, one after another. The output of conv-1 will always be >=0, so both input and output of conv-2 will always be >=0. In this case, since the input of conv-2 is always >=0, if weights[i] < 0 then the output of ReLU will always be 0 and the gradient will always be 0 - so there will be dead neurons, and this weights[i]<0 will never be changed.

But if the conv-1 layer has leaky-ReLU (as in Yolo) or Swish or Mish activation, then the input of conv-2 can be >0 or <0, so regardless of weights[i] (if weights[i] != 0) the gradient will not always be 0, and this weights[i]<0 will be changed at some point.

I see, so there is a little influence, but it should work.
I will try activation=logistic and activation=normalize_channels_softmax; results update later


@AlexeyAB
Owner

AlexeyAB commented Dec 9, 2019

This one should be right for ASFF module, yolov3-spp-asff.cfg.txt,

Did you try to use your new ASFF with default [yolo] without Gaussian and without GIoU and without iou_thresh and normalizers?

like

[yolo]
mask = 0,1,2
#anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
anchors =  57, 64,  87,113, 146,110, 116,181, 184,157, 175,230, 270,196, 236,282, 322,319
classes=1
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1

@Kyuuki93
Author

Kyuuki93 commented Dec 9, 2019

This one should be right for ASFF module, yolov3-spp-asff.cfg.txt,

Did you try to use your new ASFF with default [yolo] without Gaussian and without GIoU and without iou_thresh and normalizers?

like

[yolo]
mask = 0,1,2
#anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
anchors =  57, 64,  87,113, 146,110, 116,181, 184,157, 175,230, 270,196, 236,282, 322,319
classes=1
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1

I will try, and asff-sim results with gs, giou, iou_thresh:

| Model | AP@.5 | AP@.75 |
| --- | --- | --- |
| baseline | 91.89% | 63.53% |
| +asffsim | 91.62% | 63.28% |

results with mse loss will be reported tomorrow

@AlexeyAB
Owner

AlexeyAB commented Dec 9, 2019

@Kyuuki93
So asff-simplified doesn't improve accuracy.

Try with default [yolo]+mse without normalizers, and if it doesn't work, then try with default anchors.

@Kyuuki93
Author

Kyuuki93 commented Dec 10, 2019

Did you try to use your new ASFF with default [yolo] without Gaussian and without GIoU and without iou_thresh and normalizers?

Yes, ASFF-SIM with default [yolo] decreases AP@.5 and AP@.75 by 0.48%:

| Model | AP@.5 | AP@.75 |
| --- | --- | --- |
| baseline | 89.52% | 51.72% |
| +asffsim | 89.04% | 51.24% |

@AlexeyAB
Owner

@Kyuuki93

Try norm_channels or norm_channels_softmax with default [yolo] layers.
Maybe only [Gaussian_yolo] produces NaN with ASFF.

@Kyuuki93
Author

Kyuuki93 commented Dec 10, 2019 via email

@Kyuuki93
Author

@AlexeyAB
image

Is there any op like nn.Parameter() in this repo for implementing this wi in BiFPN?

@AlexeyAB
Owner

@Kyuuki93

Is there any op like nn.Parameter() in this repo for implementing this wi in BiFPN?

What do you mean?

If you want Unbounded fusion, then just use activation=linear instead of activation=NORM_CHAN_SOFTMAX

@Kyuuki93
Author

Kyuuki93 commented Dec 26, 2019

@AlexeyAB

For example, wi is a scalar:
P4_mid = Conv( ( w1*P4_in + w2*Resize(P5_in) ) / ( w1 + w2 ) )
This wi should be trainable but not tied to any particular feature map.

In ASFF, w is calculated from the feature map through a conv-layer
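As a NumPy sketch of the scalar-weighted fusion formula above (the outer Conv() is omitted, the trainable scalars w1, w2 are held fixed, and `resize_nearest_2x` / `weighted_fusion` are made-up names for this illustration):

```python
import numpy as np

def resize_nearest_2x(x):
    # nearest-neighbour 2x upsample, (C, H, W) -> (C, 2H, 2W)
    return x.repeat(2, axis=1).repeat(2, axis=2)

def weighted_fusion(p4_in, p5_in, w1, w2, eps=1e-4):
    # clamp to non-negative (as in fast normalized fusion), then
    # normalize so the effective weights sum to ~1
    w1, w2 = max(w1, 0.0), max(w2, 0.0)
    return (w1 * p4_in + w2 * resize_nearest_2x(p5_in)) / (w1 + w2 + eps)

rng = np.random.default_rng(2)
p4 = rng.standard_normal((8, 16, 16))
p5 = rng.standard_normal((8, 8, 8))
fused = weighted_fusion(p4, p5, w1=1.0, w2=0.5)
```

In a real framework w1 and w2 would be learnable parameters (nn.Parameter in PyTorch); the point is that they are scalars shared across the whole feature map, unlike ASFF's per-pixel weight maps.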

@AlexeyAB
Owner

AlexeyAB commented Dec 26, 2019

@Kyuuki93

In ASFF, w is calculated from the feature map through a conv-layer

Do you mean that is not so in BiFPN? https://github.com/xuannianz/EfficientDet/blob/ccc795781fa173b32a6785765c8a7105ba702d0b/model.py

If you want w to be constant during inference, then you can do something like this:

[route]
layers = P4

[convolutional]
batch_normalize=1
filters=256
groups=256
size=1
stride=1
pad=1
activation=linear

[route]
layers = P5

[convolutional]
batch_normalize=1
filters=256
groups=256
size=1
stride=1
pad=1
activation=linear

[shortcut]
from = -3

@AlexeyAB
Owner

For comparison of spinenet(fixed, 5 yolo-layers) and yolov3-spp(3 yolo-layers), training from scratch with same settings

Also try to compare with spinenet (fixed, 3 yolo-layers) + spp, where an SPP-block is added to the P5 or P6 block: #4382 (comment)

darknet/cfg/yolov3-spp.cfg

Lines 575 to 597 in 35a3870

### SPP ###
[maxpool]
stride=1
size=5

[route]
layers=-2

[maxpool]
stride=1
size=9

[route]
layers=-4

[maxpool]
stride=1
size=13

[route]
layers=-1,-3,-5,-6
### End SPP ###

image

@Kyuuki93
Author

Kyuuki93 commented Dec 27, 2019

@AlexeyAB

Do you mean that is not so in BiFPN? https://github.com/xuannianz/EfficientDet/blob/ccc795781fa173b32a6785765c8a7105ba702d0b/model.py

def build_BiFPN() here is not like that - it is without w:
https://github.com/xuannianz/EfficientDet/blob/ccc795781fa173b32a6785765c8a7105ba702d0b/model.py#L40-L93

def build_wBiFPN() here is BiFPN with w
https://github.com/xuannianz/EfficientDet/blob/ccc795781fa173b32a6785765c8a7105ba702d0b/model.py#L96-L149
w is defined here; actually, we need a layer like this one:
https://github.com/xuannianz/EfficientDet/blob/ccc795781fa173b32a6785765c8a7105ba702d0b/layers.py#L33-L60

Maybe adding weights to the [shortcut] layer is an option; also [shortcut] could take more than 2 inputs, something like

[shortcut]
from=P4, P5_up
weights_type = feature (or channel or pixel)
weights_normalization = relu (or softmax or linear)
activation = linear

@Kyuuki93
Author

image

Feature maps on P6 are only 4x4 - could that be too small to get useful features?

Normally SPP sits in the middle, connecting the Backbone and FPN, like Backbone -> SPP -> FPN.

But in SpineNet-49, it seems the whole network is an FPN

@Kyuuki93
Author

@AlexeyAB I moved the spinenet-related comment to its own issue

@AlexeyAB
Owner

@Kyuuki93

Feature maps on P6 are only 4x4 - could that be too small to get useful features?

Yes - then spp should be placed in P5 (especially if you use a small initial network resolution)


[shortcut]
from=P4, P5_up
weights_type = feature (or channel or pixel)
weights_normalization = relu (or softmax or linear)
activation = linear

Yes, or maybe just feature is enough, without channel or pixel.

Interestingly, would a fusion from BiFPN be more effective than such a fusion?

  • this is the same as w - a vector (per-channel) in BiFPN with ReLU
  • batch_normalize=1 - will do normalization to solve the training-instability issue
  • leaky in this block and in conv-layers L1, L2, L3 ensures that the weights will be mostly positive too
[route] 
layers= L1, L2, L3    # output: W x H x 3*C

[convolutional]
batch_normalize=1
filters=3*C
groups=3*C
size=1
stride=1
pad=1
activation=leaky

[local_avgpool]
avgpool_depth = 1  # isn't implemented yet
                   #avg across C instead of WxH - Same meaning as maxpool_depth=1 in [maxpool]
out_channels = C

@AlexeyAB
Owner

@Kyuuki93
It seems that a higher ignore_thresh=0.85 is better than ignore_thresh=0.7 for your dataset. #3874 (comment)
Also truth_thresh=1.0 is good.
So for your dataset it is better to use iou_thresh=1.0 (or not use it at all).

@Kyuuki93
Author

Kyuuki93 commented Dec 31, 2019

@AlexeyAB

It seems that a higher ignore_thresh=0.85 is better than ignore_thresh=0.7 for your dataset.

ignore_thresh = 0.85 got higher AP@.5 but much lower recall than ignore_thresh = 0.7

Also truth_thresh=1.0 is good.

Actually,

  • truth_thresh only worked at 1.0
  • when both truth_thresh and ignore_thresh were set to 0.7, the network became untrainable
  • keeping ignore_thresh = 0.7, truth_thresh = 0.85 decreased performance

So for your dataset it is better to use iou_thresh=1.0 (or not use it at all).

What do you mean? For now, all training uses iou_thresh = 0.213 - do you mean set iou_thresh=1.0 when changing truth_thresh or ignore_thresh?

Other one-stage methods work with a dual threshold, such as ignore_thresh = 0.3 and truth_thresh = 0.5, but yolo works with a single threshold, ignore_thresh = 0.7. This is also mentioned in yolov3's paper but not explained - I just wonder why

@AlexeyAB
Owner

@Kyuuki93

Happy New Year! 🎆 🎇

What do you mean? For now, all training with iou_thresh = 0.213, do you mean set iou_thresh=1.0 when change truth_thresh or ignore_thresh?

I mean it may be better to use on your dataset:

ignore_thresh = 0.7
truth_thresh = 1.0
iou_thresh=1.0

While for MS COCO it may be better to use:

ignore_thresh = 0.7
truth_thresh = 1.0
iou_thresh=0.213

Other one-stage methods work with a dual threshold, such as ignore_thresh = 0.3 and truth_thresh = 0.5, but yolo works with a single threshold, ignore_thresh = 0.7. This is also mentioned in yolov3's paper but not explained - I just wonder why

What methods do you mean?

In the original Darknet there are several issues which may degrade accuracy when using low values of ignore_thresh or truth_thresh.

Initially, the original Darknet had several wrong places which I fixed:

  1. There was used if (best_iou > l.ignore_thresh) {
    instead of if (best_match_iou > l.ignore_thresh) { https://github.com/AlexeyAB/darknet/blame/dcfeea30f195e0ca1210d580cac8b91b6beaf3f7/src/yolo_layer.c#L355
    Thus, it didn't decrease objectness even if there was an incorrect class_id.
    Now it decreases objectness if detection_class_id != truth_class_id - this improves accuracy if ignore_thresh < 1.0.

  2. When truth_thresh < 1.0, the probability that many objects correspond to one anchor increases. But in the original Darknet, only the last truth-bbox (from the label txt-file) affected the anchor. I fixed it - now it averages the deltas of all truths which correspond to this one anchor - so truth_thresh < 1.0 and iou_thresh < 1.0 may have a better effect.

  3. Also isn't tested and isn't fixed possible bug with MSE: XY Loss delta: t^* - t* or sigmoid(t^*) - sigmoid(t*)? #4594 (comment)

@Kyuuki93
Author

Kyuuki93 commented Jan 2, 2020

@AlexeyAB Happy New Year!

Here are the old cpc and new cpc results; it seems using a loss multiplier on all loss parts can balance the classes' AP slightly, but not improve it:

| Model | mAP@.5 (C0/C1) | mAP@.75 (C0/C1) |
| --- | --- | --- |
| giou | 79.53% (69.24%/89.83%) | 59.65% (42.96%/76.34%) |
| giou,cpc | 79.51% (69.07%/89.96%) | 59.52% (42.17%/76.87%) |
| giou,cpc(new) | 79.44% (70.03%/88.84%) | 59.61% (44.95%/74.27%) |

@Kyuuki93
Author

Kyuuki93 commented Jan 2, 2020

I mean may be better to use in your dataset: iou_thresh=1.0
While for MS COCO may be better to use: iou_thresh=0.213

Actually, on my dataset iou_thresh = 0.213 always gets better results. I think using a lower iou_thresh allows several anchors to predict the same object, while the original darknet uses only the nearest anchor, which limits yolo's ability - so a lower iou_thresh will always get better results; one just needs to search for a suitable value for a given dataset.

What methods do you mean?

Some methods use e.g. ignore_thresh = 0.5 & truth_thresh = 0.7, which means:
iou < 0.5: negative sample
0.5 < iou < 0.7: ignored
iou > 0.7: positive sample

I'm not sure this matches yolo's ignore_thresh and truth_thresh exactly
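The dual-threshold assignment described above can be sketched as follows (illustrative only - YOLO's own ignore_thresh/truth_thresh logic differs in detail, as noted):

```python
def assign_sample(iou, ignore_thresh=0.5, truth_thresh=0.7):
    """Dual-threshold assignment: below ignore_thresh a prediction is a
    negative, above truth_thresh a positive, in between it is ignored
    (contributes no loss)."""
    if iou < ignore_thresh:
        return "negative"
    if iou > truth_thresh:
        return "positive"
    return "ignore"
```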

@AlexeyAB
Owner

AlexeyAB commented Jan 2, 2020

@Kyuuki93

it seems using a loss multiplier on all loss parts can balance the classes' AP slightly, but not improve it

Yes.

Some methods use e.g. ignore_thresh = 0.5 & truth_thresh = 0.7, which means:
iou < 0.5: negative sample
0.5 < iou < 0.7: ignored
iou > 0.7: positive sample

Yes.

truth_thresh is very similar to (but not the same as) iou_thresh, so it is strange that you get better results with a higher truth_thresh but a lower iou_thresh.

For MS COCO iou_thresh=0.213 greatly increases accuracy.

@AlexeyAB
Owner

AlexeyAB commented Jan 7, 2020

@WongKinYiu @Kyuuki93
I am adding a new version of [shortcut] - I am now re-making the [shortcut] layer for fast BiFPN: #4382 (comment)

so be careful when using commits from Jan 7, 2020 - they may have bugs in the [shortcut] layer.

Before using them, try to train a small model with a [shortcut] layer.

@WongKinYiu
Collaborator

@AlexeyAB

Okay, thanks.

@Kyuuki93
Author

Kyuuki93 commented Jan 7, 2020

@AlexeyAB ok, thanks

@AlexeyAB
Owner

AlexeyAB commented Jan 9, 2020

@Kyuuki93 @WongKinYiu I added new version of [shortcut] layer for BiFPN from EfficientDet: #4662

So you can try to make a Detector with 1 or several BiFPN blocks.
And with 1 ASFF + several BiFPN blocks (yolov3-spp-asff-bifpn-db-it.cfg)

@AlexeyAB
Owner

@nyj-ocean

[convolutional]
stride=1
size=1
filters=4
activation=normalize_channels_softmax

[route]
layers=-1
group_id=0
groups=4

...


[route]
layers=-1
group_id=3
groups=4

@AlexeyAB
Owner

@nyj-ocean It is because the 4th branch has 4x (2x2) more outputs, so you should use half as many filters in the conv-layers.

@nyj-ocean

@AlexeyAB
I reduced the value of filters in some [convolutional] layers.
But the FPS of yolov3-4l+ASFF.cfg is still lower than yolov3-4l.cfg.
I am waiting to see whether the final mAP of yolov3-4l+ASFF.cfg increases or not compared with yolov3-4l.cfg.

By the way, I want to try ASFF + several BiFPN - where can I download the yolov3-spp-asff-bifpn-db-it.cfg from #4382 (comment)?
