[xpu] add topk_v2 and fc fuse #9207

Merged (1 commit) on Jul 19, 2022
49 changes: 38 additions & 11 deletions lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc
@@ -25,16 +25,19 @@ namespace fusion {

class XPUFcFuser : public FuseBase {
public:
-explicit XPUFcFuser(bool with_bias, const std::string& act_type) {
+explicit XPUFcFuser(bool with_bias,
+                    const std::string& act_type,
+                    const std::string& mul_type) {
with_bias_ = with_bias;
act_type_ = act_type;
+mul_type_ = mul_type;
}

void BuildPattern() override {
-auto* x = VarNode("x")->assert_is_op_input("mul", "X")->AsInput();
-auto* W = VarNode("W")->assert_is_op_input("mul", "Y")->AsInput();
-auto* mul = OpNode("mul", "mul")->AsIntermediate();
-auto* mul_out = VarNode("mul_out")->assert_is_op_output("mul", "Out");
+auto* x = VarNode("x")->assert_is_op_input(mul_type_, "X")->AsInput();
+auto* W = VarNode("W")->assert_is_op_input(mul_type_, "Y")->AsInput();
+auto* mul = OpNode("mul", mul_type_)->AsIntermediate();
+auto* mul_out = VarNode("mul_out")->assert_is_op_output(mul_type_, "Out");
PMNode* bias = nullptr;
PMNode* add = nullptr;
PMNode* add_out = nullptr;
@@ -81,7 +84,6 @@ class XPUFcFuser : public FuseBase {
op_desc.SetType("__xpu__fc");
op_desc.SetInput("Input", {matched.at("x")->arg()->name});
op_desc.SetInput("Filter", {matched.at("W")->arg()->name});

if (with_bias_) {
op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
}
@@ -102,6 +104,7 @@
{"relu", 1},
{"sigmoid", 2},
{"tanh", 3},
{"gelu", 4},
{"leaky_relu", 5},
{"hard_swish", 14},
{"hard_sigmoid", 15},
@@ -117,9 +120,29 @@
}
op_desc.SetAttr<int>("act_type", act_map[act_type_]);
op_desc.SetAttr<float>("act_param", act_param_);
-op_desc.SetAttr(
-"in_num_col_dims",
-matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));

+op_desc.SetAttr<int>("in_num_col_dims", -1);
+if (mul_type_ == "mul") {
+op_desc.SetAttr(
+"in_num_col_dims",
+matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
+op_desc.SetAttr("transpose_x", false);
+op_desc.SetAttr("transpose_w", true);
+} else if (mul_type_ == "matmul") {
+op_desc.SetAttr(
+"transpose_x",
+matched.at("mul")->stmt()->op_info()->GetAttr<bool>("transpose_X"));
+op_desc.SetAttr(
+"transpose_w",
+matched.at("mul")->stmt()->op_info()->GetAttr<bool>("transpose_Y"));
+} else {
+op_desc.SetAttr(
+"transpose_x",
+matched.at("mul")->stmt()->op_info()->GetAttr<bool>("trans_x"));
+op_desc.SetAttr(
+"transpose_w",
+matched.at("mul")->stmt()->op_info()->GetAttr<bool>("trans_y"));
+}

std::string max_output_name = output_name + "_xpu_max";
auto* max_output_node = graph->NewArgumentNode(max_output_name);
@@ -147,6 +170,7 @@
private:
bool with_bias_;
std::string act_type_;
+std::string mul_type_;
};

} // namespace fusion
@@ -158,15 +182,18 @@ class XPUFcFusePass : public ProgramPass {
// TODO(weihaoji) support with_no_bias and more activation types
for (auto with_bias : {true, /*false*/}) {
for (auto act_type : {"relu",
"gelu",
/*"sigmoid",
"tanh",
"leaky_relu",
"hard_swish",
"hard_sigmoid",
"relu6",*/
"linear"}) {
-fusion::XPUFcFuser fuser(with_bias, act_type);
-fuser(graph.get());
+for (auto mul_type : {"mul", "matmul_v2"}) {
+fusion::XPUFcFuser fuser(with_bias, act_type, mul_type);
+fuser(graph.get());
+}
}
}
}
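Note: the three branches in InsertNewNode above read transpose information differently depending on which op was matched. A minimal C++ sketch of that mapping (illustration only, not code from this PR; TransposeAttrNames is a hypothetical helper):

#include <string>
#include <utility>

// Attribute names carrying the transpose flags for each mul-like op type,
// mirroring the branches in XPUFcFuser::InsertNewNode: "mul" has no transpose
// attributes (the pass hardcodes transpose_x = false, transpose_w = true and
// forwards x_num_col_dims), "matmul" uses transpose_X / transpose_Y, and
// "matmul_v2" uses trans_x / trans_y.
std::pair<std::string, std::string> TransposeAttrNames(const std::string& mul_type) {
  if (mul_type == "matmul") return {"transpose_X", "transpose_Y"};
  if (mul_type == "matmul_v2") return {"trans_x", "trans_y"};
  return {"", ""};  // "mul": nothing to read from the op desc
}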
1 change: 1 addition & 0 deletions lite/kernels/xpu/CMakeLists.txt
@@ -68,6 +68,7 @@ add_kernel(search_grnn_compute_xpu XPU extra SRCS search_grnn_compute.cc)
add_kernel(sequence_unpad_compute_xpu XPU extra SRCS sequence_unpad_compute.cc)
add_kernel(lrn_compute_xpu XPU extra SRCS lrn_compute.cc)
add_kernel(topk_compute_xpu XPU extra SRCS topk_compute.cc)
+add_kernel(topk_v2_compute_xpu XPU extra SRCS topk_v2_compute.cc)
add_kernel(im2sequence_compute_xpu XPU extra SRCS im2sequence_compute.cc)
add_kernel(unstack_compute_xpu XPU extra SRCS unstack_compute.cc)
add_kernel(norm_compute_xpu XPU extra SRCS norm_compute.cc)
25 changes: 18 additions & 7 deletions lite/kernels/xpu/__xpu__fc_compute.cc
@@ -35,6 +35,7 @@ void XPUFcCompute<TGEMM, TW, DX, DY, PType>::PrepareForRun() {
auto w_ptr = param.w->template data<float>();
auto weight_dims = param.w->dims();
bool quant_int8 = false;
+bool w_trans = param.transpose_w;
if (param.quant_w_max > 0.f) {
quant_int8 = true;
}
@@ -45,7 +46,7 @@
if (quant_int8) { // for paddle slim int8 quant
xpu_quant_weight_ =
TargetWrapperXPU::ConvertCPUWeightToXPUQuantWeight<int8_t, int8_t>(
-reinterpret_cast<const int8_t*>(w_ptr), weight_dims, true);
+reinterpret_cast<const int8_t*>(w_ptr), weight_dims, w_trans);
std::vector<float> cpu_w_max(max_ptr_size, param.quant_w_max);
CHECK(xpu_quant_weight_.max_ptr_ != nullptr)
<< "slim int8 quant xpu_quant_weight_max_ptr should't be null";
@@ -62,7 +63,7 @@
} else {
xpu_quant_weight_ =
TargetWrapperXPU::ConvertCPUWeightToXPUQuantWeight<float, TW>(
-w_ptr, weight_dims, true);
+w_ptr, weight_dims, w_trans);
if (std::is_same<TW, float>::value) {
VLOG(6)
<< "If fc compute precision is int31,must check weight max should "
@@ -72,6 +73,7 @@
}
}
}

template <typename TGEMM,
typename TW,
typename DX,
@@ -82,6 +84,9 @@ void XPUFcCompute<TGEMM, TW, DX, DY, PType>::Run() {
auto& ctx = this->ctx_->template As<XPUContext>();

auto input_dims = param.input->dims();
+if (param.in_num_col_dims == -1) {
+param.in_num_col_dims += input_dims.size();
+}
auto in_mat_dims = input_dims.Flatten2D(param.in_num_col_dims);
int m = in_mat_dims[0];
int k = in_mat_dims[1];
@@ -90,6 +95,12 @@
int max_ptr_size = ctx.GetRawContext()->max_ptr_size();
param.output_max->Resize({max_ptr_size});

+bool x_trans = param.transpose_x;
+bool w_trans = param.transpose_w;
+int ldx = (x_trans ? m : k);
+int ldw = (w_trans ? k : n);
+int ldy = n;

float* output_max =
quant_int8 ? nullptr
: param.output_max->template mutable_data<float>(TARGET(kXPU));
@@ -116,14 +127,14 @@
m, // m
n, // n
k, // k
-false, // x_trans
-true, // w_trans
+x_trans, // x_trans
+w_trans, // w_trans
input_max, // x_maxptr
reinterpret_cast<const float*>(xpu_quant_weight_.max_ptr_), // w_maxptr
output_max, // y_maxptr
-k, // ldx
-k, // ldw
-n, // ldy
+ldx, // ldx
+ldw, // ldw
+ldy, // ldy
1.0f, // alpha
0.0f, // beta
bias, // bias
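Note: the ldx/ldw/ldy values introduced above follow the standard row-major leading-dimension convention for a GEMM y(m x n) = x(m x k) * w(k x n). A minimal sketch of the rule (explanatory only, not code from this PR; LeadingDim is a hypothetical helper):

// Leading dimension of a row-major operand: the stored row length, i.e. the
// logical column count unless the operand is stored transposed.
//   ldx = x_trans ? m : k;  // x is stored as (k x m) when transposed
//   ldw = w_trans ? k : n;  // w is stored as (n x k) when transposed
//   ldy = n;                // the output is never transposed here
int LeadingDim(bool transposed, int rows, int cols) {
  return transposed ? rows : cols;
}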
86 changes: 86 additions & 0 deletions lite/kernels/xpu/topk_v2_compute.cc
@@ -0,0 +1,86 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/xpu/topk_v2_compute.h"
#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

void TopkV2Compute::Run() {
auto& param = this->Param<operators::TopkParam>();
auto& ctx = this->ctx_->As<XPUContext>();
const float* x_data = param.X->data<float>();
float* out_val = param.Out->mutable_data<float>();
auto out_ind = param.Indices->mutable_data<int64_t>();

DDim x_dims = param.X->dims();
int axis = param.axis;
CHECK_EQ(axis, -1);
int dim_size = x_dims.size();
if (axis < 0) {
axis += dim_size;
}

int k = param.K;
if (param.k_is_tensor) {
k = param.KTensor->data<int>()[0];
}

int m = x_dims.count(0, axis);
int n = x_dims[axis];

XPUScratchPadGuard indices_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(m * k * sizeof(int));

int* indices_int32_device = reinterpret_cast<int*>(indices_xpu_guard_->addr_);
int64_t* indices_int64_device =
param.Indices->mutable_data<int64_t>(TARGET(kXPU));

int r = xdnn::sorted_topk(ctx.GetRawContext(),
param.X->data<float>(),
param.Out->mutable_data<float>(TARGET(kXPU)),
indices_int32_device,
m,
n,
k);
CHECK_EQ(r, 0);

r = xdnn::cast_v2<int, int64_t>(
ctx.GetRawContext(), indices_int32_device, indices_int64_device, m * k);

CHECK_EQ(r, 0);
}

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle

REGISTER_LITE_KERNEL(top_k_v2,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::TopkV2Compute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("K", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Indices",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.Finalize();
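Note: the kernel above only supports the last axis (hence CHECK_EQ(axis, -1)) and flattens the input into m rows of length n before calling xdnn::sorted_topk, then widens the int32 indices to int64 with xdnn::cast_v2. A small sketch of the flattening (illustration only; FlattenForTopk is a hypothetical helper, not part of the PR):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// m = product of dims before `axis` (x_dims.count(0, axis)), n = x_dims[axis].
// e.g. dims = {2, 3, 5}, axis = 2 (normalized from -1): m = 6 rows, n = 5,
// and sorted_topk picks k values per row.
void FlattenForTopk(const std::vector<int64_t>& dims, int axis, int* m, int* n) {
  *m = static_cast<int>(std::accumulate(dims.begin(), dims.begin() + axis,
                                        int64_t{1}, std::multiplies<int64_t>()));
  *n = static_cast<int>(dims[axis]);
}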
33 changes: 33 additions & 0 deletions lite/kernels/xpu/topk_v2_compute.h
@@ -0,0 +1,33 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

class TopkV2Compute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
virtual void Run();

virtual ~TopkV2Compute() = default;
};

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
6 changes: 5 additions & 1 deletion lite/operators/__xpu__fc_op.cc
@@ -40,6 +40,9 @@ bool XPUFcOp::CheckShape() const {
CHECK_EQ_OR_FALSE(bias_dims[0], w_dims_1);
}
}
+if (param_.in_num_col_dims == -1) {
+param_.in_num_col_dims += input_dims.size();
+}

CHECK_GT_OR_FALSE(input_dims.size(),
static_cast<size_t>(param_.in_num_col_dims));
@@ -88,7 +91,8 @@ bool XPUFcOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
param_.act_param = op_desc.GetAttr<float>("act_param");
param_.has_bias = op_desc.GetAttr<bool>("has_bias");
param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims");

+param_.transpose_x = op_desc.GetAttr<bool>("transpose_x");
+param_.transpose_w = op_desc.GetAttr<bool>("transpose_w");
// optional params
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
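Note: in_num_col_dims == -1 acts as a sentinel written by the fuse pass for matmul-style matches; CheckShape above and XPUFcCompute::Run normalize it identically. A one-line sketch (explanatory only; NormalizeInNumColDims is hypothetical):

// -1 means "flatten all leading dims; treat the last input dim as the GEMM
// column dim", i.e. the value becomes input_rank - 1 after normalization.
int NormalizeInNumColDims(int in_num_col_dims, int input_rank) {
  return in_num_col_dims == -1 ? in_num_col_dims + input_rank : in_num_col_dims;
}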
2 changes: 2 additions & 0 deletions lite/operators/op_params.h
@@ -1773,6 +1773,8 @@ struct XPUFcParam : ParamBase {
std::string precision{};
bool has_bias{false};
int in_num_col_dims{1};
+bool transpose_x{false};
+bool transpose_w{true};
};

struct XPUResNetCbamParam : ParamBase {
4 changes: 4 additions & 0 deletions lite/tests/kernels/topk_v2_compute_test.cc
@@ -118,6 +118,8 @@ void test_topk_v2(Place place, float abs_error) {
std::vector<std::vector<int64_t>>{{2, 3, 4, 5}, {3, 4, 5}, {4, 5}}) {
#if defined(NNADAPTER_WITH_HUAWEI_ASCEND_NPU)
for (int axis : {-1}) {
+#elif defined(LITE_WITH_XPU)
+for (int axis : {-1}) {
#else
for (int axis : {-1, -2, 0}) {
#endif
@@ -153,6 +155,8 @@ TEST(Topk, precision) {
#endif
#elif defined(LITE_WITH_ARM)
place = TARGET(kHost);
+#elif defined(LITE_WITH_XPU)
+place = TARGET(kXPU);
#else
return;
#endif