Apply Transformer in the backbone #2329
@dingyiwei hey very cool!! The updates seem a bit faster with slightly fewer FLOPs... I'll have to look at this a little more in depth, but very quickly I would add that the C3TR module you placed at the end of the backbone will primarily affect large objects, so many of the smaller objects may not be significantly affected by the change. To give a bit of background: the largest C3 modules, like the 1024-channel one you replaced, are responsible for most of the model parameter count but execute very fast (due to the small 20x20 feature grid they sample), whereas the earliest C3 modules contribute far fewer parameters but run on much larger feature grids. So it would be interesting to see the effects of replacing C3 modules at other points in the model. |
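To make the feature-grid point above concrete, here is a small sketch (plain Python; the standard P1-P5 strides and a 640x640 input are assumptions, not taken from the comment itself):

# Feature-grid sizes per backbone stage for a 640x640 input
img = 640
for name, stride in [("P1/2", 2), ("P2/4", 4), ("P3/8", 8), ("P4/16", 16), ("P5/32", 32)]:
    g = img // stride
    print(f"{name}: {g}x{g} = {g * g} positions")
# P5/32 -> 20x20 = 400 positions: many channels (hence many parameters) but few positions,
# so the last C3 runs fast and mainly influences large-object predictions.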
@dingyiwei just checked, we have a multigpu instance freeing up soon, I think we can add a few C3TR runs to the queue to experiment further. Could you submit a PR with your above updates please? |
@dingyiwei I pasted your modules into common.py and added C3TR to the modules list in yolo.py, and I can build a model successfully, but my numbers look a little different from yours (comparing against the default YOLOv5s). The modified backbone layer:

[-1, 3, C3TR, [1024, False]], # 9
My full C3TR module (with only self.m different):

class C3TR(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(C3TR, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = TransformerBlock(c_, c_, 4, n)

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

EDIT: had to add C3TR in a second spot in yolo.py, now I match your numbers.
|
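For anyone reproducing the "second spot" fix: in yolo.py's parse_model(), a new module has to be registered both where the channel arguments are built and where the repeat count n is passed into the module. The following is a minimal standalone sketch, not the real parse_model; module names are given as strings purely for illustration, and the exact lists differ between YOLOv5 versions.

# Sketch of the two registration points for a module like C3TR in models/yolo.py
def build_args(m, f, n, args, ch):
    if m in ("Conv", "Bottleneck", "SPP", "Focus", "BottleneckCSP", "C3", "C3TR"):  # spot 1: channel args
        c1, c2 = ch[f], args[0]
        args = [c1, c2, *args[1:]]
        if m in ("BottleneckCSP", "C3", "C3TR"):  # spot 2: pass the repeat count n into the module itself
            args.insert(2, n)
            n = 1
    return n, args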
@dingyiwei @glenn-jocher Applying dropout can greatly improve Transformer's performance, so I made a slight modification to the Transformer layer.
|
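The exact change isn't shown above, so purely as an illustration (not necessarily Joker316701882's version, and with a hypothetical class name and hyperparameters): dropout is commonly applied to the attention output and to the feed-forward output before each residual addition, roughly like this.

import torch.nn as nn

class TransformerLayerDropout(nn.Module):  # hypothetical variant for illustration only
    def __init__(self, c, num_heads, p=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(c)
        self.ln2 = nn.LayerNorm(c)
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads, dropout=p)  # dropout on attention weights
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        x_ = self.ln1(x)
        x = self.drop(self.ma(self.q(x_), self.k(x_), self.v(x_))[0]) + x  # residual 1
        x_ = self.ln2(x)
        x = self.drop(self.fc2(self.fc1(x_))) + x  # residual 2 (pre-norm placement)
        return x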
Hi @Joker316701882 , actually I removed dropout at the beginning since there's no dropout in this codebase 🤣. I'll have a try now on VOC. |
Hello @dingyiwei, may I ask if you trained with the multi-GPU option or a single GPU? I saw that you wrote "2 Nvidia GTX 1080Ti cards" in your first post. The reason I'm asking is that I set up 2-GPU and 4-GPU runs for the 5m/5l using your backbone and got an error around the 110-120th epoch:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).

Do you perhaps have any clue about this error? I also recall that glenn was planning to do multi-GPU training as well on this branch. Could you tell me if you run into any errors as well? |
Hi @NanoCode012 , I ran my experiments with plain DP rather than DDP. I guess the problem could be caused by nn.MultiheadAttention, according to the error message. Its forward() returns two outputs, and we only use the first one:

def forward(self, x):
    x_ = self.ln1(x)
    x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x  # <---- here we only use the first output
    x = self.ln2(x)
    x = self.fc2(self.fc1(x)) + x
    return x

I'm going to check it when my last experiment finishes. |
Hello @dingyiwei, I see! Have you tried just using a single GPU for training instead? From my test on COCO, DP didn't actually speed up training. Maybe you could run two trainings instead of one :) I found an issue pytorch/pytorch#26698 which talks about the incompatibility of nn.MultiheadAttention with DDP. I will try their proposed solution below. The author there did mention that it introduced another bug, but I'll have to test it out. I guess we will need a PR to DDP if we decide to include the transformer in the backbone.

passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`

Another note: this can introduce some overhead in DDP https://pytorch.org/docs/stable/notes/ddp.html
|
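In code, the workaround above amounts to one extra keyword when the model is wrapped for DDP. A hedged sketch (it assumes the process group has already been initialized by the launcher, as train.py does, and that local_rank is the current GPU index; this is not the actual train.py diff):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_ddp(model: torch.nn.Module, local_rank: int) -> DDP:
    # Enable unused-parameter detection so parameters that don't contribute to the loss
    # (e.g. inside nn.MultiheadAttention) don't crash the backward pass. Note the extra overhead.
    return DDP(
        model.to(local_rank),
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True,
    )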
Hi @Joker316701882 , I tested dropout and dropout+act on VOC (based on yolov5s + Transformer), but there was no obvious improvement. May I ask for your experimental results with dropout? @glenn-jocher @zhiqwang @NanoCode012 And I found a MISTAKE in my PR #2333: in a classic Transformer layer, the 2nd LayerNorm should be placed inside the 2nd residual block (as described in Joker316701882's comment), according to ViT. But I applied it to the main branch before the residual addition instead, as in the forward() above. Fortunately, so far I haven't seen any harm or benefit from the mistake, but I'm not sure how it will affect larger models. |
Hey, @dingyiwei, |
Hi @jaqub-manuel , usually components with a self-attention mechanism, e.g. Non-local and GCNet, are used to extract global information. So I just put the Transformer at the last part of the backbone intuitively. @glenn-jocher is trying to put the Transformer in different stages of the backbone and in the head of Yolov5. Maybe his experiments could give us some ideas. |
@dingyiwei @jaqub-manuel I started an experiment run but got sidetracked earlier in the week. I discovered some important information though: it seems like the transformer block uses up a lot of memory. I created a transformer branch and tried to train 8 models, 1 default yolov5m.yaml and then 7 transformer models. Each of the transformer models replaces C3 with C3TR in the location mentioned, i.e. only in layer 2, or only in the backbone, etc.

Unfortunately all of the 7 models except the layer 9 model hit CUDA OOM, so I cancelled the training to think a bit. The layers that use the least amount of CUDA memory are the largest-stride layers (P5/32), like layer 9, so this may be why @dingyiwei was using it for the test. I think layer 9 is then the best place to implement it, as it uses less memory and affects the whole head. So all I've really learned is that the default test @dingyiwei ran is probably the best for producing a trainable model that doesn't eat too many resources (see the rough cost sketch below).

@dingyiwei can you update the PR with a fix for the mistake in #2329 (comment), and then I'll train a YOLOv5m model side by side with the layer 9 replacement, and maybe I can try a layer 9 + P5 head replacement also. The P5 layer itself is the largest mAP contributor at 640 resolution, so it's not all bad news that we can only apply the transformer to that layer to minimize memory usage. |
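A rough way to see why only the layer-9 (P5/32) replacement fits in memory: self-attention cost grows with the square of the number of spatial positions, so the same TransformerBlock at earlier, higher-resolution stages is dramatically more expensive. Sketch under the assumption of standard strides and a 640x640 input:

# Attention-matrix size per head and image: tokens^2 entries, where tokens = (H/stride)*(W/stride)
img = 640
for name, stride in [("P2/4", 4), ("P3/8", 8), ("P4/16", 16), ("P5/32", 32)]:
    tokens = (img // stride) ** 2
    print(f"{name}: {tokens:6d} tokens -> {tokens ** 2:,} attention entries per head")
# P5/32 (layer 9): 400 tokens -> 160,000 entries; P3/8: 6,400 tokens -> 40,960,000 (256x more),
# which is consistent with C3TR at earlier layers hitting CUDA OOM.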
Hello, I finished most of my trainings (2 left) on testing the Transformer. I noted down my results in wandb. It's my first time using it, so I hope I'm doing it right. My observations were that the Transformer runs (denoted by Edit: Added table here for backup
|
Inspired by @NanoCode012 , I tried to remove both LayerNorm layers of Transformer in YOLOv5s, and got a surprise:
Will run on test-dev and upload the model later. UPDATE: Experimental results:
Here is the implementation:

class TransformerLayer(nn.Module):
    def __init__(self, c, num_heads):
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
        x = self.fc2(self.fc1(x)) + x
        return x

New model is here. |
@dingyiwei According to your posted results, mAP@0.5 improved but mAP@0.5:0.95 remains unchanged. Does that mean mAP@0.75 actually dropped? |
Hi @Joker316701882 , I didn't record mAP@0.75 in those experiments. According to @glenn-jocher 's explanation, |
Dear @dingyiwei , |
Hi everyone, I updated the experimental results, the implementation and the trained model of |
@dingyiwei very interesting result! I think layernorm() is a pretty resource intensive operation (at least when compared to batchnorm). Did removing it reduce the training memory requirements? |
Hi @glenn-jocher , in my experiments yes. For YOLOv5s + TR, |
@dingyiwei thanks for the info, so not much of a change in memory from removing layernorm(). |
Hi all, did anyone try position embedding? It seems like the transformer helps classification rather than localization, according to the results of AP@0.5 and AP@0.5:0.95. |
@dingyiwei I'm working on getting the Transformer PR #2333 merged. I merged master to bring it up to date with the latest changes, and I noticed that the TransformerLayer() module in the PR is different from your most recent version in #2329 (comment). Which do you think we should use for the PR? Let me know, thanks! |
@dingyiwei also we should add a one-line comment for each of the 3 new modules that explains a bit or cites a source, if you can please. I've done this with C3TR(), but left the other two up to you. Once we have these updates and decide on TransformerLayer(), I can merge the PR. Thanks! |
How do I train with this new module? Can you show me the details? Do you train with pretrained weights or from scratch? |
@guyiyifeurach there are no transformer pretrained weights, but you can start from the normal pretrained weights instead. To train a YOLOv5s transformer model in our Colab notebook, for example:

# Train YOLOv5s on COCO128 for 3 epochs
!python train.py --img 640 --batch 16 --epochs 3 --data coco128.yaml --weights yolov5s.pt --cfg yolov5s-transformer.yaml |
This dimensional operation will change the batch_size dim? I don't understand why we're doing this.
I think the right operation is:
|
Hi @qiy20 , I forget why I wrote this piece of code 😂. Feel free to update it if you confirm it is correct. |
@qiy20 @dingyiwei would the right simplification be this?

# b,c,w,h-->b,c,wh-->1,b,c,wh-->wh,b,c,1-->wh,b,c
p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
# simplified
p = x.flatten(2).transpose(0, 2) |
@glenn-jocher I think no..

# b,c,w,h-->b,c,wh-->1,b,c,wh-->wh,b,c,1-->wh,b,c
p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
# b,c,w,h-->b,c,wh-->b,wh,c
p = x.flatten(2).transpose(1, 2)
# b,c,w,h-->b,c,wh-->wh,c,b
p = x.flatten(2).transpose(0, 2)

I thought my original idea was to keep c after b. An alternative is adding batch_first=True to nn.MultiheadAttention and using

p = x.flatten(2).transpose(1, 2)
return self.tr(p + self.linear(p)).transpose(1, 2).reshape(b, self.c2, w, h)

I'll verify it with experiments. Let me know if you get different ideas :) |
@dingyiwei ok I think I've got it. Yes, you are right, transpose is acting unexpectedly. I had to use permute, but this seems to result in a 2x speedup:

import torch
x = torch.rand(16, 3, 80, 40)
p1 = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
p2 = x.flatten(2).permute(2, 0, 1)
print(torch.allclose(p1, p2))  # True

%timeit x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
# 5.36 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit x.flatten(2).permute(2,0,1)
# 2.83 µs ± 62 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) |
@dingyiwei if |
@glenn-jocher Training time and inference time show no obvious difference between the current code and the alternatives; I ran 10 epochs for each solution. But the permute version is simpler:

p = x.flatten(2).permute(2, 0, 1)
return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h) |
@dingyiwei understood! Yes please submit a PR for permute(). |
@dingyiwei #5645 PR is merged, replacing multiple transpose ops with a single permute in the TransformerBlock forward method. |
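For context, the TransformerBlock that wraps these layers ends up looking roughly like this after the permute change. This is a sketch reconstructed from the snippets in this thread, not copied from the merged file; Conv (YOLOv5's conv+BN+activation block) and TransformerLayer are assumed to come from models/common.py.

import torch.nn as nn

class TransformerBlock(nn.Module):
    # Optional channel-matching conv, a learned "position embedding" via a residual Linear,
    # and num_layers stacked TransformerLayer modules.
    def __init__(self, c1, c2, num_heads, num_layers):
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)
        self.linear = nn.Linear(c2, c2)  # acts as a learnable position encoding (see discussion below)
        self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
        self.c2 = c2

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        b, _, w, h = x.shape
        p = x.flatten(2).permute(2, 0, 1)  # b,c,w,h -> wh,b,c (token-first layout for nn.MultiheadAttention)
        return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)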
Sorry for the delay. @dingyiwei is right! I ignored the arg batch_first=False. |
But I have another question about the pos embedding. |
Good question 😂 Indeed, ViT uses randomly-initialized learnable 1D parameters as the pos embedding. I knew more about CV than NLP, so I was unfamiliar with the pos embedding at that time and applied a common operation in CV - something like a residual Linear layer. Detection is different from classification, so it's hard to say whether a residual layer or standalone parameters works better as the pos embedding for Yolo. I'll try to conduct experiments on this issue and post results here. |
I think the pos embedding reflects the distance between the feature points, so standalone parameters may be better; the Linear(x) doesn't contain much position information. |
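As a concrete comparison (illustrative sketch only; LearnedPosEmbed is a hypothetical module, not part of the PR): the current block computes p + linear(p), a content-dependent projection, whereas a ViT-style embedding is a standalone learnable table indexed purely by position.

import torch
import torch.nn as nn

class LearnedPosEmbed(nn.Module):  # hypothetical module for illustration, not in the PR
    """ViT-style learnable 1D position embedding for a fixed number of spatial tokens."""
    def __init__(self, num_tokens, c):
        super().__init__()
        self.pos = nn.Parameter(torch.zeros(num_tokens, 1, c))  # (wh, 1, c), broadcast over batch

    def forward(self, p):  # p: (wh, b, c) as produced by x.flatten(2).permute(2, 0, 1)
        return p + self.pos  # depends only on position, unlike p + linear(p) which depends on content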
@dingyiwei I have a question: why does the transformer block only include an encoder and not a decoder? Is the encoder more suitable for classification tasks? |
My understanding is that the intention of adding the transformer block here is to get better features (by attending to different parts of the image), which might result in better box/class predictions compared to other modules (e.g. C3). |
@dingyiwei Hi, I have a question: if the transformer module is added, does it mean that the previous pure-CNN pretrained weights can no longer be used? |
@Him-wen Yes, you have to train the model from scratch. |
Can you provide a pretrained transformer model? Thanks! |
@mx2013713828 You may want to find an outdated model here with this commit. There are no official pretrained models for Yolov5s-transformer. |
@dingyiwei Do you have a reference for using this kind of structure?
|
@zhangweida2080 You may want to take a look at my first few comments in this thread.
|
@dingyiwei Thank you for your reply. So there is no fixed rule about the usage in different settings.
Thanks a lot. |
@zhangweida2080 For the first 2 questions: I had to work out a way to get a better result in a very short time due to my personal requirements, so I built a much simpler structure than the Transformer in that paper (but it really worked on COCO anyway) and shared it here. If I had more time and more resources, I would try more structures and conduct more experiments. |
🚀 Feature
Transformer is popular in NLP and is now also applied to CV. I added C3TR just by replacing the sequential self.m in C3 with a Transformer block, which could reduce GFLOPs and let Yolo achieve a better result.

Motivation
Pitch
I add 3 classes (TransformerLayer, TransformerBlock and C3TR) in https://github.com/dingyiwei/yolov5/blob/Transformer/models/common.py , and I just put it as the last part of the backbone instead of a C3 block.

I conducted experiments on 2 Nvidia GTX 1080Ti cards, where depth_multiple and width_multiple are the same as Yolov5s. Here are my experimental results with img-size 640. For convenience I named the method in this issue Yolov5TRs.

We can see that Yolov5TRs get higher scores in mAP@0.5 with a faster speed. (I'm not sure why my results for Yolov5s differ from those shown in the README. The model was downloaded from release v4.0.) When depth_multiple and width_multiple are set to larger numbers, C3TR should be more lightweight than C3. Since I do not have much time for it and my machine is not very strong, I did not run experiments on M, L and X. Maybe someone could conduct the future experiments :smile: