
fix softmax, math unary/binary kernel int overflow (#8472)
* fix softmax, math unary/binary kernel int overflow

* use a templated IndexType to handle int32_t indexing in the CUDA kernels

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
chengtbf and mergify[bot] authored Jun 24, 2022
1 parent 26fe902 commit 3ea445a
Showing 6 changed files with 101 additions and 66 deletions.
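The underlying bug: these kernels computed their grid-stride loop index in 32-bit int, so tensors with more than about 2^31 elements wrapped the index and addressed memory out of bounds. The fix threads a 64-bit (or templated) index type through the kernels instead. As a rough sketch of the pattern — KERNEL_LOOP_T below is a stand-in written for illustration, not OneFlow's actual CUDA_1D_KERNEL_LOOP_T definition:

// Grid-stride loop whose index type is a template parameter. With
// IndexType = int the index wraps past INT_MAX; with int64_t it covers
// arbitrarily large element counts.
#define KERNEL_LOOP_T(IndexType, i, n)                                              \
  for (IndexType i = static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x, \
                 step = static_cast<IndexType>(gridDim.x) * blockDim.x;             \
       i < (n); i += step)

template<typename T, typename IndexType>
__global__ void ScaleGpu(const IndexType n, const T* x, T* y) {
  KERNEL_LOOP_T(IndexType, i, n) { y[i] = x[i] * static_cast<T>(2); }
}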
5 changes: 3 additions & 2 deletions oneflow/core/device/cuda_util.h
@@ -113,9 +113,10 @@ const int32_t kCudaWarpSize = 32;
// TODO: limit of shared memory should be different for different arch
const int32_t kCudaMaxSharedMemoryByteSize = 48 << 10;

- inline int32_t BlocksNum4ThreadsNum(const int32_t n) {
+ inline int64_t BlocksNum4ThreadsNum(const int64_t n) {
CHECK_GT(n, 0);
-   return std::min((n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock, kCudaMaxBlocksNum);
+   return std::min((n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock,
+                   static_cast<int64_t>(kCudaMaxBlocksNum));
}

#define RUN_CUDA_KERNEL(func, stream, elem_cnt, ...) \
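Why the int32_t signature above was fragile: an element count just above INT32_MAX does not even fit in the parameter, and the rounding-up expression n + kCudaThreadsNumPerBlock - 1 overflows before the division. A small worked example of the fixed int64_t arithmetic (the two constants are assumed values for illustration, not necessarily OneFlow's):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t kThreadsPerBlock = 512;  // assumed value
  const int64_t kMaxBlocksNum = 8192;    // assumed value
  const int64_t n = 3'000'000'000;       // ~3e9 elements, larger than INT32_MAX

  // In 32-bit arithmetic n would truncate and n + 511 would wrap negative;
  // in 64-bit arithmetic the block count is computed correctly and clamped.
  const int64_t blocks =
      std::min((n + kThreadsPerBlock - 1) / kThreadsPerBlock, kMaxBlocksNum);
  std::cout << blocks << "\n";  // prints 8192
  return 0;
}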
22 changes: 10 additions & 12 deletions oneflow/user/kernels/math_binary_elementwise_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,24 @@ namespace oneflow {
namespace {

template<template<typename> class BinaryFunctor, typename T>
- __global__ void MathBinaryElementwiseForwardGpu(const int n, const T* x, const T* y, T* z) {
-   CUDA_1D_KERNEL_LOOP(i, n) { z[i] = BinaryFunctor<T>::Forward(x[i], y[i]); }
+ __global__ void MathBinaryElementwiseForwardGpu(const int64_t n, const T* x, const T* y, T* z) {
+   CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { z[i] = BinaryFunctor<T>::Forward(x[i], y[i]); }
}

template<template<typename> class BinaryFunctor, typename T>
- __global__ void MathBinaryElementwiseBackwardXGradGpu(const int n, const T* x, const T* y,
+ __global__ void MathBinaryElementwiseBackwardXGradGpu(const int64_t n, const T* x, const T* y,
const T* dz, T* dx) {
-   CUDA_1D_KERNEL_LOOP(i, n) { dx[i] = BinaryFunctor<T>::BackwardXGrad(x[i], y[i], dz[i]); }
+   CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) {
+     dx[i] = BinaryFunctor<T>::BackwardXGrad(x[i], y[i], dz[i]);
+   }
}

template<template<typename> class BinaryFunctor, typename T>
- __global__ void MathBinaryElementwiseBackwardYGradGpu(const int n, const T* x, const T* y,
+ __global__ void MathBinaryElementwiseBackwardYGradGpu(const int64_t n, const T* x, const T* y,
const T* dz, T* dy) {
-   CUDA_1D_KERNEL_LOOP(i, n) { dy[i] = BinaryFunctor<T>::BackwardYGrad(x[i], y[i], dz[i]); }
+   CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) {
+     dy[i] = BinaryFunctor<T>::BackwardYGrad(x[i], y[i], dz[i]);
+   }
}

} // namespace
@@ -53,7 +57,6 @@ class MathBinaryElementwiseGpuKernel final : public user_op::OpKernel {
const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0);
user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex("z", 0);
int64_t n = tensor_x->shape_view().elem_cnt();
-   CHECK_LE(n, GetMaxVal<int32_t>() / 2);
if (n == 0) { return; }
MathBinaryElementwiseForwardGpu<BinaryFunctor, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
@@ -77,7 +80,6 @@ class MathBinaryElementwiseXGradGpuKernel final : public user_op::OpKernel {
const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex("dz", 0);
user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
int64_t n = tensor_x->shape_view().elem_cnt();
-   CHECK_LE(n, GetMaxVal<int32_t>() / 2);
if (n == 0) { return; }
MathBinaryElementwiseBackwardXGradGpu<BinaryFunctor, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
@@ -102,7 +104,6 @@ class MathBinaryElementwiseYGradGpuKernel final : public user_op::OpKernel {
const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex("dz", 0);
user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
int64_t n = tensor_x->shape_view().elem_cnt();
-   CHECK_LE(n, GetMaxVal<int32_t>() / 2);
if (n == 0) { return; }
MathBinaryElementwiseBackwardYGradGpu<BinaryFunctor, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
@@ -156,7 +157,6 @@ class MathBinaryElementwiseGpuHalfKernel final : public user_op::OpKernel {
const half* y = reinterpret_cast<const half*>(tensor_y->dptr<float16>());
half* z = reinterpret_cast<half*>(tensor_z->mut_dptr<float16>());
int64_t n = tensor_x->shape_view().elem_cnt();
-   CHECK_LE(n, GetMaxVal<int32_t>() / 2);
if (n == 0) { return; }
MathBinaryElementwiseForwardGpu<BinaryFunctor, half>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
@@ -184,7 +184,6 @@ class MathBinaryElementwiseXGradGpuHalfKernel final : public user_op::OpKernel {
const half* dz = reinterpret_cast<const half*>(tensor_dz->dptr<float16>());
half* dx = reinterpret_cast<half*>(tensor_dx->mut_dptr<float16>());
int64_t n = tensor_x->shape_view().elem_cnt();
-   CHECK_LE(n, GetMaxVal<int32_t>() / 2);
if (n == 0) { return; }
MathBinaryElementwiseBackwardXGradGpu<BinaryFunctor, half>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
@@ -212,7 +211,6 @@ class MathBinaryElementwiseYGradGpuHalfKernel final : public user_op::OpKernel {
const half* dz = reinterpret_cast<const half*>(tensor_dz->dptr<float16>());
half* dy = reinterpret_cast<half*>(tensor_dy->mut_dptr<float16>());
int64_t n = tensor_x->shape_view().elem_cnt();
-   CHECK_LE(n, GetMaxVal<int32_t>() / 2);
if (n == 0) { return; }
MathBinaryElementwiseBackwardYGradGpu<BinaryFunctor, half>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
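With the kernels above now iterating in int64_t, the host-side CHECK_LE(n, GetMaxVal<int32_t>() / 2) guards can simply be dropped and the launch path stays the same. A simplified sketch of that path (the real kernels obtain the CUDA stream from the kernel compute context rather than taking it as a parameter):

// Minimal host-side launcher mirroring the forward kernel above.
template<template<typename> class BinaryFunctor, typename T>
void LaunchBinaryForward(cudaStream_t stream, int64_t n, const T* x, const T* y, T* z) {
  if (n == 0) { return; }  // nothing to launch for empty tensors
  MathBinaryElementwiseForwardGpu<BinaryFunctor, T>
      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, stream>>>(n, x, y, z);
}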
12 changes: 4 additions & 8 deletions oneflow/user/kernels/math_unary_elementwise_kernel.cu
@@ -24,13 +24,13 @@ namespace oneflow {
namespace {

template<template<typename> class UnaryFunctor, typename T>
- __global__ void MathUnaryElementwiseForwardGpu(const int n, const T* x, T* y) {
-   CUDA_1D_KERNEL_LOOP(i, n) { y[i] = UnaryFunctor<T>::Forward(x[i]); }
+ __global__ void MathUnaryElementwiseForwardGpu(const int64_t n, const T* x, T* y) {
+   CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { y[i] = UnaryFunctor<T>::Forward(x[i]); }
}

template<template<typename> class UnaryFunctor, typename T>
- __global__ void MathUnaryElementwiseBackwardGpu(const int n, const T* x, const T* dy, T* dx) {
-   CUDA_1D_KERNEL_LOOP(i, n) { dx[i] = UnaryFunctor<T>::Backward(x[i], dy[i]); }
+ __global__ void MathUnaryElementwiseBackwardGpu(const int64_t n, const T* x, const T* dy, T* dx) {
+   CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { dx[i] = UnaryFunctor<T>::Backward(x[i], dy[i]); }
}

} // namespace
@@ -50,7 +50,6 @@ class MathUnaryElementwiseGpuKernel final : public user_op::OpKernel,
const T* x = tensor_x->dptr<T>();
T* y = tensor_y->mut_dptr<T>();
int64_t n = tensor_x->shape_view().elem_cnt();
-   CHECK_LE(n, GetMaxVal<int32_t>() / 2);
if (n == 0) { return; }
MathUnaryElementwiseForwardGpu<UnaryFunctor, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
@@ -77,7 +76,6 @@ class MathUnaryElementwiseGradGpuKernel final : public user_op::OpKernel,
const T* dy = tensor_dy->dptr<T>();
T* dx = tensor_dx->mut_dptr<T>();
int64_t n = tensor_x->shape_view().elem_cnt();
-   CHECK_LE(n, GetMaxVal<int32_t>() / 2);
if (n == 0) { return; }
MathUnaryElementwiseBackwardGpu<UnaryFunctor, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
@@ -126,7 +124,6 @@ class MathUnaryElementwiseGpuHalfKernel final : public user_op::OpKernel,
const half* x = reinterpret_cast<const half*>(tensor_x->dptr<float16>());
half* y = reinterpret_cast<half*>(tensor_y->mut_dptr<float16>());
int64_t n = tensor_x->shape_view().elem_cnt();
-   CHECK_LE(n, GetMaxVal<int32_t>() / 2);
if (n == 0) { return; }
MathUnaryElementwiseForwardGpu<UnaryFunctor, half>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
@@ -153,7 +150,6 @@ class MathUnaryElementwiseGradGpuHalfKernel final : public user_op::OpKernel,
const half* dy = reinterpret_cast<const half*>(tensor_dy->dptr<float16>());
half* dx = reinterpret_cast<half*>(tensor_dx->mut_dptr<float16>());
int64_t n = tensor_x->shape_view().elem_cnt();
-   CHECK_LE(n, GetMaxVal<int32_t>() / 2);
if (n == 0) { return; }
MathUnaryElementwiseBackwardGpu<UnaryFunctor, half>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
82 changes: 54 additions & 28 deletions oneflow/user/kernels/sparse_cross_entropy_kernel_util.cu
@@ -27,7 +27,7 @@ template<typename T, typename K>
__global__ void ComputeEntropyGpu(const int64_t num_instances, const int64_t num_classes,
const int64_t depth, const int64_t lower_bound, const T* x,
const K* labels, T* y) {
-   CUDA_1D_KERNEL_LOOP(i, num_instances) {
+   CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) {
assert(labels[i] >= 0);
assert(labels[i] < depth);
K label = labels[i] - lower_bound;
@@ -40,7 +40,7 @@ __global__ void ComputeEntropyGpuHalf(const int64_t num_instances, const int64_t
const int64_t depth, const int64_t lower_bound, const half* x,
const K* labels, half* y) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
-   CUDA_1D_KERNEL_LOOP(i, num_instances) {
+   CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) {
assert(labels[i] >= 0);
assert(labels[i] < depth);
K label = labels[i] - lower_bound;
@@ -58,7 +58,7 @@ template<typename T, typename K>
__global__ void ComputeDiffGpu(const int64_t num_instances, const int64_t num_classes,
const int64_t depth, const int64_t lower_bound, const T* x,
const K* labels, const T* dy, T* dx) {
-   CUDA_1D_KERNEL_LOOP(i, num_instances) {
+   CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) {
assert(labels[i] >= 0);
assert(labels[i] < depth);
K label = labels[i] - lower_bound;
@@ -73,7 +73,7 @@ __global__ void ComputeDiffGpuHalf(const int64_t num_instances, const int64_t nu
const int64_t depth, const int64_t lower_bound, const half* x,
const K* labels, const half* dy, half* dx) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
-   CUDA_1D_KERNEL_LOOP(i, num_instances) {
+   CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) {
assert(labels[i] >= 0);
assert(labels[i] < depth);
K label = labels[i] - lower_bound;
@@ -88,13 +88,13 @@ __global__ void ComputeDiffGpuHalf(const int64_t num_instances, const int64_t nu
#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/
}

- template<typename T, typename K>
+ template<typename T, typename K, typename IndexType>
__global__ void ComputeDiffWithSoftmaxGpu(const int64_t elem_cnt, const int64_t num_classes,
const int64_t depth, const int64_t lower_bound,
const T* prob, const K* labels, const T* dy, T* dx) {
-   CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
-     const int32_t row_id = i / num_classes;
-     const int32_t col_id = i - row_id * num_classes;
+   CUDA_1D_KERNEL_LOOP_T(IndexType, i, elem_cnt) {
+     const IndexType row_id = i / num_classes;
+     const IndexType col_id = i - row_id * num_classes;
assert(labels[row_id] >= 0);
assert(labels[row_id] < depth);
K label = labels[row_id] - lower_bound;
@@ -106,15 +106,16 @@ __global__ void ComputeDiffWithSoftmaxGpu(const int64_t elem_cnt, const int64_t
}
}

- template<typename K>
+ template<typename K, typename IndexType>
__global__ void ComputeDiffWithSoftmaxGpuHalf(const int64_t elem_cnt, const int64_t num_classes,
const int64_t depth, const int64_t lower_bound,
const half* prob, const K* labels, const half* dy,
half* dx) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
-   CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
-     const int32_t row_id = i / num_classes;
-     const int32_t col_id = i - row_id * num_classes;
+   CUDA_1D_KERNEL_LOOP_T(IndexType, i, elem_cnt) {
+     // NOTE(chengcheng): int division ('/') of i will reduce performance of int64_t.
+     const IndexType row_id = i / num_classes;
+     const IndexType col_id = i - row_id * num_classes;
assert(labels[row_id] >= 0);
assert(labels[row_id] < depth);
K label = labels[row_id] - lower_bound;
@@ -130,7 +131,7 @@ __global__ void ComputeDiffWithSoftmaxGpuHalf(const int64_t elem_cnt, const int6
#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/
}

- template<typename K>
+ template<typename K, typename IndexType>
__global__ void ComputeDiffWithSoftmaxGpuHalf2(const int64_t elem_cnt, const int64_t num_classes,
const int64_t depth, const int64_t lower_bound,
const half* prob, const K* labels, const half* dy,
@@ -140,9 +141,9 @@ __global__ void ComputeDiffWithSoftmaxGpuHalf2(const int64_t elem_cnt, const int
const int64_t h2_elem_cnt = elem_cnt / 2;
const auto* prob_h2 = reinterpret_cast<const half2*>(prob);
auto* dx_h2 = reinterpret_cast<half2*>(dx);
-   CUDA_1D_KERNEL_LOOP(i, h2_elem_cnt) {
-     const int32_t row_id = i / h2_num_classes;
-     const int32_t h2_col_id = i - row_id * h2_num_classes;
+   CUDA_1D_KERNEL_LOOP_T(IndexType, i, h2_elem_cnt) {
+     const IndexType row_id = i / h2_num_classes;
+     const IndexType h2_col_id = i - row_id * h2_num_classes;
assert(labels[row_id] >= 0);
assert(labels[row_id] < depth);
K label = labels[row_id] - lower_bound;
@@ -183,9 +184,17 @@ struct SparseCrossEntropyKernelUtil<DeviceType::kCUDA, T, K> {
const int64_t num_classes, const int64_t depth,
const int64_t lower_bound, const T* prob, const K* labels,
const T* dy, T* dx) {
-     ComputeDiffWithSoftmaxGpu<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
-                                 stream->As<ep::CudaStream>()->cuda_stream()>>>(
-         elem_cnt, num_classes, depth, lower_bound, prob, labels, dy, dx);
+     if (elem_cnt < GetMaxVal<int32_t>() / 2) {
+       ComputeDiffWithSoftmaxGpu<T, K, int32_t>
+           <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
+              stream->As<ep::CudaStream>()->cuda_stream()>>>(elem_cnt, num_classes, depth,
+                                                             lower_bound, prob, labels, dy, dx);
+     } else {
+       ComputeDiffWithSoftmaxGpu<T, K, int64_t>
+           <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
+              stream->As<ep::CudaStream>()->cuda_stream()>>>(elem_cnt, num_classes, depth,
+                                                             lower_bound, prob, labels, dy, dx);
+     }
}
};
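The if/else above keeps a 32-bit loop index whenever elem_cnt allows it — the row/column computation divides by num_classes, and 64-bit integer division is slower on the GPU — and only falls back to int64_t for very large tensors. The same dispatch could be factored into a small helper; this one is hypothetical, not part of the commit:

template<typename T> struct IndexTypeTag { using type = T; };

// Invokes launch(IndexTypeTag<int32_t>{}) when a 32-bit index is safe and
// launch(IndexTypeTag<int64_t>{}) otherwise; the /2 margin mirrors the diff.
template<typename Launcher>
void DispatchIndexType(int64_t elem_cnt, Launcher&& launch) {
  if (elem_cnt < GetMaxVal<int32_t>() / 2) {
    launch(IndexTypeTag<int32_t>{});
  } else {
    launch(IndexTypeTag<int64_t>{});
  }
}

// Usage sketch:
//   DispatchIndexType(elem_cnt, [&](auto tag) {
//     using IndexType = typename decltype(tag)::type;
//     ComputeDiffWithSoftmaxGpu<T, K, IndexType><<<grid, block, 0, stream>>>(...);
//   });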
@@ -215,16 +224,33 @@ struct SparseCrossEntropyKernelUtil<DeviceType::kCUDA, float16, K> {
const int64_t lower_bound, const float16* prob,
const K* labels, const float16* dy, float16* dx) {
if (num_classes % 2 == 0) {
-       ComputeDiffWithSoftmaxGpuHalf2<K>
-           <<<BlocksNum4ThreadsNum(elem_cnt / 2), kCudaThreadsNumPerBlock, 0,
-              stream->As<ep::CudaStream>()->cuda_stream()>>>(
-               elem_cnt, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob),
-               labels, reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));
+       if (elem_cnt < GetMaxVal<int32_t>() / 2) {
+         ComputeDiffWithSoftmaxGpuHalf2<K, int32_t>
+             <<<BlocksNum4ThreadsNum(elem_cnt / 2), kCudaThreadsNumPerBlock, 0,
+                stream->As<ep::CudaStream>()->cuda_stream()>>>(
+                 elem_cnt, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob),
+                 labels, reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));
+       } else {
+         ComputeDiffWithSoftmaxGpuHalf2<K, int64_t>
+             <<<BlocksNum4ThreadsNum(elem_cnt / 2), kCudaThreadsNumPerBlock, 0,
+                stream->As<ep::CudaStream>()->cuda_stream()>>>(
+                 elem_cnt, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob),
+                 labels, reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));
+       }
} else {
-       ComputeDiffWithSoftmaxGpuHalf<K><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
-                                          stream->As<ep::CudaStream>()->cuda_stream()>>>(
-           elem_cnt, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob), labels,
-           reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));
+       if (elem_cnt < GetMaxVal<int32_t>() / 2) {
+         ComputeDiffWithSoftmaxGpuHalf<K, int32_t>
+             <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
+                stream->As<ep::CudaStream>()->cuda_stream()>>>(
+                 elem_cnt, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob),
+                 labels, reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));
+       } else {
+         ComputeDiffWithSoftmaxGpuHalf<K, int64_t>
+             <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
+                stream->As<ep::CudaStream>()->cuda_stream()>>>(
+                 elem_cnt, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob),
+                 labels, reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));
+       }
}
}
};
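The float16 specialization keeps two paths: when num_classes is even the data is walked as packed half2 pairs, so the kernel is launched over elem_cnt / 2 elements; otherwise the plain half kernel runs over elem_cnt. Both paths now get the same int32/int64 dispatch, with the threshold conservatively compared against the unpacked elem_cnt. A rough sketch of the packed row/column indexing, for illustration only:

// One loop index i addresses one half2, i.e. two adjacent classes of one row.
template<typename IndexType>
__device__ void PackedRowCol(IndexType i, IndexType h2_num_classes,
                             IndexType* row_id, IndexType* h2_col_id) {
  *row_id = i / h2_num_classes;               // which instance (row)
  *h2_col_id = i - *row_id * h2_num_classes;  // which half2 within that row
}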
@@ -48,7 +48,7 @@ __global__ void ComputeSparseSoftmaxCrossEntropyResultGpu(const int64_t num_inst
const int64_t depth,
const int64_t lower_bound,
const K* labels, const T* prob, T* out) {
-   CUDA_1D_KERNEL_LOOP(i, num_instances) {
+   CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) {
assert(labels[i] >= 0);
assert(labels[i] < depth);
K label = labels[i] - lower_bound;