From 4787170d0a3d1c07bea7fc2a5ffc176090fbab28 Mon Sep 17 00:00:00 2001
From: guo ran <360112263@qq.com>
Date: Sun, 15 May 2022 23:03:05 +0800
Subject: [PATCH] [OneEmbedding] optimize unsorted_segment_sum when col is odd
 (#8204)

* optimize unsorted_segment_sum when col is odd

* address review

Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
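Background for reviewers: vectorized fp16 atomics operate on half2, i.e. on
pairs of columns, so a row with an odd number of columns cannot be accumulated
two elements at a time. This patch pads odd fp16 embedding rows to an even
width, runs the unsorted segment sum on the padded layout with half2 atomics
(the pad column only ever accumulates zeros), and strips the pad column with a
2-D CopyNd at the end.

Below is a minimal standalone sketch of the packing trick with illustrative
names only; it is not the kernel added by this patch. half2 atomicAdd requires
sm_60 or newer, e.g. `nvcc -arch=sm_60 sketch.cu`:

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

__global__ void SegmentSumPadded(int n_h2, int h2_cols, int cols, const half* data,
                                 const int* segment_ids, half2* out_h2) {
  // Grid-stride loop over the half2 elements of the padded output layout.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n_h2;
       i += blockDim.x * gridDim.x) {
    const int row = i / h2_cols;
    const int h2_col = i - row * h2_cols;
    const int col0 = 2 * h2_col;
    const int col1 = col0 + 1;
    const half* data_row = data + row * cols;  // input rows are NOT padded
    half2 val;
    val.x = data_row[col0];
    // The last half2 of each row would read past the real columns; use zero.
    val.y = (col1 >= cols) ? __float2half(0.f) : data_row[col1];
    atomicAdd(out_h2 + segment_ids[row] * h2_cols + h2_col, val);
  }
}

int main() {
  const int rows = 4, cols = 3, num_segments = 2;  // cols is odd
  const int padded = cols + 1, h2_cols = padded / 2;
  half h_data[rows * cols];
  for (int i = 0; i < rows * cols; ++i) h_data[i] = __float2half(1.f);
  int h_ids[rows] = {0, 1, 0, 1};  // two rows per segment
  half* d_data; half2* d_out; int* d_ids;
  cudaMalloc(&d_data, sizeof(h_data));
  cudaMalloc(&d_out, num_segments * h2_cols * sizeof(half2));
  cudaMalloc(&d_ids, sizeof(h_ids));
  cudaMemcpy(d_data, h_data, sizeof(h_data), cudaMemcpyHostToDevice);
  cudaMemcpy(d_ids, h_ids, sizeof(h_ids), cudaMemcpyHostToDevice);
  cudaMemset(d_out, 0, num_segments * h2_cols * sizeof(half2));
  SegmentSumPadded<<<1, 64>>>(rows * h2_cols, h2_cols, cols, d_data, d_ids, d_out);
  half h_out[num_segments * padded];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int s = 0; s < num_segments; ++s) {
    for (int c = 0; c < cols; ++c) printf("%g ", __half2float(h_out[s * padded + c]));
    printf("\n");  // expect "2 2 2" per segment; the pad column is simply skipped
  }
  cudaFree(d_data); cudaFree(d_out); cudaFree(d_ids);
  return 0;
}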
 oneflow/user/kernels/data_shuffle_kernel.cu   | 220 +++++++++++++-----
 .../oneflow/test/expensive/test_id_shuffle.py |  24 +-
 .../test/modules/test_id_shuffle_global.py    |  22 +-
 3 files changed, 207 insertions(+), 59 deletions(-)

diff --git a/oneflow/user/kernels/data_shuffle_kernel.cu b/oneflow/user/kernels/data_shuffle_kernel.cu
index df72470c090..348d69ba669 100644
--- a/oneflow/user/kernels/data_shuffle_kernel.cu
+++ b/oneflow/user/kernels/data_shuffle_kernel.cu
@@ -23,6 +23,8 @@ limitations under the License.
 #include "oneflow/core/cuda/atomic.cuh"
 #include "oneflow/core/embedding/hash_functions.cuh"
 #include "oneflow/core/cuda/elementwise.cuh"
+#include "oneflow/core/ep/include/primitive/copy_nd.h"
+#include "oneflow/core/cuda/atomic.cuh"
 
 namespace oneflow {
 
@@ -1099,6 +1101,131 @@ void ShuffleEmbeddingsGrad(cudaStream_t cuda_stream, ncclComm_t comm, int64_t pa
                            recv_offsets, recv_elem_cnt, received_cur_rank_quantize_factor);
 }
 
+template<typename K, typename IDX>
+__global__ void UnsortedSegmentHalfGpu(const IDX in_h2_elem_cnt, const IDX h2_inner_dim_size,
+                                       const IDX inner_dim_size, const half* data,
+                                       const K* segment_ids, const IDX num_segments,
+                                       half2* out_h2) {
+  CUDA_1D_KERNEL_LOOP_T(IDX, i, in_h2_elem_cnt) {
+    const IDX segment_id_idx = i / h2_inner_dim_size;
+    const IDX h2_inner_idx = i - segment_id_idx * h2_inner_dim_size;
+    const IDX inner_idx_0 = 2 * h2_inner_idx;
+    const IDX inner_idx_1 = inner_idx_0 + 1;
+    const half* data_row = data + segment_id_idx * inner_dim_size;
+    half2 val;
+    val.x = data_row[inner_idx_0];
+    val.y = (inner_idx_1 >= inner_dim_size) ? static_cast<half>(0) : data_row[inner_idx_1];
+    const IDX idx = segment_ids[segment_id_idx];
+    const IDX out_h2_offset = idx * h2_inner_dim_size + h2_inner_idx;
+    cuda::atomic::Add(out_h2 + out_h2_offset, val);
+  }
+}
+
+template<typename T, typename K>
+struct UnsortedSegmentSumPad {
+  void operator()(ep::Stream* stream, const K* segment_ids, const T* data, int64_t num_segment_ids,
+                  int64_t num_segments, int64_t inner_dim_size, int64_t padded_inner_dim_size,
+                  T* out) const {
+    UNIMPLEMENTED();
+  }
+};
+
+template<typename K>
+struct UnsortedSegmentSumPad<half, K> {
+  void operator()(ep::Stream* stream, const K* segment_ids, const half* data,
+                  int64_t num_segment_ids, int64_t num_segments, int64_t inner_dim_size,
+                  int64_t padded_inner_dim_size, half* out) const {
+    const int64_t data_elem_cnt = num_segment_ids * inner_dim_size;
+    const int64_t out_elem_cnt = num_segments * padded_inner_dim_size;
+    CHECK_EQ(padded_inner_dim_size % 2, 0);
+    CHECK_EQ(inner_dim_size + 1, padded_inner_dim_size);
+    const int64_t h2_inner_dim_size = padded_inner_dim_size / 2;
+    const int64_t in_h2_elem_cnt = num_segment_ids * h2_inner_dim_size;
+    if (std::max(data_elem_cnt, out_elem_cnt) < GetMaxVal<int32_t>() / 2) {
+      UnsortedSegmentHalfGpu<K, int32_t>
+          <<<BlocksNum4ThreadsNum(in_h2_elem_cnt), kCudaThreadsNumPerBlock, 0,
+             stream->As<ep::CudaStream>()->cuda_stream()>>>(
+              in_h2_elem_cnt, h2_inner_dim_size, inner_dim_size, data, segment_ids, num_segments,
+              reinterpret_cast<half2*>(out));
+    } else {
+      UnsortedSegmentHalfGpu<K, int64_t>
+          <<<BlocksNum4ThreadsNum(in_h2_elem_cnt), kCudaThreadsNumPerBlock, 0,
+             stream->As<ep::CudaStream>()->cuda_stream()>>>(
+              in_h2_elem_cnt, h2_inner_dim_size, inner_dim_size, data, segment_ids, num_segments,
+              reinterpret_cast<half2*>(out));
+    }
+  }
+};
+
+template<typename T, typename K>
+void UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const T* data,
+                        int64_t num_segment_ids, int64_t num_segments, int64_t inner_dim_size,
+                        int64_t padded_inner_dim_size, T* out) {
+  if (inner_dim_size == padded_inner_dim_size) {
+    UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, T, K, T>::UnsortedSegmentSum(
+        stream, segment_ids, data, num_segment_ids, num_segments, 1, inner_dim_size, 0, out);
+  } else {
+    CHECK_EQ(inner_dim_size + 1, padded_inner_dim_size);
+    UnsortedSegmentSumPad<T, K>()(stream, segment_ids, data, num_segment_ids, num_segments,
+                                  inner_dim_size, padded_inner_dim_size, out);
+  }
+}
+
+template<typename T, typename IDX>
+void UniquePartitionEmbeddingGrad(ep::Stream* stream, int64_t parallel_id, int64_t parallel_num,
+                                  int64_t num_ids, int64_t embedding_size,
+                                  int64_t padded_embedding_size, const IDX* host_num_unique_matrix,
+                                  const T* embedding_grad,
+                                  const IDX* inverse_unique_partition_indices,
+                                  T* unique_partition_embedding_grad) {
+  for (int64_t i = 0; i < parallel_num; ++i) {
+    const int64_t offset = i * num_ids * padded_embedding_size;
+    const int64_t valid_value_size =
+        host_num_unique_matrix[parallel_id * parallel_num + i] * padded_embedding_size * sizeof(T);
+    OF_CUDA_CHECK(cudaMemsetAsync(unique_partition_embedding_grad + offset, 0, valid_value_size,
+                                  stream->As<ep::CudaStream>()->cuda_stream()));
+  }
+  UnsortedSegmentSum(stream, inverse_unique_partition_indices, embedding_grad, num_ids,
+                     parallel_num * num_ids, embedding_size, padded_embedding_size,
+                     unique_partition_embedding_grad);
+}
+
+template<typename T, typename IDX>
+void UniqueCurRankEmbeddingGrad(ep::Stream* stream, DataType data_type, int64_t cur_rank_num_ids,
+                                int64_t embedding_size, int64_t padded_embedding_size,
+                                const T* cur_rank_embedding_grad,
+                                const IDX* cur_rank_inverse_indices,
+                                T* cur_rank_unique_embedding_grad, T* tmp_buffer) {
+  T* unsorted_segment_sum_out =
+      (embedding_size == padded_embedding_size) ? cur_rank_unique_embedding_grad : tmp_buffer;
+  OF_CUDA_CHECK(cudaMemsetAsync(unsorted_segment_sum_out, 0,
+                                cur_rank_num_ids * padded_embedding_size * sizeof(T),
+                                stream->As<ep::CudaStream>()->cuda_stream()));
+  UnsortedSegmentSum(stream, cur_rank_inverse_indices, cur_rank_embedding_grad,
+                     cur_rank_num_ids, cur_rank_num_ids, padded_embedding_size,
+                     padded_embedding_size, unsorted_segment_sum_out);
+  if (embedding_size != padded_embedding_size) {
+    std::unique_ptr<ep::primitive::CopyNd> primitive =
+        ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(DeviceType::kCUDA, 2);
+    DimVector dst_shape = {cur_rank_num_ids, embedding_size};
+    DimVector dst_pos_vec = {0, 0};
+    DimVector src_shape = {cur_rank_num_ids, padded_embedding_size};
+    DimVector src_pos_vec = {0, 0};
+    DimVector extent_vec = {cur_rank_num_ids, embedding_size};
+    primitive->Launch(stream, data_type, 2, cur_rank_unique_embedding_grad, dst_shape.data(),
+                      dst_pos_vec.data(), unsorted_segment_sum_out, src_shape.data(),
+                      src_pos_vec.data(), extent_vec.data());
+  }
+}
+
+int64_t GetPaddedEmbeddingSize(DataType data_type, int64_t embedding_size) {
+  if (data_type == DataType::kFloat16 && embedding_size % 2 != 0) {
+    return embedding_size + 1;
+  } else {
+    return embedding_size;
+  }
+}
+
 template<typename T, typename IDX>
 class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
  public:
"; @@ -1152,7 +1281,7 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel { cur_rank_num_ids += host_num_unique_matrix[i * parallel_num + parallel_id]; } size_t full_num_ids = parallel_num * num_ids; - size_t full_elem_cnt = full_num_ids * embedding_size; + size_t full_elem_cnt = full_num_ids * padded_embedding_size; size_t unique_partition_embedding_grad_size = GetCudaAlignedSize(full_elem_cnt * sizeof(T)); if (!enable_quantized_comm) { @@ -1163,30 +1292,22 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel { CHECK_GE(tmp_buffer->shape().elem_cnt(), unique_partition_embedding_grad_size + received_embedding_grad_size); - // unique and partition embedding grad - for (int64_t i = 0; i < parallel_num; ++i) { - const int64_t offset = i * num_ids * embedding_size; - const int64_t valid_value_size = - host_num_unique_matrix[parallel_id * parallel_num + i] * embedding_size * sizeof(T); - OF_CUDA_CHECK(cudaMemsetAsync(unique_partition_embedding_grad + offset, 0, valid_value_size, - cuda_stream)); - } - UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( - ctx->stream(), reinterpret_cast(inverse_unique_partition_indices->dptr()), - embedding_grad->dptr(), num_ids, parallel_num * num_ids, 1, embedding_size, 0, + UniquePartitionEmbeddingGrad( + ctx->stream(), parallel_id, parallel_num, num_ids, embedding_size, padded_embedding_size, + host_num_unique_matrix, embedding_grad->dptr(), + reinterpret_cast(inverse_unique_partition_indices->dptr()), unique_partition_embedding_grad); - ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size, - data_type, host_num_unique_matrix, unique_partition_embedding_grad, - received_embedding_grad); + ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids, + padded_embedding_size, data_type, host_num_unique_matrix, + unique_partition_embedding_grad, received_embedding_grad); - // unique cur_rank embedding grad - OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_embedding_grad->mut_dptr(), 0, - cur_rank_num_ids * embedding_size * sizeof(T), cuda_stream)); - UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( - ctx->stream(), reinterpret_cast(cur_rank_inverse_indices->dptr()), - received_embedding_grad, cur_rank_num_ids, cur_rank_num_ids, 1, embedding_size, 0, - cur_rank_unique_embedding_grad->mut_dptr()); + // use unique_partition_embedding_grad as UniqueCurRankEmbeddingGrad buffer. 
+      T* buffer_ptr = unique_partition_embedding_grad;
+      UniqueCurRankEmbeddingGrad(ctx->stream(), data_type, cur_rank_num_ids, embedding_size,
+                                 padded_embedding_size, received_embedding_grad,
+                                 reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),
+                                 cur_rank_unique_embedding_grad->mut_dptr<T>(), buffer_ptr);
     } else {
       size_t received_embedding_grad_size = GetCudaAlignedSize(full_elem_cnt * sizeof(int8_t));
       size_t quantize_cur_rank_embedding_grad_size = received_embedding_grad_size;
@@ -1218,35 +1339,28 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
                    + received_embedding_grad_size + quantize_cur_rank_embedding_grad_size
                    + cur_rank_quantize_factor_size + received_cur_rank_quantize_factor_size);
 
-      // unique and partition embedding grad
-      for (int64_t i = 0; i < parallel_num; ++i) {
-        const int64_t offset = i * num_ids * embedding_size;
-        const int64_t valid_value_size =
-            host_num_unique_matrix[parallel_id * parallel_num + i] * embedding_size * sizeof(T);
-        OF_CUDA_CHECK(cudaMemsetAsync(unique_partition_embedding_grad + offset, 0, valid_value_size,
-                                      cuda_stream));
-      }
-
-      UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, T, IDX, T>::UnsortedSegmentSum(
-          ctx->stream(), reinterpret_cast<const IDX*>(inverse_unique_partition_indices->dptr()),
-          embedding_grad->dptr<T>(), num_ids, parallel_num * num_ids, 1, embedding_size, 0,
+      UniquePartitionEmbeddingGrad(
+          ctx->stream(), parallel_id, parallel_num, num_ids, embedding_size, padded_embedding_size,
+          host_num_unique_matrix, embedding_grad->dptr<T>(),
+          reinterpret_cast<const IDX*>(inverse_unique_partition_indices->dptr()),
           unique_partition_embedding_grad);
 
       // Quantize.
       for (int64_t i = 0; i < parallel_num; ++i) {
-        const int64_t embedding_grad_offset = i * num_ids * embedding_size;
+        const int64_t embedding_grad_offset = i * num_ids * padded_embedding_size;
         const int64_t quantize_factor_offset = i * num_ids;
         const int64_t valid_row_size = host_num_unique_matrix[parallel_id * parallel_num + i];
         DispatchQuantizeWarpImplPackSize<T, ComputeType>()(
            cuda_stream, unique_partition_embedding_grad + embedding_grad_offset,
            quantize_cur_rank_embedding_grad + embedding_grad_offset,
-           cur_rank_quantize_factor + quantize_factor_offset, valid_row_size, embedding_size);
+           cur_rank_quantize_factor + quantize_factor_offset, valid_row_size,
+           padded_embedding_size);
       }
 
-      ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids, embedding_size,
-                            data_type, host_num_unique_matrix, quantize_cur_rank_embedding_grad,
-                            received_embedding_grad, cur_rank_quantize_factor,
-                            received_cur_rank_quantize_factor);
+      ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids,
+                            padded_embedding_size, data_type, host_num_unique_matrix,
+                            quantize_cur_rank_embedding_grad, received_embedding_grad,
+                            cur_rank_quantize_factor, received_cur_rank_quantize_factor);
 
       int64_t dequantize_cur_rank_num = 0;
       for (int64_t i = 0; i < parallel_num; ++i) {
@@ -1262,17 +1376,16 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
         */
         dequantize_cur_rank_num += host_num_unique_matrix[i * parallel_num + parallel_id];
       }
-      IDX dequantize_elem_cnt = dequantize_cur_rank_num * embedding_size;
+      IDX dequantize_elem_cnt = dequantize_cur_rank_num * padded_embedding_size;
       OF_CUDA_CHECK((LaunchDequantizeKernel<T, ComputeType, IDX>(
           cuda_stream, received_embedding_grad, received_cur_rank_quantize_factor,
-          dequantize_cur_rank_embedding_grad, embedding_size, dequantize_elem_cnt)));
-      // unique cur_rank embedding grad
-      OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_embedding_grad->mut_dptr(), 0,
-                                    cur_rank_num_ids * embedding_size * sizeof(T), cuda_stream));
-      UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, T, IDX, T>::UnsortedSegmentSum(
-          ctx->stream(), reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),
-          dequantize_cur_rank_embedding_grad, cur_rank_num_ids, cur_rank_num_ids, 1, embedding_size,
-          0, cur_rank_unique_embedding_grad->mut_dptr<T>());
+          dequantize_cur_rank_embedding_grad, padded_embedding_size, dequantize_elem_cnt)));
+      // use unique_partition_embedding_grad as UniqueCurRankEmbeddingGrad buffer.
+      T* buffer_ptr = unique_partition_embedding_grad;
+      UniqueCurRankEmbeddingGrad(ctx->stream(), data_type, cur_rank_num_ids, embedding_size,
+                                 padded_embedding_size, dequantize_cur_rank_embedding_grad,
+                                 reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),
+                                 cur_rank_unique_embedding_grad->mut_dptr<T>(), buffer_ptr);
     }
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
@@ -1291,10 +1404,13 @@ class EmbeddingGradientShuffleKernel final : public user_op::OpKernel {
         ctx->InputTensorDesc("cur_rank_unique_embedding_grad", 0);                          \
     size_t cur_rank_embedding_grad_num = cur_rank_unique_embedding_grad.shape().At(0);      \
     size_t embedding_size = cur_rank_unique_embedding_grad.shape().At(1);                   \
-    size_t cur_rank_embedding_grad_elem_cnt = cur_rank_embedding_grad_num * embedding_size; \
+    size_t padded_embedding_size =                                                          \
+        GetPaddedEmbeddingSize(cur_rank_unique_embedding_grad.data_type(), embedding_size); \
+    size_t cur_rank_embedding_grad_elem_cnt =                                               \
+        cur_rank_embedding_grad_num * padded_embedding_size;                                \
     bool enable_quantized_comm =                                                            \
         ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false)           \
-        && (embedding_size < kMaxColSize);                                                  \
+        && (padded_embedding_size < kMaxColSize);                                           \
     size_t tmp_size = 0;                                                                    \
     if (!enable_quantized_comm) {                                                           \
       size_t cur_rank_embedding_grad_size = GetCudaAlignedSize(                             \
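For reference: the CopyNd launch in UniqueCurRankEmbeddingGrad above is a plain
strided row copy that drops the trailing pad column. The same copy can be
expressed with cudaMemcpy2D; the helper below is only an illustration of what
the primitive does here (hypothetical name, not code from this patch):

#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Copy a [rows, cols + 1] fp16 buffer into a tightly packed [rows, cols]
// buffer, discarding the last (pad) column of every row.
cudaError_t StripPadColumn(half* dst, const half* src, int rows, int cols) {
  const int padded = cols + 1;  // fp16 rows were padded to an even width
  return cudaMemcpy2D(dst, cols * sizeof(half),    // dst, dst row pitch
                      src, padded * sizeof(half),  // src, src row pitch
                      cols * sizeof(half), rows,   // width in bytes, height
                      cudaMemcpyDeviceToDevice);
}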
diff --git a/python/oneflow/test/expensive/test_id_shuffle.py b/python/oneflow/test/expensive/test_id_shuffle.py
index 6eef77c5ae2..301f186ee1d 100644
--- a/python/oneflow/test/expensive/test_id_shuffle.py
+++ b/python/oneflow/test/expensive/test_id_shuffle.py
@@ -178,17 +178,21 @@ def build(self, ids, table_ids, data):
     )
 
 
-def _test_embedding_gradient_shuffle(test_case, enable_quantize):
+def _test_embedding_gradient_shuffle(test_case, enable_quantize, fp16, embedding_size):
     batch_size = 512
     num_tables = 26
-    embedding_size = 128
     ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64)
     enable_quantized_comm = enable_quantize and embedding_size < 1025
     if enable_quantized_comm:
+        np_tolerance = 0.5
         os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1"
         ids = np.arange(batch_size * num_tables, dtype=np.int64)
         np.random.shuffle(ids)
     else:
+        if fp16:
+            np_tolerance = 1e-2
+        else:
+            np_tolerance = 1e-4
         os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0"
 
     table_ids = (
@@ -216,12 +220,18 @@ def build(self, ids, table_ids, embedding_grad):
             _,
             cur_rank_inverse_indices,
         ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables)
+        if fp16:
+            embedding_grad = flow.cast(embedding_grad, flow.float16)
         cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle(
             embedding_grad,
             num_unique_matrix,
             cur_rank_inverse_indices,
             inverse_unique_partition_indices,
         )
+        if fp16:
+            cur_rank_unique_embedding_grad = flow.cast(
+                cur_rank_unique_embedding_grad, flow.float32
+            )
         return (
             cur_rank_unique_embedding_grad,
             flow.cast(cur_rank_unique_ids, flow.int32),
@@ -243,6 +253,8 @@ def build(self, ids, table_ids, embedding_grad):
     ).reshape(-1, embedding_size)
 
     embedding_grad = embedding_grad.reshape(-1, embedding_size)
+    if fp16:
+        embedding_grad = embedding_grad.astype(np.float16)
     for k in range(np_num_unique):
         np_data = sum(embedding_grad[np.where(ids.flatten() == np_unique_ids[k])[0]])
         # Quantize Embedding Gradient.
@@ -270,13 +282,15 @@ def build(self, ids, table_ids, embedding_grad):
             of_cur_rank_embedding_grad, (-1, embedding_size)
         )
         np_cur_rank_embedding_grad = np_cur_rank_unique_embedding_grad[np_inverse]
+        if fp16:
+            np_cur_rank_embedding_grad = np_cur_rank_embedding_grad.astype(np.float32)
 
         test_case.assertTrue(
             np.allclose(
                 of_cur_rank_embedding_grad.numpy().flatten(),
                 np_cur_rank_embedding_grad.flatten(),
-                atol=1e-4,
-                rtol=1e-4,
+                atol=np_tolerance,
+                rtol=np_tolerance,
             )
         )
 
@@ -348,6 +362,8 @@ def test_embedding_shuffle(test_case):
     def test_embedding_gradient_shuffle(test_case):
         arg_dict = OrderedDict()
         arg_dict["enable_quantize"] = [True, False]
+        arg_dict["fp16"] = [True, False]
+        arg_dict["embedding_size"] = [128, 17]
         for kwargs in GenArgDict(arg_dict):
             _test_embedding_gradient_shuffle(test_case, **kwargs)
diff --git a/python/oneflow/test/modules/test_id_shuffle_global.py b/python/oneflow/test/modules/test_id_shuffle_global.py
index 6ed40f6dfbf..872eb7f0e04 100644
--- a/python/oneflow/test/modules/test_id_shuffle_global.py
+++ b/python/oneflow/test/modules/test_id_shuffle_global.py
@@ -217,18 +217,20 @@ def build(self, ids, table_ids, data):
     test_case.assertTrue(np.array_equal(embeddings.numpy(), np_embeddings))
 
 
-def _test_embedding_gradient_shuffle(test_case, enable_quantize):
+def _test_embedding_gradient_shuffle(test_case, enable_quantize, fp16, embedding_size):
     np_tolerance = 0
     batch_size = int(1024 / parallel_num)
     placement = flow.placement(type="cuda", ranks=list(range(parallel_num)))
     num_tables = 26
-    embedding_size = 128
     enable_quantized_comm = enable_quantize and embedding_size < 1025
     if enable_quantized_comm:
         np_tolerance = 0.5
         os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "1"
     else:
-        np_tolerance = 1e-4
+        if fp16:
+            np_tolerance = 1e-2
+        else:
+            np_tolerance = 1e-4
         os.environ["ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"] = "0"
     embedding_grad = np.random.rand(batch_size, num_tables, embedding_size).astype(
         np.float32
@@ -250,12 +252,18 @@ def build(self, ids, table_ids, embedding_grad):
             _,
             cur_rank_inverse_indices,
         ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables)
+        if fp16:
+            embedding_grad = flow.cast(embedding_grad, flow.float16)
         cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle(
             embedding_grad,
             num_unique_matrix,
             cur_rank_inverse_indices,
             inverse_unique_partition_indices,
         )
+        if fp16:
+            cur_rank_unique_embedding_grad = flow.cast(
+                cur_rank_unique_embedding_grad, flow.float32
+            )
         return (
             cur_rank_unique_embedding_grad,
             flow.cast(cur_rank_num_unique, flow.int32),
@@ -280,6 +288,8 @@ def build(self, ids, table_ids, embedding_grad):
     np_unique_ids = np.unique(global_ids)
     np_num_unique = np_unique_ids.size
     np_cur_rank_unique_embedding_grad = np.zeros((max_id, embedding_size))
+    if fp16:
+        global_embedding_grad = global_embedding_grad.astype(np.float16)
     for k in range(np_num_unique):
         unique_id = np_unique_ids[k]
         np_data = sum(
@@ -300,6 +310,10 @@ def build(self, ids, table_ids, embedding_grad):
             np_data = np_data * dequantize_factor
         np_cur_rank_unique_embedding_grad[unique_id, :] = np_data
 
+    if fp16:
+        np_cur_rank_unique_embedding_grad = np_cur_rank_unique_embedding_grad.astype(
+            np.float32
+        )
     cur_rank_num_ids = batch_size * num_tables * parallel_num
     of_unique_embedding_grad = np.zeros((max_id, embedding_size))
@@ -346,6 +360,8 @@ def test_embedding_shuffle(test_case):
     def test_embedding_gradient_shuffle(test_case):
         arg_dict = OrderedDict()
         arg_dict["enable_quantize"] = [True, False]
+        arg_dict["fp16"] = [True, False]
+        arg_dict["embedding_size"] = [128, 17]
         for kwargs in GenArgDict(arg_dict):
             _test_embedding_gradient_shuffle(test_case, **kwargs)