[XPU] Add embedding plugin (PaddlePaddle#56488)
csy0225 authored and BeingGod committed Sep 9, 2023
1 parent 60f8b19 commit 0af594a
Showing 8 changed files with 725 additions and 21 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -239,6 +239,8 @@ if(WITH_XPU)
   pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
   pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(cast_embedding_trans_ids_to_int32_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
   pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu
137 changes: 137 additions & 0 deletions paddle/fluid/framework/ir/xpu/cast_embedding_trans_ids_to_int32_pass.cc
@@ -0,0 +1,137 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "glog/logging.h"

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
} // namespace phi

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct CastEmbeddingTransIdsToInt32Pattern : public PatternBase {
CastEmbeddingTransIdsToInt32Pattern(PDPattern* pattern,
const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(cast);
PATTERN_DECL_NODE(embedding);
// declare variable node's name
PATTERN_DECL_NODE(cast_x);
PATTERN_DECL_NODE(embedding_ids);
PATTERN_DECL_NODE(embedding_w);
PATTERN_DECL_NODE(embedding_out);
};

CastEmbeddingTransIdsToInt32Pattern::CastEmbeddingTransIdsToInt32Pattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto cast = pattern->NewNode(cast_repr())->assert_is_op("cast");
auto cast_x = pattern->NewNode(cast_x_repr())
->assert_is_op_input("cast", "X")
->assert_var_not_persistable()
->AsInput();
auto embedding_ids = pattern->NewNode(embedding_ids_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("lookup_table_v2", "Ids")
->assert_has_n_outputs(1);
cast->LinksFrom({cast_x}).LinksTo({embedding_ids});
auto embedding_w = pattern->NewNode(embedding_w_repr())
->assert_is_op_input("lookup_table_v2", "W");
auto embedding =
pattern->NewNode(embedding_repr())->assert_is_op("lookup_table_v2");
auto embedding_out = pattern->NewNode(embedding_out_repr())
->assert_is_op_output("lookup_table_v2", "Out")
->AsOutput();
embedding->LinksFrom({embedding_ids, embedding_w}).LinksTo({embedding_out});
}

} // namespace patterns

class CastEmbeddingTransIdsToInt32Pass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

private:
const std::string name_scope_{"cast_embedding_trans_ids_to_int32_pass"};
};
void CastEmbeddingTransIdsToInt32Pass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);

GraphPatternDetector gpd;
patterns::CastEmbeddingTransIdsToInt32Pattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle CastEmbeddingTransIdsToInt32Pass";
GET_IR_NODE(cast);
GET_IR_NODE(embedding);
GET_IR_NODE(embedding_ids);
auto cast_node_attr_out_dtype =
cast->Op()->GetAttrIfExists<int>("out_dtype");
if (cast_node_attr_out_dtype !=
static_cast<int>(paddle::framework::proto::VarType::INT64)) {
return;
}
cast->Op()->SetAttr(
"out_dtype",
static_cast<int>(paddle::framework::proto::VarType::INT32));
embedding_ids->Var()->SetDataType(paddle::framework::proto::VarType::INT32);
embedding->Op()->Flush();
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
if (found_subgraph_count) {
VLOG(4) << "There is a risk of overflow when converting the data type of "
"embedded ids from int64 to int32."
"Please ensure that the numerical range of ids is within the "
"maximum value of int32."
"If it exceeds this range, it may result in incorrect results. "
"You can try removing this pass.";
}
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(cast_embedding_trans_ids_to_int32_pass,
paddle::framework::ir::CastEmbeddingTransIdsToInt32Pass);

REGISTER_PASS_CAPABILITY(cast_embedding_trans_ids_to_int32_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().LE(
"lookup_table_v2", 1));
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -516,6 +516,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"reshape_unstack_concat_fuse_pass",
"delete_op_device_pass",
"constant_folding_pass",
"cast_embedding_trans_ids_to_int32_pass",
"delete_elementwise_mul_op_pass",
"generate_sequence_xpu_fuse_pass",
"embedding_with_eltwise_add_xpu_fuse_pass",
72 changes: 51 additions & 21 deletions paddle/phi/kernels/xpu/embedding_kernel.cc
@@ -44,18 +44,6 @@ void EmbeddingKernel(const Context &ctx,
   auto *table = table_t->data<T>();
   auto *output = dev_ctx.template Alloc<T>(output_t);
 
-  xpu::ctx_guard RAII_GUARD(ctx.x_context());
-  const int64_t *ids;
-  if (ids_t->dtype() == phi::DataType::INT64) {
-    ids = ids_t->data<int64_t>();
-  } else {
-    int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm<int64_t>(ids_t->numel());
-    int r = xpu::cast<int32_t, int64_t>(
-        ctx.x_context(), ids_t->data<int>(), ids_tt, ids_t->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-    ids = reinterpret_cast<const int64_t *>(ids_tt);
-  }
-
   PADDLE_ENFORCE_EQ(
       ids_numel <= std::numeric_limits<int32_t>::max(),
       true,
@@ -68,15 +56,57 @@
   size_t xm = table_t->dims()[0];
   size_t n = table_t->dims()[1];
 
-  int r = xpu::embedding<XPUType>(dev_ctx.x_context(),
-                                  reinterpret_cast<const XPUType *>(table),
-                                  ids,
-                                  reinterpret_cast<XPUType *>(output),
-                                  xm,
-                                  n,
-                                  ym,
-                                  padding_idx);
-
+  int r;
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+  if (ids_t->dtype() == phi::DataType::INT64) {
+#ifndef PADDLE_WITH_XPU_PLUGIN
+    r = xpu::embedding<XPUType, int64_t>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType *>(table),
+        ids_t->data<int64_t>(),
+        reinterpret_cast<XPUType *>(output),
+        xm,
+        n,
+        ym,
+        padding_idx);
+#else
+    r = xpu::plugin::fast_embedding<XPUType, int64_t>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType *>(table),
+        ids_t->data<int64_t>(),
+        reinterpret_cast<XPUType *>(output),
+        xm,
+        n,
+        ym,
+        padding_idx);
+#endif
+  } else {
+#ifndef PADDLE_WITH_XPU_PLUGIN
+    int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm<int64_t>(ids_t->numel());
+    r = xpu::cast<int32_t, int64_t>(
+        ctx.x_context(), ids_t->data<int>(), ids_tt, ids_t->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    const int64_t *ids = reinterpret_cast<const int64_t *>(ids_tt);
+    r = xpu::embedding<XPUType>(dev_ctx.x_context(),
+                                reinterpret_cast<const XPUType *>(table),
+                                ids,
+                                reinterpret_cast<XPUType *>(output),
+                                xm,
+                                n,
+                                ym,
+                                padding_idx);
+#else
+    r = xpu::plugin::fast_embedding<XPUType, int>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType *>(table),
+        ids_t->data<int>(),
+        reinterpret_cast<XPUType *>(output),
+        xm,
+        n,
+        ym,
+        padding_idx);
+#endif
+  }
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
 }
 
11 changes: 11 additions & 0 deletions paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
@@ -104,6 +104,17 @@ DLL_EXPORT int fast_reduce_min(Context* ctx,
                                const std::vector<int>& xshape,
                                const std::vector<int>& rdims);
 
+template <typename T, typename TID>
+DLL_EXPORT int fast_embedding(Context* ctx,
+                              const T* x,
+                              const TID* indices,
+                              T* y,
+                              int64_t xm,
+                              int64_t n,
+                              int64_t ym,
+                              int64_t padding_idx,
+                              TID start_index = 0);
+
 }  // namespace plugin
 }  // namespace api
 }  // namespace xpu
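plugin.h leaves the parameters undocumented; from the call sites in embedding_kernel.cc above, x is the xm-by-n embedding table, indices holds ym ids, y is the ym-by-n output, and padding_idx selects the row whose lookups come back as zeros. A plain CPU sketch of the assumed semantics (the actual XPU kernel is in one of the files not shown here):

// Reference semantics sketch, inferred from the call sites; not the real
// XPU plugin kernel.
template <typename T, typename TID>
void embedding_ref(const T* x, const TID* indices, T* y,
                   int64_t xm, int64_t n, int64_t ym,
                   int64_t padding_idx, TID start_index = 0) {
  for (int64_t i = 0; i < ym; ++i) {
    int64_t row = static_cast<int64_t>(indices[i]) - start_index;
    // rows outside [0, xm) would be out of bounds in the real kernel
    for (int64_t j = 0; j < n; ++j) {
      y[i * n + j] =
          (row == padding_idx) ? static_cast<T>(0) : x[row * n + j];
    }
  }
}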
(Diffs for the remaining 3 changed files did not load and are not shown.)
