Skip to content

Commit

Permalink
[ARM]fix ios bug && add mul quant op && add quant transformer support (
Browse files Browse the repository at this point in the history
  • Loading branch information
xingjing1 authored Dec 15, 2021
1 parent d92bff6 commit 3fc7367
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 25 deletions.
13 changes: 11 additions & 2 deletions lite/backends/arm/math/gemm_s8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,17 @@ void gemm_s8(bool is_transA,
for (int i = 0; i < N; i++) {
scale_ptr[i] = scale[0];
}
gemv_int8(
B, A, C, true, N, K, scale_ptr, is_bias, bias_ptr, act_param, ctx);
gemv_int8(B,
A,
C,
!is_transB,
N,
K,
scale_ptr,
is_bias,
bias_ptr,
act_param,
ctx);
return;
}

Expand Down
15 changes: 13 additions & 2 deletions lite/backends/host/math/beam_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ struct Item {
return os.str();
}
};
LoD ToAbsOffset(const LoD &in) {
if (in.empty() || in.size() == 1) return in;
LoD result = in;
for (auto level = static_cast<int>(in.size() - 2); level >= 0; level--) {
for (size_t i = 0; i < in[level].size(); ++i) {
size_t index = in[level][i];
result[level][i] = result[level + 1][index];
}
}
return result;
}

/*
* Prune the source sentences all branchs finished, and it is optional.
Expand Down Expand Up @@ -149,7 +160,7 @@ std::vector<std::vector<Item>> SelectTopBeamSizeItems(const Tensor *pre_ids,

// find the current candidates
// auto abs_lod = framework::ToAbsOffset(scores->lod());
auto abs_lod = scores->lod();
auto abs_lod = ToAbsOffset(scores->lod());
auto *pre_ids_data = pre_ids->data<int64_t>();
auto *pre_scores_data = pre_scores->data<float>();

Expand Down Expand Up @@ -206,7 +217,7 @@ void beam_search(const Tensor *pre_ids,
int beam_size,
int end_id,
bool is_accumulated) {
auto abs_lod = scores->lod();
auto abs_lod = ToAbsOffset(scores->lod());
auto &high_level = abs_lod[level];
auto items = SelectTopBeamSizeItems(pre_ids,
pre_scores,
Expand Down
2 changes: 1 addition & 1 deletion lite/core/optimizer/mir/fusion/quant_dequant_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {

// fuse quantized node and dequant node
std::vector<std::string> quantized_op_types = {
"conv2d", "depthwise_conv2d", "conv2d_transpose", "mul"};
"conv2d", "depthwise_conv2d", "conv2d_transpose", "mul", "matmul"};
for (auto& op_type : quantized_op_types) {
fusion::DequantOpFuser fuser(op_type);
fuser(graph.get());
Expand Down
3 changes: 3 additions & 0 deletions lite/core/optimizer/mir/fusion/quant_dequant_op_fuser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,9 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
auto op_info = *quantized_node->stmt()->op_info();
op_info.UpdateAllInputs(output_var_name, input_var_name);
op_info.SetAttr<int>("bit_length", bit_length);
#ifndef LITE_WITH_FPGA
op_info.SetAttr("enable_int8", true);
#endif

if (input_var_is_activation) {
op_info.SetInputScale(input_var_name, scales);
Expand Down
3 changes: 2 additions & 1 deletion lite/kernels/arm/fc_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
if (param.bias) {
bias_.Resize(param.bias->dims());
auto* ptr = bias_.mutable_data<float>();
auto* ptr_in = bias_.data<float>();
auto* ptr_in = param.bias->data<float>();
float out_scale = param.output_scale;
for (int i = 0; i < bias_.numel(); ++i) {
ptr[i] = ptr_in[i] / out_scale;
Expand Down Expand Up @@ -300,6 +300,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
// b_data = param.bias->data<float>();
bool flag_relu = false;
operators::ActivationParam act_param;
act_param.has_active = false;
Expand Down
38 changes: 28 additions & 10 deletions lite/kernels/arm/matmul_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,21 @@ void MatMulCompute<PRECISION(kInt8), PRECISION(kFloat)>::ReInitWhenNeeded() {
}

scale_.resize(n_);
scale_one.resize(n_);
for (int i = 0; i < n_; i++) {
param.output_scale = param.input_scale * param.weight_scale[i];
scale_[i] = param.output_scale;
scale_one.resize(m_);

if (param.weight_scale.size() == 1) {
param.output_scale =
param.input_scale * param.weight_scale[0] * param.alpha;
for (int i = 0; i < n_; i++) {
scale_[i] = param.output_scale;
}
} else {
for (int i = 0; i < n_; i++) {
param.output_scale = param.input_scale * param.weight_scale[i];
scale_[i] = param.output_scale;
}
}
for (int i = 0; i < m_; i++) {
scale_one[i] = 1;
}
}
Expand Down Expand Up @@ -362,6 +373,7 @@ void MatMulCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
operators::ActivationParam act_param;
act_param.has_active = false;

if ((x_dims.size() >= 2 && y_dims.size() >= 2) &&
(x_dims.size() != 2 || y_dims.size() != 2)) {
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
Expand All @@ -385,9 +397,10 @@ void MatMulCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
o_data + i * out_inner,
nullptr,
false,
scale_.data(),
scale_one.data(),
act_param,
&ctx);
matmul_add_n_scale_bias(o_data + i * out_inner, scale_.data(), m_, n_);
}
} else if (x_dims.size() > 2 && y_dims.size() == 2) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
Expand All @@ -404,6 +417,7 @@ void MatMulCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
scale_one.data(),
act_param,
&ctx);
matmul_add_n_scale_bias(o_data + i * out_inner, scale_.data(), m_, n_);
}
} else if (x_dims.size() == 2 && y_dims.size() > 2) {
for (size_t i = 0; i < y_dims.count(0, y_dims.size() - 2); ++i) {
Expand All @@ -417,9 +431,10 @@ void MatMulCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
o_data + i * out_inner,
nullptr,
false,
scale_.data(),
scale_one.data(),
act_param,
&ctx);
matmul_add_n_scale_bias(o_data + i * out_inner, scale_.data(), m_, n_);
}
}
} else if ((x_dims.size() == 2 && y_dims.size() == 2) ||
Expand All @@ -435,9 +450,10 @@ void MatMulCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
o_data,
nullptr,
false,
scale_.data(),
scale_one.data(),
act_param,
&ctx);
matmul_add_n_scale_bias(o_data, scale_.data(), m_, n_);
} else if (x_dims.size() > 2 && y_dims.size() == 1) {
// x: [B, M, K], y: [K], out: [B, M]
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0])
Expand All @@ -449,6 +465,7 @@ void MatMulCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
o_data[i] += x_data[i * y_dims[0] + j] * y_data[j] * alpha;
}
}
matmul_add_n_scale_bias(o_data, scale_.data(), m_, n_);
} else if (x_dims.size() == 1 && y_dims.size() == 1) {
// x: [K], y: [K], out: [1]
if (x_dims[0] == y_dims[0] && x_transpose == false &&
Expand All @@ -467,7 +484,7 @@ void MatMulCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
false,
m_,
k_,
scale_.data(),
scale_one.data(),
false,
nullptr,
act_param,
Expand All @@ -488,16 +505,17 @@ void MatMulCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
o_data,
nullptr,
false,
scale_.data(),
scale_one.data(),
act_param,
&ctx);
}
}

matmul_add_n_scale_bias(o_data, scale_.data(), m_, n_);
} else {
LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ")";
}
matmul_add_n_scale_bias(o_data, scale_.data(), m_, n_);
}

} // namespace arm
Expand Down
114 changes: 110 additions & 4 deletions lite/kernels/arm/mul_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ namespace lite {
namespace kernels {
namespace arm {

void MulCompute::PrepareForRun() {
template <>
void MulCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>();
}

void MulCompute::Run() {
template <>
void MulCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = Param<param_t>();

const auto* x_data = param.x->data<float>();
Expand Down Expand Up @@ -95,14 +97,118 @@ void MulCompute::Run() {
}
}

void mul_add_n_scale_bias(float* o_data, float* scale_, int m_, int n_) {
float32x4_t bias_v, scale_v, out_v, tmp_v;
int n_tail = n_ % 4;
int n_inner = n_ - n_tail;
for (int i = 0; i < m_; i++) {
for (int j = 0; j < n_inner; j += 4) {
tmp_v = vld1q_f32(&o_data[i * n_ + j]);
scale_v = vld1q_f32(&scale_[j]);
out_v = vmulq_f32(scale_v, tmp_v);
vst1q_f32(&o_data[i * n_ + j], out_v);
}
for (int j = n_inner; j < n_; j++) {
o_data[i * n_ + j] *= scale_[j];
}
}
}

template <>
void MulCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>();
}

template <>
void MulCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = Param<param_t>();

const auto* x_data = param.x->data<int8_t>();
const auto* y_data = param.y->data<int8_t>();
auto* o_data = param.output->mutable_data<float>();

m_ = static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production());
int x_w =
static_cast<int>(param.x->dims()
.Slice(param.x_num_col_dims, param.x->dims().size())
.production());
int y_h = static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production());
n_ = static_cast<int>(param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production());

scale_.resize(n_);
scale_one.resize(m_);
for (int i = 0; i < n_; i++) {
param.output_scale = param.input_scale * param.weight_scale[i];
scale_[i] = param.output_scale;
}
for (int i = 0; i < m_; i++) {
scale_one[i] = 1;
}

CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
k_ = x_w;
auto& ctx = this->ctx_->template As<ARMContext>();
operators::ActivationParam act_param;
act_param.has_active = false;
if (n_ == 1) {
lite::arm::math::gemv_int8(x_data,
y_data,
o_data,
false,
m_,
k_,
scale_one.data(),
false,
nullptr,
act_param,
&ctx);
} else {
constexpr bool is_tranposed_y = false;
int ldb = n_;
if (is_tranposed_y) {
ldb = k_;
}
lite::arm::math::gemm_s8(is_tranposed_y,
false,
m_,
n_,
k_,
x_data,
y_data,
o_data,
nullptr,
false,
scale_one.data(),
act_param,
&ctx);
}
mul_add_n_scale_bias(o_data, scale_.data(), m_, n_);
}

} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
typedef paddle::lite::kernels::arm::MulCompute<PRECISION(kFloat),
PRECISION(kFloat)>
Mul_f32_f32;

REGISTER_LITE_KERNEL(
mul, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::MulCompute, def)
REGISTER_LITE_KERNEL(mul, kARM, kFloat, kNCHW, Mul_f32_f32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();

typedef paddle::lite::kernels::arm::MulCompute<PRECISION(kInt8),
PRECISION(kFloat)>
Mul_int8_f32;

REGISTER_LITE_KERNEL(mul, kARM, kInt8, kNCHW, Mul_int8_f32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
5 changes: 4 additions & 1 deletion lite/kernels/arm/mul_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#pragma once
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
Expand All @@ -22,7 +23,8 @@ namespace lite {
namespace kernels {
namespace arm {

class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
template <PrecisionType PType, PrecisionType OutType>
class MulCompute : public KernelLite<TARGET(kARM), PType> {
public:
using param_t = operators::MulParam;

Expand All @@ -34,6 +36,7 @@ class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {

private:
int m_, n_, k_;
std::vector<float> scale_, scale_one;
};

} // namespace arm
Expand Down
6 changes: 3 additions & 3 deletions lite/kernels/arm/mul_compute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ TEST(mul_arm, retrive_op) {
}

TEST(mul_arm, init) {
MulCompute mul;
MulCompute<PRECISION(kFloat), PRECISION(kFloat)> mul;
ASSERT_EQ(mul.precision(), PRECISION(kFloat));
ASSERT_EQ(mul.target(), TARGET(kARM));
}
Expand Down Expand Up @@ -105,7 +105,7 @@ TEST(mul_arm, compare_test) {
FillData<T>(out_data, out.dims().production(), 0, 0);
FillData<T>(ref_data, ref.dims().production(), 0, 0);

MulCompute mul;
MulCompute<PRECISION(kFloat), PRECISION(kFloat)> mul;
operators::MulParam param;

param.x = &x;
Expand Down Expand Up @@ -150,7 +150,7 @@ TEST(mul_arm, num_col_dims) {
FillData<T>(out_data, out.dims().production());
FillData<T>(ref_data, out.dims().production());

MulCompute mul;
MulCompute<PRECISION(kFloat), PRECISION(kFloat)> mul;
operators::MulParam param;

param.x = &x;
Expand Down
Loading

0 comments on commit 3fc7367

Please sign in to comment.