Skip to content

Commit

Permalink
【CINN】Unify for,ifthenelse expression (#57312)
Browse files Browse the repository at this point in the history
* unify for,ifthenelse expression

* delete logic about simplify block in ifthenelse

* fix test case

* delete comment
  • Loading branch information
Courtesy-Xs authored Sep 15, 2023
1 parent ec619e6 commit c3b0078
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 59 deletions.
22 changes: 2 additions & 20 deletions paddle/cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,31 +285,13 @@ void CodeGenC::Visit(const ir::Select *op) {
void CodeGenC::Visit(const ir::IfThenElse *op) {
str_ += "if (";
IrPrinter::Visit(op->condition);
str_ += ") {\n";
str_ += ") ";

if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent();
IrPrinter::Visit(op->true_case);
if (!op->true_case.As<ir::Block>()) str_ += ";";
str_ += "\n";

if (!op->true_case.As<ir::Block>()) DecIndent();

DoIndent();
str_ += "}";

if (op->false_case.defined()) {
str_ += " else {\n";

if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent();
str_ += " else ";
IrPrinter::Visit(op->false_case);
if (!op->false_case.As<ir::Block>()) str_ += ";";
str_ += "\n";
if (!op->true_case.As<ir::Block>()) DecIndent();

DoIndent();
str_ += "}";
}
}
void CodeGenC::Visit(const ir::Block *op) {
Expand Down
4 changes: 0 additions & 4 deletions paddle/cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -794,10 +794,8 @@ void test_simple_compute_at(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
};
};
};
Expand Down Expand Up @@ -869,10 +867,8 @@ void test_compute_at0(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
};
};
};
Expand Down
8 changes: 6 additions & 2 deletions paddle/cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ Expr For::Make(Var loop_var,
node->min = min;
node->extent = extent;
node->device_api = device_api;
node->body = body;
node->body = body.As<ir::Block>() ? body : ir::Block::Make({body});
node->set_for_type(for_type);
node->set_vectorize_info(vector_info);
node->set_bind_info(bind_info);
Expand Down Expand Up @@ -346,6 +346,10 @@ std::vector<const Expr *> ScheduleBlockRealize::expr_fields() const {
}

Expr IfThenElse::Make(Expr condition, Expr true_case, Expr false_case) {
if (true_case.defined() && (!true_case.As<Block>()))
true_case = ir::Block::Make({true_case});
if (false_case.defined() && (!false_case.As<Block>()))
false_case = ir::Block::Make({false_case});
auto node = make_shared<IfThenElse>(condition, true_case, false_case);
return Expr(node);
}
Expand Down Expand Up @@ -513,7 +517,7 @@ Expr PolyFor::Make(Var iterator,
n->condition = condition;
n->inc = inc;
n->device_api = device_api;
n->body = body;
n->body = body.As<ir::Block>() ? body : ir::Block::Make({body});
n->set_for_type(for_type);
n->set_vectorize_info(vectorize_info);
n->set_bind_info(bind_info);
Expand Down
18 changes: 2 additions & 16 deletions paddle/cinn/ir/utils/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,26 +229,12 @@ void IrPrinter::Visit(const PolyFor *x) {
void IrPrinter::Visit(const IfThenElse *x) {
str_ += "if (";
Visit(x->condition);
str_ += ") {\n";
IncIndent();
DoIndent();
str_ += ") ";
Visit(x->true_case);
DecIndent();
str_ += "\n";
DoIndent();
str_ += "}";

if (x->false_case.defined()) {
str_ += " else {\n";
IncIndent();

DoIndent();
str_ += " else ";
Visit(x->false_case);
str_ += "\n";

DecIndent();
DoIndent();
str_ += "}";
}
}
void IrPrinter::Visit(const Block *x) {
Expand Down
17 changes: 0 additions & 17 deletions paddle/cinn/optim/ir_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,23 +306,6 @@ struct SimplifyBlocksMutator : public ir::IRMutator<> {
}
}

void Visit(const IfThenElse* op, Expr* expr) override {
auto* node = expr->As<IfThenElse>();
Visit(&node->condition, &node->condition);
if (node->true_case.As<Block>() &&
(node->true_case.As<Block>()->stmts.size() == 1)) {
node->true_case = node->true_case.As<Block>()->stmts[0];
}
Visit(&node->true_case, &node->true_case);
if (node->false_case.defined()) {
if (node->false_case.As<Block>() &&
(node->false_case.As<Block>()->stmts.size() == 1)) {
node->false_case = node->false_case.As<Block>()->stmts[0];
}
Visit(&node->false_case, &node->false_case);
}
}

void Visit(const ScheduleBlock* op, Expr* expr) override {
auto* node = expr->As<ScheduleBlock>();
CHECK(node);
Expand Down

0 comments on commit c3b0078

Please sign in to comment.