Fix a computational problem of scaledSoftmax. #1096
Conversation
plugin/common/common.cuh (Outdated)

    cub::Sum sum;
    float threadData(-FLT_MAX);

    if (lastValid >= blockDim.x)
Could we keep this check here in case the kernel is used in the scenario TPB > ld? Thanks!
Yeah, you can keep it. Actually, this check confused me a lot. From my point of view, with this check the initial value of threadData will be 0 in the scenario TPB <= lastValid but -FLT_MAX in the scenario TPB > lastValid. However, threadData will definitely be updated by threadData = max(static_cast<float>(input[idx]), threadData); in the next loop, so is this check really necessary for the threads in the block? Or could we directly initialize threadData to 0 instead of -FLT_MAX?
The update threadData = max(static_cast<float>(input[idx]), threadData) is only executed under the condition i < lastValid. So if TPB > lastValid, some threads never enter the loop, and we would accumulate their -FLT_MAX in the next block-reduced sum.
If we instead initialize threadData to 0, then threadData = max(static_cast<float>(input[idx]), threadData) produces a wrong maximum when the inputs are negative.
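As a toy host-side illustration of that second point (the score values are hypothetical, and this is not the kernel code itself): starting the running max from 0 silently clips an all-negative row, while -FLT_MAX is the correct identity element.

```cpp
#include <algorithm>
#include <cfloat>
#include <cstdio>

int main()
{
    // Hypothetical attention scores for one row; all of them are negative.
    const float scores[3] = {-4.0f, -2.5f, -7.0f};

    float maxFromNegInf = -FLT_MAX; // proper identity for a max reduction
    float maxFromZero = 0.0f;       // wrong starting value when scores can be negative

    for (float s : scores)
    {
        maxFromNegInf = std::max(s, maxFromNegInf);
        maxFromZero = std::max(s, maxFromZero);
    }

    printf("max starting from -FLT_MAX: %g\n", maxFromNegInf); // -2.5 (correct)
    printf("max starting from 0:        %g\n", maxFromZero);   // 0   (wrong)
    return 0;
}
```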
From my point of view, maybe we should just keep -FLT_MAX here so that the BlockReduce max produces the correct fMax? I'll set threadData to 0 before the BlockReduce sum.
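A sketch of the two-phase structure being proposed, pieced together from the hunks quoted in this review; it is not the literal contents of common.cuh, and names such as fMax, rZ, tmpStorage, and the simplified one-row-per-block indexing are assumptions made for the sketch.

```cpp
#include <cfloat>
#include <cub/cub.cuh>

// Sketch only: each block normalizes one row of ld elements, of which the
// first lastValid entries are real attention scores and the rest are padding.
template <typename T, unsigned TPB>
__device__ inline void scaledSoftmaxSketch(
    const int ld, const int lastValid, const float rsqrtHeadSize, const T* input, T* output)
{
    using BlockReduce = cub::BlockReduce<float, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;
    __shared__ float fMax; // row maximum, broadcast to the whole block
    __shared__ float rZ;   // reciprocal of the row sum

    const int offset = blockIdx.x * ld;
    const float w = rsqrtHeadSize;

    // Phase 1: row maximum. -FLT_MAX is the identity of max, so threads that
    // never enter the loop (threadIdx.x >= lastValid) cannot corrupt the result.
    float threadData = -FLT_MAX;
    for (int i = threadIdx.x; i < lastValid; i += TPB)
    {
        threadData = max(static_cast<float>(input[offset + i]), threadData);
    }
    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
    if (threadIdx.x == 0)
    {
        fMax = maxElem;
    }
    __syncthreads();

    // Phase 2: sum of exponentials. Reset threadData to 0 for *all* threads;
    // otherwise an idle thread would feed -FLT_MAX into the sum reduction.
    threadData = 0.f;
    for (int i = threadIdx.x; i < lastValid; i += TPB)
    {
        threadData += expf((static_cast<float>(input[offset + i]) - fMax) * w);
    }
    const float Z = BlockReduce(tmpStorage).Reduce(threadData, cub::Sum());
    if (threadIdx.x == 0)
    {
        rZ = 1.f / Z;
    }
    __syncthreads();

    // Phase 3: write normalized probabilities; padded positions get exactly 0.
    for (int i = threadIdx.x; i < ld; i += TPB)
    {
        const float val
            = (i < lastValid) ? expf((static_cast<float>(input[offset + i]) - fMax) * w) * rZ : 0.f;
        output[offset + i] = T(val);
    }
}
```

The part of this sketch that corresponds to the fix under discussion is the unconditional threadData = 0; between the two reductions.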
I mean line 347
Adding @rajeevsrao for visibility. Thanks @yuanzexi, good catch! So you fixed 2 issues in this PR, right?
Overall this looks good to me. Could you follow https://github.com/NVIDIA/TensorRT/blob/master/CONTRIBUTING.md to reformat the code? Thanks!
Sorry about the code format; I'll format the code later.
plugin/common/common.cuh (Outdated)

    }
    __syncthreads();

    threadData = 0;
I'll set threadData to 0 here for all threads.
plugin/common/common.cuh (Outdated)

        threadData = max(static_cast<float>(input[idx]), threadData);
    }

    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
Some threadData values are still -FLT_MAX when maxElem is computed here, namely in the scenario TPB > ld.
Oh! I see~ Thanks for answering!
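To restate that exchange in code: -FLT_MAX is harmless as the identity of a max reduction, but destructive as a term of a sum. A toy host-side illustration (hypothetical values, not the kernel itself):

```cpp
#include <algorithm>
#include <cfloat>
#include <cstdio>

int main()
{
    const float leftover = -FLT_MAX; // kept by a thread that never entered the loop
    const float real = 0.8f;         // a genuine partial result from another thread

    printf("max reduce: %g\n", std::max(leftover, real)); // 0.8      -> maxElem is still correct
    printf("sum reduce: %g\n", leftover + real);          // ~-3.4e38 -> the sum is destroyed
    return 0;
}
```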
The original implementation produces a wrong softmax normalization sum, so the results of BERT models (128 < seq_len < 384 and seq_len > 384) are very large or even NaN. This implementation fixes the computational problem so that the results of BERT models (128 < seq_len < 384 and seq_len > 384) become correct. Signed-off-by: yuanzexi <hiyuanzexi@outlook.com> Signed-off-by: yuanzexi <percyyuan@tencent.com>
I have formatted the code, and the check we discussed has been added. Looking forward to your code review.
plugin/common/common.cuh (Outdated)

    @@ -35,7 +35,6 @@ __device__ inline T rsqrt(const T& x);
     template <typename T>
     __device__ inline T exp(const T& x);
Could you also revert the changes to the blank lines? That would help us with the integration between the public repo and the internal repo. Thanks!
No problem! I have reverted them. Looking forward to your review.
plugin/common/common.cuh (Outdated)

    @@ -326,7 +326,7 @@ template <typename T, unsigned TPB>
     __device__ inline void scaledSoftmax(
         const int ld, const int lastValid, const float rsqrtHeadSize, const T* input, T* output)
     {
Could you also remove the blanks in this line? Thanks.
removed.
plugin/common/common.cuh (Outdated)

    @@ -343,10 +343,11 @@ __device__ inline void scaledSoftmax(
     {
         threadData = 0;
     }
blank line
removed.
LGTM, thanks @yuanzexi. Assigning to @rajeevsrao, thanks!
Thanks for the fix @yuanzexi - will test internally and integrate.
Problem: The original implementation produces a wrong softmax normalization sum, so the results of BERT models (128 < seq_len < 384 and seq_len > 384) are very large or even NaN, especially in FP16 mode.
Solution: This implementation fixes the computational problem so that the results of BERT models (128 < seq_len < 384 and seq_len > 384) become correct.
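A small host-side reproduction of the failure mode discussed in this review (the numbers and the 4-thread block are hypothetical; the real kernel accumulates across TPB threads with cub::BlockReduce): when TPB > lastValid, a partial sum that still holds the initial -FLT_MAX contaminates the block-wide sum Z, whereas resetting every thread's partial sum to 0 yields the correct normalization constant.

```cpp
#include <cfloat>
#include <cmath>
#include <cstdio>

int main()
{
    const int TPB = 4;       // threads per block (hypothetical)
    const int lastValid = 3; // valid scores in the row, so TPB > lastValid
    const float scores[3] = {1.0f, 0.5f, -0.5f};

    float buggyZ = 0.f, fixedZ = 0.f;
    for (int t = 0; t < TPB; ++t)
    {
        // Buggy: a thread that owns no valid element keeps its initial -FLT_MAX.
        buggyZ += (t < lastValid) ? std::exp(scores[t]) : -FLT_MAX;
        // Fixed: every thread's partial sum is reset to 0 before accumulating.
        fixedZ += (t < lastValid) ? std::exp(scores[t]) : 0.f;
    }

    printf("buggy Z = %g (normalization constant is garbage)\n", buggyZ); // about -3.4e38
    printf("fixed Z = %g (softmax values sum to 1)\n", fixedZ);           // about 4.97
    return 0;
}
```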
Signed-off-by: yuanzexi hiyuanzexi@outlook.com