[NPU] cast op bridge and ut
test=develop
zhupengyang committed Jan 8, 2020
1 parent 08afd3a commit 5f80dca
Showing 5 changed files with 229 additions and 71 deletions.
2 changes: 2 additions & 0 deletions lite/kernels/xpu/bridges/CMakeLists.txt
@@ -24,6 +24,7 @@ lite_cc_library(subgraph_bridge_reshape_op_xpu SRCS reshape_op.cc DEPS ${xpu_sub
lite_cc_library(subgraph_bridge_layer_norm_op_xpu SRCS layer_norm_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_dropout_op_xpu SRCS dropout_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_matmul_op_xpu SRCS matmul_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_cast_op_xpu SRCS cast_op.cc DEPS ${xpu_subgraph_bridge_deps})

set(xpu_subgraph_bridges
subgraph_bridge_registry
@@ -46,6 +47,7 @@ set(xpu_subgraph_bridges
subgraph_bridge_layer_norm_op_xpu
subgraph_bridge_dropout_op_xpu
subgraph_bridge_matmul_op_xpu
subgraph_bridge_cast_op_xpu
CACHE INTERNAL "xpu_subgraph_bridges")

message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}")
99 changes: 99 additions & 0 deletions lite/kernels/xpu/bridges/cast_op.cc
@@ -0,0 +1,99 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"

namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {

int CvtDtype(int dtype, PrecisionType* ptype) {
switch (dtype) {
case 21:
*ptype = PRECISION(kInt8);
break;
case 1:
*ptype = PRECISION(kInt16);
break;
case 2:
*ptype = PRECISION(kInt32);
break;
case 3:
*ptype = PRECISION(kInt64);
break;
case 5:
*ptype = PRECISION(kFloat);
break;
default:
LOG(WARNING) << "[XPU] unsupported data type: " << dtype;
return FAILED;
}
return SUCCESS;
}

int CastConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";

// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto out_name = op_info->Output("Out").front();

// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
int in_dtype = op_info->GetAttr<int>("in_dtype");
PrecisionType in_ptype;
if (CvtDtype(in_dtype, &in_ptype) == FAILED) {
return FAILED;
}

int out_dtype = op_info->GetAttr<int>("out_dtype");
PrecisionType out_ptype;
if (CvtDtype(out_dtype, &out_ptype) == FAILED) {
return FAILED;
}

// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x, in_ptype);
}

// Cast node
graph->Add(
out_name,
graph->builder_.CreateCast(*x_node->data(), CvtPrecisionType(out_ptype)));

return SUCCESS;
}

} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle

REGISTER_SUBGRAPH_BRIDGE(cast,
kXPU,
paddle::lite::subgraph::xpu::CastConverter);
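
For readers unfamiliar with the bridge helpers, here is a minimal sketch of how the new CvtDtype function maps the framework's integer dtype codes to Lite precision types. The codes and the SUCCESS/FAILED return values come from the file above; the standalone caller below is hypothetical, since CvtDtype is only defined inside cast_op.cc and is not declared in any header:

// Hypothetical caller, assumed to live in the same translation unit as CvtDtype.
PrecisionType ptype;
if (paddle::lite::subgraph::xpu::CvtDtype(2 /* INT32 */, &ptype) == SUCCESS) {
  // ptype is now PRECISION(kInt32); codes 1, 3, 5 and 21 map to kInt16,
  // kInt64, kFloat and kInt8, while any other code logs a warning and
  // returns FAILED, which makes CastConverter bail out early.
}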
1 change: 1 addition & 0 deletions lite/kernels/xpu/bridges/paddle_use_bridges.h
@@ -36,3 +36,4 @@ USE_SUBGRAPH_BRIDGE(layer_norm, kXPU);
USE_SUBGRAPH_BRIDGE(gelu, kXPU);
USE_SUBGRAPH_BRIDGE(dropout, kXPU);
USE_SUBGRAPH_BRIDGE(matmul, kXPU);
USE_SUBGRAPH_BRIDGE(cast, kXPU);
2 changes: 1 addition & 1 deletion lite/tests/kernels/CMakeLists.txt
@@ -13,7 +13,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
196 changes: 126 additions & 70 deletions lite/tests/kernels/cast_compute_test.cc
@@ -22,12 +22,13 @@ namespace lite

class CastComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "x";
std::string output_ = "out";
std::string x_ = "x";
std::string out_ = "out";
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
int in_dtype_;
int out_dtype_;
DDim x_dims_{{2, 2}};
DDim dims_{{2, 2}};

public:
CastComputeTester(const Place& place,
@@ -36,92 +37,147 @@ class CastComputeTester : public arena::TestCase {
int out_dtype)
: TestCase(place, alias), in_dtype_(in_dtype), out_dtype_(out_dtype) {}

void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
template <typename T1, typename T2>
void RunBaselineHelper(Scope* scope) {
auto* x = scope->FindTensor(x_);
auto* x_data = x->data<T1>();
auto* out = scope->NewTensor(out_);
CHECK(out);
out->Resize(x_dims_);
out->Resize(dims_);
auto* out_data = out->mutable_data<T2>();
for (int i = 0; i < dims_.production(); i++) {
*out_data = static_cast<T2>(*x_data);
out_data++;
x_data++;
}
}

if (out_dtype_ == 5 && in_dtype_ == 20) {
auto* x = scope->FindTensor(input_);
auto* x_data = x->data<unsigned char>();
auto* output_data = out->mutable_data<float>();
for (int i = 0; i < x_dims_.production(); i++) {
*output_data = static_cast<float>(*x_data);
output_data++;
x_data++;
}
} else if (out_dtype_ == 5 && in_dtype_ == 21) {
auto* output_data = out->mutable_data<float>();
auto* x = scope->FindTensor(input_);
auto* x_data = x->data<char>();
for (int i = 0; i < x_dims_.production(); i++) {
*output_data = static_cast<float>(*x_data);
output_data++;
x_data++;
}
} else if (out_dtype_ == 5 && in_dtype_ == 2) {
auto* output_data = out->mutable_data<float>();
auto* x = scope->FindTensor(input_);
auto* x_data = x->data<int32_t>();
for (int i = 0; i < x_dims_.production(); i++) {
*output_data = static_cast<float>(*x_data);
output_data++;
x_data++;
}
void RunBaseline(Scope* scope) override {
if (in_dtype_ == 20 && out_dtype_ == 5) {
RunBaselineHelper<uint8_t, float>(scope);
} else if (in_dtype_ == 2 && out_dtype_ == 5) {
RunBaselineHelper<int32_t, float>(scope);
} else if (in_dtype_ == 3 && out_dtype_ == 5) {
RunBaselineHelper<int64_t, float>(scope);
} else if (in_dtype_ == 5 && out_dtype_ == 3) {
RunBaselineHelper<float, int64_t>(scope);
} else if (in_dtype_ == 21 && out_dtype_ == 5) {
RunBaselineHelper<int8_t, float>(scope);
} else if (in_dtype_ == 5 && out_dtype_ == 21) {
RunBaselineHelper<float, int8_t>(scope);
} else {
LOG(FATAL) << "unsupported";
}
}

void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("cast");
op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {output_});
op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("in_dtype", in_dtype_);
op_desc->SetAttr("out_dtype", out_dtype_);
}

template <typename T1>
void PrepareDataHelper() {
std::vector<T1> x_data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
x_data[i] = static_cast<T1>(i % 128);
}
SetCommonTensor(x_, dims_, x_data.data());
}

void PrepareData() override {
SetPrecisionType(output_, PRECISION(kFloat));
if (in_dtype_ == 20) {
std::vector<unsigned char> x_data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) {
x_data[i] = static_cast<unsigned char>(i % 128);
}
SetCommonTensor(input_, x_dims_, x_data.data());
} else if (in_dtype_ == 21) {
std::vector<char> x_data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
x_data[i] = sign * static_cast<char>(i % 128);
}
SetCommonTensor(input_, x_dims_, x_data.data());
} else if (in_dtype_ == 2) {
std::vector<int32_t> x_data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) {
int sign = i % 3 == 0 ? -1 : 1;
x_data[i] = sign * static_cast<int32_t>(i % 128);
}
SetCommonTensor(input_, x_dims_, x_data.data());
} else {
LOG(FATAL) << "not implemented!";
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
switch (in_dtype_) {
case 20:
PrepareDataHelper<uint8_t>();
break;
case 21:
PrepareDataHelper<int8_t>();
break;
case 1:
PrepareDataHelper<int16_t>();
break;
case 2:
PrepareDataHelper<int32_t>();
break;
case 3:
PrepareDataHelper<int64_t>();
break;
case 5:
PrepareDataHelper<float>();
break;
case 6:
PrepareDataHelper<double>();
break;
case 19:
PrepareDataHelper<size_t>();
break;
default:
LOG(FATAL) << "unsupported data type: " << in_dtype_;
break;
}

PrecisionType out_ptype;
switch (out_dtype_) {
case 0:
out_ptype = PRECISION(kBool);
break;
case 21:
out_ptype = PRECISION(kInt8);
break;
case 1:
out_ptype = PRECISION(kInt16);
break;
case 2:
out_ptype = PRECISION(kInt32);
break;
case 3:
out_ptype = PRECISION(kInt64);
break;
case 4:
out_ptype = PRECISION(kFP16);
break;
case 5:
out_ptype = PRECISION(kFloat);
break;
default:
LOG(FATAL) << "unsupported data type: " << out_dtype_;
break;
}
SetPrecisionType(out_, out_ptype);
}
};

TEST(Cast, precision) {
LOG(INFO) << "test cast op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));

void TestCast(Place place, float abs_error, int in_dtype, int out_dtype) {
std::unique_ptr<arena::TestCase> tester(
new CastComputeTester(place, "def", 20, 5));
arena::Arena arena(std::move(tester), place, 2e-5);
new CastComputeTester(place, "def", in_dtype, out_dtype));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}

TEST(Cast, precision) {
LOG(INFO) << "test cast op";
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif

std::unique_ptr<arena::TestCase> tester1(
new CastComputeTester(place, "def", 2, 5));
arena::Arena arena1(std::move(tester1), place, 2e-5);
arena1.TestPrecision();
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
#ifndef LITE_WITH_XPU
TestCast(place, abs_error, 20, 5);
#endif
TestCast(place, abs_error, 2, 5);
TestCast(place, abs_error, 3, 5);
TestCast(place, abs_error, 5, 3);
}

} // namespace lite
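
As a usage note, the refactored TestCast helper makes trying further dtype pairs a one-line change. The sketch below is hedged: the RunBaseline helper above already supports float -> int8 (5 -> 21), but whether a given target kernel handles that pair is not established by this commit, so the test name and placement are hypothetical.

// Hypothetical sketch, placed alongside TEST(Cast, precision) inside
// namespace paddle::lite in cast_compute_test.cc.
TEST(Cast, float_to_int8_sketch) {
#if defined(LITE_WITH_ARM)
  Place place(TARGET(kARM));
  // 5 = FP32, 21 = INT8 (see the dtype comment above).
  TestCast(place, 2e-5 /* abs_error */, 5 /* in_dtype */, 21 /* out_dtype */);
#endif
}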