diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu index d4894d667d..0f6e7e1a45 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu +++ b/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu @@ -21,8 +21,8 @@ namespace { template __global__ inline void _float_to_FP8rowwise_cuda_kernel( const input_t* __restrict__ input, - const int nrows, - const int ncols, + const int64_t nrows, + const int64_t ncols, std::uint8_t* __restrict__ output, const bool forward) { constexpr float kEpsilon = 1e-20f; @@ -30,10 +30,10 @@ __global__ inline void _float_to_FP8rowwise_cuda_kernel( const int bias = forward ? 15 : 31; const float max_pos = forward ? 0.9375 : 0.875; - const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; - const int output_columns = ncols_aligned + 2 * sizeof(float); + const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4; + const int64_t output_columns = ncols_aligned + 2 * sizeof(float); - const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x; + const int64_t row = blockIdx.x * blockDim.x + threadIdx.x; if (row < nrows) { const input_t* input_row = input + row * ncols; @@ -47,7 +47,7 @@ __global__ inline void _float_to_FP8rowwise_cuda_kernel( const auto scale = max_pos / (kEpsilon + fmaxf(maximum_element, -minimum_element)); output_row_scale_bias[0] = scale; - for (std::size_t col = 0; col < ncols; ++col) { + for (int64_t col = 0; col < ncols; ++col) { output_row[col] = float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos); } @@ -57,15 +57,15 @@ __global__ inline void _float_to_FP8rowwise_cuda_kernel( template __global__ inline void _get_FP8_qparam_cuda_kernel( const input_t* __restrict__ input, - const int nrows, - const int ncols, + const int64_t nrows, + const int64_t ncols, uint8_t* __restrict__ output, float* __restrict__ range_list, const bool forward) { - const int row = (int)blockIdx.x * blockDim.y + threadIdx.y; + const int64_t row = blockIdx.x * blockDim.y + threadIdx.y; - const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; - const int output_columns = ncols_aligned + 2 * sizeof(float); + const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4; + const int64_t output_columns = ncols_aligned + 2 * sizeof(float); float max_pos; if (forward) { max_pos = 0.9375; @@ -84,7 +84,7 @@ __global__ inline void _get_FP8_qparam_cuda_kernel( if (row < nrows) { const input_t* const input_row = input + row * ncols; - for (int col = threadIdx.x; col < ncols; col += lane_width) { + for (int64_t col = threadIdx.x; col < ncols; col += lane_width) { // Get thread-local minmax. These are the smallest min and max ever seen // by this thread. maximum_element = fmaxf(maximum_element, fabs(input_row[col])); @@ -116,8 +116,8 @@ template __global__ inline void _compute_FP8_quantize_cuda_kernel( const input_t* const __restrict__ input, const float* const __restrict__ range_list, - const int nrows, - const int ncols, + const int64_t nrows, + const int64_t ncols, std::uint8_t* const __restrict__ output, const bool forward) { int ebit; @@ -133,18 +133,18 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel( max_pos = 0.875; } - const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; - const int output_columns = ncols_aligned + 2 * sizeof(float); + const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4; + const int64_t output_columns = ncols_aligned + 2 * sizeof(float); - int row = (int)blockIdx.y * blockDim.y + threadIdx.y; - const int col = (int)blockIdx.x * blockDim.x + threadIdx.x; - const int row_incre = blockDim.y * gridDim.y; + int64_t row = blockIdx.y * blockDim.y + threadIdx.y; + const int64_t col = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t row_incre = blockDim.y * gridDim.y; for (/*row*/; row < nrows; row += row_incre) { if (col < ncols) { float* row_qparams = reinterpret_cast( output + row * output_columns + ncols_aligned); const float scale = row_qparams[0]; - const int input_idx = row * ncols + col; + const auto input_idx = row * ncols + col; uint8_t* output_addr = output + row * output_columns + col; // TODO: lift range_list into shared memory. However, when nrows is large, // it might exceed the size of shared memory.