Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a computational problem of scaledSoftmax. #1096

Merged
merged 5 commits into from
Mar 8, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions plugin/common/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,6 @@ template <typename T, unsigned TPB>
__device__ inline void scaledSoftmax(
const int ld, const int lastValid, const float rsqrtHeadSize, const T* input, T* output)
{

using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;

Expand All @@ -346,7 +345,7 @@ __device__ inline void scaledSoftmax(
for (int i = threadIdx.x; i < lastValid; i += TPB)
{
const int idx = offset + i;
threadData = input[idx];
threadData = max(static_cast<float>(input[idx]), threadData);
}

const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
Expand All @@ -356,16 +355,12 @@ __device__ inline void scaledSoftmax(
}
__syncthreads();

if (lastValid < blockDim.x)
{
if (threadIdx.x >= lastValid)
{
threadData = 0;
}
}
threadData = 0;

for (int i = threadIdx.x; i < lastValid; i += TPB)
{
threadData += exp((threadData - fMax) * w);
const int idx = offset + i;
threadData += exp((static_cast<float>(input[idx]) - fMax) * w);
}

const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
Expand All @@ -379,7 +374,7 @@ __device__ inline void scaledSoftmax(
for (int i = threadIdx.x; i < ld; i += TPB)
{
const int idx = offset + i;
const float val = (i < lastValid) ? exp(float(input[idx]) * w) * rZ : 0.f;
const float val = (i < lastValid) ? exp((static_cast<float>(input[idx]) - fMax) * w) * rZ : 0.f;
output[idx] = T(val);
}
}
Expand Down