From 0af594a71073594a159e81c6eec8246c9a459868 Mon Sep 17 00:00:00 2001 From: csy0225 <78470701+csy0225@users.noreply.github.com> Date: Thu, 24 Aug 2023 10:15:43 +0800 Subject: [PATCH] [XPU] Add embedding plugin (#56488) --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../cast_embedding_trans_ids_to_int32_pass.cc | 137 ++++++++++ .../inference/api/paddle_pass_builder.cc | 1 + paddle/phi/kernels/xpu/embedding_kernel.cc | 72 ++++-- .../kernels/xpu/plugin/include/xpu/plugin.h | 11 + .../kunlun2cpp/embedding_fwd_tiny_dict.xpu | 240 ++++++++++++++++++ .../xpu/plugin/src/wrapper/fast_embedding.cpp | 189 ++++++++++++++ ..._cast_embedding_trans_ids_to_int32_pass.py | 94 +++++++ 8 files changed, 725 insertions(+), 21 deletions(-) create mode 100644 paddle/fluid/framework/ir/xpu/cast_embedding_trans_ids_to_int32_pass.cc create mode 100644 paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/embedding_fwd_tiny_dict.xpu create mode 100644 paddle/phi/kernels/xpu/plugin/src/wrapper/fast_embedding.cpp create mode 100644 test/ir/inference/test_xpu_cast_embedding_trans_ids_to_int32_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index bcd016a43e9b41..b1dafb0d3934db 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -239,6 +239,8 @@ if(WITH_XPU) pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(cast_embedding_trans_ids_to_int32_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu diff --git a/paddle/fluid/framework/ir/xpu/cast_embedding_trans_ids_to_int32_pass.cc b/paddle/fluid/framework/ir/xpu/cast_embedding_trans_ids_to_int32_pass.cc new file mode 100644 index 00000000000000..f17f71e7b84a54 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/cast_embedding_trans_ids_to_int32_pass.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/ir/xpu/quant_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct CastEmbeddingTransIdsToInt32Pattern : public PatternBase { + CastEmbeddingTransIdsToInt32Pattern(PDPattern* pattern, + const std::string& name_scope); + // declare operator node's name + PATTERN_DECL_NODE(cast); + PATTERN_DECL_NODE(embedding); + // declare variable node's name + PATTERN_DECL_NODE(cast_x); + PATTERN_DECL_NODE(embedding_ids); + PATTERN_DECL_NODE(embedding_w); + PATTERN_DECL_NODE(embedding_out); +}; + +CastEmbeddingTransIdsToInt32Pattern::CastEmbeddingTransIdsToInt32Pattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto cast = pattern->NewNode(cast_repr())->assert_is_op("cast"); + auto cast_x = pattern->NewNode(cast_x_repr()) + ->assert_is_op_input("cast", "X") + ->assert_var_not_persistable() + ->AsInput(); + auto embedding_ids = pattern->NewNode(embedding_ids_repr()) + ->assert_is_op_output("cast", "Out") + ->assert_is_op_input("lookup_table_v2", "Ids") + ->assert_has_n_outputs(1); + cast->LinksFrom({cast_x}).LinksTo({embedding_ids}); + auto embedding_w = pattern->NewNode(embedding_w_repr()) + ->assert_is_op_input("lookup_table_v2", "W"); + auto embedding = + pattern->NewNode(embedding_repr())->assert_is_op("lookup_table_v2"); + auto embedding_out = pattern->NewNode(embedding_out_repr()) + ->assert_is_op_output("lookup_table_v2", "Out") + ->AsOutput(); + embedding->LinksFrom({embedding_ids, embedding_w}).LinksTo({embedding_out}); +} + +} // namespace patterns + +class CastEmbeddingTransIdsToInt32Pass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + const std::string name_scope_{"cast_embedding_trans_ids_to_int32_pass"}; +}; +void CastEmbeddingTransIdsToInt32Pass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + GraphPatternDetector gpd; + patterns::CastEmbeddingTransIdsToInt32Pattern pattern(gpd.mutable_pattern(), + name_scope_); + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle CastEmbeddingTransIdsToInt32Pass"; + GET_IR_NODE(cast); + GET_IR_NODE(embedding); + GET_IR_NODE(embedding_ids); + auto cast_node_attr_out_dtype = + cast->Op()->GetAttrIfExists("out_dtype"); + if (cast_node_attr_out_dtype != + static_cast(paddle::framework::proto::VarType::INT64)) { + return; + } + cast->Op()->SetAttr( + "out_dtype", + static_cast(paddle::framework::proto::VarType::INT32)); + embedding_ids->Var()->SetDataType(paddle::framework::proto::VarType::INT32); + embedding->Op()->Flush(); + found_subgraph_count++; + }; + gpd(graph, handler); + AddStatis(found_subgraph_count); + if (found_subgraph_count) { + VLOG(4) << "There is a risk of overflow when converting the data type of " + "embedded ids from int64 to int32." + "Please ensure that the numerical range of ids is within the " + "maximum value of int32." + "If it exceeds this range, it may result in incorrect results. " + "You can try removing this pass."; + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(cast_embedding_trans_ids_to_int32_pass, + paddle::framework::ir::CastEmbeddingTransIdsToInt32Pass); + +REGISTER_PASS_CAPABILITY(cast_embedding_trans_ids_to_int32_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().LE( + "lookup_table_v2", 1)); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 052db0a9e1af6b..09d1197d35b556 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -516,6 +516,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "reshape_unstack_concat_fuse_pass", "delete_op_device_pass", "constant_folding_pass", + "cast_embedding_trans_ids_to_int32_pass", "delete_elementwise_mul_op_pass", "generate_sequence_xpu_fuse_pass", "embedding_with_eltwise_add_xpu_fuse_pass", diff --git a/paddle/phi/kernels/xpu/embedding_kernel.cc b/paddle/phi/kernels/xpu/embedding_kernel.cc index 99faf8b5819661..7137357aa48b23 100644 --- a/paddle/phi/kernels/xpu/embedding_kernel.cc +++ b/paddle/phi/kernels/xpu/embedding_kernel.cc @@ -44,18 +44,6 @@ void EmbeddingKernel(const Context &ctx, auto *table = table_t->data(); auto *output = dev_ctx.template Alloc(output_t); - xpu::ctx_guard RAII_GUARD(ctx.x_context()); - const int64_t *ids; - if (ids_t->dtype() == phi::DataType::INT64) { - ids = ids_t->data(); - } else { - int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm(ids_t->numel()); - int r = xpu::cast( - ctx.x_context(), ids_t->data(), ids_tt, ids_t->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); - ids = reinterpret_cast(ids_tt); - } - PADDLE_ENFORCE_EQ( ids_numel <= std::numeric_limits::max(), true, @@ -68,15 +56,57 @@ void EmbeddingKernel(const Context &ctx, size_t xm = table_t->dims()[0]; size_t n = table_t->dims()[1]; - int r = xpu::embedding(dev_ctx.x_context(), - reinterpret_cast(table), - ids, - reinterpret_cast(output), - xm, - n, - ym, - padding_idx); - + int r; + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + if (ids_t->dtype() == phi::DataType::INT64) { +#ifndef PADDLE_WITH_XPU_PLUGIN + r = xpu::embedding( + dev_ctx.x_context(), + reinterpret_cast(table), + ids_t->data(), + reinterpret_cast(output), + xm, + n, + ym, + padding_idx); +#else + r = xpu::plugin::fast_embedding( + dev_ctx.x_context(), + reinterpret_cast(table), + ids_t->data(), + reinterpret_cast(output), + xm, + n, + ym, + padding_idx); +#endif + } else { +#ifndef PADDLE_WITH_XPU_PLUGIN + int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm(ids_t->numel()); + r = xpu::cast( + ctx.x_context(), ids_t->data(), ids_tt, ids_t->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); + const int64_t *ids = reinterpret_cast(ids_tt); + r = xpu::embedding(dev_ctx.x_context(), + reinterpret_cast(table), + ids, + reinterpret_cast(output), + xm, + n, + ym, + padding_idx); +#else + r = xpu::plugin::fast_embedding( + dev_ctx.x_context(), + reinterpret_cast(table), + ids_t->data(), + reinterpret_cast(output), + xm, + n, + ym, + padding_idx); +#endif + } PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding"); } diff --git a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h index c063ebf949b917..1357ca43001c8b 100644 --- a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h +++ b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h @@ -104,6 +104,17 @@ DLL_EXPORT int fast_reduce_min(Context* ctx, const std::vector& xshape, const std::vector& rdims); +template +DLL_EXPORT int fast_embedding(Context* ctx, + const T* x, + const TID* indices, + T* y, + int64_t xm, + int64_t n, + int64_t ym, + int64_t padding_idx, + TID start_index = 0); + } // namespace plugin } // namespace api } // namespace xpu diff --git a/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/embedding_fwd_tiny_dict.xpu b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/embedding_fwd_tiny_dict.xpu new file mode 100644 index 00000000000000..984b0111c52a50 --- /dev/null +++ b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/embedding_fwd_tiny_dict.xpu @@ -0,0 +1,240 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/xtdk_io.h" + +namespace xpu2 { +namespace plugin { + +/* +Kernel usage conditions: Dict is tiny, Local memory can be loaded in at once. +Optimizer ideas: + - Reduce frequent memory handling, allocate fixed size buffers, accumulate +data to buffer size and move it out together. + + ********** Local Memory Addr ********** + Part 1: dict(size = dict_idx_len * emb_dim) + ----------------------------------- + Part 2: index(size = idx_len * sizeof(emb_idx_type)) + ----------------------------------- + Part 3: result + ----------------------------------- +*/ + +template +static inline __device__ void embedding_fwd_kl2_tiny_dict_align64( + _global_ptr_ const emb_idx_type* idx, + _global_ptr_ const char* dict, + _global_ptr_ char* featvec, + int64_t emb_dim, + int64_t dict_idx_len, + int64_t idx_len, + int64_t padding_idx, + emb_idx_type start_index) { + int cid = core_id(); + int ncores = core_num(); + int tid = cid * cluster_num() + cluster_id(); + int nthreads = ncores * cluster_num(); + int64_t row_start = -1; + int64_t row_end = -1; + partition(tid, nthreads, idx_len, 1, &row_start, &row_end); + + // 1. Pre allocation total Local Memory size = 6 KB + const int TOTAL_LM_SIZE = 6144; // 6 KB + __simd__ char lm[TOTAL_LM_SIZE]; + + // 2. Load dict from Global Memory to Local memory only once. + int total_emb_dict_size = dict_idx_len * emb_dim; + GM2LM(dict, lm, total_emb_dict_size); + + // residual_lm_space = index + result + int residual_lm_space = TOTAL_LM_SIZE - total_emb_dict_size - + 64; // 64 to preventing memory overflow, because the + // total index memory need to align to 64. + + // The maximum count that can be processed in one iteration. + int idx_cnt = residual_lm_space / (sizeof(emb_idx_type) + emb_dim); + int index_lm_offset = total_emb_dict_size; + int result_lm_offset = + total_emb_dict_size + + (idx_cnt * sizeof(emb_idx_type) + 64) / 64 * 64; // Align to 64 bytes + + // 3. Loop Calc + for (int64_t i = row_start; i < row_end; i += idx_cnt) { + int curr_idx_len = idx_cnt; + if (i + idx_cnt >= row_end) { + curr_idx_len = row_end - i; + } + // 3.1 Load idx to Local Memory + GM2LM(idx + i, lm + index_lm_offset, curr_idx_len * sizeof(emb_idx_type)); + + // 3.2 Save result into result memory buffer. + for (int j = 0; j < curr_idx_len; j++) { + emb_idx_type real_index = + *((emb_idx_type*)(lm + index_lm_offset + j * sizeof(emb_idx_type))) - + start_index; + if (real_index == padding_idx) { + for (int koffset = 0; koffset < emb_dim; koffset += 64) { + float32x16_t v_src = vload_lm_float32x16_mz((void*)lm, 0); + vstore_lm_float32x16( + (void*)(lm + result_lm_offset + j * emb_dim + koffset), v_src); + } + } else { + if (real_index >= 0 && real_index < dict_idx_len) { + for (int koffset = 0; koffset < emb_dim; koffset += 64) { + float32x16_t v_src = vload_lm_float32x16( + (void*)(lm + real_index * emb_dim + koffset)); + vstore_lm_float32x16( + (void*)(lm + result_lm_offset + j * emb_dim + koffset), v_src); + } + } else { + for (int koffset = 0; koffset < emb_dim; koffset += 64) { + float32x16_t v_src = vload_lm_float32x16_mz((void*)lm, 0); + vstore_lm_float32x16( + (void*)(lm + result_lm_offset + j * emb_dim + koffset), v_src); + } + } + } + mfence_lm(); + } + // 3.3 Save result into global memory buffer. + LM2GM(lm + result_lm_offset, + (_global_ptr_ char*)(featvec + i * emb_dim), + curr_idx_len * emb_dim); + } +} + +template +static inline __device__ void embedding_fwd_kl2_tiny_dict_not_align64( + _global_ptr_ const emb_idx_type* idx, + _global_ptr_ const char* dict, + _global_ptr_ char* featvec, + int64_t emb_dim, + int64_t dict_idx_len, + int64_t idx_len, + int64_t padding_idx, + emb_idx_type start_index) { + int cid = core_id(); + int ncores = core_num(); + int tid = cid * cluster_num() + cluster_id(); + int nthreads = ncores * cluster_num(); + int64_t row_start = -1; + int64_t row_end = -1; + partition(tid, nthreads, idx_len, 1, &row_start, &row_end); + + // 1. Pre allocation total Local Memory size = 6 KB + const int TOTAL_LM_SIZE = 6144; // 6 KB + __local__ char lm[TOTAL_LM_SIZE]; + + // 2. Load dict from Global Memory to Local memory only once. + GM2LM(dict, lm, dict_idx_len * emb_dim); + + // residual_lm_space = index + result + int residual_lm_space = TOTAL_LM_SIZE - dict_idx_len * emb_dim; + + // The maximum count that can be processed in one iteration. + int idx_cnt = residual_lm_space / (sizeof(emb_idx_type) + emb_dim); + int index_lm_offset = dict_idx_len * emb_dim; + int result_lm_offset = index_lm_offset + idx_cnt * sizeof(emb_idx_type); + + // 3. Loop Calc + for (int64_t i = row_start; i < row_end; i += idx_cnt) { + int curr_idx_len = idx_cnt; + if (i + idx_cnt >= row_end) { + curr_idx_len = row_end - i; + } + // 3.1 Load idx to Local Memory + GM2LM(idx + i, lm + index_lm_offset, curr_idx_len * sizeof(emb_idx_type)); + + // 3.2 Save result into result memory buffer. + for (int j = 0; j < curr_idx_len; j++) { + emb_idx_type real_index = + *((emb_idx_type*)(lm + index_lm_offset + j * sizeof(emb_idx_type))) - + start_index; + if (real_index == padding_idx) { + for (int k = 0; k < emb_dim; k++) { + lm[result_lm_offset + j * emb_dim + k] = 0; + } + } else { + if (real_index >= 0 && real_index < dict_idx_len) { + for (int k = 0; k < emb_dim; k++) { + lm[result_lm_offset + j * emb_dim + k] = + lm[real_index * emb_dim + k]; + } + } else { + for (int k = 0; k < emb_dim; k++) { + lm[result_lm_offset + j * emb_dim + k] = 0; + } + } + } + mfence_lm(); + } + // 3.3 Save result into global memory buffer. + LM2GM(lm + result_lm_offset, + (_global_ptr_ char*)(featvec + i * emb_dim), + curr_idx_len * emb_dim); + } +} + +template +__global__ void embedding_fwd_kl2_tiny_dict(const emb_idx_type* idx, + const char* dict, + char* featvec, + int64_t emb_dim, + int64_t dict_idx_len, + int64_t idx_len, + int64_t padding_idx, + emb_idx_type start_index) { + if (emb_dim % 64 == 0) { + embedding_fwd_kl2_tiny_dict_align64(idx, + dict, + featvec, + emb_dim, + dict_idx_len, + idx_len, + padding_idx, + start_index); + } else { + embedding_fwd_kl2_tiny_dict_not_align64(idx, + dict, + featvec, + emb_dim, + dict_idx_len, + idx_len, + padding_idx, + start_index); + } +} + +#define _XPU_DEF__EMBEDDING_FWD_KL2_TINY_DICT_(EMB_IDX_TYPE) \ + template __global__ void embedding_fwd_kl2_tiny_dict( \ + const EMB_IDX_TYPE* idx, \ + const char* dict, \ + char* featvec, \ + int64_t emb_dim, \ + int64_t dict_idx_len, \ + int64_t idx_len, \ + int64_t padding_idx, \ + EMB_IDX_TYPE start_index); +_XPU_DEF__EMBEDDING_FWD_KL2_TINY_DICT_(int); +_XPU_DEF__EMBEDDING_FWD_KL2_TINY_DICT_(int64_t); + +} // namespace plugin +} // namespace xpu2 diff --git a/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_embedding.cpp b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_embedding.cpp new file mode 100644 index 00000000000000..3bf4a04a7cd8d4 --- /dev/null +++ b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_embedding.cpp @@ -0,0 +1,189 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" +#include "xpu/refactor/util/vector_util.h" + +namespace xpu2 { +namespace plugin { +template +__attribute__((global)) void embedding_fwd_kl2_tiny_dict( + const emb_idx_type* idx, + const char* dict, + char* featvec, + int64_t emb_dim, + int64_t dict_idx_len, + int64_t idx_len, + int64_t padding_idx, + emb_idx_type start_index); +} // namespace plugin +} // namespace xpu2 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +// CPU implementation +template +static int cpu_wrapper(Context* ctx, + const T* x, + const TID* indices, + T* y, + int64_t xm, + int64_t n, + int64_t ym, + int64_t padding_idx, + TID start_index) { + for (int64_t i = 0; i < ym; i++) { + TID real_index = indices[i] - start_index; // -start_index BEFORE compare + if (real_index == padding_idx) { + ::memset(y + i * n, 0, sizeof(T) * n); + } else { + if (real_index >= 0 && real_index < xm) { + std::memcpy(y + i * n, x + real_index * n, sizeof(T) * n); + } else { + // set zeros + for (int64_t k = 0; k < n; ++k) { + y[i * n + k] = 0; + } + } + } + } + return api::SUCCESS; +} + +template +static int xpu2_wrapper(Context* ctx, + const T* x, + const TID* indices, + T* y, + int64_t xm, + int64_t n, + int64_t ym, + int64_t padding_idx, + TID start_index) { + const int TOTAL_LM_SIZE = 6144; // 6 KB + int total_emb_dict_size = xm * n * sizeof(T); + // residual_lm_space = index + result + int residual_lm_space = TOTAL_LM_SIZE - total_emb_dict_size - 64; + // The maximum count that can be processed in one iteration. + int idx_cnt = residual_lm_space / (sizeof(TID) + n * sizeof(T)); + bool plugin_entry_condition = idx_cnt >= 16; + // This plugin is suitable for scenarios with relatively small dictionary + // sizes, requiring process greater than 16 index count one iter, in order to + // load the dictionary into local memory at once, and to leave enough space + // for the local memory to store the results. + if (plugin_entry_condition) { + using XPU_TID = typename XPUIndexType::type; + const XPU_TID* casted_indices = + static_cast(static_cast(indices)); + XPU_TID casted_start_index = static_cast(start_index); + if (ctx->dev().type() == api::kXPU2) { + xpu2::plugin::embedding_fwd_kl2_tiny_dict + <<ncluster(), 64, ctx->xpu_stream>>>( + casted_indices, + reinterpret_cast(x), + reinterpret_cast(y), + n * sizeof(T), + xm, + ym, + padding_idx, + casted_start_index); + } + } else { + embedding(ctx, x, indices, y, xm, n, ym, padding_idx, start_index); + } + + return api::SUCCESS; +} + +template +int fast_embedding(Context* ctx, + const T* x, + const TID* indices, + T* y, + int64_t xm, + int64_t n, + int64_t ym, + int64_t padding_idx, + TID start_index) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T2(ctx, "fast_embedding", T, TID); + WRAPPER_DUMP_PARAM6(ctx, x, indices, y, xm, n, ym); + WRAPPER_DUMP_PARAM3(ctx, padding_idx, start_index, ctx->_l3_mgr.get_size()); + WRAPPER_DUMP(ctx); + int64_t xlen = -1; + int64_t ylen = -1; + WRAPPER_CHECK_SHAPE(ctx, &xlen, {xm, n}); + WRAPPER_CHECK_SHAPE(ctx, &ylen, {ym, n}); + WRAPPER_CHECK_PTR(ctx, T, xlen, x); + WRAPPER_CHECK_PTR(ctx, T, ylen, y); + WRAPPER_CHECK_PTR(ctx, TID, ym, indices); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper( + ctx, x, indices, y, xm, n, ym, padding_idx, start_index); + } + if (ctx->dev().type() == api::kXPU2) { + return xpu2_wrapper( + ctx, x, indices, y, xm, n, ym, padding_idx, start_index); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +template int fast_embedding(Context*, + const float*, + const int*, + float*, + int64_t, + int64_t, + int64_t, + int64_t, + int); +template int fast_embedding(Context*, + const float*, + const int64_t*, + float*, + int64_t, + int64_t, + int64_t, + int64_t, + int64_t); +template int fast_embedding(Context*, + const float16*, + const int*, + float16*, + int64_t, + int64_t, + int64_t, + int64_t, + int); +template int fast_embedding(Context*, + const float16*, + const int64_t*, + float16*, + int64_t, + int64_t, + int64_t, + int64_t, + int64_t); + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/test/ir/inference/test_xpu_cast_embedding_trans_ids_to_int32_pass.py b/test/ir/inference/test_xpu_cast_embedding_trans_ids_to_int32_pass.py new file mode 100644 index 00000000000000..627af42d5fa861 --- /dev/null +++ b/test/ir/inference/test_xpu_cast_embedding_trans_ids_to_int32_pass.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestXpuCastEmbeddingTransIdsToInt32Pass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ["cast", "lookup_table_v2"], (1e-5, 1e-5) + + def sample_program_config(self, draw): + ids_shape = draw(st.integers(min_value=1, max_value=128)) + w_shape = draw( + st.sampled_from([[20, 64], [32, 32], [23, 15], [24, 33]]) + ) + padding_idx = draw(st.sampled_from([-1])) + + cast_op = OpConfig( + "cast", + inputs={ + "X": ["cast_input"], + }, + outputs={"Out": ["cast_out"]}, + in_dtype=5, + out_dtype=3, + ) + lookup_table_op = OpConfig( + "lookup_table_v2", + inputs={ + "Ids": ["cast_out"], + "W": ["lookup_table_w"], + }, + outputs={"Out": ["lookup_table_out"]}, + padding_idx=padding_idx, + ) + + def gen_lookup_table_weights_data(): + weights = {} + w_name = "lookup_table_w" + weights[w_name] = TensorConfig(shape=w_shape) + return weights + + def generate_cast_input(*args, **kwargs): + return np.random.randint(0, w_shape[0], ids_shape).astype( + np.float32 + ) + + def gen_input_data(*args, **kwargs): + inputs = {} + input_name = "cast_input" + inputs[input_name] = TensorConfig( + data_gen=partial(generate_cast_input) + ) + return inputs + + inputs = gen_input_data() + weights = gen_lookup_table_weights_data() + + program_config = ProgramConfig( + ops=[cast_op, lookup_table_op], + weights=weights, + inputs=inputs, + outputs=["lookup_table_out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=25, + passes=["cast_embedding_trans_ids_to_int32_pass"], + ) + + +if __name__ == "__main__": + unittest.main()