Skip to content

Commit

Permalink
[Host] Add empty op and ceil op (#10092)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangzhou2000 authored Mar 20, 2023
1 parent 1f8b83e commit 44ae965
Show file tree
Hide file tree
Showing 11 changed files with 340 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lite/kernels/host/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ add_kernel(round_compute_host Host extra SRCS round_compute.cc)
add_kernel(temporal_shift_compute_host Host extra SRCS temporal_shift_compute.cc)
add_kernel(pad_compute_host Host extra SRCS pad_compute.cc)
add_kernel(bitwise_compute_host Host extra SRCS bitwise_compute.cc)
add_kernel(empty_compute_host Host extra SRCS empty_compute.cc)


if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc)
Expand Down
16 changes: 16 additions & 0 deletions lite/kernels/host/activation_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ void FloorCompute::Run() {
}
}

// Element-wise ceiling: Out[i] = ceil(X[i]) for every element of X.
// Input and output are float tensors of identical shape on the host target.
void CeilCompute::Run() {
  auto& param = this->Param<param_t>();
  CHECK(param.X);
  auto x_dims = param.X->dims();
  auto x_data = param.X->data<float>();
  auto output_data = param.Out->mutable_data<float>();
  // production() returns int64_t; use int64_t for the index so tensors with
  // more than INT_MAX elements don't overflow, and hoist the invariant
  // element count out of the loop condition.
  const int64_t numel = x_dims.production();
  for (int64_t i = 0; i < numel; i++) {
    output_data[i] = std::ceil(x_data[i]);
  }
}

void HardSigmoidCompute::Run() {
auto& param = this->Param<param_t>();
CHECK(param.X);
Expand Down Expand Up @@ -377,6 +388,11 @@ REGISTER_LITE_KERNEL(
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
// Register the host float NCHW kernel for the "ceil" activation op.
REGISTER_LITE_KERNEL(
    ceil, kHost, kFloat, kNCHW, paddle::lite::kernels::host::CeilCompute, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
    .Finalize();
REGISTER_LITE_KERNEL(hard_sigmoid,
kHost,
kFloat,
Expand Down
9 changes: 9 additions & 0 deletions lite/kernels/host/activation_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ class FloorCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
virtual ~FloorCompute() = default;
};

// Host-side kernel computing the element-wise ceiling of a float tensor
// (Out[i] = ceil(X[i])). Registered for the "ceil" activation op.
class CeilCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
 public:
  // Shares ActivationParam with the other unary activation kernels.
  using param_t = operators::ActivationParam;

  void Run() override;

  virtual ~CeilCompute() = default;
};

class HardSigmoidCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
Expand Down
60 changes: 60 additions & 0 deletions lite/kernels/host/empty_compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/host/empty_compute.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

void EmptyCompute::Run() {
auto& param = *param_.get_mutable<param_t>();
auto output = param.Out;
auto output_dims = output->dims();
if (param.dtype == static_cast<int32_t>(lite::core::FluidType::BOOL)) {
output->set_precision(PRECISION(kBool));
output->template mutable_data<bool>();
} else if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
output->set_precision(PRECISION(kFloat));
output->template mutable_data<float>();
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT32)) {
output->set_precision(PRECISION(kInt32));
output->template mutable_data<int32_t>();
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT64)) {
output->set_precision(PRECISION(kInt64));
output->template mutable_data<int64_t>();
} else {
output->set_precision(PRECISION(kInt32));
output->template mutable_data<int32_t>();
}

return;
}

} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle

// Register the host kernel for the "empty" op. Precision is kAny because the
// output element type is chosen at runtime from the "dtype" attribute; both
// shape inputs are optional alternatives to the "shape" attribute.
REGISTER_LITE_KERNEL(
    empty, kHost, kAny, kNCHW, paddle::lite::kernels::host::EmptyCompute, def)
    .BindInput("ShapeTensor",
               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
    .BindInput("ShapeTensorList",
               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
    .Finalize();
36 changes: 36 additions & 0 deletions lite/kernels/host/empty_compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

// Host kernel for the "empty" op: allocates an uninitialized output tensor
// whose dtype comes from EmptyParam::dtype. PRECISION(kAny) because the
// element type is only known at runtime.
class EmptyCompute : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
 public:
  using param_t = operators::EmptyParam;

  void Run() override;

  virtual ~EmptyCompute() = default;
};

} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
3 changes: 3 additions & 0 deletions lite/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ add_operator(pow_op extra SRCS pow_op.cc)
add_operator(sign_op extra SRCS sign_op.cc)
add_operator(rnn_op extra SRCS rnn_op.cc)

add_operator(empty_op extra SRCS empty_op.cc)

# 2.basic ops not used in basic models
add_operator(negative_op extra SRCS negative_op.cc)
add_operator(crop_op extra SRCS crop_op.cc)
Expand Down Expand Up @@ -212,6 +214,7 @@ add_operator(unique_with_counts_op extra SRCS unique_with_counts_op.cc)
add_operator(unique_op extra SRCS unique_op.cc)
add_operator(viterbi_decode extra SRCS viterbi_decode_op.cc)


# for content-dnn specific
add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc)
add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc)
Expand Down
1 change: 1 addition & 0 deletions lite/operators/activation_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(abs, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(ceil, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(sqrt, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp);
Expand Down
89 changes: 89 additions & 0 deletions lite/operators/empty_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/empty_op.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

// Validates op wiring: only requires that the output tensor exists.
// CHECK_OR_FALSE returns false from this function when the condition fails.
bool EmptyOp::CheckShape() const {
  CHECK_OR_FALSE(param_.Out);
  return true;
}

// Resolves the output shape with the following precedence:
//   1. ShapeTensor      — a 1-D int tensor holding all dims;
//   2. ShapeTensorList  — one scalar int tensor per dim;
//   3. shape attribute  — a static int64 vector.
// Fails fatally if none of the three sources is provided.
bool EmptyOp::InferShapeImpl() const {
  std::vector<int64_t> out_shape;
  auto* shape_tensor = param_.ShapeTensor;
  // Bind by const reference: the original copied the whole pointer vector.
  const auto& shape_tensor_list = param_.ShapeTensorList;
  if (shape_tensor != nullptr) {
    const auto* shape_tensor_data = shape_tensor->data<int>();
    // numel() returns int64_t; use an int64_t index (the original used int).
    const int64_t num = shape_tensor->numel();
    out_shape.reserve(num);
    for (int64_t i = 0; i < num; i++) {
      out_shape.push_back(shape_tensor_data[i]);
    }
  } else if (!shape_tensor_list.empty()) {
    out_shape.reserve(shape_tensor_list.size());
    for (size_t i = 0; i < shape_tensor_list.size(); i++) {
      out_shape.push_back(shape_tensor_list[i]->data<int>()[0]);
    }
  } else if (!param_.shape.empty()) {
    out_shape = param_.shape;
  } else {
    LOG(FATAL) << "no valid out_shape. Must set one of shape_tensor, or "
                  "shape_tensor_list, or shape.";
  }

  param_.Out->Resize(out_shape);
  return true;
}

// Binds op-desc inputs/outputs/attributes to param_:
//   - optional ShapeTensor / ShapeTensorList inputs (shape sources),
//   - "shape" attribute (accepts legacy paddle1.0 int32 lists and int64),
//   - required Out output and optional "dtype" attribute.
bool EmptyOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
  if (opdesc.HasInput("ShapeTensor") && !opdesc.Input("ShapeTensor").empty()) {
    param_.ShapeTensor =
        scope->FindMutableTensor(opdesc.Input("ShapeTensor").front());
  }
  param_.ShapeTensorList.clear();
  if (opdesc.HasInput("ShapeTensorList") &&
      !opdesc.Input("ShapeTensorList").empty()) {
    // const reference avoids copying each name string per iteration.
    for (const auto& name : opdesc.Input("ShapeTensorList")) {
      param_.ShapeTensorList.push_back(
          GetMutableVar<lite::Tensor>(scope, name));
    }
  }
  if (opdesc.HasAttr("shape")) {
    auto type = opdesc.GetAttrType("shape");
    if (type == OpAttrType::INTS) {  // paddle1.0 shape type is ints
      auto shape = opdesc.GetAttr<std::vector<int32_t>>("shape");
      // assign() widens each int32 to int64 and replaces the original's
      // manual copy loop, which compared a signed int index against the
      // unsigned shape.size().
      param_.shape.assign(shape.begin(), shape.end());
    } else {
      param_.shape = opdesc.GetAttr<std::vector<int64_t>>("shape");
    }
  }
  param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
  CHECK(param_.Out) << "Output(Out) of EmptyOp should not be null.";
  if (opdesc.HasAttr("dtype")) {
    param_.dtype = opdesc.GetAttr<int>("dtype");
  }

  return true;
}

} // namespace operators
} // namespace lite
} // namespace paddle

REGISTER_LITE_OP(empty, paddle::lite::operators::EmptyOp);
44 changes: 44 additions & 0 deletions lite/operators/empty_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"

namespace paddle {
namespace lite {
namespace operators {

// Operator wrapper for "empty": produces an uninitialized tensor of a given
// shape and dtype. Shape may come from a ShapeTensor input, a
// ShapeTensorList input, or the "shape" attribute (see InferShapeImpl).
class EmptyOp : public OpLite {
 public:
  EmptyOp() {}
  explicit EmptyOp(const std::string &op_type) : OpLite(op_type) {}

  bool CheckShape() const override;

  bool InferShapeImpl() const override;

  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
  std::string DebugString() const override { return "empty"; }

 protected:
  // mutable: InferShapeImpl is const but resizes param_.Out.
  mutable EmptyParam param_;
};

} // namespace operators
} // namespace lite
} // namespace paddle
8 changes: 8 additions & 0 deletions lite/operators/op_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -2633,6 +2633,14 @@ struct TemporalShiftParam : ParamBase {
std::string data_format{"NCHW"};
};

// Parameters for the "empty" op. Exactly one shape source is used, in this
// priority order: ShapeTensor, ShapeTensorList, shape.
struct EmptyParam : ParamBase {
  // Optional 1-D int tensor holding every output dim.
  lite::Tensor* ShapeTensor{nullptr};
  // Optional list of scalar int tensors, one per output dim.
  std::vector<lite::Tensor*> ShapeTensorList{};
  // Static shape attribute fallback.
  std::vector<int64_t> shape{};
  // Output element type as a VarDataType/FluidType code; defaults to FP32.
  int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
  lite::Tensor* Out{};
};

struct ViterbiDecodeParam : ParamBase {
const lite::Tensor* input{};
const lite::Tensor* length{};
Expand Down
Loading

0 comments on commit 44ae965

Please sign in to comment.