[X86] add gru_unit #5739

Merged · 1 commit · Mar 19, 2021
2 changes: 1 addition & 1 deletion lite/kernels/x86/CMakeLists.txt
@@ -36,7 +36,7 @@ endif()
# lite_cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(uniform_random_compute_x86 SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps} )
add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps} blas math_function sequence2batch gru_compute)
-#add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(gru_unit_compute_x86 X86 basic SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_conv_compute_x86 X86 basic SRCS sequence_conv_compute.cc DEPS ${lite_kernel_deps} math_function blas context_project)

144 changes: 144 additions & 0 deletions lite/kernels/x86/gru_unit_compute.cc
@@ -0,0 +1,144 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/x86/gru_unit_compute.h"
#include "lite/backends/x86/math/blas.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

template <typename T,
          int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = lite::fluid::EigenMatrix<T, MajorType, IndexType>;

template <class T>
void GRUUnitCompute<T>::Run() {
  auto& param = this->Param<param_t>();
  auto& context = ctx_->As<X86Context>();
  auto* input = param.input;
  auto* hidden_prev = param.hidden_prev;
  auto* weight = param.weight;
  auto* bias = param.bias;
  auto* gate = param.gate;
  gate->template mutable_data<T>();
  auto* reset_hidden_prev = param.reset_hidden_prev;
  reset_hidden_prev->template mutable_data<T>();
  auto* hidden = param.hidden;
  hidden->template mutable_data<T>();

  int batch_size = input->dims()[0];
  int frame_size = hidden_prev->dims()[1];

  auto x = EigenMatrix<T>::From(*input);
  auto h_p = EigenMatrix<T>::From(*hidden_prev);
  auto g = EigenMatrix<T>::From(*gate);
  auto r_h_p = EigenMatrix<T>::From(*reset_hidden_prev);
  auto h = EigenMatrix<T>::From(*hidden);
  const auto& place = lite::fluid::EigenDeviceType<lite::TargetType::kX86>();

  if (bias) {
    auto b = EigenMatrix<T>::From(*bias);
    g.device(place) = x +
                      b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
                          .broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
  } else {
    g.device(place) = x;
  }

  // calculate unactivated gate outputs: accumulate hidden_prev * W_{u,r}
  // into the update/reset columns (the first 2 * frame_size columns of gate)
  const T* hidden_prev_data = hidden_prev->template data<T>();
  const T* weight_data = weight->template data<T>();
  T* gate_data = gate->template mutable_data<T>();
  T* reset_hidden_prev_data = reset_hidden_prev->template mutable_data<T>();
  auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
  blas.GEMM(false,
            false,
            batch_size,
            2 * frame_size,
            frame_size,
            1,
            hidden_prev_data,
            frame_size,
            weight_data,
            frame_size * 2,
            1,
            gate_data,
            frame_size * 3);

  // activate the update and reset gates
  Eigen::array<int, 2> extents{{batch_size, frame_size}};
  Eigen::array<int, 2> u_offsets{{0, 0}};
  ActCompute(param.gate_activation,
             place,
             g.slice(u_offsets, extents),
             g.slice(u_offsets, extents));
  auto u = g.slice(u_offsets, extents);  // update gate
  Eigen::array<int, 2> r_offsets{{0, frame_size}};
  ActCompute(param.gate_activation,
             place,
             g.slice(r_offsets, extents),
             g.slice(r_offsets, extents));
  auto r = g.slice(r_offsets, extents);  // reset gate
  r_h_p.device(place) = r * h_p;         // reset previous hidden state
  // accumulate (r . h_prev) * W_c into the candidate pre-activation
  // (the last frame_size columns of gate)
  blas.GEMM(false,
            false,
            batch_size,
            frame_size,
            frame_size,
            1,
            reset_hidden_prev_data,
            frame_size,
            weight_data + frame_size * frame_size * 2,
            frame_size,
            1,
            gate_data + frame_size * 2,
            frame_size * 3);

  Eigen::array<int, 2> c_offsets{{0, frame_size * 2}};
  ActCompute(param.activation,
             place,
             g.slice(c_offsets, extents),
             g.slice(c_offsets, extents));
  auto c = g.slice(c_offsets, extents);  // output candidate

  // calculate final output
  if (param.origin_mode) {
    h.device(place) = c + u * (h_p - c);  // (1 - u) * c + u * h_p
  } else {
    h.device(place) = u * (c - h_p) + h_p;  // u * c + (1 - u) * h_p
  }
}

} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle

REGISTER_LITE_KERNEL(gru_unit,
                     kX86,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::x86::GRUUnitCompute<float>,
                     def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindInput("HiddenPrev", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindOutput("Gate", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindOutput("ResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kX86))})
    .Finalize();
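For reference, Run() above implements a single GRU unit step. Reading the three frame_size-wide column blocks of gate as the update gate u, the reset gate r, and the candidate c, the kernel computes the following (a sketch of the math; gate_activation and activation are dispatched through ActCompute and are typically sigmoid and tanh respectively):

\begin{aligned}
u_t &= \mathrm{act}_{gate}\!\left(x_u + h_{t-1} W_u + b_u\right) \\
r_t &= \mathrm{act}_{gate}\!\left(x_r + h_{t-1} W_r + b_r\right) \\
\tilde{h}_t &= \mathrm{act}\!\left(x_c + \left(r_t \odot h_{t-1}\right) W_c + b_c\right) \\
h_t &= u_t \odot \tilde{h}_t + \left(1 - u_t\right) \odot h_{t-1}
\end{aligned}

With origin_mode set, the last line becomes h_t = u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t, matching the two branches at the end of Run().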
62 changes: 62 additions & 0 deletions lite/kernels/x86/gru_unit_compute.h
@@ -0,0 +1,62 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/op_registry.h"
#include "lite/fluid/eigen.h"
#include "lite/kernels/x86/activation_compute.h"
#include "lite/utils/macros.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 };

template <class T>
class GRUUnitCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::GRUUnitParam;

  void Run() override;

  virtual ~GRUUnitCompute() = default;

  template <typename Device, typename X, typename Y>
  void ActCompute(const int act_type, const Device& d, X x, Y y) const {
    switch (GRUActivationType(act_type)) {
      case identity:
        y.device(d) = x;
        break;
      case sigmoid:
        SigmoidFunctor<T>()(d, x, y);
        break;
      case tanh:
        TanhFunctor<T>()(d, x, y);
        break;
      case relu:
        ReluFunctor<T>()(d, x, y);
        break;
      default:
        LOG(FATAL) << "Unsupported activation type, only supports identity, "
                      "sigmoid, tanh and relu.";
    }
  }
};

} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
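To make the memory layout behind the two blas.GEMM calls concrete, here is a minimal scalar reference for one GRU unit step under the same layout assumptions: gate is [batch, 3 * frame] packed as [u | r | c], and weight packs W_{u,r} of shape [frame, 2 * frame] followed by W_c of shape [frame, frame]. The function name, the std::vector interface, and the hard-coded sigmoid/tanh choice are illustrative only, not part of the kernel's API; bias is assumed to be already folded into x.

#include <cmath>
#include <vector>

// Scalar reference for one GRU unit step (illustrative; the caller sizes
// hidden to batch * frame beforehand).
void GRUUnitRef(const std::vector<float>& x,       // [batch, 3 * frame], bias folded in
                const std::vector<float>& h_prev,  // [batch, frame]
                const std::vector<float>& weight,  // W_{u,r} [frame, 2 * frame] + W_c [frame, frame]
                std::vector<float>* hidden,        // [batch, frame], pre-sized
                int batch, int frame, bool origin_mode) {
  auto sigmoid = [](float v) { return 1.f / (1.f + std::exp(-v)); };
  std::vector<float> gate(x);  // gate pre-activations start from the input
  for (int b = 0; b < batch; ++b) {
    float* g = gate.data() + b * 3 * frame;
    const float* hp = h_prev.data() + b * frame;
    // First GEMM: accumulate h_prev * W_{u,r} into columns [0, 2 * frame).
    for (int j = 0; j < 2 * frame; ++j) {
      float acc = g[j];
      for (int k = 0; k < frame; ++k) acc += hp[k] * weight[k * 2 * frame + j];
      g[j] = sigmoid(acc);  // gate_activation
    }
    // Second GEMM: accumulate (r .* h_prev) * W_c into columns [2 * frame, 3 * frame).
    const float* w_c = weight.data() + 2 * frame * frame;
    for (int j = 0; j < frame; ++j) {
      float acc = g[2 * frame + j];
      for (int k = 0; k < frame; ++k)
        acc += g[frame + k] * hp[k] * w_c[k * frame + j];  // g[frame + k] is r
      g[2 * frame + j] = std::tanh(acc);  // activation
    }
    // Blend candidate c with the previous hidden state via the update gate u.
    for (int j = 0; j < frame; ++j) {
      float u = g[j];
      float c = g[2 * frame + j];
      (*hidden)[b * frame + j] =
          origin_mode ? u * hp[j] + (1.f - u) * c : u * c + (1.f - u) * hp[j];
    }
  }
}

The column offsets here mirror the leading-dimension arguments passed to blas.GEMM in the kernel (2 * frame_size for W_{u,r}, 3 * frame_size for gate), which is what routes each product into the correct gate block.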
22 changes: 13 additions & 9 deletions lite/tests/kernels/gru_unit_test.cc
@@ -318,22 +318,22 @@ class GRUUnitTester : public arena::TestCase {

    // set input data
    std::vector<float> data(dims_.production());
-    fill_data_rand(data.data(), 0.f, 1.f, dims_.production());
+    fill_data_rand(data.data(), -1.f, 1.f, dims_.production());
    SetCommonTensor(input_, dims_, data.data());

    // set hidden_prev data
    data.resize(hpdim.production());
-    fill_data_rand(data.data(), 0.f, 1.f, hpdim.production());
+    fill_data_rand(data.data(), -1.f, 1.f, hpdim.production());
    SetCommonTensor(hidden_prev_, hpdim, data.data());

    // set weight data
    data.resize(wdim.production());
-    fill_data_rand(data.data(), 0.f, 1.f, wdim.production());
+    fill_data_rand(data.data(), -1.f, 1.f, wdim.production());
    SetCommonTensor(weight_, wdim, data.data());

    // set bias data
    data.resize(bdim.production());
-    fill_data_rand(data.data(), 0.f, 1.f, bdim.production());
+    fill_data_rand(data.data(), -1.f, 1.f, bdim.production());
    SetCommonTensor(bias_, bdim, data.data());
  }
};
@@ -346,17 +346,21 @@ void test_gru_unit(Place place) {
  auto& ctx = tester->context()->template As<ARMContext>();
  ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 1);
#endif
-  arena::Arena arena(std::move(tester), place, 2e-5);
+  arena::Arena arena(std::move(tester), place, 1e-4);
  arena.TestPrecision();
}

TEST(GRUUnit, precision) {
-#ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
-  test_gru_unit(place);
+  Place place;
+#if defined(LITE_WITH_ARM)
+  place = TARGET(kARM);
+#elif defined(LITE_WITH_X86)
+  place = TARGET(kX86);
+#else
-  Place place(TARGET(kHost));
+  return;
#endif

  test_gru_unit(place);
}

} // namespace lite