This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

【PaddlePaddle Hackathon 68】add argmax and argmin op #946

Merged: 12 commits, merged on Sep 27, 2022
18 changes: 18 additions & 0 deletions cinn/frontend/net_builder.cc
@@ -436,6 +436,24 @@ Variable NetBuilder::Sort(const Variable& operand, const int& axis, const bool&
return instr.GetOutput(0);
}

Variable NetBuilder::Argmax(const Variable& x, const int& axis, const bool& keep_dim) {
Instruction instr("argmax", {x});
instr.SetAttr("axis", axis);
instr.SetAttr("keep_dim", keep_dim);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::Argmin(const Variable& x, const int& axis, const bool& keep_dim) {
Instruction instr("argmin", {x});
instr.SetAttr("axis", axis);
instr.SetAttr("keep_dim", keep_dim);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::Conv2d(const Variable& a,
const Variable& b,
const std::vector<int>& strides,
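For orientation, here is a minimal sketch of how the two new entry points might be driven from frontend code, in the style of the tests further down (the shapes and attribute values are illustrative assumptions, not part of this diff):

NetBuilder builder("net_builder");
// A float32 input of shape [N, C, H, W] = [4, 3, 7, 7].
Placeholder x = builder.CreateInput(Float(32), {4, 3, 7, 7}, "X");
// Indices of the maximum along axis 1; keep_dim=true keeps the rank: [4, 1, 7, 7].
Variable max_idx = builder.Argmax(x, /*axis=*/1, /*keep_dim=*/true);
// Indices of the minimum along axis 1; keep_dim=false drops the axis: [4, 7, 7].
Variable min_idx = builder.Argmin(x, /*axis=*/1, /*keep_dim=*/false);
auto program = builder.Build();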
20 changes: 20 additions & 0 deletions cinn/frontend/net_builder.h
@@ -813,6 +813,26 @@ class NetBuilder {
const float epsilon = 1e-5f,
const std::string& data_layout = "NCHW");

/**
* @brief Get the indices of the maximum values of variable x along the given axis.
* @param x An input N-D variable.
* @param axis The axis along which to compute the indices. Default: 0.
* @param keep_dim Whether to keep the reduced dimension in the output. Default: false.
* @return The indices of the maximum values of x along the given axis.
*/
Variable Argmax(const Variable& x, const int& axis = 0, const bool& keep_dim = false);

/**
* @brief Get the indices of the minimum values of variable x along the given axis.
* @param x An input N-D variable.
* @param axis The axis along which to compute the indices. Default: 0.
* @param keep_dim Whether to keep the reduced dimension in the output. Default: false.
* @return The indices of the minimum values of x along the given axis.
*/
Variable Argmin(const Variable& x, const int& axis = 0, const bool& keep_dim = false);

/**
* @brief Sort Variable x along the given axis and return sorted index. The original Variable x will not be changed.
* @param operand The variable that will be sorted.
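To pin down the documented semantics, here is a small self-contained reference implementation of Argmax along axis 1 of a row-major [N, C, H, W] buffer. This is a sketch for illustration only, not the operator CINN actually runs (that is the argmax.cc kernel registered in the CMakeLists change below); note that keep_dim affects only the reported output shape ([N, 1, H, W] vs. [N, H, W]), never the values:

#include <vector>

// Reference semantics: one index in [0, C) per (n, h, w) position.
std::vector<int> ArgmaxAxis1(const std::vector<float>& in, int N, int C, int H, int W) {
  std::vector<int> out(N * H * W, 0);
  for (int n = 0; n < N; ++n) {
    for (int h = 0; h < H; ++h) {
      for (int w = 0; w < W; ++w) {
        int best = 0;  // index of the current maximum along axis 1
        for (int c = 1; c < C; ++c) {
          int idx  = w + W * (h + H * (c + C * n));
          int bidx = w + W * (h + H * (best + C * n));
          if (in[idx] > in[bidx]) best = c;
        }
        out[w + W * (h + H * n)] = best;
      }
    }
  }
  return out;
}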
294 changes: 293 additions & 1 deletion cinn/frontend/net_builder_test.cc
@@ -745,7 +745,7 @@ TEST(net_build, program_execute_argsort) {
for (int b = 0; b < B; ++b) {
int index = h + H * b;
sorted_data.push_back(input_data[index]);
- out_sorted_data[output_data[index]] = input_data[index];
+ out_sorted_data[b] = input_data[h + H * output_data[index]];
}
std::sort(sorted_data.begin(), sorted_data.begin() + B);

@@ -887,5 +887,297 @@ TEST(net_build, program_execute_arange_int) {
}
}

TEST(net_build, program_argmax_case1) {
const int N = 4;
const int IN_C = 3;
const int OUT_C = 1;
const int H = 7;
const int W = 7;

NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In");
Variable output = builder.Argmax(input, 1, true);
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<float>(input_tensor, target);
float* input_data = input_tensor->mutable_data<float>(target);
VLOG(6) << "Visualize input_data";
for (int n = 0; n < N; ++n) {
for (int c = 0; c < IN_C; ++c) {
VLOG(6) << "n = " << n << ", c = " << c;
for (int h = 0; h < H; ++h) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
line += (std::to_string(input_data[index]) + ", ");
}
VLOG(6) << line;
}
}
}
runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_shape.size(), 4UL);
EXPECT_EQ(output_shape[0], N);
EXPECT_EQ(output_shape[1], OUT_C);
EXPECT_EQ(output_shape[2], H);
EXPECT_EQ(output_shape[3], W);

int* output_data = output_tensor->mutable_data<int>(target);
VLOG(6) << "Visualize output_data";
for (int n = 0; n < N; ++n) {
for (int c = 0; c < IN_C; ++c) {
VLOG(6) << "n = " << n << ", c = " << c;
for (int h = 0; h < H; ++h) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
int out_index = w + W * (h + H * n);
float in_data = input_data[index];
int out_data = output_data[out_index];
EXPECT_LE(0, out_data);
EXPECT_LT(out_data, IN_C);
int max_index = w + W * (h + H * (out_data + IN_C * n));
float max_value = input_data[max_index];
line += (std::to_string(out_data) + ", ");
EXPECT_LE(in_data, max_value);
}
VLOG(6) << line;
}
}
}
}
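// Note on the index arithmetic used throughout these tests: the tensors are
// row-major [N, C, H, W], so the last axis varies fastest. A hypothetical
// helper making the flattening explicit (the tests inline the same expression):
//
//   int Offset(int n, int c, int h, int w, int C, int H, int W) {
//     return ((n * C + c) * H + h) * W + w;  // == w + W * (h + H * (c + C * n))
//   }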

TEST(net_build, program_argmax_case2) {
const int N = 4;
const int IN_C = 3;
const int H = 7;
const int W = 7;

NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In");
Variable output = builder.Argmax(input, 1, false);
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<float>(input_tensor, target);
float* input_data = input_tensor->mutable_data<float>(target);
VLOG(6) << "Visualize input_data";
for (int n = 0; n < N; ++n) {
for (int c = 0; c < IN_C; ++c) {
VLOG(6) << "n = " << n << ", c = " << c;
for (int h = 0; h < H; ++h) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
line += (std::to_string(input_data[index]) + ", ");
}
VLOG(6) << line;
}
}
}
runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_shape.size(), 3UL);
EXPECT_EQ(output_shape[0], N);
EXPECT_EQ(output_shape[1], H);
EXPECT_EQ(output_shape[2], W);

int* output_data = output_tensor->mutable_data<int>(target);
VLOG(6) << "Visualize output_data";
for (int n = 0; n < N; ++n) {
for (int c = 0; c < IN_C; ++c) {
VLOG(6) << "n = " << n << ", c = " << c;
for (int h = 0; h < H; ++h) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
int out_index = w + W * (h + H * n);
float in_data = input_data[index];
int out_data = output_data[out_index];
EXPECT_LE(0, out_data);
EXPECT_LT(out_data, IN_C);
int max_index = w + W * (h + H * (out_data + IN_C * n));
float max_value = input_data[max_index];
line += (std::to_string(out_data) + ", ");
EXPECT_LE(in_data, max_value);
}
VLOG(6) << line;
}
}
}
}

TEST(net_build, program_argmin_case1) {
const int N = 4;
const int IN_C = 3;
const int OUT_C = 1;
const int H = 7;
const int W = 7;

NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In");
Variable output = builder.Argmin(input, 1, true);
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<float>(input_tensor, target);
float* input_data = input_tensor->mutable_data<float>(target);
VLOG(6) << "Visualize input_data";
for (int n = 0; n < N; ++n) {
for (int c = 0; c < IN_C; ++c) {
VLOG(6) << "n = " << n << ", c = " << c;
for (int h = 0; h < H; ++h) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
line += (std::to_string(input_data[index]) + ", ");
}
VLOG(6) << line;
}
}
}
runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_shape.size(), 4UL);
EXPECT_EQ(output_shape[0], N);
EXPECT_EQ(output_shape[1], OUT_C);
EXPECT_EQ(output_shape[2], H);
EXPECT_EQ(output_shape[3], W);

int* output_data = output_tensor->mutable_data<int>(target);
VLOG(6) << "Visualize output_data";
for (int n = 0; n < N; ++n) {
for (int c = 0; c < IN_C; ++c) {
VLOG(6) << "n = " << n << ", c = " << c;
for (int h = 0; h < H; ++h) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
int out_index = w + W * (h + H * n);
float in_data = input_data[index];
int out_data = output_data[out_index];
EXPECT_LE(0, out_data);
EXPECT_LT(out_data, IN_C);
int min_index = w + W * (h + H * (out_data + IN_C * n));
float min_value = input_data[min_index];
line += (std::to_string(out_data) + ", ");
EXPECT_GE(in_data, min_value);
}
VLOG(6) << line;
}
}
}
}

TEST(net_build, program_argmin_case2) {
const int N = 4;
const int IN_C = 3;
const int H = 7;
const int W = 7;

NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In");
Variable output = builder.Argmin(input, 1, false);
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<float>(input_tensor, target);
float* input_data = input_tensor->mutable_data<float>(target);
VLOG(6) << "Visualize input_data";
for (int n = 0; n < N; ++n) {
for (int c = 0; c < IN_C; ++c) {
VLOG(6) << "n = " << n << ", c = " << c;
for (int h = 0; h < H; ++h) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
line += (std::to_string(input_data[index]) + ", ");
}
VLOG(6) << line;
}
}
}
runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_shape.size(), 3UL);
EXPECT_EQ(output_shape[0], N);
EXPECT_EQ(output_shape[1], H);
EXPECT_EQ(output_shape[2], W);

int* output_data = output_tensor->mutable_data<int>(target);
VLOG(6) << "Visualize output_data";
for (int n = 0; n < N; ++n) {
for (int c = 0; c < IN_C; ++c) {
VLOG(6) << "n = " << n << ", c = " << c;
for (int h = 0; h < H; ++h) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
int out_index = w + W * (h + H * n);
float in_data = input_data[index];
int out_data = output_data[out_index];
EXPECT_LE(0, out_data);
EXPECT_LT(out_data, IN_C);
int min_index = w + W * (h + H * (out_data + IN_C * n));
float min_value = input_data[min_index];
line += (std::to_string(out_data) + ", ");
EXPECT_GE(in_data, min_value);
}
VLOG(6) << line;
}
}
}
}

} // namespace frontend
} // namespace cinn
4 changes: 4 additions & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
@@ -8,6 +8,8 @@ gather_srcs(cinnapi_src SRCS
clip.cc
arange.cc
sort.cc
argmin.cc
argmax.cc
squeeze.cc
)

@@ -17,4 +19,6 @@ cc_test(test_gather SRCS gather_test.cc DEPS cinncore)
cc_test(test_scatter SRCS scatter_test.cc DEPS cinncore)
cc_test(test_clip SRCS clip_test.cc DEPS cinncore)
cc_test(test_sort SRCS sort_test.cc DEPS cinncore)
cc_test(test_argmin SRCS argmin_test.cc DEPS cinncore)
cc_test(test_argmax SRCS argmax_test.cc DEPS cinncore)
cc_test(test_arange SRCS arange_test.cc DEPS cinncore)