Skip to content

Commit

Permalink
fix concat when axis < 0; test=develop
Browse files: browse the repository at this point in the history
  • Loading branch information
zhupengyang committed Apr 19, 2021
1 parent e3a4de0 commit a2bdf19
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion lite/kernels/arm/concat_compute.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -68,7 +68,7 @@ void ConcatCompute::Run() {
axis = axis_tensor_data[0];
}
if (axis < 0) {
axis += inputs[0]->dims().size();
axis += static_cast<int>(inputs[0]->dims().size());
}

lite_api::PrecisionType type = PRECISION(kUnk);
Expand Down
9 changes: 6 additions & 3 deletions lite/kernels/x86/concat_compute.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -44,14 +44,17 @@ class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
return;
}

int64_t axis = static_cast<int64_t>(param.axis);
int axis = param.axis;
auto* axis_tensor = param.axis_tensor;
if (axis_tensor != nullptr) {
auto* axis_tensor_data = axis_tensor->template data<int>();
axis = static_cast<int64_t>(axis_tensor_data[0]);
axis = axis_tensor_data[0];
}

const auto& x_dims = param.x[0]->dims();
if (axis < 0) {
axis += static_cast<int>(x_dims.size());
}

auto* out = param.output;
T* output_data = param.output->template mutable_data<T>();

Expand Down
4 changes: 3 additions & 1 deletion lite/kernels/xpu/concat_compute.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -29,7 +29,9 @@ void ConcatCompute<InType>::Run() {

auto ins = param.x;
auto out = param.output;
int64_t axis = param.axis;
int64_t axis = param.axis < 0
? param.axis + static_cast<int>(ins[0]->dims().size())
: param.axis;

std::vector<const float*> x_list;
std::vector<std::vector<int>> xdims_list;
Expand Down
2 changes: 1 addition & 1 deletion lite/operators/concat_op.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,7 +39,7 @@ bool ConcatOpLite::InferShapeImpl() const {
axis = axis_tensor_val[0];
}
if (axis < 0) {
axis += inputs[0]->dims().size();
axis += static_cast<int>(inputs[0]->dims().size());
}

auto out_dims = inputs[0]->dims();
Expand Down
15 changes: 9 additions & 6 deletions lite/tests/kernels/concat_compute_test.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -74,23 +74,24 @@ class ConcateComputeTester : public arena::TestCase {
x_vct.push_back(scope->FindTensor(name));
}

int axis = axis_ < 0 ? axis_ + static_cast<int>(x_dims_.size()) : axis_;
auto* out = scope->NewTensor(out_);
DDim output_dims = infer_shape(x_vct, axis_);
DDim output_dims = infer_shape(x_vct, axis);
out->Resize(output_dims);
auto* output_data = out->mutable_data<float>();

int num = x_vct.size();
int rows = 1;
auto dim_0 = x_vct[0]->dims();
for (int i = 0; i < axis_; ++i) {
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;

std::vector<int> input_cols(x_vct.size());
for (int i = 0; i < num; ++i) {
int input_i_numel = x_vct[i]->dims().size() == 0 ? 0 : 1;
for (int didx = 0; didx < x_vct[i]->dims().size(); ++didx) {
for (size_t didx = 0; didx < x_vct[i]->dims().size(); ++didx) {
input_i_numel *= x_vct[i]->dims()[didx];
}
int t_cols = input_i_numel / rows;
Expand Down Expand Up @@ -142,7 +143,9 @@ class ConcateComputeTester : public arena::TestCase {
TEST(Concat, precision) {
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_NPU)
#if defined(LITE_WITH_XPU) && !defined(LITE_WITH_XTCL)
place = TARGET(kXPU);
#elif defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
Expand All @@ -156,9 +159,9 @@ TEST(Concat, precision) {
return;
#endif

for (int axis : {1, 2}) {
for (int axis : {-1, 1, 2}) {
for (bool is_use_axis_tensor : {false, true}) {
#ifdef LITE_WITH_NPU
#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU)
if (is_use_axis_tensor) continue;
#endif
std::unique_ptr<arena::TestCase> tester(
Expand Down

0 comments on commit a2bdf19

Please sign in to comment.