diff --git a/oneflow/user/kernels/data_shuffle_kernel.cu b/oneflow/user/kernels/data_shuffle_kernel.cu
index 2b50235dda1..df72470c090 100644
--- a/oneflow/user/kernels/data_shuffle_kernel.cu
+++ b/oneflow/user/kernels/data_shuffle_kernel.cu
@@ -22,6 +22,7 @@ limitations under the License.
 #include "oneflow/user/kernels/unsorted_segment_sum_kernel_util.h"
 #include "oneflow/core/cuda/atomic.cuh"
 #include "oneflow/core/embedding/hash_functions.cuh"
+#include "oneflow/core/cuda/elementwise.cuh"
 
 namespace oneflow {
 
@@ -458,6 +459,391 @@ void ShuffleEmbeddings(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parall
               reverse_unique_cur_rank_embeddings, recv_offsets, recv_elem_cnt, received_embeddings);
 }
 
+// Quantized Version.
+template<typename T, typename IDX>
+void ShuffleEmbeddings(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id,
+                       int64_t parallel_num, int64_t num_ids, int64_t embedding_size,
+                       DataType data_type, IDX* host_num_unique_matrix,
+                       int8_t* reverse_unique_cur_rank_embeddings, int8_t* received_embeddings,
+                       T* reverse_cur_rank_quantize_factor, T* recv_quantize_factor) {
+  std::vector<int64_t> send_offsets;
+  std::vector<int64_t> send_elem_cnt;
+  std::vector<int64_t> recv_offsets;
+  std::vector<int64_t> recv_elem_cnt;
+  // shuffle quantized_embedding
+  MakeShuffleParams(host_num_unique_matrix, num_ids, embedding_size, parallel_id, parallel_num,
+                    &recv_offsets, &recv_elem_cnt, &send_offsets, &send_elem_cnt);
+  ShuffleData(cuda_stream, comm, DataType::kInt8, send_offsets, send_elem_cnt,
+              reverse_unique_cur_rank_embeddings, recv_offsets, recv_elem_cnt, received_embeddings);
+  // shuffle quantize_factor
+  MakeShuffleParams(host_num_unique_matrix, num_ids, /*embedding_size=*/1, parallel_id,
+                    parallel_num, &recv_offsets, &recv_elem_cnt, &send_offsets, &send_elem_cnt);
+  ShuffleData(cuda_stream, comm, data_type, send_offsets, send_elem_cnt,
+              reverse_cur_rank_quantize_factor, recv_offsets, recv_elem_cnt, recv_quantize_factor);
+}
+
+__device__ float RoundHalfAwayFromZero(const float x) {
+  float abs_val = abs(x);
+  float floor_val = floor(abs_val + static_cast<float>(0.5));
+  return copysignf(floor_val, x);
+}
+
+// warp reduce version.
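As a reference for the warp-level kernel that follows, here is a minimal NumPy sketch of the per-row symmetric int8 scheme it implements: each row is scaled by 127 / abs-max, rounded half away from zero, and the row's abs-max is kept as the quantize factor so the receiver can undo the scaling. The helper names below are illustrative and not part of this patch.

```python
import numpy as np

def round_half_away_from_zero(x):
    # Same rounding as the device-side RoundHalfAwayFromZero above.
    return np.copysign(np.floor(np.abs(x) + 0.5), x)

def quantize_rows(embeddings):
    # Per-row abs-max is the quantize factor that gets shuffled alongside the int8 data.
    x = embeddings.astype(np.float32)
    factor = np.max(np.abs(x), axis=1, keepdims=True)
    q = round_half_away_from_zero(x * (127.0 / factor)).astype(np.int8)
    return q, factor

def dequantize_rows(q, factor):
    # Receiver side: undo the scaling with factor / 127.
    return q.astype(np.float32) * (factor / 127.0)

rows = np.random.uniform(-1.0, 1.0, size=(8, 128)).astype(np.float32)
q, factor = quantize_rows(rows)
restored = dequantize_rows(q, factor)
# Each element is off by at most half a quantization step, i.e. abs_max / 127 / 2 per row.
assert np.all(np.abs(restored - rows) <= factor / 127.0 / 2.0 + 1e-6)
```

The per-row abs-max is exactly what the quantized ShuffleEmbeddings above sends as the quantize factor at the original data type, alongside the int8 payload.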
+constexpr int32_t kWarpSize = 32; +constexpr int32_t kMaxColSize = 1024; + +template +__inline__ __device__ T WarpMaxAllReduce(T val) { + for (int32_t lane_mask = thread_group_width / 2; lane_mask > 0; lane_mask /= 2) { + val = max(val, __shfl_xor_sync(0xffffffff, val, lane_mask, thread_group_width)); + } + return val; +} + +inline cudaError_t GetWarpImplNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves, + int* num_blocks) { + int dev; + { + cudaError_t err = cudaGetDevice(&dev); + if (err != cudaSuccess) { return err; } + } + int sm_count; + { + cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); + if (err != cudaSuccess) { return err; } + } + int tpm; + { + cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev); + if (err != cudaSuccess) { return err; } + } + *num_blocks = + std::max(1, std::min(max_blocks, sm_count * tpm / block_size * waves)); + return cudaSuccess; +} + +template +__global__ void QuantizeWarpImplKernel(const T* src, int8_t* dst, T* quantize_factor, + const int64_t rows, const int64_t cols) { + static_assert(cols_per_thread % pack_size == 0, ""); + static_assert(thread_group_width <= kWarpSize, ""); + static_assert(kWarpSize % thread_group_width == 0, ""); + constexpr int num_packs = cols_per_thread / pack_size; + assert(cols <= cols_per_thread * thread_group_width); + ComputeType buf[rows_per_access][cols_per_thread]; + const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; + const int num_global_thread_group = gridDim.x * blockDim.y; + const int lane_id = threadIdx.x; + const int64_t step = num_global_thread_group * rows_per_access; + using LoadType = cuda::elementwise::PackType; + using LoadPack = cuda::elementwise::Pack; + using StoreType = cuda::elementwise::PackType; + using StorePack = cuda::elementwise::Pack; + + for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) { + ComputeType thread_abs_max[rows_per_access]; +#pragma unroll + for (int row_id = 0; row_id < rows_per_access; row_id++) { + ComputeType* row_buf = buf[row_id]; + thread_abs_max[row_id] = 0.0; +#pragma unroll + for (int pack_id = 0; pack_id < num_packs; pack_id++) { + const int pack_offset = pack_id * pack_size; + const int col = (pack_id * thread_group_width + lane_id) * pack_size; + LoadPack load_pack; + if (!padding || col < cols) { + const int64_t load_offset = ((row + row_id) * cols + col) / pack_size; + load_pack.storage = *(reinterpret_cast(src) + load_offset); +#pragma unroll + for (int i = 0; i < pack_size; i++) { + row_buf[pack_offset + i] = static_cast(load_pack.elem[i]); + thread_abs_max[row_id] = max(thread_abs_max[row_id], abs(row_buf[pack_offset + i])); + } + } else { +#pragma unroll + for (int i = 0; i < pack_size; i++) { row_buf[pack_offset + i] = 0.0; } + } + } + } + ComputeType warp_max[rows_per_access]; +#pragma unroll + for (int row_id = 0; row_id < rows_per_access; row_id++) { + warp_max[row_id] = WarpMaxAllReduce(thread_abs_max[row_id]); + if (threadIdx.x == 0) { quantize_factor[row + row_id] = static_cast(warp_max[row_id]); } + ComputeType* row_buf = buf[row_id]; + ComputeType quantize_factor_val = static_cast(127.0) / warp_max[row_id]; +#pragma unroll + for (int col = 0; col < cols_per_thread; col++) { + row_buf[col] = RoundHalfAwayFromZero(row_buf[col] * quantize_factor_val); + } +#pragma unroll + for (int pack_id = 0; pack_id < num_packs; pack_id++) { + const int pack_offset = pack_id * pack_size; + const int col = (pack_id * 
thread_group_width + lane_id) * pack_size; + StorePack store_pack; + if (!padding || col < cols) { + const int64_t store_offset = ((row + row_id) * cols + col) / pack_size; + for (int i = 0; i < pack_size; i++) { + store_pack.elem[i] = static_cast(row_buf[pack_id * pack_size + i]); + } + *(reinterpret_cast(dst) + store_offset) = store_pack.storage; + } + } + } + } +} + +template +inline cudaError_t LaunchQuantizeWarpImpl(cudaStream_t stream, const T* src, int8_t* dst, + T* quantize_factor, const int64_t rows, + const int64_t cols) { + constexpr int block_size = 128; + constexpr int waves = 32; + static_assert(block_size % thread_group_width == 0, ""); + constexpr int thread_groups_per_block = block_size / thread_group_width; + dim3 block_dim(thread_group_width, thread_groups_per_block); + const int64_t num_blocks = + (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block; + int grid_dim_x = 0; + + cudaError_t err = GetWarpImplNumBlocks(block_size, num_blocks, waves, &grid_dim_x); + if (err != cudaSuccess) { return err; } + + QuantizeWarpImplKernel + <<>>(src, dst, quantize_factor, rows, cols); + return cudaPeekAtLastError(); +} + +template +inline cudaError_t DispatchQuantizeWarpImplPadding(cudaStream_t stream, const T* src, int8_t* dst, + T* quantize_factor, const int64_t rows, + const int64_t cols) { + if (cols == cols_per_thread * thread_group_width) { + return LaunchQuantizeWarpImpl(stream, src, dst, quantize_factor, rows, + cols); + } else { + return LaunchQuantizeWarpImpl(stream, src, dst, quantize_factor, rows, + cols); + } +} + +template +typename std::enable_if::type DispatchQuantizeWarpImplCols( + cudaStream_t stream, const T* src, int8_t* dst, T* quantize_factor, const int64_t rows, + const int64_t cols) { + if (cols <= 0) { return cudaErrorInvalidValue; } +#define DEFINE_ONE_ELIF(thread_group_width) \ + else if (cols <= (thread_group_width)*pack_size) { \ + if (rows % 2 == 0) { \ + return DispatchQuantizeWarpImplPadding(stream, src, dst, \ + quantize_factor, rows, cols); \ + } else { \ + return DispatchQuantizeWarpImplPadding(stream, src, dst, \ + quantize_factor, rows, cols); \ + } \ + } + DEFINE_ONE_ELIF(1) + DEFINE_ONE_ELIF(2) + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF +#define DEFINE_ONE_ELIF(col) \ + else if (cols <= (col)*kWarpSize) { \ + return DispatchQuantizeWarpImplPadding( \ + stream, src, dst, quantize_factor, rows, cols); \ + } + DEFINE_ONE_ELIF(2) + DEFINE_ONE_ELIF(3) + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(5) + DEFINE_ONE_ELIF(6) + DEFINE_ONE_ELIF(7) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(9) + DEFINE_ONE_ELIF(10) + DEFINE_ONE_ELIF(11) + DEFINE_ONE_ELIF(12) + DEFINE_ONE_ELIF(13) + DEFINE_ONE_ELIF(14) + DEFINE_ONE_ELIF(15) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(17) + DEFINE_ONE_ELIF(18) + DEFINE_ONE_ELIF(19) + DEFINE_ONE_ELIF(20) + DEFINE_ONE_ELIF(21) + DEFINE_ONE_ELIF(22) + DEFINE_ONE_ELIF(23) + DEFINE_ONE_ELIF(24) + DEFINE_ONE_ELIF(25) + DEFINE_ONE_ELIF(26) + DEFINE_ONE_ELIF(27) + DEFINE_ONE_ELIF(28) + DEFINE_ONE_ELIF(29) + DEFINE_ONE_ELIF(30) + DEFINE_ONE_ELIF(31) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF + else { + return cudaErrorInvalidValue; + } +} + +template +typename std::enable_if::type DispatchQuantizeWarpImplCols( + cudaStream_t stream, const T* src, int8_t* dst, T* quantize_factor, const int64_t rows, + const int64_t cols) { + if (cols <= 0) { return cudaErrorInvalidValue; } +#define DEFINE_ONE_ELIF(thread_group_width) \ + else if (cols <= 
(thread_group_width)*pack_size) { \ + if (rows % 2 == 0) { \ + return DispatchQuantizeWarpImplPadding(stream, src, dst, \ + quantize_factor, rows, cols); \ + } else { \ + return DispatchQuantizeWarpImplPadding(stream, src, dst, \ + quantize_factor, rows, cols); \ + } \ + } + DEFINE_ONE_ELIF(1) + DEFINE_ONE_ELIF(2) + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF +#define DEFINE_ONE_ELIF(col) \ + else if (cols <= (col)*kWarpSize) { \ + return DispatchQuantizeWarpImplPadding( \ + stream, src, dst, quantize_factor, rows, cols); \ + } + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(6) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(10) + DEFINE_ONE_ELIF(12) + DEFINE_ONE_ELIF(14) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(18) + DEFINE_ONE_ELIF(20) + DEFINE_ONE_ELIF(22) + DEFINE_ONE_ELIF(24) + DEFINE_ONE_ELIF(26) + DEFINE_ONE_ELIF(28) + DEFINE_ONE_ELIF(30) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF + else { + return cudaErrorInvalidValue; + } +} + +template +struct DispatchQuantizeWarpImplPackSize { + cudaError_t operator()(cudaStream_t stream, const T* src, int8_t* dst, T* quantize_factor, + const int64_t rows, const int64_t cols) { + if (cols % 2 == 0) { + return DispatchQuantizeWarpImplCols(stream, src, dst, quantize_factor, + rows, cols); + } else { + return DispatchQuantizeWarpImplCols(stream, src, dst, quantize_factor, + rows, cols); + } + } +}; + +template +__global__ void DequantizeKernel(const int8_t* x, T* quantize_factor, T* out, IDX col_size, + IDX elem_cnt); + +template +__global__ void DequantizeKernel(const int8_t* x, T* quantize_factor, T* out, IDX col_size, + IDX elem_cnt) { + IDX global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + for (int index = global_thread_id * pack_size; index < elem_cnt; + index += gridDim.x * blockDim.x * pack_size) { + IDX quantize_factor_idx = index / col_size; + ComputeType quantize_factor_val = static_cast(quantize_factor[quantize_factor_idx]) + / static_cast(127.0); + using LoadPackType = cuda::elementwise::PackType; + using LoadPack = cuda::elementwise::Pack; + using StorePackType = cuda::elementwise::PackType; + using StorePack = cuda::elementwise::Pack; + LoadPack load_pack{}; + StorePack store_pack{}; + load_pack.storage = *(reinterpret_cast(x) + index / pack_size); +#pragma unroll + for (int i = 0; i < pack_size; i++) { + store_pack.elem[i] = + static_cast(static_cast(load_pack.elem[i]) * quantize_factor_val); + } + *(reinterpret_cast(out) + index / pack_size) = store_pack.storage; + } +} + +template +cudaError_t DispatchDequantizeKernelPackSize(cudaStream_t stream, const int8_t* src, + T* quantize_factor, T* dst, const int64_t col_size, + const int64_t elem_cnt) { + const int64_t pack_num = elem_cnt / pack_size; + int grid_size = 0; + cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size); + if (err != cudaSuccess) { return err; } + DequantizeKernel + <<>>(src, quantize_factor, dst, col_size, + elem_cnt); + return cudaSuccess; +} + +template +inline cudaError_t LaunchDequantizeKernel(cudaStream_t stream, const int8_t* src, + T* quantize_factor, T* dst, const int64_t col_size, + const int64_t elem_cnt) { + constexpr int quantized_src_pack_size = cuda::elementwise::PackSize(); + constexpr int dst_pack_size = cuda::elementwise::PackSize(); + int launch_pack_size = std::min(quantized_src_pack_size, dst_pack_size); + if (launch_pack_size == 8 && col_size % 8 == 0) { + cudaError_t err = DispatchDequantizeKernelPackSize( + stream, src, quantize_factor, dst, col_size, elem_cnt); + 
if (err != cudaSuccess) { return err; } + } else if (launch_pack_size == 4 && col_size % 4 == 0) { + cudaError_t err = DispatchDequantizeKernelPackSize( + stream, src, quantize_factor, dst, col_size, elem_cnt); + if (err != cudaSuccess) { return err; } + } else if (launch_pack_size == 2 && col_size % 2 == 0) { + cudaError_t err = DispatchDequantizeKernelPackSize( + stream, src, quantize_factor, dst, col_size, elem_cnt); + if (err != cudaSuccess) { return err; } + } else { + cudaError_t err = DispatchDequantizeKernelPackSize( + stream, src, quantize_factor, dst, col_size, elem_cnt); + if (err != cudaSuccess) { return err; } + } + return cudaPeekAtLastError(); +} + +template +struct DefaultComputeType { + using type = T; +}; + +template<> +struct DefaultComputeType { + using type = float; +}; + template class EmbeddingShuffleKernel final : public user_op::OpKernel { public: @@ -484,13 +870,21 @@ class EmbeddingShuffleKernel final : public user_op::OpKernel { ctx->Tensor4ArgNameAndIndex("inverse_unique_partition_indices", 0); user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - + ncclComm_t comm = kernel_state->comm(); + using ComputeType = typename DefaultComputeType::type; const int64_t embedding_size = cur_rank_embeddings->shape().At(1); IDX* host_num_unique_matrix = kernel_state->HostNumUniqueMatrix(); DataType data_type = cur_rank_embeddings->data_type(); const int64_t num_ids = inverse_unique_partition_indices->shape().elem_cnt(); const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); + bool enable_quantized_comm_env_var = + ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false); + bool enable_quantized_comm = enable_quantized_comm_env_var && (embedding_size < kMaxColSize); + if (enable_quantized_comm_env_var && !enable_quantized_comm) { + LOG(WARNING) << "Only envrionment variable ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 and " + "embedding_size less equal than 1024 can use quantized communication. 
"; + } cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); OF_CUDA_CHECK(cudaMemcpyAsync( host_num_unique_matrix, reinterpret_cast(num_unique_matrix->dptr()), @@ -500,34 +894,121 @@ class EmbeddingShuffleKernel final : public user_op::OpKernel { for (int64_t i = 0; i < parallel_num; ++i) { cur_rank_num_ids += host_num_unique_matrix[i * parallel_num + parallel_id]; } + size_t full_elem_cnt = parallel_num * num_ids * embedding_size; + CHECK_EQ(full_elem_cnt, cur_rank_embeddings->shape().elem_cnt()); + if (!enable_quantized_comm) { + size_t reverse_unique_cur_rank_embeddings_size = + GetCudaAlignedSize(full_elem_cnt * sizeof(T)); + size_t received_embeddings_size = reverse_unique_cur_rank_embeddings_size; - CHECK_EQ(parallel_num * num_ids * embedding_size, cur_rank_embeddings->shape().elem_cnt()); - size_t reverse_unique_cur_rank_embeddings_size = - GetCudaAlignedSize(parallel_num * num_ids * embedding_size * sizeof(T)); - size_t received_embeddings_size = reverse_unique_cur_rank_embeddings_size; - T* reverse_unique_cur_rank_embeddings = reinterpret_cast(tmp_buffer->mut_dptr()); - T* received_embeddings = reinterpret_cast(tmp_buffer->mut_dptr() - + reverse_unique_cur_rank_embeddings_size); - CHECK_GE(tmp_buffer->shape().elem_cnt(), - reverse_unique_cur_rank_embeddings_size + received_embeddings_size); - - // reverse cur_rank unique - GatherKernelUtilImpl::Forward( - ctx->stream(), reinterpret_cast(cur_rank_inverse_indices->dptr()), - cur_rank_num_ids, cur_rank_embeddings->dptr(), - Shape({1, cur_rank_embeddings->shape().elem_cnt() / embedding_size, embedding_size}), - reverse_unique_cur_rank_embeddings, 0); + CHECK_GE(tmp_buffer->shape().elem_cnt(), + reverse_unique_cur_rank_embeddings_size + received_embeddings_size); - ncclComm_t comm = kernel_state->comm(); - ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size, - data_type, host_num_unique_matrix, reverse_unique_cur_rank_embeddings, - received_embeddings); + T* reverse_unique_cur_rank_embeddings = reinterpret_cast(tmp_buffer->mut_dptr()); + T* received_embeddings = reinterpret_cast(tmp_buffer->mut_dptr() + + reverse_unique_cur_rank_embeddings_size); + // reverse cur_rank unique + GatherKernelUtilImpl::Forward( + ctx->stream(), reinterpret_cast(cur_rank_inverse_indices->dptr()), + cur_rank_num_ids, cur_rank_embeddings->dptr(), + Shape({1, cur_rank_embeddings->shape().elem_cnt() / embedding_size, embedding_size}), + reverse_unique_cur_rank_embeddings, 0); + + ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size, + data_type, host_num_unique_matrix, reverse_unique_cur_rank_embeddings, + received_embeddings); + + // reverse unique_partition + GatherKernelUtilImpl::Forward( + ctx->stream(), reinterpret_cast(inverse_unique_partition_indices->dptr()), + inverse_unique_partition_indices->shape().elem_cnt(), received_embeddings, + Shape({1, parallel_num * num_ids, embedding_size}), embeddings->mut_dptr(), 0); + } else { + size_t reverse_unique_cur_rank_embeddings_size = + GetCudaAlignedSize(full_elem_cnt * sizeof(int8_t)); + size_t received_embeddings_size = reverse_unique_cur_rank_embeddings_size; + size_t quantize_cur_rank_embeddings_size = reverse_unique_cur_rank_embeddings_size; + size_t reverse_recv_quantize_cur_rank_embeddings_size = + reverse_unique_cur_rank_embeddings_size; + size_t cur_rank_quantize_factor_size = + GetCudaAlignedSize(cur_rank_embeddings->shape().At(0) * sizeof(T)); + size_t reverse_cur_rank_quantize_factor_size = 
cur_rank_quantize_factor_size; + size_t recv_quantize_factor_size = cur_rank_quantize_factor_size; + size_t reverse_recv_quantize_factor_size = cur_rank_quantize_factor_size; + CHECK_GE(tmp_buffer->shape().elem_cnt(), + reverse_unique_cur_rank_embeddings_size + received_embeddings_size + + quantize_cur_rank_embeddings_size + + reverse_recv_quantize_cur_rank_embeddings_size + cur_rank_quantize_factor_size + + reverse_cur_rank_quantize_factor_size + recv_quantize_factor_size + + reverse_recv_quantize_factor_size); + int8_t* reverse_unique_cur_rank_embeddings = + reinterpret_cast(tmp_buffer->mut_dptr()); + int8_t* received_embeddings = reinterpret_cast( + tmp_buffer->mut_dptr() + reverse_unique_cur_rank_embeddings_size); + int8_t* quantize_cur_rank_embeddings = reinterpret_cast( + tmp_buffer->mut_dptr() + reverse_unique_cur_rank_embeddings_size + + received_embeddings_size); + int8_t* reverse_recv_quantize_cur_rank_embeddings = reinterpret_cast( + tmp_buffer->mut_dptr() + reverse_unique_cur_rank_embeddings_size + + received_embeddings_size + quantize_cur_rank_embeddings_size); + T* cur_rank_quantize_factor = reinterpret_cast( + tmp_buffer->mut_dptr() + reverse_unique_cur_rank_embeddings_size + + received_embeddings_size + quantize_cur_rank_embeddings_size + + reverse_recv_quantize_cur_rank_embeddings_size); + T* reverse_cur_rank_quantize_factor = reinterpret_cast( + tmp_buffer->mut_dptr() + reverse_unique_cur_rank_embeddings_size + + received_embeddings_size + quantize_cur_rank_embeddings_size + + reverse_recv_quantize_cur_rank_embeddings_size + cur_rank_quantize_factor_size); + T* recv_quantize_factor = reinterpret_cast( + tmp_buffer->mut_dptr() + reverse_unique_cur_rank_embeddings_size + + received_embeddings_size + quantize_cur_rank_embeddings_size + + reverse_recv_quantize_cur_rank_embeddings_size + cur_rank_quantize_factor_size + + reverse_cur_rank_quantize_factor_size); + T* reverse_recv_quantize_factor = reinterpret_cast( + tmp_buffer->mut_dptr() + reverse_unique_cur_rank_embeddings_size + + received_embeddings_size + quantize_cur_rank_embeddings_size + + reverse_recv_quantize_cur_rank_embeddings_size + cur_rank_quantize_factor_size + + reverse_cur_rank_quantize_factor_size + recv_quantize_factor_size); + DispatchQuantizeWarpImplPackSize()( + cuda_stream, cur_rank_embeddings->dptr(), quantize_cur_rank_embeddings, + cur_rank_quantize_factor, cur_rank_num_ids, embedding_size); + // reverse cur_rank embedding unique + GatherKernelUtilImpl::Forward( + ctx->stream(), reinterpret_cast(cur_rank_inverse_indices->dptr()), + cur_rank_num_ids, quantize_cur_rank_embeddings, + Shape({1, cur_rank_embeddings->shape().elem_cnt() / embedding_size, embedding_size}), + reverse_unique_cur_rank_embeddings, 0); + + // reverse cur_rank quantize factor unique + GatherKernelUtilImpl::Forward( + ctx->stream(), reinterpret_cast(cur_rank_inverse_indices->dptr()), + cur_rank_num_ids, cur_rank_quantize_factor, + Shape({1, cur_rank_embeddings->shape().elem_cnt() / embedding_size, 1}), + reverse_cur_rank_quantize_factor, 0); + + ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size, + data_type, host_num_unique_matrix, reverse_unique_cur_rank_embeddings, + received_embeddings, reverse_cur_rank_quantize_factor, + recv_quantize_factor); + + // reverse unique_partition + GatherKernelUtilImpl::Forward( + ctx->stream(), reinterpret_cast(inverse_unique_partition_indices->dptr()), + inverse_unique_partition_indices->shape().elem_cnt(), received_embeddings, + Shape({1, parallel_num * 
num_ids, embedding_size}), + reverse_recv_quantize_cur_rank_embeddings, 0); + + GatherKernelUtilImpl::Forward( + ctx->stream(), reinterpret_cast(inverse_unique_partition_indices->dptr()), + inverse_unique_partition_indices->shape().elem_cnt(), recv_quantize_factor, + Shape({1, parallel_num * num_ids, 1}), reverse_recv_quantize_factor, 0); - // reverse unique_partition - GatherKernelUtilImpl::Forward( - ctx->stream(), reinterpret_cast(inverse_unique_partition_indices->dptr()), - inverse_unique_partition_indices->shape().elem_cnt(), received_embeddings, - Shape({1, parallel_num * num_ids, embedding_size}), embeddings->mut_dptr(), 0); + int32_t dequantize_row_size = inverse_unique_partition_indices->shape().elem_cnt(); + IDX dequantize_elem_cnt = dequantize_row_size * embedding_size; + OF_CUDA_CHECK((LaunchDequantizeKernel( + cuda_stream, reverse_recv_quantize_cur_rank_embeddings, reverse_recv_quantize_factor, + embeddings->mut_dptr(), embedding_size, dequantize_elem_cnt))); + } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; @@ -543,15 +1024,39 @@ class EmbeddingShuffleKernel final : public user_op::OpKernel { .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const user_op::TensorDesc& cur_rank_embeddings = \ ctx->InputTensorDesc("cur_rank_embeddings", 0); \ - const user_op::TensorDesc& embeddings = ctx->InputTensorDesc("embeddings", 0); \ - size_t reverse_cur_rank_embeddings_size = GetCudaAlignedSize( \ - cur_rank_embeddings.shape().elem_cnt() * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ - size_t recv_unique_embeddings = reverse_cur_rank_embeddings_size; \ - return reverse_cur_rank_embeddings_size + recv_unique_embeddings; \ + bool enable_quantized_comm = \ + ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false) \ + && (cur_rank_embeddings.shape().At(1) < kMaxColSize); \ + size_t tmp_size = 0; \ + if (!enable_quantized_comm) { \ + size_t reverse_cur_rank_embeddings_size = GetCudaAlignedSize( \ + cur_rank_embeddings.shape().elem_cnt() * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ + size_t recv_unique_embeddings_size = reverse_cur_rank_embeddings_size; \ + tmp_size = reverse_cur_rank_embeddings_size + recv_unique_embeddings_size; \ + } else { \ + size_t total_elem_cnt = cur_rank_embeddings.shape().elem_cnt(); \ + size_t reverse_cur_rank_embeddings_size = \ + GetCudaAlignedSize(total_elem_cnt * sizeof(int8_t)); \ + size_t recv_unique_embeddings = reverse_cur_rank_embeddings_size; \ + size_t quantize_cur_rank_embeddings_size = reverse_cur_rank_embeddings_size; \ + size_t reverse_recv_quantize_cur_rank_embeddings_size = \ + reverse_cur_rank_embeddings_size; \ + size_t cur_rank_quantize_factor_size = GetCudaAlignedSize( \ + cur_rank_embeddings.shape().At(0) * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ + size_t reverse_cur_rank_quantize_factor_size = cur_rank_quantize_factor_size; \ + size_t recv_quantize_factor_size = cur_rank_quantize_factor_size; \ + size_t reverse_recv_quantize_factor_size = cur_rank_quantize_factor_size; \ + tmp_size = reverse_cur_rank_embeddings_size + recv_unique_embeddings \ + + quantize_cur_rank_embeddings_size \ + + reverse_recv_quantize_cur_rank_embeddings_size \ + + cur_rank_quantize_factor_size + reverse_cur_rank_quantize_factor_size \ + + recv_quantize_factor_size + reverse_recv_quantize_factor_size; \ + } \ + return tmp_size; \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_SHUFFLE_KERNEL, - FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) + FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, 
IDX_DATA_TYPE_SEQ) template void ShuffleEmbeddingsGrad(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id, @@ -569,6 +1074,31 @@ void ShuffleEmbeddingsGrad(cudaStream_t cuda_stream, ncclComm_t comm, int64_t pa received_embeddings_grad); } +// Quantize Version. +template +void ShuffleEmbeddingsGrad(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id, + int64_t parallel_num, int64_t num_ids, int64_t embedding_size, + DataType data_type, IDX* host_num_unique_matrix, + int8_t* unique_partition_embedding_grad, + int8_t* received_embeddings_grad, T* cur_rank_quantize_factor, + T* received_cur_rank_quantize_factor) { + std::vector send_offsets; + std::vector send_elem_cnt; + std::vector recv_offsets; + std::vector recv_elem_cnt; + // Shuffle Embedding Grad. + MakeShuffleParams(host_num_unique_matrix, num_ids, embedding_size, parallel_id, parallel_num, + &send_offsets, &send_elem_cnt, &recv_offsets, &recv_elem_cnt); + ShuffleData(cuda_stream, comm, DataType::kInt8, send_offsets, send_elem_cnt, + unique_partition_embedding_grad, recv_offsets, recv_elem_cnt, + received_embeddings_grad); + // Shuffle Quantize factor. + MakeShuffleParams(host_num_unique_matrix, num_ids, /*embedding_size=*/1, parallel_id, + parallel_num, &send_offsets, &send_elem_cnt, &recv_offsets, &recv_elem_cnt); + ShuffleData(cuda_stream, comm, data_type, send_offsets, send_elem_cnt, cur_rank_quantize_factor, + recv_offsets, recv_elem_cnt, received_cur_rank_quantize_factor); +} + template class EmbeddingGradientShuffleKernel final : public user_op::OpKernel { public: @@ -587,6 +1117,7 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel { auto* kernel_state = dynamic_cast*>(state); CHECK(kernel_state != nullptr); const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0); + const user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex("num_unique_matrix", 0); const user_op::Tensor* cur_rank_inverse_indices = ctx->Tensor4ArgNameAndIndex("cur_rank_inverse_indices", 0); @@ -601,51 +1132,148 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel { const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - + ncclComm_t comm = kernel_state->comm(); + using ComputeType = typename DefaultComputeType::type; + bool enable_quantized_comm_env_var = + ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false); + bool enable_quantized_comm = enable_quantized_comm_env_var && (embedding_size < kMaxColSize); + if (enable_quantized_comm_env_var && !enable_quantized_comm) { + LOG(WARNING) << "Only envrionment variable ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 and " + "embedding_size less equal than 1024 can use quantized communication. 
"; + } cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); OF_CUDA_CHECK(cudaMemcpyAsync(host_num_unique_matrix, num_unique_matrix->dptr(), parallel_num * parallel_num * sizeof(IDX), cudaMemcpyDefault, cuda_stream)); CHECK_JUST(ctx->stream()->Sync()); + int64_t cur_rank_num_ids = 0; for (int64_t i = 0; i < parallel_num; ++i) { cur_rank_num_ids += host_num_unique_matrix[i * parallel_num + parallel_id]; } + size_t full_num_ids = parallel_num * num_ids; + size_t full_elem_cnt = full_num_ids * embedding_size; + size_t unique_partition_embedding_grad_size = GetCudaAlignedSize(full_elem_cnt * sizeof(T)); - size_t unique_partition_embedding_grad_size = - GetCudaAlignedSize(parallel_num * num_ids * embedding_size * sizeof(T)); - size_t received_embedding_grad_size = unique_partition_embedding_grad_size; - T* unique_partition_embedding_grad = reinterpret_cast(tmp_buffer->mut_dptr()); - T* received_embedding_grad = - reinterpret_cast(tmp_buffer->mut_dptr() + unique_partition_embedding_grad_size); - CHECK_GE(tmp_buffer->shape().elem_cnt(), - unique_partition_embedding_grad_size + received_embedding_grad_size); + if (!enable_quantized_comm) { + size_t received_embedding_grad_size = unique_partition_embedding_grad_size; + T* unique_partition_embedding_grad = reinterpret_cast(tmp_buffer->mut_dptr()); + T* received_embedding_grad = + reinterpret_cast(tmp_buffer->mut_dptr() + unique_partition_embedding_grad_size); + CHECK_GE(tmp_buffer->shape().elem_cnt(), + unique_partition_embedding_grad_size + received_embedding_grad_size); - // unique and partition embedding grad - for (int64_t i = 0; i < parallel_num; ++i) { - const int64_t offset = i * num_ids * embedding_size; - const int64_t valid_value_size = - host_num_unique_matrix[parallel_id * parallel_num + i] * embedding_size * sizeof(T); - OF_CUDA_CHECK(cudaMemsetAsync(unique_partition_embedding_grad + offset, 0, valid_value_size, - cuda_stream)); - } - UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( - ctx->stream(), reinterpret_cast(inverse_unique_partition_indices->dptr()), - embedding_grad->dptr(), num_ids, parallel_num * num_ids, 1, embedding_size, 0, - unique_partition_embedding_grad); + // unique and partition embedding grad + for (int64_t i = 0; i < parallel_num; ++i) { + const int64_t offset = i * num_ids * embedding_size; + const int64_t valid_value_size = + host_num_unique_matrix[parallel_id * parallel_num + i] * embedding_size * sizeof(T); + OF_CUDA_CHECK(cudaMemsetAsync(unique_partition_embedding_grad + offset, 0, valid_value_size, + cuda_stream)); + } + UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( + ctx->stream(), reinterpret_cast(inverse_unique_partition_indices->dptr()), + embedding_grad->dptr(), num_ids, parallel_num * num_ids, 1, embedding_size, 0, + unique_partition_embedding_grad); - ncclComm_t comm = kernel_state->comm(); - ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size, - data_type, host_num_unique_matrix, unique_partition_embedding_grad, - received_embedding_grad); - - // unique cur_rank embedding grad - OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_embedding_grad->mut_dptr(), 0, - cur_rank_num_ids * embedding_size * sizeof(T), cuda_stream)); - UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( - ctx->stream(), reinterpret_cast(cur_rank_inverse_indices->dptr()), - received_embedding_grad, cur_rank_num_ids, cur_rank_num_ids, 1, embedding_size, 0, - cur_rank_unique_embedding_grad->mut_dptr()); + ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids, 
embedding_size, + data_type, host_num_unique_matrix, unique_partition_embedding_grad, + received_embedding_grad); + + // unique cur_rank embedding grad + OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_embedding_grad->mut_dptr(), 0, + cur_rank_num_ids * embedding_size * sizeof(T), cuda_stream)); + UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( + ctx->stream(), reinterpret_cast(cur_rank_inverse_indices->dptr()), + received_embedding_grad, cur_rank_num_ids, cur_rank_num_ids, 1, embedding_size, 0, + cur_rank_unique_embedding_grad->mut_dptr()); + } else { + size_t received_embedding_grad_size = GetCudaAlignedSize(full_elem_cnt * sizeof(int8_t)); + size_t quantize_cur_rank_embedding_grad_size = received_embedding_grad_size; + size_t cur_rank_quantize_factor_size = GetCudaAlignedSize(full_num_ids * sizeof(T)); + size_t received_cur_rank_quantize_factor_size = cur_rank_quantize_factor_size; + size_t dequantize_cur_rank_embedding_grad_size = + GetCudaAlignedSize(full_elem_cnt * sizeof(T)); + CHECK_GE(tmp_buffer->shape().elem_cnt(), + unique_partition_embedding_grad_size + received_embedding_grad_size + + quantize_cur_rank_embedding_grad_size + cur_rank_quantize_factor_size + + received_cur_rank_quantize_factor_size + + dequantize_cur_rank_embedding_grad_size); + T* unique_partition_embedding_grad = reinterpret_cast(tmp_buffer->mut_dptr()); + int8_t* received_embedding_grad = reinterpret_cast( + tmp_buffer->mut_dptr() + unique_partition_embedding_grad_size); + + int8_t* quantize_cur_rank_embedding_grad = reinterpret_cast( + tmp_buffer->mut_dptr() + unique_partition_embedding_grad_size + + received_embedding_grad_size); + T* cur_rank_quantize_factor = reinterpret_cast( + tmp_buffer->mut_dptr() + unique_partition_embedding_grad_size + + received_embedding_grad_size + quantize_cur_rank_embedding_grad_size); + T* received_cur_rank_quantize_factor = reinterpret_cast( + tmp_buffer->mut_dptr() + unique_partition_embedding_grad_size + + received_embedding_grad_size + quantize_cur_rank_embedding_grad_size + + cur_rank_quantize_factor_size); + T* dequantize_cur_rank_embedding_grad = reinterpret_cast( + tmp_buffer->mut_dptr() + unique_partition_embedding_grad_size + + received_embedding_grad_size + quantize_cur_rank_embedding_grad_size + + cur_rank_quantize_factor_size + received_cur_rank_quantize_factor_size); + + // unique and partition embedding grad + for (int64_t i = 0; i < parallel_num; ++i) { + const int64_t offset = i * num_ids * embedding_size; + const int64_t valid_value_size = + host_num_unique_matrix[parallel_id * parallel_num + i] * embedding_size * sizeof(T); + OF_CUDA_CHECK(cudaMemsetAsync(unique_partition_embedding_grad + offset, 0, valid_value_size, + cuda_stream)); + } + + UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( + ctx->stream(), reinterpret_cast(inverse_unique_partition_indices->dptr()), + embedding_grad->dptr(), num_ids, parallel_num * num_ids, 1, embedding_size, 0, + unique_partition_embedding_grad); + + // Quantize. 
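+      // Each partition destined for rank i holds valid_row_size =
+      // host_num_unique_matrix[parallel_id * parallel_num + i] rows; it is quantized
+      // independently here, writing one per-row abs-max factor at offset i * num_ids
+      // so the factors line up with the rows after the shuffle.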
+ for (int64_t i = 0; i < parallel_num; ++i) { + const int64_t embedding_grad_offset = i * num_ids * embedding_size; + const int64_t quantize_factor_offset = i * num_ids; + const int64_t valid_row_size = host_num_unique_matrix[parallel_id * parallel_num + i]; + DispatchQuantizeWarpImplPackSize()( + cuda_stream, unique_partition_embedding_grad + embedding_grad_offset, + quantize_cur_rank_embedding_grad + embedding_grad_offset, + cur_rank_quantize_factor + quantize_factor_offset, valid_row_size, embedding_size); + } + + ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size, + data_type, host_num_unique_matrix, quantize_cur_rank_embedding_grad, + received_embedding_grad, cur_rank_quantize_factor, + received_cur_rank_quantize_factor); + + int64_t dequantize_cur_rank_num = 0; + for (int64_t i = 0; i < parallel_num; ++i) { + /* + Host num unique matrix: + | Partition0 | Partition1 | + | Rank0 | 2 | 4 | + | Rank1 | 3 | 3 | + After ShuffleEmbeddingGrads, each rank will exchange partition. + For example: + Rank0 will have (matrix[rank0][part0] + matrix[rank1][part0]) grad tensor. + Rank1 will have (matrix[rank0][part1] + matrix[rank1][part1]) grad tensor. + */ + dequantize_cur_rank_num += host_num_unique_matrix[i * parallel_num + parallel_id]; + } + IDX dequantize_elem_cnt = dequantize_cur_rank_num * embedding_size; + OF_CUDA_CHECK((LaunchDequantizeKernel( + cuda_stream, received_embedding_grad, received_cur_rank_quantize_factor, + dequantize_cur_rank_embedding_grad, embedding_size, dequantize_elem_cnt))); + // unique cur_rank embedding grad + OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_embedding_grad->mut_dptr(), 0, + cur_rank_num_ids * embedding_size * sizeof(T), cuda_stream)); + UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( + ctx->stream(), reinterpret_cast(cur_rank_inverse_indices->dptr()), + dequantize_cur_rank_embedding_grad, cur_rank_num_ids, cur_rank_num_ids, 1, embedding_size, + 0, cur_rank_unique_embedding_grad->mut_dptr()); + } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; @@ -661,10 +1289,33 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel { .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const user_op::TensorDesc& cur_rank_unique_embedding_grad = \ ctx->InputTensorDesc("cur_rank_unique_embedding_grad", 0); \ - size_t cur_rank_embedding_grad_size = \ - GetCudaAlignedSize(cur_rank_unique_embedding_grad.shape().elem_cnt() \ - * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ - return 2 * cur_rank_embedding_grad_size; \ + size_t cur_rank_embedding_grad_num = cur_rank_unique_embedding_grad.shape().At(0); \ + size_t embedding_size = cur_rank_unique_embedding_grad.shape().At(1); \ + size_t cur_rank_embedding_grad_elem_cnt = cur_rank_embedding_grad_num * embedding_size; \ + bool enable_quantized_comm = \ + ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false) \ + && (embedding_size < kMaxColSize); \ + size_t tmp_size = 0; \ + if (!enable_quantized_comm) { \ + size_t cur_rank_embedding_grad_size = GetCudaAlignedSize( \ + cur_rank_embedding_grad_elem_cnt * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ + tmp_size = 2 * cur_rank_embedding_grad_size; \ + } else { \ + size_t unique_partition_embedding_grad_size = GetCudaAlignedSize( \ + cur_rank_embedding_grad_elem_cnt * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ + size_t received_embedding_grad_size = \ + GetCudaAlignedSize(cur_rank_embedding_grad_elem_cnt * sizeof(int8_t)); \ + size_t quantize_cur_rank_embedding_grad_size = 
received_embedding_grad_size; \ + size_t cur_rank_quantize_factor_size = GetCudaAlignedSize( \ + cur_rank_embedding_grad_num * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair))); \ + size_t received_cur_rank_quantize_factor_size = cur_rank_quantize_factor_size; \ + size_t dequantize_cur_rank_embedding_grad_size = unique_partition_embedding_grad_size; \ + tmp_size = unique_partition_embedding_grad_size + received_embedding_grad_size \ + + quantize_cur_rank_embedding_grad_size + cur_rank_quantize_factor_size \ + + received_cur_rank_quantize_factor_size \ + + dequantize_cur_rank_embedding_grad_size; \ + } \ + return tmp_size; \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_GRADIENT_SHUFFLE_KERNEL, diff --git a/oneflow/user/kernels/gather_kernel.cpp b/oneflow/user/kernels/gather_kernel.cpp index e9f02511ae1..42a0a6dc976 100644 --- a/oneflow/user/kernels/gather_kernel.cpp +++ b/oneflow/user/kernels/gather_kernel.cpp @@ -118,6 +118,12 @@ class GatherKernel final : public user_op::OpKernel, public user_op::CudaGraphSu OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GATHER_KERNEL, DEVICE_TYPE_SEQ, GATHER_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) +#ifdef WITH_CUDA +// For Half +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GATHER_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), + HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) +#endif + } // namespace user_op } // namespace oneflow diff --git a/oneflow/user/kernels/gather_kernel_util.cu b/oneflow/user/kernels/gather_kernel_util.cu index d2d83a4e7bd..492eca7b825 100644 --- a/oneflow/user/kernels/gather_kernel_util.cu +++ b/oneflow/user/kernels/gather_kernel_util.cu @@ -115,8 +115,8 @@ struct GatherKernelUtilImpl final { #define INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL(in_type_pair, index_type_pair) \ template struct GatherKernelUtilImpl; -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL, GATHER_DATA_TYPE_SEQ, - GATHER_INDEX_TYPE_SEQ); +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL, + GATHER_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, GATHER_INDEX_TYPE_SEQ); #undef INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL } // namespace oneflow diff --git a/oneflow/user/kernels/gather_kernel_util.h b/oneflow/user/kernels/gather_kernel_util.h index e11e3397abe..2eb1a774c84 100644 --- a/oneflow/user/kernels/gather_kernel_util.h +++ b/oneflow/user/kernels/gather_kernel_util.h @@ -34,7 +34,7 @@ struct GatherKernelUtilImpl final { const Shape& flat_in_shape, T* out, int64_t offset); }; -#define GATHER_DATA_TYPE_SEQ ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ +#define GATHER_DATA_TYPE_SEQ ARITHMETIC_DATA_TYPE_SEQ #define GATHER_INDEX_TYPE_SEQ INDEX_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) } // namespace oneflow diff --git a/python/oneflow/test/expensive/test_id_shuffle.py b/python/oneflow/test/expensive/test_id_shuffle.py index 197933239e2..6eef77c5ae2 100644 --- a/python/oneflow/test/expensive/test_id_shuffle.py +++ b/python/oneflow/test/expensive/test_id_shuffle.py @@ -19,7 +19,6 @@ from oneflow.test_utils.test_util import GenArgDict import numpy as np import oneflow as flow -import oneflow.unittest from oneflow.test_utils.automated_test_util import * @@ -85,10 +84,51 @@ def build(self, ids, table_ids): # when has_table_id=False, we can not test table ids because in this case same ids not lead to same table id -def _test_embedding_shuffle(test_case, dtype): +def round_half_away_from_zero(x): + sign = np.sign(x) + abs_val = np.abs(x) + abs_val += 0.5 + floor_val = np.floor(abs_val) + out = floor_val * sign + return out + + +def 
embedding_shuffle_quantize(np_data, np_dtype): + # When use float16, ComputeType is set to as Float. + np_reduce_data = np_data.astype(np.float32) + abs_max_factor = np.max(np.abs(np_reduce_data), axis=2) + abs_max_factor = np.expand_dims(abs_max_factor, axis=2) + transport_quantize_factor = abs_max_factor.astype(np_dtype) + int8_factor = np.ones(abs_max_factor.shape, dtype=np.float32) * 127.0 + int8_factor = int8_factor.astype(np.float32) + quantize_factor = int8_factor / abs_max_factor + + # Covert to Compute Type. + np_data.astype(np.float32) + np_data = np_data * quantize_factor + np_data = round_half_away_from_zero(np_data) + np_data = np_data.astype(np.int8) + + # Covert to Compute Type. + np_data = np_data.astype(np.float32) + dequantize_factor = transport_quantize_factor.astype(np.float32) / int8_factor + np_data = np_data * dequantize_factor + np_data = np_data.astype(np_dtype) + return np_data + + +def _test_embedding_shuffle(test_case, dtype, enable_quantize): batch_size = 512 num_tables = 26 + embedding_size = 128 ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64) + + enable_quantized_comm = enable_quantize and embedding_size < 1025 + if enable_quantized_comm: + os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1" + else: + os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0" + table_ids = ( ids % num_tables ) # same id must have same table id, so in this case get table_ids from ids @@ -96,7 +136,8 @@ def _test_embedding_shuffle(test_case, dtype): np_dtype = np.float16 else: np_dtype = np.float32 - data = np.random.rand(1000, 128).astype(np_dtype) + data = np.random.rand(1000, embedding_size).astype(np_dtype) + ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda") table_ids_tensor = flow.tensor(table_ids.astype(np.int32), requires_grad=False).to( "cuda" @@ -129,20 +170,33 @@ def build(self, ids, table_ids, data): embeddings = graph(ids_tensor, table_ids_tensor, data_tensor) np_embeddings = data[ids] - test_case.assertTrue(np.array_equal(embeddings.numpy(), np_embeddings)) + # Quantized numpy embedding. 
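+    # Under quantized communication the kernel output is data[ids] after an int8
+    # quantize/dequantize round trip, so the reference is transformed the same way
+    # before the tolerance-based comparison below.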
+ if enable_quantized_comm: + np_embeddings = embedding_shuffle_quantize(np_embeddings, np_dtype) + test_case.assertTrue( + np.allclose(embeddings.numpy(), np_embeddings, atol=1e-4, rtol=1e-4) + ) -def _test_embedding_gradient_shuffle(test_case): +def _test_embedding_gradient_shuffle(test_case, enable_quantize): batch_size = 512 num_tables = 26 embedding_size = 128 ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64) + enable_quantized_comm = enable_quantize and embedding_size < 1025 + if enable_quantized_comm: + os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1" + ids = np.arange(batch_size * num_tables, dtype=np.int64) + np.random.shuffle(ids) + else: + os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0" + table_ids = ( ids % num_tables ) # same id must have same table id, so in this case get table_ids from ids - embedding_grad = np.random.rand(batch_size, num_tables, embedding_size).astype( - np.float32 - ) + embedding_grad = np.random.uniform( + low=-1, high=1, size=(batch_size, num_tables, embedding_size) + ).astype(np.float32) ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda") table_ids_tensor = flow.tensor(table_ids.astype(np.int32), requires_grad=False).to( "cuda" @@ -185,26 +239,42 @@ def build(self, ids, table_ids, embedding_grad): np_unique_ids, np_inverse = np.unique(ids, return_inverse=True) np_num_unique = np_unique_ids.size np_cur_rank_unique_embedding_grad = np.zeros( - cur_rank_unique_embedding_grad.shape + cur_rank_unique_embedding_grad.shape, dtype=np.float32 ).reshape(-1, embedding_size) + + embedding_grad = embedding_grad.reshape(-1, embedding_size) for k in range(np_num_unique): - np_cur_rank_unique_embedding_grad[k, :] = sum( - embedding_grad.reshape(-1, embedding_size)[ - np.where(ids.flatten() == np_unique_ids[k])[0] - ] - ) + np_data = sum(embedding_grad[np.where(ids.flatten() == np_unique_ids[k])[0]]) + # Quantize Embedding Gradient. 
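+        # Reference path: scale the summed row by 127 / abs_max, round half away from
+        # zero, cast to int8, then undo the scaling, mirroring the kernel's scheme.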
+ if enable_quantized_comm: + abs_max_factor = np.max(np.abs(np_data)) + int8_factor = np.full(abs_max_factor.shape, 127.0, dtype=np.float32) + quantize_factor = int8_factor / abs_max_factor + np_data = np_data * quantize_factor + np_data = round_half_away_from_zero(np_data) + np_data = np_data.astype(np.int8) + np_data = np_data.astype(np.float32) + dequantize_factor = abs_max_factor / int8_factor + np_data = np_data * dequantize_factor + + np_cur_rank_unique_embedding_grad[k, :] = np_data + reversed_ids = cur_rank_unique_ids[cur_rank_inverse_indices][ inverse_unique_partition_indices ] test_case.assertTrue(np.array_equal(reversed_ids.numpy(), ids)) + of_cur_rank_embedding_grad = cur_rank_unique_embedding_grad[ + cur_rank_inverse_indices + ][inverse_unique_partition_indices] + of_cur_rank_embedding_grad = flow.reshape( + of_cur_rank_embedding_grad, (-1, embedding_size) + ) + np_cur_rank_embedding_grad = np_cur_rank_unique_embedding_grad[np_inverse] + test_case.assertTrue( np.allclose( - cur_rank_unique_embedding_grad[cur_rank_inverse_indices][ - inverse_unique_partition_indices - ] - .numpy() - .flatten(), - np_cur_rank_unique_embedding_grad[np_inverse].flatten(), + of_cur_rank_embedding_grad.numpy().flatten(), + np_cur_rank_embedding_grad.flatten(), atol=1e-4, rtol=1e-4, ) @@ -270,11 +340,14 @@ def test_id_shuffle(test_case): def test_embedding_shuffle(test_case): arg_dict = OrderedDict() arg_dict["dtype"] = [flow.float32, flow.float16] + arg_dict["enable_quantize"] = [True, False] + for kwargs in GenArgDict(arg_dict): _test_embedding_shuffle(test_case, **kwargs) def test_embedding_gradient_shuffle(test_case): arg_dict = OrderedDict() + arg_dict["enable_quantize"] = [True, False] for kwargs in GenArgDict(arg_dict): _test_embedding_gradient_shuffle(test_case, **kwargs) diff --git a/python/oneflow/test/modules/test_id_shuffle_global.py b/python/oneflow/test/modules/test_id_shuffle_global.py index 37316ce238e..6ed40f6dfbf 100644 --- a/python/oneflow/test/modules/test_id_shuffle_global.py +++ b/python/oneflow/test/modules/test_id_shuffle_global.py @@ -19,7 +19,6 @@ from oneflow.test_utils.test_util import GenArgDict import numpy as np import oneflow as flow -import oneflow.unittest from oneflow.test_utils.automated_test_util import * @@ -127,15 +126,55 @@ def build(self, ids, table_ids): test_case.assertTrue(np.array_equal(unique_table_ids, np_unique_table_ids)) -def _test_embedding_shuffle(test_case, dtype): +def round_half_away_from_zero(x): + sign = np.sign(x) + abs_val = np.abs(x) + abs_val += 0.5 + floor_val = np.floor(abs_val) + out = floor_val * sign + return out + + +def embedding_shuffle_quantize(np_data, np_dtype): + # When use float16, ComputeType is set to as Float. + np_reduce_data = np_data.astype(np.float32) + abs_max_factor = np.max(np.abs(np_reduce_data), axis=2) + abs_max_factor = np.expand_dims(abs_max_factor, axis=2) + transport_quantize_factor = abs_max_factor.astype(np_dtype) + int8_factor = np.ones(abs_max_factor.shape, dtype=np.float32) * 127.0 + int8_factor = int8_factor.astype(np.float32) + quantize_factor = int8_factor / abs_max_factor + + # Covert to Compute Type. + np_data.astype(np.float32) + np_data = np_data * quantize_factor + np_data = round_half_away_from_zero(np_data) + np_data = np_data.astype(np.int8) + + # Covert to Compute Type. 
+ np_data = np_data.astype(np.float32) + dequantize_factor = transport_quantize_factor.astype(np.float32) / int8_factor + np_data = np_data * dequantize_factor + np_data = np_data.astype(np_dtype) + return np_data + + +def _test_embedding_shuffle(test_case, dtype, enable_quantize): batch_size = int(1024 / parallel_num) placement = flow.placement(type="cuda", ranks=list(range(parallel_num))) num_tables = 26 + embedding_size = 128 + enable_quantized_comm = enable_quantize and embedding_size < 1025 + if enable_quantized_comm: + os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1" + else: + os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0" + if dtype == flow.float16: np_dtype = np.float16 else: np_dtype = np.float32 - data = np.random.rand(max_id, 128).astype(np_dtype) + data = np.random.rand(max_id, embedding_size).astype(np_dtype) data_tensor = flow.tensor(data, requires_grad=False).to_global( placement=placement, sbp=flow.sbp.broadcast() ) @@ -170,14 +209,27 @@ def build(self, ids, table_ids, data): global_ids = ids_tensor.numpy() global_data = data_tensor.numpy() np_embeddings = global_data[global_ids] + + # Quantized numpy embedding. + if enable_quantized_comm: + np_embeddings = embedding_shuffle_quantize(np_embeddings, np_dtype) + test_case.assertTrue(np.array_equal(embeddings.numpy(), np_embeddings)) -def _test_embedding_gradient_shuffle(test_case): +def _test_embedding_gradient_shuffle(test_case, enable_quantize): + np_tolerance = 0 batch_size = int(1024 / parallel_num) placement = flow.placement(type="cuda", ranks=list(range(parallel_num))) num_tables = 26 embedding_size = 128 + enable_quantized_comm = enable_quantize and embedding_size < 1025 + if enable_quantized_comm: + np_tolerance = 0.5 + os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1" + else: + np_tolerance = 1e-4 + os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0" embedding_grad = np.random.rand(batch_size, num_tables, embedding_size).astype( np.float32 ) @@ -230,11 +282,24 @@ def build(self, ids, table_ids, embedding_grad): np_cur_rank_unique_embedding_grad = np.zeros((max_id, embedding_size)) for k in range(np_num_unique): unique_id = np_unique_ids[k] - np_cur_rank_unique_embedding_grad[unique_id, :] = sum( + np_data = sum( global_embedding_grad.reshape(-1, embedding_size)[ np.where(global_ids.flatten() == unique_id)[0] ] ) + # Quantize Embedding Gradient. 
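+        # Same int8 round trip as the single-rank test. The kernel quantizes each
+        # rank's partial sums before the shuffle, so the result can differ from this
+        # reference by more than one quantization step; hence the looser np_tolerance.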
+ if enable_quantized_comm: + abs_max_factor = np.max(np.abs(np_data)) + int8_factor = np.full(abs_max_factor.shape, 127.0, dtype=np.float32) + quantize_factor = int8_factor / abs_max_factor + np_data = np_data * quantize_factor + np_data = round_half_away_from_zero(np_data) + np_data = np_data.astype(np.int8) + np_data = np_data.astype(np.float32) + dequantize_factor = abs_max_factor / int8_factor + np_data = np_data * dequantize_factor + + np_cur_rank_unique_embedding_grad[unique_id, :] = np_data cur_rank_num_ids = batch_size * num_tables * parallel_num of_unique_embedding_grad = np.zeros((max_id, embedding_size)) @@ -254,8 +319,8 @@ def build(self, ids, table_ids, embedding_grad): np.allclose( of_unique_embedding_grad, np_cur_rank_unique_embedding_grad, - atol=1e-4, - rtol=1e-4, + atol=np_tolerance, + rtol=np_tolerance, ), ) @@ -273,11 +338,14 @@ def test_id_shuffle(test_case): def test_embedding_shuffle(test_case): arg_dict = OrderedDict() arg_dict["dtype"] = [flow.float32, flow.float16] + arg_dict["enable_quantize"] = [True, False] + for kwargs in GenArgDict(arg_dict): _test_embedding_shuffle(test_case, **kwargs) def test_embedding_gradient_shuffle(test_case): arg_dict = OrderedDict() + arg_dict["enable_quantize"] = [True, False] for kwargs in GenArgDict(arg_dict): _test_embedding_gradient_shuffle(test_case, **kwargs)
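For context on the looser gradient tolerance, here is a standalone sketch (not part of the test suite; the two-rank split and helper name are illustrative) of how dequantization error accumulates when per-rank partial gradients are quantized before being summed:

```python
import numpy as np

def int8_round_trip(x):
    # Symmetric abs-max int8 quantize/dequantize over the last axis,
    # rounding half away from zero, as the shuffle kernels do per row.
    x = x.astype(np.float32)
    factor = np.max(np.abs(x), axis=-1, keepdims=True)
    q = np.copysign(np.floor(np.abs(x) * (127.0 / factor) + 0.5), x).astype(np.int8)
    return q.astype(np.float32) * (factor / 127.0)

embedding_size = 128
parallel_num = 2
# Per-rank partial gradients for the same set of unique rows.
partials = [
    np.random.uniform(-1.0, 1.0, size=(1000, embedding_size)).astype(np.float32)
    for _ in range(parallel_num)
]
exact = sum(partials)
# Quantized communication dequantizes each rank's partial sum before the final reduction.
approx = sum(int8_round_trip(p) for p in partials)
# Every rank contributes at most half a quantization step of error per element.
bound = sum(np.abs(p).max() for p in partials) / 127.0 / 2.0
assert np.abs(approx - exact).max() <= bound + 1e-5
```

With two ranks of uniform(-1, 1) gradients the worst-case error is roughly on the order of 1e-2, comfortably inside the 0.5 tolerance the global gradient test uses when quantization is enabled.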