Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Sep 19, 2022
1 parent c2672f7 commit a828143
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
23 changes: 11 additions & 12 deletions cinn/frontend/net_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,12 +693,12 @@ TEST(net_build, program_argmax_case1) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
float in_data = input_data[index];
int out_index = w + W * (h + H * n);
float in_data = input_data[index];
int out_data = output_data[out_index];
EXPECT_LE(0, out_data);
EXPECT_LT(out_data, IN_C);
int max_index = w + W * (h + H * (out_data + OUT_C * n));
int max_index = w + W * (h + H * (out_data + IN_C * n));
float max_value = input_data[max_index];
line += (std::to_string(out_data) + ", ");
EXPECT_LE(in_data, max_value);
Expand Down Expand Up @@ -765,12 +765,12 @@ TEST(net_build, program_argmax_case2) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
float in_data = input_data[index];
int out_index = w + W * (h + H * n);
int out_data = output_data[index];
float in_data = input_data[index];
int out_data = output_data[out_index];
EXPECT_LE(0, out_data);
EXPECT_LT(out_data, IN_C);
int max_index = w + W * (h + H * (out_data + n));
int max_index = w + W * (h + H * (out_data + IN_C * n));
float max_value = input_data[max_index];
line += (std::to_string(out_data) + ", ");
EXPECT_LE(in_data, max_value);
Expand Down Expand Up @@ -839,13 +839,12 @@ TEST(net_build, program_argmin_case1) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
float in_data = input_data[index];
int out_index = w + W * (h + H * n);
int out_data = output_data[index];

float in_data = input_data[index];
int out_data = output_data[out_index];
EXPECT_LE(0, out_data);
EXPECT_LT(out_data, IN_C);
int max_index = w + W * (h + H * (out_data + OUT_C * n));
int max_index = w + W * (h + H * (out_data + IN_C * n));
float max_value = input_data[max_index];
line += (std::to_string(out_data) + ", ");
EXPECT_GE(in_data, max_value);
Expand Down Expand Up @@ -912,12 +911,12 @@ TEST(net_build, program_argmin_case2) {
std::string line;
for (int w = 0; w < W; ++w) {
int index = w + W * (h + H * (c + IN_C * n));
float in_data = input_data[index];
int out_index = w + W * (h + H * n);
int out_data = output_data[index];
float in_data = input_data[index];
int out_data = output_data[out_index];
EXPECT_LE(0, out_data);
EXPECT_LT(out_data, IN_C);
int max_index = w + W * (h + H * (out_data + n));
int max_index = w + W * (h + H * (out_data + IN_C * n));
float max_value = input_data[max_index];
line += (std::to_string(out_data) + ", ");
EXPECT_GE(in_data, max_value);
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/op/contrib/argmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Tensor Argmax(const Tensor &in_tensor,
output_shape.push_back(Expr(1));
}

auto sort_index = ArgSort(in_tensor, target, stages, pos_axis, true, name + "_index");
auto sort_index = ArgSort(in_tensor, target, stages, pos_axis, false, name + "_index");
auto res = Compute(
output_shape,
[=](const std::vector<Expr> &indices) {
Expand Down

0 comments on commit a828143

Please sign in to comment.