diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc index 4cf33ca9ab..95fd8cb2ef 100644 --- a/cinn/frontend/net_builder.cc +++ b/cinn/frontend/net_builder.cc @@ -436,6 +436,24 @@ Variable NetBuilder::Sort(const Variable& operand, const int& axis, const bool& return instr.GetOutput(0); } +Variable NetBuilder::Argmax(const Variable& x, const int& axis, const bool& keep_dim) { + Instruction instr("argmax", {x}); + instr.SetAttr("axis", axis); + instr.SetAttr("keep_dim", keep_dim); + InferShape(instr); + AppendInstruction(instr); + return instr.GetOutput(0); +} + +Variable NetBuilder::Argmin(const Variable& x, const int& axis, const bool& keep_dim) { + Instruction instr("argmin", {x}); + instr.SetAttr("axis", axis); + instr.SetAttr("keep_dim", keep_dim); + InferShape(instr); + AppendInstruction(instr); + return instr.GetOutput(0); +} + Variable NetBuilder::Conv2d(const Variable& a, const Variable& b, const std::vector& strides, diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h index fbfd68c574..f4e33e1e53 100644 --- a/cinn/frontend/net_builder.h +++ b/cinn/frontend/net_builder.h @@ -813,6 +813,26 @@ class NetBuilder { const float epsilon = 1e-5f, const std::string& data_layout = "NCHW"); + /** + * @brief Get index of variable x to the maximum value along the given axis. + * @param x An input N-D variable. + * @param axis Specify the axis to operate on the input. Default: 0. + * @param keep_dim Decide whether to keep the dimension. + * Defalut “NCHW”. + * @return `Index of variable x to the maximum value`. + */ + Variable Argmax(const Variable& x, const int& axis = 0, const bool& keep_dim = false); + + /** + * @brief Get index of variable x to the minimum value along the given axis. + * @param x An input N-D variable. + * @param axis Specify the axis to operate on the input. Default: 0. + * @param keep_dim Decide whether to keep the dimension. + * Defalut “NCHW”. + * @return `Index of variable x to the minimum value`. 
+ */ + Variable Argmin(const Variable& x, const int& axis = 0, const bool& keep_dim = false); + /** * @brief Sort Variable x along the given axis and return sorted index. The original Variable x will not be changed. * @param operand The variable that will be sorted. diff --git a/cinn/frontend/net_builder_test.cc b/cinn/frontend/net_builder_test.cc index 924e00060f..c85acd6d89 100755 --- a/cinn/frontend/net_builder_test.cc +++ b/cinn/frontend/net_builder_test.cc @@ -745,7 +745,7 @@ TEST(net_build, program_execute_argsort) { for (int b = 0; b < B; ++b) { int index = h + H * b; sorted_data.push_back(input_data[index]); - out_sorted_data[output_data[index]] = input_data[index]; + out_sorted_data[b] = input_data[h + H * output_data[index]]; } std::sort(sorted_data.begin(), sorted_data.begin() + B); @@ -887,5 +887,297 @@ TEST(net_build, program_execute_arange_int) { } } +TEST(net_build, program_argmax_case1) { + const int N = 4; + const int IN_C = 3; + const int OUT_C = 1; + const int H = 7; + const int W = 7; + + NetBuilder builder("net_builder"); + Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In"); + Variable output = builder.Argmax(input, 1, true); + auto program = builder.Build(); + + Target target = common::DefaultHostTarget(); + + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(input.id())); + scope->Var(std::string(output->id)); + + auto input_tensor = scope->GetTensor(std::string(input.id())); + SetRandData(input_tensor, target); + float* input_data = input_tensor->mutable_data(target); + VLOG(6) << "Visualize input_data"; + for (int n = 0; n < N; ++n) { + for (int c = 0; c < IN_C; ++c) { + VLOG(6) << "n = " << n << ", c = " << c; + for (int h = 0; h < H; ++h) { + std::string line; + for (int w = 0; w < W; ++w) { + int index = w + W * (h + H * (c + IN_C * n)); + line += 
(std::to_string(input_data[index]) + ", "); + } + VLOG(6) << line; + } + } + } + runtime_program->Execute(); + + auto output_tensor = scope->GetTensor(std::string(output->id)); + const std::vector& output_shape = output_tensor->shape().data(); + EXPECT_EQ(output_shape.size(), 4UL); + EXPECT_EQ(output_shape[0], N); + EXPECT_EQ(output_shape[1], OUT_C); + EXPECT_EQ(output_shape[2], H); + EXPECT_EQ(output_shape[3], W); + + int* output_data = output_tensor->mutable_data(target); + VLOG(6) << "Visualize output_data"; + for (int n = 0; n < N; ++n) { + for (int c = 0; c < IN_C; ++c) { + VLOG(6) << "n = " << n << ", c = " << c; + for (int h = 0; h < H; ++h) { + std::string line; + for (int w = 0; w < W; ++w) { + int index = w + W * (h + H * (c + IN_C * n)); + int out_index = w + W * (h + H * n); + float in_data = input_data[index]; + int out_data = output_data[out_index]; + EXPECT_LE(0, out_data); + EXPECT_LT(out_data, IN_C); + int max_index = w + W * (h + H * (out_data + IN_C * n)); + float max_value = input_data[max_index]; + line += (std::to_string(out_data) + ", "); + EXPECT_LE(in_data, max_value); + } + VLOG(6) << line; + } + } + } +} + +TEST(net_build, program_argmax_case2) { + const int N = 4; + const int IN_C = 3; + const int H = 7; + const int W = 7; + + NetBuilder builder("net_builder"); + Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In"); + Variable output = builder.Argmax(input, 1, false); + auto program = builder.Build(); + + Target target = common::DefaultHostTarget(); + + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(input.id())); + scope->Var(std::string(output->id)); + + auto input_tensor = scope->GetTensor(std::string(input.id())); + SetRandData(input_tensor, target); + float* input_data = input_tensor->mutable_data(target); + VLOG(6) << "Visualize input_data"; + for 
(int n = 0; n < N; ++n) { + for (int c = 0; c < IN_C; ++c) { + VLOG(6) << "n = " << n << ", c = " << c; + for (int h = 0; h < H; ++h) { + std::string line; + for (int w = 0; w < W; ++w) { + int index = w + W * (h + H * (c + IN_C * n)); + line += (std::to_string(input_data[index]) + ", "); + } + VLOG(6) << line; + } + } + } + runtime_program->Execute(); + + auto output_tensor = scope->GetTensor(std::string(output->id)); + const std::vector& output_shape = output_tensor->shape().data(); + EXPECT_EQ(output_shape.size(), 3UL); + EXPECT_EQ(output_shape[0], N); + EXPECT_EQ(output_shape[1], H); + EXPECT_EQ(output_shape[2], W); + + int* output_data = output_tensor->mutable_data(target); + VLOG(6) << "Visualize output_data"; + for (int n = 0; n < N; ++n) { + for (int c = 0; c < IN_C; ++c) { + VLOG(6) << "n = " << n << ", c = " << c; + for (int h = 0; h < H; ++h) { + std::string line; + for (int w = 0; w < W; ++w) { + int index = w + W * (h + H * (c + IN_C * n)); + int out_index = w + W * (h + H * n); + float in_data = input_data[index]; + int out_data = output_data[out_index]; + EXPECT_LE(0, out_data); + EXPECT_LT(out_data, IN_C); + int max_index = w + W * (h + H * (out_data + IN_C * n)); + float max_value = input_data[max_index]; + line += (std::to_string(out_data) + ", "); + EXPECT_LE(in_data, max_value); + } + VLOG(6) << line; + } + } + } +} + +TEST(net_build, program_argmin_case1) { + const int N = 4; + const int IN_C = 3; + const int OUT_C = 1; + const int H = 7; + const int W = 7; + + NetBuilder builder("net_builder"); + Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In"); + Variable output = builder.Argmin(input, 1, true); + auto program = builder.Build(); + + Target target = common::DefaultHostTarget(); + + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(input.id())); + 
scope->Var(std::string(output->id)); + + auto input_tensor = scope->GetTensor(std::string(input.id())); + SetRandData(input_tensor, target); + float* input_data = input_tensor->mutable_data(target); + VLOG(6) << "Visualize input_data"; + for (int n = 0; n < N; ++n) { + for (int c = 0; c < IN_C; ++c) { + VLOG(6) << "n = " << n << ", c = " << c; + for (int h = 0; h < H; ++h) { + std::string line; + for (int w = 0; w < W; ++w) { + int index = w + W * (h + H * (c + IN_C * n)); + line += (std::to_string(input_data[index]) + ", "); + } + VLOG(6) << line; + } + } + } + runtime_program->Execute(); + + auto output_tensor = scope->GetTensor(std::string(output->id)); + const std::vector& output_shape = output_tensor->shape().data(); + EXPECT_EQ(output_shape.size(), 4UL); + EXPECT_EQ(output_shape[0], N); + EXPECT_EQ(output_shape[1], OUT_C); + EXPECT_EQ(output_shape[2], H); + EXPECT_EQ(output_shape[3], W); + + int* output_data = output_tensor->mutable_data(target); + VLOG(6) << "Visualize output_data"; + for (int n = 0; n < N; ++n) { + for (int c = 0; c < IN_C; ++c) { + VLOG(6) << "n = " << n << ", c = " << c; + for (int h = 0; h < H; ++h) { + std::string line; + for (int w = 0; w < W; ++w) { + int index = w + W * (h + H * (c + IN_C * n)); + int out_index = w + W * (h + H * n); + float in_data = input_data[index]; + int out_data = output_data[out_index]; + EXPECT_LE(0, out_data); + EXPECT_LT(out_data, IN_C); + int max_index = w + W * (h + H * (out_data + IN_C * n)); + float max_value = input_data[max_index]; + line += (std::to_string(out_data) + ", "); + EXPECT_GE(in_data, max_value); + } + VLOG(6) << line; + } + } + } +} + +TEST(net_build, program_argmin_case2) { + const int N = 4; + const int IN_C = 3; + const int H = 7; + const int W = 7; + + NetBuilder builder("net_builder"); + Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In"); + Variable output = builder.Argmin(input, 1, false); + auto program = builder.Build(); + + Target target = 
common::DefaultHostTarget(); + + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(input.id())); + scope->Var(std::string(output->id)); + + auto input_tensor = scope->GetTensor(std::string(input.id())); + SetRandData(input_tensor, target); + float* input_data = input_tensor->mutable_data(target); + VLOG(6) << "Visualize input_data"; + for (int n = 0; n < N; ++n) { + for (int c = 0; c < IN_C; ++c) { + VLOG(6) << "n = " << n << ", c = " << c; + for (int h = 0; h < H; ++h) { + std::string line; + for (int w = 0; w < W; ++w) { + int index = w + W * (h + H * (c + IN_C * n)); + line += (std::to_string(input_data[index]) + ", "); + } + VLOG(6) << line; + } + } + } + runtime_program->Execute(); + + auto output_tensor = scope->GetTensor(std::string(output->id)); + const std::vector& output_shape = output_tensor->shape().data(); + EXPECT_EQ(output_shape.size(), 3UL); + EXPECT_EQ(output_shape[0], N); + EXPECT_EQ(output_shape[1], H); + EXPECT_EQ(output_shape[2], W); + + int* output_data = output_tensor->mutable_data(target); + VLOG(6) << "Visualize output_data"; + for (int n = 0; n < N; ++n) { + for (int c = 0; c < IN_C; ++c) { + VLOG(6) << "n = " << n << ", c = " << c; + for (int h = 0; h < H; ++h) { + std::string line; + for (int w = 0; w < W; ++w) { + int index = w + W * (h + H * (c + IN_C * n)); + int out_index = w + W * (h + H * n); + float in_data = input_data[index]; + int out_data = output_data[out_index]; + EXPECT_LE(0, out_data); + EXPECT_LT(out_data, IN_C); + int max_index = w + W * (h + H * (out_data + IN_C * n)); + float max_value = input_data[max_index]; + line += (std::to_string(out_data) + ", "); + EXPECT_GE(in_data, max_value); + } + VLOG(6) << line; + } + } + } +} + } // namespace frontend } // namespace cinn diff --git a/cinn/hlir/op/contrib/CMakeLists.txt b/cinn/hlir/op/contrib/CMakeLists.txt 
index 5643c4cd69..82aa58b77b 100644 --- a/cinn/hlir/op/contrib/CMakeLists.txt +++ b/cinn/hlir/op/contrib/CMakeLists.txt @@ -8,6 +8,8 @@ gather_srcs(cinnapi_src SRCS clip.cc arange.cc sort.cc + argmin.cc + argmax.cc squeeze.cc ) @@ -17,4 +19,6 @@ cc_test(test_gather SRCS gather_test.cc DEPS cinncore) cc_test(test_scatter SRCS scatter_test.cc DEPS cinncore) cc_test(test_clip SRCS clip_test.cc DEPS cinncore) cc_test(test_sort SRCS sort_test.cc DEPS cinncore) +cc_test(test_argmin SRCS argmin_test.cc DEPS cinncore) +cc_test(test_argmax SRCS argmax_test.cc DEPS cinncore) cc_test(test_arange SRCS arange_test.cc DEPS cinncore) diff --git a/cinn/hlir/op/contrib/argmax.cc b/cinn/hlir/op/contrib/argmax.cc new file mode 100644 index 0000000000..16b0319393 --- /dev/null +++ b/cinn/hlir/op/contrib/argmax.cc @@ -0,0 +1,221 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/hlir/op/contrib/argmax.h" + +#include +#include + +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/op/contrib/sort.h" +#include "cinn/hlir/pe/broadcast.h" +#include "cinn/hlir/pe/ir_schedule_pe.h" +#include "cinn/hlir/pe/schedule.h" +#include "cinn/hlir/pe/transform.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_schedule.h" + +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace hlir { +namespace op { + +using common::CINNValue; +using framework::shape_t; +using ir::Tensor; + +Tensor Argmax(const Tensor &in_tensor, + const common::Target &target, + poly::StageMap stages, + const int &axis, + const bool &keep_dims, + const std::string &name) { + auto shape = in_tensor->shape; + auto ndim = shape.size(); + CHECK_GT(ndim, 0) << "tensor's dim must be more than 0"; + + int pos_axis = axis; + if (axis < 0) { + pos_axis = static_cast(ndim) + axis; + } + CHECK_LT(pos_axis, ndim) << "Axis must be less than tensor's dim"; + CHECK_GE(pos_axis, 0) << "Axis must be more than 0"; + + std::vector output_shape; + for (int i = 0; i < shape.size(); ++i) { + CHECK(shape[i].is_constant()) << "Input tensor's shape should be constant value."; + if (axis == i) { + if (keep_dims) { + output_shape.push_back(Expr(1)); + } + } else { + output_shape.push_back(shape[i]); + } + } + if (output_shape.empty()) { + output_shape.push_back(Expr(1)); + } + + auto sort_index = ArgSort(in_tensor, target, stages, pos_axis, false, name + "_index"); + auto res = Compute( + output_shape, + [=](const std::vector &indices) { + std::vector eval_indices(indices); + if (!keep_dims) { + eval_indices.insert(eval_indices.begin() + pos_axis, Expr(0)); + } else { + eval_indices[pos_axis] = Expr(0); + } + return sort_index(eval_indices); + }, + name); + stages->InsertLazily(sort_index); + return res; +} + +std::shared_ptr StrategyForArgmax(const framework::NodeAttr &attrs, + const 
std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + int axis; + bool keep_dims = false; + + if (attrs.attr_store.count("axis")) { + axis = absl::get(attrs.attr_store.at("axis")); + } else { + LOG(FATAL) << "reduce dimension is not set!"; + } + if (attrs.attr_store.count("keep_dim")) { + keep_dims = absl::get(attrs.attr_store.at("keep_dim")); + } + + framework::CINNCompute argmax_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of argmax compute is empty! Please check."; + common::CINNValuePack pack_args = args[0]; + std::string tensor_name = UniqName("Argmax_out"); + CHECK_EQ(pack_args.size(), 1U) << "There should be 1 input args for argmax compute"; + Expr in_expr = pack_args[0]; + CHECK(in_expr.as_tensor()); + Tensor in_tensor = in_expr.as_tensor_ref(); + auto stages = CreateStages({in_tensor}); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2U); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } + auto out_tensor = Argmax(in_tensor, target, stages, axis, keep_dims, tensor_name); + + stages->InsertLazily(out_tensor); + std::vector cinn_values{CINNValue(out_tensor), CINNValue(stages)}; + *ret = common::CINNValuePack{cinn_values}; + }); + + framework::CINNSchedule argmax_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of argmax schedule is empty! 
Please check."; + common::CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 2UL); + Expr out = arg_pack[0]; + CHECK(out.as_tensor()); + + // When develop FLAGS_cinn_ir_schedule=true case, we should run unit test with + // FLAGS_cinn_ir_schedule=1 + if (FLAGS_cinn_ir_schedule) { + *ret = common::CINNValuePack{{common::CINNValue(out)}}; + } else { + poly::StageMap stages = arg_pack[arg_pack.size() - 1]; + *ret = common::CINNValuePack{{common::CINNValue(out), common::CINNValue(stages)}}; + } + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(argmax_compute, argmax_schedule, "strategy.argmax.x86", 1); + + return strategy; +} + +std::vector InferShapeForArgmax(const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(inputs_shape.size() == 1UL); + auto ndim = inputs_shape[0].size(); + CHECK_GT(ndim, 0) << "tensor's dim must be more than 0"; + int axis; + bool keep_dim; + + CHECK(attrs.find("axis") != attrs.end()); + axis = absl::get(attrs.at("axis")); + if (axis < 0) { + axis = static_cast(ndim) + axis; + } + CHECK_LT(axis, ndim) << "Axis must be less than tensor's dim"; + CHECK_GE(axis, 0) << "Axis must be more than 0"; + + CHECK(attrs.find("keep_dim") != attrs.end()); + keep_dim = absl::get(attrs.at("keep_dim")); + + std::vector out_shapes; + for (size_t i = 0; i < ndim; ++i) { + if (axis == i) { + if (keep_dim) { + out_shapes.push_back(1); + } + } else { + out_shapes.push_back(inputs_shape[0][i]); + } + } + + if (keep_dim) { + CHECK_EQ(ndim, out_shapes.size()); + } else { + CHECK_EQ(ndim - 1, out_shapes.size()); + } + if (out_shapes.empty()) { + out_shapes.push_back(1); + } + + return {out_shapes}; +} + +std::vector InferDtypeForArgmax(const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) << "The input's type size is 0! 
Please check again."; + return {Int(32)}; +} + +std::vector> InferLayoutForArgmax(const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_shapes.size(), 1U) << "The input's shape size is not 1! Please check again."; + CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again."; + return {input_layouts, input_layouts}; +} +} // namespace op +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(argmax_ops) { + CINN_REGISTER_OP(argmax) + .describe("This operator implements the op argmax.") + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForArgmax) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForArgmax)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForArgmax)) + .set_support_level(4); + + return true; +} diff --git a/cinn/hlir/op/contrib/argmax.h b/cinn/hlir/op/contrib/argmax.h new file mode 100644 index 0000000000..068b8af480 --- /dev/null +++ b/cinn/hlir/op/contrib/argmax.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include + +#include "cinn/ir/ir.h" + +namespace cinn { +namespace hlir { +namespace op { +ir::Tensor Argmax(const ir::Tensor &in_tensor, + const common::Target &target, + poly::StageMap stages, + const int &axis, + const bool &keep_dims = false, + const std::string &name = "T_Argmax_out"); +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/contrib/argmax_test.cc b/cinn/hlir/op/contrib/argmax_test.cc new file mode 100644 index 0000000000..461fc19fe8 --- /dev/null +++ b/cinn/hlir/op/contrib/argmax_test.cc @@ -0,0 +1,120 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/hlir/op/contrib/argmax.h" + +#include +#include + +#include +#include + +#include "cinn/backends/codegen_c.h" +#include "cinn/backends/codegen_c_x86.h" +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/common/context.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace hlir { +namespace op { + +TEST(GenerateCode_Cpu, Argmax_Keep) { + common::Context::Global().ResetNameId(); + + common::Target target = common::DefaultHostTarget(); + + int axis = 1; + ir::Expr n(4); + ir::Expr in_c(3); + ir::Expr out_c(1); + ir::Expr h(28); + ir::Expr w(28); + + lang::Placeholder in("in", {n, in_c, h, w}); + poly::StageMap stages = poly::CreateStages({in}); + ir::Tensor res = Argmax(in, target, stages, axis, true, "test_argmax_in"); + stages->InsertLazily(res); + + std::vector funcs = + lang::LowerVec("TestGenerateCodeCpu_Argmax_Keep", stages, {in, res}, {}, {}, nullptr, target, true); + + VLOG(6) << "Expr before CPU codegen:"; + VLOG(6) << funcs[0]->body; + + ir::Module::Builder builder("Argmax_Keep_Module", target); + for (auto& f : funcs) { + builder.AddFunction(f); + } + + backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); + codegen.SetInlineBuiltinCodes(false); + std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + auto target_source = R"ROC( +#include +#include + +void TestGenerateCodeCpu_Argmax_Keep(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _in = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _test_argmax_in = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _test_argmax_in_index = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { 4, 3, 28, 28 }); + cinn_buffer_t* _test_argmax_in_index_temp = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { 4, 3, 28, 28 }); + 
cinn_buffer_malloc((void*)(0), _test_argmax_in); + cinn_buffer_malloc((void*)(0), _test_argmax_in_index); + cinn_buffer_malloc((void*)(0), _test_argmax_in_index_temp); + const float* in = ((const float*)(_in->memory)); + int32_t* test_argmax_in = ((int32_t*)(_test_argmax_in->memory)); + int32_t* test_argmax_in_index = ((int32_t*)(_test_argmax_in_index->memory)); + int32_t* test_argmax_in_index_temp = ((int32_t*)(_test_argmax_in_index_temp->memory)); + { + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 3; j += 1) { + for (int32_t k = 0; k < 28; k += 1) { + for (int32_t a = 0; a < 28; a += 1) { + test_argmax_in_index_temp[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_gt_num_float(_in, 3, in[((2352 * i) + ((784 * j) + ((28 * k) + a)))], ((2352 * i) + ((28 * k) + a)), 784); + }; + }; + }; + }; + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 3; j += 1) { + for (int32_t k = 0; k < 28; k += 1) { + for (int32_t a = 0; a < 28; a += 1) { + test_argmax_in_index[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_find_int_nd(_test_argmax_in_index_temp, 3, j, ((2352 * i) + ((28 * k) + a)), 784); + }; + }; + }; + }; + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t k = 0; k < 28; k += 1) { + for (int32_t a = 0; a < 28; a += 1) { + test_argmax_in[((784 * i) + ((28 * k) + a))] = test_argmax_in_index[((2352 * i) + ((28 * k) + a))]; + }; + }; + }; + }; + cinn_buffer_free((void*)(0), _test_argmax_in_index); + cinn_buffer_free((void*)(0), _test_argmax_in_index_temp); + cinn_buffer_free((void*)(0), _test_argmax_in); +} + )ROC"; + CHECK_EQ(utils::Trim(code), utils::Trim(target_source)); +} + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/contrib/argmin.cc b/cinn/hlir/op/contrib/argmin.cc new file mode 100644 index 0000000000..44842126ad --- /dev/null +++ b/cinn/hlir/op/contrib/argmin.cc @@ -0,0 +1,221 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/hlir/op/contrib/argmin.h" + +#include +#include + +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/op/contrib/sort.h" +#include "cinn/hlir/pe/broadcast.h" +#include "cinn/hlir/pe/ir_schedule_pe.h" +#include "cinn/hlir/pe/schedule.h" +#include "cinn/hlir/pe/transform.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_schedule.h" + +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace hlir { +namespace op { + +using common::CINNValue; +using framework::shape_t; +using ir::Tensor; + +Tensor Argmin(const Tensor &in_tensor, + const common::Target &target, + poly::StageMap stages, + const int &axis, + const bool &keep_dims, + const std::string &name) { + auto shape = in_tensor->shape; + auto ndim = shape.size(); + CHECK_GT(ndim, 0) << "tensor's dim must be more than 0"; + + int pos_axis = axis; + if (axis < 0) { + pos_axis = static_cast(ndim) + axis; + } + CHECK_LT(pos_axis, ndim) << "Axis must be less than tensor's dim"; + CHECK_GE(pos_axis, 0) << "Axis must be more than 0"; + + std::vector output_shape; + for (int i = 0; i < shape.size(); ++i) { + CHECK(shape[i].is_constant()) << "Input tensor's shape should be constant value."; + if (axis == i) { + if (keep_dims) { + output_shape.push_back(Expr(1)); + } + } else { + output_shape.push_back(shape[i]); + } + } + if (output_shape.empty()) 
{ + output_shape.push_back(Expr(1)); + } + auto sort_index = ArgSort(in_tensor, target, stages, pos_axis, true, name + "_index"); + auto res = Compute( + output_shape, + [=](const std::vector &indices) { + std::vector eval_indices(indices); + if (!keep_dims) { + eval_indices.insert(eval_indices.begin() + pos_axis, Expr(0)); + } else { + eval_indices[pos_axis] = Expr(0); + } + return sort_index(eval_indices); + }, + name); + stages->InsertLazily(sort_index); + return res; +} + +std::shared_ptr StrategyForArgmin(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + int axis; + bool keep_dims = false; + + if (attrs.attr_store.count("axis")) { + axis = absl::get(attrs.attr_store.at("axis")); + } else { + LOG(FATAL) << "reduce dimension is not set!"; + } + if (attrs.attr_store.count("keep_dim")) { + keep_dims = absl::get(attrs.attr_store.at("keep_dim")); + } + + framework::CINNCompute argmin_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of argmin compute is empty! 
Please check."; + common::CINNValuePack pack_args = args[0]; + std::string tensor_name = UniqName("Argmin_out"); + CHECK_EQ(pack_args.size(), 1U) << "There should be 1 input args for argmax compute"; + Expr in_expr = pack_args[0]; + CHECK(in_expr.as_tensor()); + Tensor in_tensor = in_expr.as_tensor_ref(); + auto stages = CreateStages({in_tensor}); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2U); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } + auto out_tensor = Argmin(in_tensor, target, stages, axis, keep_dims, tensor_name); + + stages->InsertLazily(out_tensor); + std::vector cinn_values{CINNValue(out_tensor), CINNValue(stages)}; + *ret = common::CINNValuePack{cinn_values}; + }); + + framework::CINNSchedule argmin_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of argmin schedule is empty! Please check."; + common::CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 2UL); + Expr out = arg_pack[0]; + CHECK(out.as_tensor()); + + // When develop FLAGS_cinn_ir_schedule=true case, we should run unit test with + // FLAGS_cinn_ir_schedule=1 + if (FLAGS_cinn_ir_schedule) { + *ret = common::CINNValuePack{{common::CINNValue(out)}}; + } else { + poly::StageMap stages = arg_pack[arg_pack.size() - 1]; + *ret = common::CINNValuePack{{common::CINNValue(out), common::CINNValue(stages)}}; + } + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(argmin_compute, argmin_schedule, "strategy.argmin.x86", 1); + + return strategy; +} + +std::vector InferShapeForArgmin(const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(inputs_shape.size() == 1UL); + auto ndim = inputs_shape[0].size(); + CHECK_GT(ndim, 0) << "tensor's dim must be more than 0"; + int axis; + bool keep_dim; + + CHECK(attrs.find("axis") != attrs.end()); + axis = absl::get(attrs.at("axis")); + if (axis < 0) { + axis = static_cast(ndim) + axis; + } + CHECK_LT(axis, 
ndim) << "Axis must be less than tensor's dim"; + CHECK_GE(axis, 0) << "Axis must be more than 0"; + + CHECK(attrs.find("keep_dim") != attrs.end()); + keep_dim = absl::get(attrs.at("keep_dim")); + + std::vector out_shapes; + for (size_t i = 0; i < ndim; ++i) { + if (axis == i) { + if (keep_dim) { + out_shapes.push_back(1); + } + } else { + out_shapes.push_back(inputs_shape[0][i]); + } + } + + if (keep_dim) { + CHECK_EQ(ndim, out_shapes.size()); + } else { + CHECK_EQ(ndim - 1, out_shapes.size()); + } + + if (out_shapes.empty()) { + out_shapes.push_back(1); + } + + return {out_shapes}; +} + +std::vector InferDtypeForArgmin(const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; + return {Int(32)}; +} + +std::vector> InferLayoutForArgmin(const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_shapes.size(), 1U) << "The input's shape size is not 1! Please check again."; + CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again."; + return {input_layouts, input_layouts}; +} +} // namespace op +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(argmin_ops) { + CINN_REGISTER_OP(argmin) + .describe("This operator implements the op argmin.") + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForArgmin) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForArgmin)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForArgmin)) + .set_support_level(4); + + return true; +} diff --git a/cinn/hlir/op/contrib/argmin.h b/cinn/hlir/op/contrib/argmin.h new file mode 100644 index 0000000000..1c2005f961 --- /dev/null +++ b/cinn/hlir/op/contrib/argmin.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include <string> +#include <vector> + +#include "cinn/ir/ir.h" + +namespace cinn { +namespace hlir { +namespace op { +ir::Tensor Argmin(const ir::Tensor& in_tensor, + const common::Target& target, + poly::StageMap stages, + const int& axis, + const bool& keep_dims = false, + const std::string& name = "T_Argmin_out"); +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/contrib/argmin_test.cc b/cinn/hlir/op/contrib/argmin_test.cc new file mode 100644 index 0000000000..2be2dd2c84 --- /dev/null +++ b/cinn/hlir/op/contrib/argmin_test.cc @@ -0,0 +1,120 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/hlir/op/contrib/argmin.h" + +#include <glog/logging.h> +#include <gtest/gtest.h> + +#include <string> +#include <vector> + +#include "cinn/backends/codegen_c.h" +#include "cinn/backends/codegen_c_x86.h" +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/common/context.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace hlir { +namespace op { + +TEST(GenerateCode_Cpu, Argmin_Keep) { + common::Context::Global().ResetNameId(); + + common::Target target = common::DefaultHostTarget(); + + int axis = 1; + ir::Expr n(4); + ir::Expr in_c(3); + ir::Expr out_c(1); + ir::Expr h(28); + ir::Expr w(28); + + lang::Placeholder<float> in("in", {n, in_c, h, w}); + poly::StageMap stages = poly::CreateStages({in}); + ir::Tensor res = Argmin(in, target, stages, axis, true, "test_argmin_in"); + stages->InsertLazily(res); + + std::vector<ir::LoweredFunc> funcs = + lang::LowerVec("TestGenerateCodeCpu_Argmin_Keep", stages, {in, res}, {}, {}, nullptr, target, true); + + VLOG(6) << "Expr before CPU codegen:"; + VLOG(6) << funcs[0]->body; + + ir::Module::Builder builder("Argmin_Keep_Module", target); + for (auto& f : funcs) { + builder.AddFunction(f); + } + + backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); + codegen.SetInlineBuiltinCodes(false); + std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + auto target_source = R"ROC( +#include <cinn_runtime.h> +#include <stdio.h> + +void TestGenerateCodeCpu_Argmin_Keep(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _in = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _test_argmin_in = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _test_argmin_in_index = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { 4, 3, 28, 28 }); + cinn_buffer_t* _test_argmin_in_index_temp = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { 4, 3, 28, 28 }); + 
cinn_buffer_malloc((void*)(0), _test_argmin_in); + cinn_buffer_malloc((void*)(0), _test_argmin_in_index); + cinn_buffer_malloc((void*)(0), _test_argmin_in_index_temp); + const float* in = ((const float*)(_in->memory)); + int32_t* test_argmin_in = ((int32_t*)(_test_argmin_in->memory)); + int32_t* test_argmin_in_index = ((int32_t*)(_test_argmin_in_index->memory)); + int32_t* test_argmin_in_index_temp = ((int32_t*)(_test_argmin_in_index_temp->memory)); + { + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 3; j += 1) { + for (int32_t k = 0; k < 28; k += 1) { + for (int32_t a = 0; a < 28; a += 1) { + test_argmin_in_index_temp[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_lt_num_float(_in, 3, in[((2352 * i) + ((784 * j) + ((28 * k) + a)))], ((2352 * i) + ((28 * k) + a)), 784); + }; + }; + }; + }; + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 3; j += 1) { + for (int32_t k = 0; k < 28; k += 1) { + for (int32_t a = 0; a < 28; a += 1) { + test_argmin_in_index[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_find_int_nd(_test_argmin_in_index_temp, 3, j, ((2352 * i) + ((28 * k) + a)), 784); + }; + }; + }; + }; + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t k = 0; k < 28; k += 1) { + for (int32_t a = 0; a < 28; a += 1) { + test_argmin_in[((784 * i) + ((28 * k) + a))] = test_argmin_in_index[((2352 * i) + ((28 * k) + a))]; + }; + }; + }; + }; + cinn_buffer_free((void*)(0), _test_argmin_in_index); + cinn_buffer_free((void*)(0), _test_argmin_in_index_temp); + cinn_buffer_free((void*)(0), _test_argmin_in); +} + )ROC"; + CHECK_EQ(utils::Trim(code), utils::Trim(target_source)); +} + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/contrib/sort.cc b/cinn/hlir/op/contrib/sort.cc index 75f0c20f51..1bf21ef20d 100644 --- a/cinn/hlir/op/contrib/sort.cc +++ b/cinn/hlir/op/contrib/sort.cc @@ -47,27 +47,31 @@ using common::CINNValuePack; ir::Tensor ArgSort(const ir::Tensor &A, const common::Target 
&target, + poly::StageMap stages, const int &axis, const bool &is_ascend, const std::string &name) { - std::string extern_fun_name; + std::string find_func_name; + std::string index_func_name; if (target.arch == common::Target::Arch::NVGPU) { - extern_fun_name.assign("cinn_cuda_"); + index_func_name.assign("cinn_cuda_"); + find_func_name.assign("cinn_cuda_find_int_nd"); } else if (target.arch == common::Target::Arch::X86) { - extern_fun_name.assign("cinn_host_"); + index_func_name.assign("cinn_host_"); + find_func_name.assign("cinn_host_find_int_nd"); } else { LOG(FATAL) << "ArgSort only supports X86 and NVGPU ! Please Check.\n"; } if (is_ascend) { - extern_fun_name.append("lt_num_float"); + index_func_name.append("lt_num_float"); } else { - extern_fun_name.append("gt_num_float"); + index_func_name.append("gt_num_float"); } int pos_axis = axis; if (pos_axis < 0) { pos_axis += A->shape.size(); } - auto res = Compute( + auto positions = Compute( A->shape, [=](const std::vector<Expr> &indices) { Expr offset(0); @@ -85,31 +89,10 @@ ir::Tensor ArgSort(const ir::Tensor &A, offset = common::AutoSimplify(offset); stride = common::AutoSimplify(stride); auto A_shape_axis = A->shape[pos_axis]; - return lang::CallExtern(index_func_name, {A, A_shape_axis, A(indices), offset, stride}); + return lang::CallExtern(index_func_name, {A, A_shape_axis, A(indices), offset, stride}); - name); - return res; -} - -std::vector<ir::Tensor> Sort(const ir::Tensor &A, - const common::Target &target, - const int &axis, - const bool &is_ascend, - const std::string &name) { - std::string extern_fun_name; - if (target.arch == common::Target::Arch::NVGPU) { - extern_fun_name.assign("cinn_cuda_find_int_nd"); - } else if (target.arch == common::Target::Arch::X86) { - extern_fun_name.assign("cinn_host_find_int_nd"); - } else { - LOG(FATAL) << "Sort only supports X86 and NVGPU ! 
Please Check.\n"; - } - int pos_axis = axis; - if (pos_axis < 0) { - pos_axis += A->shape.size(); - } - auto sort_index = ArgSort(A, target, pos_axis, is_ascend, name + "_index"); - auto res = Compute( + name + "_temp"); + auto res = Compute( A->shape, [=](const std::vector<Expr> &indices) { Expr offset(0); @@ -128,13 +111,35 @@ std::vector<ir::Tensor> Sort(const ir::Tensor &A, stride = common::AutoSimplify(stride); auto A_shape_axis = A->shape[pos_axis]; - auto idx = lang::CallExtern(extern_fun_name, {sort_index, A_shape_axis, indices[pos_axis], offset, stride}); + auto idx = lang::CallExtern(find_func_name, {positions, A_shape_axis, indices[pos_axis], offset, stride}); + return idx; + }, + name); + stages->InsertLazily(positions); + return res; +} + +ir::Tensor Sort(const ir::Tensor &A, + const common::Target &target, + poly::StageMap stages, + const int &axis, + const bool &is_ascend, + const std::string &name) { + int pos_axis = axis; + if (pos_axis < 0) { + pos_axis += A->shape.size(); + } + auto sort_index = ArgSort(A, target, stages, pos_axis, is_ascend, name + "_index"); + auto res = Compute( + A->shape, + [=](const std::vector<Expr> &indices) { std::vector<Expr> A_indices(indices); - A_indices[pos_axis] = idx; + A_indices[pos_axis] = sort_index(indices); return A(A_indices); }, name); - return {sort_index, res}; + stages->InsertLazily(sort_index); + return res; } std::shared_ptr<framework::OpStrategy> StrategyForSort(const framework::NodeAttr &attrs, @@ -169,11 +174,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(const framework::NodeAttr CHECK(pack_args[1].is_string()); tensor_name = pack_args[1].operator std::string(); } - std::vector<ir::Tensor> outputs = Sort(tensor_A, target, axis, is_ascend, tensor_name); - ir::Tensor sort_index = outputs[0]; - ir::Tensor out = outputs[1]; + ir::Tensor out = Sort(tensor_A, target, stages, axis, is_ascend, tensor_name); std::vector<CINNValue> res; - stages->InsertLazily(sort_index); stages->InsertLazily(out); res.push_back(CINNValue(out)); CHECK(!out_type.empty()) << "Output type of Sort is empty! 
Please check.\n"; @@ -224,7 +226,7 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(const framework::NodeA CHECK(pack_args[1].is_string()); tensor_name = pack_args[1].operator std::string(); } - ir::Tensor out = ArgSort(tensor_A, target, axis, is_ascend, tensor_name); + ir::Tensor out = ArgSort(tensor_A, target, stages, axis, is_ascend, tensor_name); std::vector<CINNValue> res; stages->InsertLazily(out); res.push_back(CINNValue(out)); diff --git a/cinn/hlir/op/contrib/sort.h b/cinn/hlir/op/contrib/sort.h index 8ac93ad57c..090ca59e8c 100644 --- a/cinn/hlir/op/contrib/sort.h +++ b/cinn/hlir/op/contrib/sort.h @@ -25,11 +25,19 @@ namespace cinn { namespace hlir { namespace op { -ir::Tensor ArgSort( - const ir::Tensor& A, const common::Target& target, const int& axis, const bool& is_ascend, const std::string& name); +ir::Tensor ArgSort(const ir::Tensor& A, + const common::Target& target, + poly::StageMap stages, + const int& axis, + const bool& is_ascend, + const std::string& name); -std::vector<ir::Tensor> Sort( - const ir::Tensor& A, const common::Target& target, const int& axis, const bool& is_ascend, const std::string& name); +ir::Tensor Sort(const ir::Tensor& A, + const common::Target& target, + poly::StageMap stages, + const int& axis, + const bool& is_ascend, + const std::string& name); } // namespace op } // namespace hlir diff --git a/cinn/hlir/op/contrib/sort_test.cc b/cinn/hlir/op/contrib/sort_test.cc index 2618c6e3a9..9d3f7b9604 100644 --- a/cinn/hlir/op/contrib/sort_test.cc +++ b/cinn/hlir/op/contrib/sort_test.cc @@ -35,19 +35,15 @@ namespace op { TEST(GenerateCode_Cpu, ArgSort) { common::Context::Global().ResetNameId(); -#ifdef CINN_WITH_CUDA - Target target = common::DefaultNVGPUTarget(); -#else Target target = common::DefaultHostTarget(); -#endif ir::Expr n(4); ir::Expr h(28); lang::Placeholder<int> in("in", {n, h}); - ir::Tensor res = ArgSort(in.tensor(), target, 1, true, "test_arg_sort_out"); - - poly::StageMap stages = poly::CreateStages({in, res}); + poly::StageMap stages = 
poly::CreateStages({in}); + ir::Tensor res = ArgSort(in.tensor(), target, stages, 1, true, "test_arg_sort_out"); + stages->InsertLazily(res); std::vector<ir::LoweredFunc> funcs = lang::LowerVec("TestGenerateCodeCpu_ArgSort", stages, {in, res}, {}, {}, nullptr, target, true); @@ -69,23 +65,17 @@ TEST(GenerateCode_Cpu, ArgSort) { TEST(GenerateCode_Cpu, Sort) { common::Context::Global().ResetNameId(); -#ifdef CINN_WITH_CUDA - Target target = common::DefaultNVGPUTarget(); -#else Target target = common::DefaultHostTarget(); -#endif ir::Expr n(4); ir::Expr h(28); lang::Placeholder<int> in("in", {n, h}); - std::vector<ir::Tensor> outputs = Sort(in.tensor(), target, 1, true, "test_sort_out"); - ir::Tensor index = outputs[0]; - ir::Tensor out = outputs[1]; - - poly::StageMap stages = poly::CreateStages({in, index, out}); + auto stages = poly::CreateStages({in}); + ir::Tensor out = Sort(in, target, stages, 1, true, "test_sort_out"); + stages->InsertLazily(out); std::vector<ir::LoweredFunc> funcs = - lang::LowerVec("TestGenerateCodeCpu_Sort", stages, {in, index, out}, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCpu_Sort", stages, {in, out}, {}, {}, nullptr, target, true); @@ -97,9 +87,47 @@ TEST(GenerateCode_Cpu, Sort) { backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); - VLOG(6) << "Cpu Codegen result:"; - VLOG(6) << code << std::endl; + std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + auto target_source = R"ROC( +#include <cinn_runtime.h> +#include <stdio.h> + +void TestGenerateCodeCpu_Sort(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _in = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _test_sort_out = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _test_sort_out_index = 
cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { 4, 28 }); + cinn_buffer_t* _test_sort_out_index_temp = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_int32_t(), { 4, 28 }); + cinn_buffer_malloc((void*)(0), _test_sort_out); + cinn_buffer_malloc((void*)(0), _test_sort_out_index); + cinn_buffer_malloc((void*)(0), _test_sort_out_index_temp); + const int32_t* in = ((const int32_t*)(_in->memory)); + int32_t* test_sort_out = ((int32_t*)(_test_sort_out->memory)); + int32_t* test_sort_out_index = ((int32_t*)(_test_sort_out_index->memory)); + int32_t* test_sort_out_index_temp = ((int32_t*)(_test_sort_out_index_temp->memory)); + { + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 28; j += 1) { + test_sort_out_index_temp[((28 * i) + j)] = cinn_host_lt_num_float(_in, 28, in[((28 * i) + j)], (28 * i), 1); + }; + }; + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 28; j += 1) { + test_sort_out_index[((28 * i) + j)] = cinn_host_find_int_nd(_test_sort_out_index_temp, 28, j, (28 * i), 1); + }; + }; + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 28; j += 1) { + test_sort_out[((28 * i) + j)] = in[((28 * i) + test_sort_out_index[((28 * i) + j)])]; + }; + }; + }; + cinn_buffer_free((void*)(0), _test_sort_out_index); + cinn_buffer_free((void*)(0), _test_sort_out_index_temp); + cinn_buffer_free((void*)(0), _test_sort_out); +} + )ROC"; + CHECK_EQ(utils::Trim(code), utils::Trim(target_source)); } } // namespace op diff --git a/cinn/hlir/op/use_ops.h b/cinn/hlir/op/use_ops.h index 38f538f1d5..1d08d6def8 100644 --- a/cinn/hlir/op/use_ops.h +++ b/cinn/hlir/op/use_ops.h @@ -27,6 +27,8 @@ CINN_USE_REGISTER(scatter_ops) CINN_USE_REGISTER(cast_ops) CINN_USE_REGISTER(sort_ops) CINN_USE_REGISTER(squeeze_ops) +CINN_USE_REGISTER(argmin_ops) +CINN_USE_REGISTER(argmax_ops) CINN_USE_REGISTER(reduce_ops) CINN_USE_REGISTER(clip_ops) CINN_USE_REGISTER(custom_call_op) diff --git a/cinn/pybind/frontend.cc 
b/cinn/pybind/frontend.cc index dc51bf47f9..531de237f1 100644 --- a/cinn/pybind/frontend.cc +++ b/cinn/pybind/frontend.cc @@ -437,6 +437,8 @@ void BindFrontend(pybind11::module *m) { py::arg("axis") = -1) .def("relu6", &NetBuilder::Relu6, py::arg("a"), py::arg("threshold") = 6.0f) .def("squeeze", &NetBuilder::Squeeze, py::arg("a"), py::arg("axes")) + .def("argmax", &NetBuilder::Argmax, py::arg("x"), py::arg("axis"), py::arg("keep_dim") = false) + .def("argmin", &NetBuilder::Argmin, py::arg("x"), py::arg("axis"), py::arg("keep_dim") = false) .def("conv2d", &NetBuilder::Conv2d, py::arg("x"),