[Enhancement] Support MultiScaleDeformableAttention with AMP #2541
Conversation
Should I post a new PR for the mmcv 2.x version?
Thanks for your contribution. I will cherry-pick this commit to 2.x.
Please add unit tests to cover this modification to MultiScaleDeformableAttention. You can refer to https://github.com/open-mmlab/mmcv/blob/master/tests/test_ops/test_deform_conv.py.
Force-pushed from 85e0c20 to 8ed8896.
@zhouzaida Sorry for the late response. I added UT.
@nijkah Sorry for my late reply. The precision threshold is OK.
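The precision threshold mentioned above matters because fp16 carries only about 11 bits of significand, so an fp16 path cannot match fp32 outputs exactly and must be compared within a tolerance. A minimal sketch of such a tolerance check, emulating half precision with Python's struct 'e' format (this is illustrative only, not the actual mmcv unit test; all names here are hypothetical):

```python
import math
import struct

def to_fp16(x: float) -> float:
    # Round-trip through IEEE 754 half precision using the struct 'e' format,
    # emulating the rounding an fp16 CUDA kernel would introduce.
    return struct.unpack('e', struct.pack('e', x))[0]

def allclose(a, b, atol=1e-3, rtol=1e-3):
    # Elementwise tolerance check, similar in spirit to torch.allclose.
    return all(
        math.isclose(x, y, abs_tol=atol, rel_tol=rtol)
        for x, y in zip(a, b)
    )

# Reference fp32 outputs vs. their fp16-rounded counterparts.
fp32_out = [0.123456, 1.5, 3.1415926]
fp16_out = [to_fp16(v) for v in fp32_out]

# fp16 has roughly 3 decimal digits of precision, so a 1e-3 relative
# threshold is a reasonable bound for comparing the two paths.
assert allclose(fp32_out, fp16_out)
```

A real test would run the op twice (fp32 and fp16 inputs under AMP) and compare the outputs with thresholds like these.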
…lab#2541)
* [Enhance] Support FP16 for MSDeformAttn
* [Fix] Data type mismatch
* Update mmcv/ops/multi_scale_deform_attn.py
* Add UT
* Add cuda available condition
Author: nijkah <nijkah@gmail.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.
Motivation
Currently, MultiScaleDeformableAttention with AMP is not supported.

Modification
Following https://discuss.pytorch.org/t/how-can-i-write-the-cuda-code-to-support-fp16-calculation/107181 and DeformConv2d, apply AT_DISPATCH_FLOATING_TYPES_AND_HALF and solve the type mismatch.

Checklist
Before PR:
After PR:
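The dispatch change described in the Modification section can be sketched by analogy. AT_DISPATCH_FLOATING_TYPES_AND_HALF expands to a switch on the tensor's dtype that instantiates the templated kernel for float, double, and half; before this change, half was missing from the dispatch, producing the data type mismatch under AMP. A Python analogue (hypothetical names throughout; the real mechanism is a C++ macro in ATen):

```python
def scale_kernel(values, factor):
    """Stand-in for the templated CUDA kernel body."""
    return [v * factor for v in values]

# One entry per dtype the macro dispatches on. Adding "float16" here is the
# analogue of switching from AT_DISPATCH_FLOATING_TYPES to
# AT_DISPATCH_FLOATING_TYPES_AND_HALF.
DISPATCH_TABLE = {
    "float32": scale_kernel,
    "float64": scale_kernel,
    "float16": scale_kernel,  # the half-precision path AMP autocast needs
}

def dispatch_floating_types_and_half(dtype, values, factor):
    # Mirrors the macro: unsupported dtypes fail loudly instead of running a
    # mismatched kernel (the original "data type mismatch" failure mode).
    try:
        kernel = DISPATCH_TABLE[dtype]
    except KeyError:
        raise TypeError(f"unsupported dtype: {dtype}") from None
    return kernel(values, factor)
```

Under AMP, autocast feeds the op fp16 tensors, so the kernel must have a half instantiation registered for the dispatch to succeed.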