ASFF - Learning Spatial Fusion for Single-Shot Object Detection - 63% mAP@0.5 with 45.5FPS #4382
ASFF significantly improves the box AP from 38.8% to 40.6%, as shown in Table 3. The following are also used:
Can you provide a link to these lines of code?
This formula seems to be softmax. I added fixes to implement ASFF and BiFPN (from EfficientDet): #3772 (comment)
It seems to be layers 17, 24, and 32.
Waiting for improvements; good things are happening here.
Yeah, I got it: his fusion is done with a 1x1 conv, a softmax, and a sum.
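A minimal sketch of that 1x1 conv, softmax, and sum fusion at a single pixel, in plain Python. It is a simplification under stated assumptions: each level gets its own hypothetical 1x1 conv that produces one scalar logit (the original ASFF code may compress and concatenate channels before its final 1x1 conv), and the three input feature vectors are assumed to be already resized to the same resolution and channel count. The names `asff_fuse_pixel` and `softmax` are illustrative, not from Darknet or the ASFF repository.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def asff_fuse_pixel(features, weight_vectors, biases):
    """Fuse one spatial position from three resized levels, ASFF-style.

    features: three channel vectors (one per level) of equal length.
    weight_vectors, biases: parameters of a hypothetical 1x1 conv per level
        that maps that level's channel vector to one scalar logit.
    Returns the fused channel vector and the (alpha, beta, gamma) weights.
    """
    # A 1x1 conv evaluated at one pixel is a dot product over channels.
    logits = [sum(w * x for w, x in zip(wv, f)) + b
              for f, wv, b in zip(features, weight_vectors, biases)]
    # Softmax across levels, so alpha + beta + gamma == 1 at this pixel.
    weights = softmax(logits)
    channels = len(features[0])
    # Weighted sum of the levels, channel by channel.
    fused = [sum(a * f[c] for a, f in zip(weights, features))
             for c in range(channels)]
    return fused, weights
```

With all-zero conv parameters the logits are equal, so each level gets weight 1/3 and the fused vector is simply the per-channel mean of the three levels.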
I will try to implement the ASFF and BiFPN modules and run some tests.
@AlexeyAB How do I implement this upscale in a .cfg file?
@Kyuuki93
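Darknet's `[upsample]` layer takes a `stride` parameter (`stride=2` in yolov3.cfg). As far as I know it upsamples by copying each value into a stride x stride block (nearest neighbour); the sketch below shows that for a single channel under that assumption. Going from ASFF level 1 (stride 32) to level 3 (stride 8) needs a factor of 4, which here is just `stride=4`. The function name `upsample_nearest` is hypothetical.

```python
def upsample_nearest(fmap, stride):
    """Nearest-neighbour upsampling of one channel stored as a list of rows.

    Each input value is copied into a stride x stride block of the output,
    so an H x W map becomes (H * stride) x (W * stride).
    """
    out = []
    for row in fmap:
        # Repeat every value `stride` times horizontally.
        wide = [v for v in row for _ in range(stride)]
        # Repeat the widened row `stride` times vertically.
        for _ in range(stride):
            out.append(list(wide))
    return out
```

Calling `upsample_nearest([[1, 2], [3, 4]], 2)` turns the 2x2 map into a 4x4 map in which each original value fills a 2x2 block.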
@AlexeyAB I created an asff.cfg based on yolov3-spp.cfg, but there seems to be an error.
@Kyuuki93 It seems I fixed it: 5ddf9c7#diff-35a105a0ce468de87dbd554c901a45eeR23
@AlexeyAB If
Is 'activation = normalize_channels' the same as this? (Lines 151 to 177 in 9bb3c53)
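To compare the two normalizations discussed here, a sketch of ReLU-style normalization (the EfficientDet "fast normalized fusion", w_i / (eps + sum_j w_j) with ReLU keeping w_i >= 0) next to softmax. Whether Darknet's `activation=normalize_channels` computes exactly this is an assumption, as is the value of `EPS`; the function names are hypothetical.

```python
import math

EPS = 1e-4  # assumed small constant that avoids division by zero


def relu_normalize(weights):
    """Fast normalized fusion: clamp to >= 0, then divide by the sum (+ EPS)."""
    clipped = [max(0.0, w) for w in weights]
    total = sum(clipped) + EPS
    return [c / total for c in clipped]


def softmax_normalize(weights):
    """Softmax across inputs: every weight is strictly positive, sum is 1."""
    m = max(weights)
    exps = [math.exp(w - m) for w in weights]
    total = sum(exps)
    return [e / total for e in exps]
```

The design difference matters for the zero-output discussion below: with `relu_normalize([2.0, -1.0, -3.0])` the two negative inputs get a weight of exactly 0 (and no gradient through the ReLU), while `softmax_normalize` keeps every input slightly positive.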
Maybe this explains the result.
I think the normalization with constraints. Also, this ASFF module is a little different from your example: instead of
use
What do you mean?
Why?
There is in the
Later I will add
Sorry, let me clarify:
In
many
I checked the author's model:
instead of
I will try to find out why BiFPN works with ReLU-style normalize_channels but ASFF does not. I have an idea; let me check it out.
I will run another test then.
You did it right. I have not yet verified the cfg file as a whole. The layer indices 22, 33, and 44 are just an example, not the exact layers; the way you did it is correct.
Yes, for one image some outputs (alpha, beta, or gamma) will be zero, and for another image other outputs (alpha, beta, or gamma) will be zero. There will be no dead neurons in Yolo, since all other layers use leaky-ReLU rather than ReLU. Dead neurons are a common problem with ReLU: https://datascience.stackexchange.com/questions/5706/what-is-the-dying-relu-problem-in-neural-networks The dead-neuron problem occurs only if at least 2 conv layers with ReLU come one after another. So the output of conv-1 will always be
But if the conv-1 layer has leaky-ReLU (as in Yolo), Swish, or Mish activation, then the input of conv-2 can be >0 or <0, so regardless of weights[i] (if weights[i] != 0) the gradient will not always be 0, and this
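The argument above can be checked numerically with a deliberately simplified model: one weight `w` of a ReLU unit in conv-2, no bias, fed by conv-1's activation. If conv-1 uses ReLU, conv-2's input is never negative, so a negative `w` keeps the pre-activation <= 0 for every sample and the gradient is 0 everywhere (a dead unit). With leaky-ReLU in conv-1 (slope 0.1 assumed, which I believe matches Darknet's `leaky`), negative inputs flip the sign and some gradients stay nonzero. All names here are hypothetical.

```python
def relu(x):
    return max(0.0, x)


def leaky_relu(x, slope=0.1):
    # Negative inputs are scaled by `slope` instead of being zeroed.
    return x if x > 0 else slope * x


def relu_grad(z):
    # d relu(z) / dz: 1 for z > 0, otherwise 0 (the "dead" region).
    return 1.0 if z > 0 else 0.0


def conv2_weight_grads(conv1_act, w, inputs):
    """Per-sample d relu(w * a) / dw, where a = conv1_act(x) feeds conv-2.

    A zero gradient for every sample means the weight can never change,
    i.e. the ReLU unit in conv-2 is dead.
    """
    grads = []
    for x in inputs:
        a = conv1_act(x)  # output of conv-1 = input of conv-2
        z = w * a         # pre-activation of the conv-2 unit
        grads.append(relu_grad(z) * a)  # chain rule: dz/dw = a
    return grads


samples = [-2.0, -0.5, 0.5, 2.0]
w = -1.0  # e.g. a weight pushed negative by a large update
dead = conv2_weight_grads(relu, w, samples)
alive = conv2_weight_grads(leaky_relu, w, samples)
```

Here `dead` is all zeros, while `alive` has nonzero entries for the negative samples, which is the situation described for leaky-ReLU, Swish, or Mish before conv-2.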
Also you can try to use
instead of
I added
Yes, I am aware that
I see, so there is a little influence, but it should still work.
Did you try your new ASFF with the default [yolo] layer, without Gaussian, GIoU, iou_thresh, and normalizers? Like
I will try, and I will report the asff-sim results with MSE loss tomorrow.
@Kyuuki93 Try the default [yolo] + MSE without normalizers, and if that doesn't work, then try the default anchors.
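For context on the MSE vs. GIoU comparison being tested here: MSE works on box coordinates directly, while GIoU is computed from box overlap as GIoU = IoU - (|C| - |A ∪ B|) / |C|, where C is the smallest box enclosing both boxes (the GIoU loss is 1 - GIoU). A sketch for axis-aligned boxes given as (x1, y1, x2, y2); the function name is hypothetical and this is not Darknet's implementation.

```python
def iou_giou(a, b):
    """IoU and GIoU of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    # Intersection rectangle (zero area if the boxes do not overlap).
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    iou = inter / union
    # Smallest box that encloses both a and b.
    cx1, cy1 = min(a[0], b[0]), min(a[1], b[1])
    cx2, cy2 = max(a[2], b[2]), max(a[3], b[3])
    enclose = (cx2 - cx1) * (cy2 - cy1)
    giou = iou - (enclose - union) / enclose
    return iou, giou
```

Unlike plain IoU, GIoU still distinguishes non-overlapping boxes: for unit boxes at (0, 0, 1, 1) and (2, 0, 3, 1) IoU is 0 and GIoU is -1/3, and moving the second box further away makes GIoU more negative.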
Yes, |baseline | AP@.5 = |
Try
I tried; it's the same.
Is there any op like
What do you mean? If you want
For example, in
Do you mean that this is not the case in BiFPN? https://github.com/xuannianz/EfficientDet/blob/ccc795781fa173b32a6785765c8a7105ba702d0b/model.py If you want
Also try to compare with Lines 575 to 597 in 35a3870
Maybe add a
@AlexeyAB I moved
Yes, then SPP should be placed in P5 (especially if you use a small initial network resolution).
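One way to see why SPP suits P5 at small resolutions: assuming the yolov3-spp.cfg layout (three stride-1 max-pools of sizes 5, 9, and 13, concatenated with their input), the largest pool already covers the whole P5 map at a 416x416 input, where P5 is 416 / 32 = 13 cells wide. A single-channel sketch; `maxpool_same` and `spp` are hypothetical names, and skipping out-of-range cells stands in for padding.

```python
def maxpool_same(fmap, size):
    """Stride-1 max pooling that keeps the H x W size; out-of-range cells are skipped."""
    h, w = len(fmap), len(fmap[0])
    r = size // 2
    out = []
    for i in range(h):
        row = []
        for j in range(w):
            window = [fmap[y][x]
                      for y in range(max(0, i - r), min(h, i + r + 1))
                      for x in range(max(0, j - r), min(w, j + r + 1))]
            row.append(max(window))
        out.append(row)
    return out


def spp(fmap, sizes=(5, 9, 13)):
    """Return the input map followed by one pooled map per kernel size.

    Concatenating these along the channel axis multiplies the channel
    count by len(sizes) + 1 while keeping the spatial size unchanged.
    """
    return [fmap] + [maxpool_same(fmap, k) for k in sizes]
```

On a 13x13 map with a single active cell in the centre, the size-13 pool spreads that value to every cell, while the size-5 pool only reaches cells within 2 steps of the centre.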
Yes, or maybe just enough. Interestingly, is a fusion from BiFPN more effective than such a fusion?
@Kyuuki93
Actually,
What do you mean? For now, all training with
Other one-stage methods work with a dual threshold, such as
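A sketch of what a dual-threshold assignment looks like: an anchor whose best IoU with any ground truth is at or above the positive threshold is a positive, one below the negative threshold is a negative, and anything in between is ignored. The 0.5 / 0.4 values are the ones RetinaNet uses and serve here only as an illustrative assumption; the thread does not say which values are meant. Darknet's [yolo] layer has a related mechanism in `ignore_thresh` (predictions overlapping a truth above it are not penalized as background). The function name is hypothetical.

```python
POS_IOU = 0.5  # assumed positive threshold (RetinaNet-style), illustrative only
NEG_IOU = 0.4  # assumed negative threshold, illustrative only


def assign_label(best_iou, pos_thresh=POS_IOU, neg_thresh=NEG_IOU):
    """Dual-threshold labelling: 1 = positive, 0 = negative, -1 = ignored."""
    if best_iou >= pos_thresh:
        return 1
    if best_iou < neg_thresh:
        return 0
    # Between the thresholds: excluded from the classification loss.
    return -1
```

The ignored band keeps ambiguous anchors from being pushed toward either label, which a single threshold cannot do.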
Happy New Year! 🎆 🎇
I mean it may be better to use this on your dataset:
While for MS COCO it may be better to use
What methods do you mean? In the original Darknet there are several issues which may degrade accuracy when using low values of
Initially, the original Darknet had several bugs, which I fixed:
@AlexeyAB Happy New Year! There are
Actually, in my dataset
Some methods use something like
I'm not sure this is exactly yolo's
Yes.
Yes.
For MS COCO
@WongKinYiu @Kyuuki93 So be careful when using commits from
Before using them, try to train a small model with the [shortcut] layer.
Okay, thanks.
@AlexeyAB ok, thanks |
@Kyuuki93 @WongKinYiu I added a new version of
So you can try to make a detector with one or several BiFPN blocks.
@nyj-ocean This is because the 4th branch has 4 = (2x2) times more outputs, so you should use half as many filters in its conv layers.
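The arithmetic behind halving the filters: a k x k convolution costs roughly H * W * C_in * C_out * k * k multiply-accumulates. A branch at twice the resolution has 2 x 2 = 4 times more positions; halving the filters in consecutive conv layers halves both C_in and C_out, cutting the cost by 4 and cancelling that factor. This is one reading of the advice, and the sizes below (a 52x52 branch with 256 filters vs. a 104x104 branch with 128) are hypothetical.

```python
def conv_macs(h, w, c_in, c_out, k=3):
    """Multiply-accumulates of a k x k convolution with same-size output."""
    return h * w * c_in * c_out * k * k


# Hypothetical sizes: an existing 52x52 branch vs. a new 104x104 4th branch.
existing = conv_macs(52, 52, 256, 256)
same_filters = conv_macs(104, 104, 256, 256)  # 4x the cost
halved_filters = conv_macs(104, 104, 128, 128)  # back to the original cost
```

The output tensor of the halved branch (104 * 104 * 128 values) is still twice the size of the existing one (52 * 52 * 256), so memory grows even when compute does not.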
@AlexeyAB By the way, I want to try
Learning Spatial Fusion for Single-Shot Object Detection
@AlexeyAB It seems worth taking a look.