Fix a computational problem of scaledSoftmax. #1096

Merged: 5 commits into NVIDIA:master on Mar 8, 2021
Conversation

@yuanzexi (Contributor) commented Mar 4, 2021

Problem: The original implementation produces a wrong softmax sum, so the outputs of BERT models (128 < seq_len < 384 and seq_len > 384) become very large or even NaN, especially in FP16 mode.
Solution: This change fixes the computation so that the results of BERT models (128 < seq_len < 384 and seq_len > 384) are correct.

Signed-off-by: yuanzexi <hiyuanzexi@outlook.com>
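
For context, a minimal host-side sketch of what the kernel is meant to compute per attention row (the function name scaledSoftmaxRef and the std::vector interface are illustrative, not part of this PR): a softmax over the first lastValid scores of a row of length ld, with the scores scaled by rsqrtHeadSize and stabilized by subtracting the row maximum. Subtracting the maximum keeps every exponent non-positive, which is exactly what prevents the overflow/NaN behaviour described above, especially in FP16:

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

// Illustrative reference only; the plugin does this on the GPU with block reductions.
std::vector<float> scaledSoftmaxRef(const std::vector<float>& row, int ld, int lastValid, float rsqrtHeadSize)
{
    std::vector<float> out(ld, 0.0f);   // positions >= lastValid are masked and stay 0
    float fMax = -FLT_MAX;              // identity element of max
    for (int i = 0; i < lastValid; ++i)
    {
        fMax = std::max(fMax, row[i]);
    }
    float z = 0.0f;                     // identity element of sum
    for (int i = 0; i < lastValid; ++i)
    {
        z += std::exp((row[i] - fMax) * rsqrtHeadSize);   // exponent <= 0, so exp() cannot overflow
    }
    for (int i = 0; i < lastValid; ++i)
    {
        out[i] = std::exp((row[i] - fMax) * rsqrtHeadSize) / z;
    }
    return out;
}

The device kernel computes the same quantity with two cub block reductions (a max, then a sum), which is where the issues discussed below come in.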

cub::Sum sum;
float threadData(-FLT_MAX);

if (lastValid >= blockDim.x)
Collaborator

Could we keep this check here in case the kernel is used in a scenario where TPB > ld? Thanks!

Contributor Author

Yeah, you can keep it. Actually, this check confused me a lot. From my point of view, with this check the initial value of threadData is 0 in the scenario TPB <= lastValid but -FLT_MAX in the scenario TPB > lastValid. However, threadData will definitely be updated by threadData = max(static_cast<float>(input[idx]), threadData); in the next loop, so is this check even necessary for the threads in the block? Or could we directly initialize threadData to 0 instead of -FLT_MAX?

Collaborator

threadData = max(static_cast<float>(input[idx]), threadData) is only executed when i < lastValid, so if TPB > lastValid some threads never update it and would carry -FLT_MAX into the next block-reduced sum.
If we initialize threadData to 0 instead, threadData = max(static_cast<float>(input[idx]), threadData) gives the wrong maximum when the inputs are negative.

Contributor Author

From my point of view, maybe we only need to get the fMax here for the BlockReduce max? I'll set threadData to 0 before the BlockReduce sum.

Collaborator

I mean line 347
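
To make the point above concrete, here is a hedged sketch of the per-row reduction pattern after the fix, assuming the kernel shape visible in this file (TPB threads cooperate on one row starting at offset); rowSoftmaxSketch and its parameter list are illustrative, not the exact TensorRT source. The accumulator starts at -FLT_MAX (the identity of max) unless every thread is guaranteed to read at least one element, and it must be re-initialized to 0 (the identity of sum) before the second block reduction, otherwise idle threads would feed -FLT_MAX into the sum:

#include <cfloat>
#include <cub/cub.cuh>

template <typename T, unsigned TPB>
__device__ inline void rowSoftmaxSketch(
    const int ld, const int lastValid, const float w, const T* input, T* output, const int offset)
{
    using BlockReduce = cub::BlockReduce<float, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;
    __shared__ float fMax;
    __shared__ float rZ;

    // -FLT_MAX is the identity of max; starting from 0 is only safe when every
    // thread reads at least one element (lastValid >= TPB).
    float threadData = (lastValid >= blockDim.x) ? 0.f : -FLT_MAX;
    for (int i = threadIdx.x; i < lastValid; i += TPB)
    {
        threadData = fmaxf(static_cast<float>(input[offset + i]), threadData);
    }
    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
    if (threadIdx.x == 0)
    {
        fMax = maxElem;
    }
    __syncthreads();

    // 0 is the identity of sum; without this reset, threads that saw no valid
    // element would carry -FLT_MAX into the sum and corrupt it.
    threadData = 0.f;
    for (int i = threadIdx.x; i < lastValid; i += TPB)
    {
        threadData += expf((static_cast<float>(input[offset + i]) - fMax) * w);
    }
    const float Z = BlockReduce(tmpStorage).Reduce(threadData, cub::Sum());
    if (threadIdx.x == 0)
    {
        rZ = 1.f / Z;
    }
    __syncthreads();

    for (int i = threadIdx.x; i < ld; i += TPB)
    {
        const float val = (i < lastValid) ? expf((static_cast<float>(input[offset + i]) - fMax) * w) * rZ : 0.f;
        output[offset + i] = T(val);
    }
}

In the merged change the same effect comes from keeping the existing lastValid >= blockDim.x guard and adding the threadData = 0 reset after __syncthreads(), as the diff context later in this thread shows.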

@ttyio (Collaborator) commented Mar 5, 2021

Adding @rajeevsrao for visibility.

Thanks @yuanzexi , good catch!

So you fixed 2 issues in this PR, right?

  • The fMax is wrong when TPB < ld and there are multiple blocks
  • The - fMax is missing in the final softmax calculation

Overall it looks good to me. Could you follow https://github.com/NVIDIA/TensorRT/blob/master/CONTRIBUTING.md to reformat the code? Thanks!

@yuanzexi (Contributor Author) commented Mar 5, 2021

> Adding @rajeevsrao for visibility.
>
> Thanks @yuanzexi, good catch!
>
> So you fixed 2 issues in this PR, right?
>
>   • The fMax is wrong when TPB < ld and there are multiple blocks
>   • The - fMax is missing in the final softmax calculation
>
> Overall it looks good to me. Could you follow https://github.com/NVIDIA/TensorRT/blob/master/CONTRIBUTING.md to reformat the code? Thanks!

Sorry about the code format, I'll format the code later.

}
__syncthreads();

threadData = 0;
Contributor Author

I'll set threadData to 0 here for all threads.

threadData = max(static_cast<float>(input[idx]), threadData);
}

const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
Collaborator

Some threadData values are still -FLT_MAX when maxElem is computed here, in the case TPB > ld.

Contributor Author

Oh! I see~ Thanks for answering!

yuanzexi and others added 3 commits March 5, 2021 15:18
The original implementation produces a wrong softmax sum, so the results of BERT models (128 < seq_len < 384 and seq_len > 384) are very large or even NaN.
This change fixes the computation so that the results of BERT models (128 < seq_len < 384 and seq_len > 384) become correct.

Signed-off-by: yuanzexi <hiyuanzexi@outlook.com>
Signed-off-by: yuanzexi <percyyuan@tencent.com>
@yuanzexi (Contributor Author) commented Mar 5, 2021

I have formatted the code, and the check we discussed has been added. Looking forward to your code review.

@@ -35,7 +35,6 @@ __device__ inline T rsqrt(const T& x);
template <typename T>
__device__ inline T exp(const T& x);


Collaborator

Could you also revert the blank-line changes? That helps us with the integration between the public repo and the internal repo. Thanks!

Contributor Author

No problem! I have reverted them. Looking forward to your review.

Signed-off-by: yuanzexi <percyyuan@tencent.com>
@@ -326,7 +326,7 @@ template <typename T, unsigned TPB>
__device__ inline void scaledSoftmax(
const int ld, const int lastValid, const float rsqrtHeadSize, const T* input, T* output)
{

Collaborator

Could you also remove the blanks on this line? Thanks!

Contributor Author

removed.

@@ -343,10 +343,11 @@ __device__ inline void scaledSoftmax(
{
threadData = 0;
}

Collaborator

blank line

Contributor Author

removed.

Signed-off-by: yuanzexi <percyyuan@tencent.com>
@ttyio (Collaborator) commented Mar 5, 2021

LGTM, thanks @yuanzexi

Assigning to @rajeevsrao, thanks!

@rajeevsrao (Collaborator)

Thanks for the fix @yuanzexi - will test internally and integrate.

@rajeevsrao added the labels bug and Plugins (Issues when using TensorRT plugins) on Mar 5, 2021
@rajeevsrao merged commit 8c10371 into NVIDIA:master on Mar 8, 2021