diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index a133d474adee7..a121e25fe14d3 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -571,7 +571,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(expr.const_value); emit(expr.atomic); auto *e = expr.expr.get(); - emit(e->stmt); + emit(e->get_flattened_stmt()); emit(e->attributes); emit(e->ret_type); expr.expr->accept(this); diff --git a/taichi/ir/expression.h b/taichi/ir/expression.h index d789826156cc6..7b49a16dfa713 100644 --- a/taichi/ir/expression.h +++ b/taichi/ir/expression.h @@ -11,8 +11,10 @@ class ExpressionVisitor; // always a tree - used as rvalues class Expression { - public: + protected: Stmt *stmt; + + public: std::string tb; std::map attributes; DataType ret_type; @@ -53,6 +55,10 @@ class Expression { virtual ~Expression() { } + + Stmt *get_flattened_stmt() const { + return stmt; + } }; class ExprGroup { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 4d22a8f4b6eba..51eee829d6926 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -190,8 +190,8 @@ bool UnaryOpExpression::is_cast() const { } void UnaryOpExpression::flatten(FlattenContext *ctx) { - flatten_rvalue(operand, ctx); - auto unary = std::make_unique(type, operand->stmt); + auto operand_stmt = flatten_rvalue(operand, ctx); + auto unary = std::make_unique(type, operand_stmt); if (is_cast()) { unary->cast_type = cast_type; } @@ -331,17 +331,18 @@ void BinaryOpExpression::type_check(CompileConfig *config) { void BinaryOpExpression::flatten(FlattenContext *ctx) { // if (stmt) // return; - flatten_rvalue(lhs, ctx); + auto lhs_stmt = flatten_rvalue(lhs, ctx); + if (binary_is_logical(type)) { auto result = ctx->push_back(ret_type); - ctx->push_back(result, lhs->stmt); + ctx->push_back(result, lhs_stmt); auto cond = ctx->push_back(result); auto if_stmt = ctx->push_back(cond); FlattenContext rctx; rctx.current_block = ctx->current_block; - flatten_rvalue(rhs, &rctx); - rctx.push_back(result, rhs->stmt); + auto rhs_stmt = flatten_rvalue(rhs, &rctx); + rctx.push_back(result, rhs_stmt); auto true_block = std::make_unique(); if (type == BinaryOpType::logical_and) { @@ -361,8 +362,8 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) { stmt->ret_type = ret_type; return; } - flatten_rvalue(rhs, ctx); - ctx->push_back(std::make_unique(type, lhs->stmt, rhs->stmt)); + auto rhs_stmt = flatten_rvalue(rhs, ctx); + ctx->push_back(std::make_unique(type, lhs_stmt, rhs_stmt)); ctx->stmts.back()->tb = tb; stmt = ctx->back_stmt(); stmt->ret_type = ret_type; @@ -374,18 +375,18 @@ void make_ifte(Expression::FlattenContext *ctx, Expr true_val, Expr false_val) { auto result = ctx->push_back(ret_type); - flatten_rvalue(cond, ctx); - auto if_stmt = ctx->push_back(cond->stmt); + auto cond_stmt = flatten_rvalue(cond, ctx); + auto if_stmt = ctx->push_back(cond_stmt); Expression::FlattenContext lctx; lctx.current_block = ctx->current_block; - flatten_rvalue(true_val, &lctx); - lctx.push_back(result, true_val->stmt); + auto true_val_stmt = flatten_rvalue(true_val, &lctx); + lctx.push_back(result, true_val_stmt); Expression::FlattenContext rctx; rctx.current_block = ctx->current_block; - flatten_rvalue(false_val, &rctx); - rctx.push_back(result, false_val->stmt); + auto false_val_stmt = flatten_rvalue(false_val, &rctx); + rctx.push_back(result, false_val_stmt); auto true_block = std::make_unique(); true_block->set_statements(std::move(lctx.stmts)); @@ -492,11 +493,11 @@ void TernaryOpExpression::flatten(FlattenContext *ctx) { // if (stmt) // return; if (type == TernaryOpType::select) { - flatten_rvalue(op1, ctx); - flatten_rvalue(op2, ctx); - flatten_rvalue(op3, ctx); + auto op1_stmt = flatten_rvalue(op1, ctx); + auto op2_stmt = flatten_rvalue(op2, ctx); + auto op3_stmt = flatten_rvalue(op3, ctx); ctx->push_back( - std::make_unique(type, op1->stmt, op2->stmt, op3->stmt)); + std::make_unique(type, op1_stmt, op2_stmt, op3_stmt)); } else if (type == TernaryOpType::ifte) { make_ifte(ctx, ret_type, op1, op2, op3); } @@ -517,8 +518,7 @@ void InternalFuncCallExpression::type_check(CompileConfig *) { void InternalFuncCallExpression::flatten(FlattenContext *ctx) { std::vector args_stmts(args.size()); for (int i = 0; i < (int)args.size(); ++i) { - flatten_rvalue(args[i], ctx); - args_stmts[i] = args[i]->stmt; + args_stmts[i] = flatten_rvalue(args[i], ctx); } ctx->push_back(func_name, args_stmts, nullptr, with_runtime_context); @@ -554,8 +554,7 @@ std::vector make_index_stmts(Expression::FlattenContext *ctx, const std::vector &offsets) { std::vector index_stmts; for (int i = 0; i < (int)indices.size(); i++) { - flatten_rvalue(indices.exprs[i], ctx); - Stmt *ind = indices.exprs[i]->stmt; + Stmt *ind = flatten_rvalue(indices.exprs[i], ctx); if (!offsets.empty()) { auto offset = ctx->push_back(TypedConstant(offsets[i])); ind = ctx->push_back(BinaryOpType::sub, ind, offset); @@ -591,14 +590,13 @@ Stmt *make_ndarray_access(Expression::FlattenContext *ctx, ExprGroup indices) { std::vector index_stmts; for (int i = 0; i < (int)indices.size(); i++) { - flatten_rvalue(indices.exprs[i], ctx); - Stmt *ind = indices.exprs[i]->stmt; + Stmt *ind = flatten_rvalue(indices.exprs[i], ctx); index_stmts.push_back(ind); } - flatten_lvalue(var, ctx); + auto var_stmt = flatten_lvalue(var, ctx); auto expr = var.cast(); auto external_ptr_stmt = std::make_unique( - expr->stmt, index_stmts, expr->dt.get_shape(), expr->element_dim); + var_stmt, index_stmts, expr->dt.get_shape(), expr->element_dim); if (expr->dim == indices.size()) { // Indexing into an scalar element external_ptr_stmt->ret_type = expr->dt.ptr_removed().get_element_type(); @@ -611,7 +609,7 @@ Stmt *make_ndarray_access(Expression::FlattenContext *ctx, } Stmt *make_tensor_access_single_element(Expression::FlattenContext *ctx, - const Expr &var, + Stmt *var_stmt, const ExprGroup &indices, const std::vector &shape, int stride, @@ -626,12 +624,12 @@ Stmt *make_tensor_access_single_element(Expression::FlattenContext *ctx, if (needs_dynamic_index) { offset_stmt = ctx->push_back(TypedConstant(0)); for (int i = 0; i < (int)indices.size(); ++i) { - flatten_rvalue(indices[i], ctx); + auto index_stmt = flatten_rvalue(indices[i], ctx); Stmt *shape_stmt = ctx->push_back(TypedConstant(shape[i])); Stmt *mul_stmt = ctx->push_back(BinaryOpType::mul, offset_stmt, shape_stmt); - offset_stmt = ctx->push_back(BinaryOpType::add, mul_stmt, - indices[i]->stmt); + offset_stmt = + ctx->push_back(BinaryOpType::add, mul_stmt, index_stmt); } } else { int offset = 0; @@ -646,7 +644,7 @@ Stmt *make_tensor_access_single_element(Expression::FlattenContext *ctx, offset_stmt = ctx->push_back(BinaryOpType::mul, offset_stmt, stride_stmt); } - return ctx->push_back(var->stmt, offset_stmt, tb); + return ctx->push_back(var_stmt, offset_stmt, tb); } Stmt *make_tensor_access(Expression::FlattenContext *ctx, @@ -656,22 +654,22 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, std::vector shape, int stride, const std::string &tb) { - flatten_lvalue(var, ctx); + auto var_stmt = flatten_lvalue(var, ctx); if (!var->is_lvalue()) { auto alloca_stmt = ctx->push_back(var->ret_type); - ctx->push_back(alloca_stmt, var->stmt); - var->stmt = alloca_stmt; + ctx->push_back(alloca_stmt, var_stmt); + var_stmt = alloca_stmt; } if (is_tensor(ret_type)) { std::vector stmts; for (auto &indices : indices_group) { - stmts.push_back(make_tensor_access_single_element(ctx, var, indices, + stmts.push_back(make_tensor_access_single_element(ctx, var_stmt, indices, shape, stride, tb)); } return ctx->push_back(stmts, ret_type); } - return make_tensor_access_single_element(ctx, var, indices_group[0], shape, - stride, tb); + return make_tensor_access_single_element(ctx, var_stmt, indices_group[0], + shape, stride, tb); } void MatrixExpression::type_check(CompileConfig *config) { @@ -686,8 +684,7 @@ void MatrixExpression::flatten(FlattenContext *ctx) { TI_ASSERT(this->dt->is()); std::vector values; for (auto &elt : elements) { - flatten_rvalue(elt, ctx); - values.push_back(elt->stmt); + values.push_back(flatten_rvalue(elt, ctx)); } stmt = ctx->push_back(values); stmt->ret_type = this->dt; @@ -831,10 +828,10 @@ void RangeAssumptionExpression::type_check(CompileConfig *) { } void RangeAssumptionExpression::flatten(FlattenContext *ctx) { - flatten_rvalue(input, ctx); - flatten_rvalue(base, ctx); + auto input_stmt = flatten_rvalue(input, ctx); + auto base_stmt = flatten_rvalue(base, ctx); ctx->push_back( - Stmt::make(input->stmt, base->stmt, low, high)); + Stmt::make(input_stmt, base_stmt, low, high)); stmt = ctx->back_stmt(); } @@ -848,8 +845,8 @@ void LoopUniqueExpression::type_check(CompileConfig *) { } void LoopUniqueExpression::flatten(FlattenContext *ctx) { - flatten_rvalue(input, ctx); - ctx->push_back(Stmt::make(input->stmt, covers)); + auto input_stmt = flatten_rvalue(input, ctx); + ctx->push_back(Stmt::make(input_stmt, covers)); stmt = ctx->back_stmt(); } @@ -919,10 +916,9 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) { op_type = AtomicOpType::add; } // expand rhs - flatten_rvalue(val, ctx); - auto src_val = val->stmt; - flatten_lvalue(dest, ctx); - stmt = ctx->push_back(op_type, dest->stmt, src_val); + auto val_stmt = flatten_rvalue(val, ctx); + auto dest_stmt = flatten_lvalue(dest, ctx); + stmt = ctx->push_back(op_type, dest_stmt, val_stmt); stmt->ret_type = stmt->as()->dest->ret_type; stmt->tb = tb; } @@ -952,8 +948,7 @@ void SNodeOpExpression::type_check(CompileConfig *config) { void SNodeOpExpression::flatten(FlattenContext *ctx) { std::vector indices_stmt; for (int i = 0; i < (int)indices.size(); i++) { - flatten_rvalue(indices[i], ctx); - indices_stmt.push_back(indices[i]->stmt); + indices_stmt.push_back(flatten_rvalue(indices[i], ctx)); } auto is_cell_access = SNodeOpStmt::activation_related(op_type) && snode->type != SNodeType::dynamic; @@ -971,18 +966,16 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) { } else if (op_type == SNodeOpType::get_addr) { ctx->push_back(SNodeOpType::get_addr, snode, ptr, nullptr); } else if (op_type == SNodeOpType::append) { - for (auto &value : values) { - flatten_rvalue(value, ctx); - } auto alloca = ctx->push_back(PrimitiveType::i32); alloca->set_tb(tb); auto addr = ctx->push_back(SNodeOpType::allocate, snode, ptr, alloca); addr->set_tb(tb); for (int i = 0; i < values.size(); i++) { + auto value_stmt = flatten_rvalue(values[i], ctx); auto ch_addr = ctx->push_back(addr, snode, i); ch_addr->set_tb(tb); - ctx->push_back(ch_addr, values[i]->stmt)->set_tb(tb); + ctx->push_back(ch_addr, value_stmt)->set_tb(tb); } ctx->push_back(alloca)->set_tb(tb); TI_ERROR_IF(snode->type != SNodeType::dynamic, @@ -1072,13 +1065,12 @@ void TextureOpExpression::type_check(CompileConfig *config) { } void TextureOpExpression::flatten(FlattenContext *ctx) { - flatten_rvalue(texture_ptr, ctx); + auto texture_ptr_stmt = flatten_rvalue(texture_ptr, ctx); std::vector arg_stmts; for (Expr &arg : args.exprs) { - flatten_rvalue(arg, ctx); - arg_stmts.push_back(arg->stmt); + arg_stmts.push_back(flatten_rvalue(arg, ctx)); } - ctx->push_back(op, texture_ptr->stmt, arg_stmts); + ctx->push_back(op, texture_ptr_stmt, arg_stmts); stmt = ctx->back_stmt(); } @@ -1121,8 +1113,7 @@ void FuncCallExpression::type_check(CompileConfig *) { void FuncCallExpression::flatten(FlattenContext *ctx) { std::vector stmt_args; for (auto &arg : args.exprs) { - flatten_rvalue(arg, ctx); - stmt_args.push_back(arg->stmt); + stmt_args.push_back(flatten_rvalue(arg, ctx)); } ctx->push_back(func, stmt_args); stmt = ctx->back_stmt(); @@ -1137,7 +1128,7 @@ void GetElementExpression::type_check(CompileConfig *config) { } void GetElementExpression::flatten(FlattenContext *ctx) { - ctx->push_back(src->stmt, index); + ctx->push_back(src->get_flattened_stmt(), index); stmt = ctx->back_stmt(); } // Mesh related. @@ -1157,13 +1148,13 @@ void MeshRelationAccessExpression::type_check(CompileConfig *) { } void MeshRelationAccessExpression::flatten(FlattenContext *ctx) { - flatten_rvalue(mesh_idx, ctx); + auto mesh_idx_stmt = flatten_rvalue(mesh_idx, ctx); if (neighbor_idx) { - flatten_rvalue(neighbor_idx, ctx); - ctx->push_back(mesh, mesh_idx->stmt, to_type, - neighbor_idx->stmt); + auto neighbor_idx_stmt = flatten_rvalue(neighbor_idx, ctx); + ctx->push_back(mesh, mesh_idx_stmt, to_type, + neighbor_idx_stmt); } else { - ctx->push_back(mesh, mesh_idx->stmt, to_type); + ctx->push_back(mesh, mesh_idx_stmt, to_type); } stmt = ctx->back_stmt(); } @@ -1173,8 +1164,8 @@ void MeshIndexConversionExpression::type_check(CompileConfig *) { } void MeshIndexConversionExpression::flatten(FlattenContext *ctx) { - flatten_rvalue(idx, ctx); - ctx->push_back(mesh, idx_type, idx->stmt, conv_type); + auto idx_stmt = flatten_rvalue(idx, ctx); + ctx->push_back(mesh, idx_type, idx_stmt, conv_type); stmt = ctx->back_stmt(); } @@ -1183,8 +1174,8 @@ void ReferenceExpression::type_check(CompileConfig *) { } void ReferenceExpression::flatten(FlattenContext *ctx) { - flatten_lvalue(var, ctx); - ctx->push_back(var->stmt); + auto var_stmt = flatten_lvalue(var, ctx); + ctx->push_back(var_stmt); stmt = ctx->back_stmt(); } @@ -1570,39 +1561,43 @@ void ASTBuilder::pop_scope() { loop_state_stack_.pop_back(); } -void flatten_lvalue(Expr expr, Expression::FlattenContext *ctx) { +Stmt *flatten_lvalue(Expr expr, Expression::FlattenContext *ctx) { expr->flatten(ctx); + return expr->get_flattened_stmt(); } -void flatten_global_load(Expr ptr, Expression::FlattenContext *ctx) { - ctx->push_back(std::make_unique(ptr->stmt)); - ptr->stmt = ctx->back_stmt(); +Stmt *flatten_global_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) { + ctx->push_back(std::make_unique(ptr_stmt)); + return ctx->back_stmt(); } -void flatten_local_load(Expr ptr, Expression::FlattenContext *ctx) { - ctx->push_back(ptr->stmt); - ptr->stmt = ctx->back_stmt(); +Stmt *flatten_local_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) { + ctx->push_back(ptr_stmt); + return ctx->back_stmt(); } -void flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) { +Stmt *flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) { ptr->flatten(ctx); + Stmt *ptr_stmt = ptr->get_flattened_stmt(); if (ptr.is()) { - if (ptr->stmt->is()) { - flatten_local_load(ptr, ctx); + if (ptr_stmt->is()) { + return flatten_local_load(ptr_stmt, ctx); } } else if (ptr.is()) { auto ix = ptr.cast(); if (ix->is_local()) { - flatten_local_load(ptr, ctx); + return flatten_local_load(ptr_stmt, ctx); } else { - flatten_global_load(ptr, ctx); + return flatten_global_load(ptr_stmt, ctx); } } else if (ptr.is()) { - flatten_global_load(ptr, ctx); + return flatten_global_load(ptr_stmt, ctx); } else if (ptr.is() && ptr.cast()->is_ptr) { - flatten_global_load(ptr, ctx); + return flatten_global_load(ptr_stmt, ctx); } + + return ptr_stmt; } } // namespace taichi::lang diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 31670288af918..e2a52fa88f0a2 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -1067,8 +1067,8 @@ class FrontendContext { } }; -void flatten_lvalue(Expr expr, Expression::FlattenContext *ctx); +Stmt *flatten_lvalue(Expr expr, Expression::FlattenContext *ctx); -void flatten_rvalue(Expr expr, Expression::FlattenContext *ctx); +Stmt *flatten_rvalue(Expr expr, Expression::FlattenContext *ctx); } // namespace taichi::lang diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 68dcb47b70401..c473c165684c2 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -84,9 +84,9 @@ class LowerAST : public IRVisitor { void visit(FrontendIfStmt *stmt) override { auto fctx = make_flatten_ctx(); - flatten_rvalue(stmt->condition, &fctx); + auto condition_stmt = flatten_rvalue(stmt->condition, &fctx); - auto new_if = std::make_unique(stmt->condition->stmt); + auto new_if = std::make_unique(condition_stmt); if (stmt->true_statements) { new_if->set_true_statements(std::move(stmt->true_statements)); @@ -116,9 +116,9 @@ class LowerAST : public IRVisitor { for (auto c : stmt->contents) { if (std::holds_alternative(c)) { auto x = std::get(c); - flatten_rvalue(x, &fctx); - stmts.push_back(x->stmt); - new_contents.push_back(x->stmt); + auto x_stmt = flatten_rvalue(x, &fctx); + stmts.push_back(x_stmt); + new_contents.push_back(x_stmt); } else { auto x = std::get(c); new_contents.push_back(x); @@ -145,8 +145,7 @@ class LowerAST : public IRVisitor { // while (1) { cond; if (no active) break; original body...} auto cond = stmt->cond; auto fctx = make_flatten_ctx(); - flatten_rvalue(cond, &fctx); - auto cond_stmt = fctx.back_stmt(); + auto cond_stmt = flatten_rvalue(cond, &fctx); auto &&new_while = std::make_unique(std::move(stmt->body)); auto mask = std::make_unique(PrimitiveType::i32); @@ -291,15 +290,15 @@ class LowerAST : public IRVisitor { TI_ASSERT(stmt->loop_var_ids.size() == 1); auto begin = stmt->begin; auto end = stmt->end; - flatten_rvalue(begin, &fctx); - flatten_rvalue(end, &fctx); + auto begin_stmt = flatten_rvalue(begin, &fctx); + auto end_stmt = flatten_rvalue(end, &fctx); bool is_good_range_for = detected_fors_with_break_.find(stmt) == detected_fors_with_break_.end(); // #578: a good range for is a range for that doesn't contain a break // statement if (is_good_range_for) { auto &&new_for = std::make_unique( - begin->stmt, end->stmt, std::move(stmt->body), + begin_stmt, end_stmt, std::move(stmt->body), stmt->is_bit_vectorized, stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized); new_for->body->insert(std::make_unique(new_for.get(), 0), @@ -316,7 +315,7 @@ class LowerAST : public IRVisitor { stmt->parent->local_var_to_stmt[stmt->loop_var_ids[0]] = loop_var; auto const_one = fctx.push_back(TypedConstant((int32)1)); auto begin_minus_one = fctx.push_back( - BinaryOpType::sub, begin->stmt, const_one); + BinaryOpType::sub, begin_stmt, const_one); fctx.push_back(loop_var, begin_minus_one); auto loop_var_addr = loop_var->as(); VecStatement load_and_compare; @@ -326,7 +325,7 @@ class LowerAST : public IRVisitor { BinaryOpType::add, loop_var_load_stmt, const_one); auto cond_stmt = load_and_compare.push_back( - BinaryOpType::cmp_lt, loop_var_add_one, end->stmt); + BinaryOpType::cmp_lt, loop_var_add_one, end_stmt); auto &&new_while = std::make_unique(std::move(stmt->body)); auto mask = std::make_unique(PrimitiveType::i32); @@ -384,8 +383,7 @@ class LowerAST : public IRVisitor { auto fctx = make_flatten_ctx(); std::vector return_ele; for (auto &x : expr_group.exprs) { - flatten_rvalue(x, &fctx); - return_ele.push_back(fctx.back_stmt()); + return_ele.push_back(flatten_rvalue(x, &fctx)); } fctx.push_back(return_ele); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); @@ -395,23 +393,23 @@ class LowerAST : public IRVisitor { auto dest = assign->lhs; auto expr = assign->rhs; auto fctx = make_flatten_ctx(); - flatten_rvalue(expr, &fctx); - flatten_lvalue(dest, &fctx); + auto expr_stmt = flatten_rvalue(expr, &fctx); + auto dest_stmt = flatten_lvalue(dest, &fctx); if (dest.is()) { - fctx.push_back(dest->stmt, expr->stmt); + fctx.push_back(dest_stmt, expr_stmt); } else if (dest.is()) { auto ix = dest.cast(); if (ix->is_local()) { - fctx.push_back(dest->stmt, expr->stmt); + fctx.push_back(dest_stmt, expr_stmt); } else { - fctx.push_back(dest->stmt, expr->stmt); + fctx.push_back(dest_stmt, expr_stmt); } } else if (dest.is()) { - fctx.push_back(dest->stmt, expr->stmt); + fctx.push_back(dest_stmt, expr_stmt); } else { TI_ASSERT(dest.is() && dest.cast()->is_ptr); - fctx.push_back(dest->stmt, expr->stmt); + fctx.push_back(dest_stmt, expr_stmt); } fctx.stmts.back()->set_tb(assign->tb); assign->parent->replace_with(assign, std::move(fctx.stmts)); @@ -422,15 +420,12 @@ class LowerAST : public IRVisitor { Stmt *val_stmt = nullptr; auto fctx = make_flatten_ctx(); if (stmt->val.expr) { - auto expr = stmt->val; - flatten_rvalue(expr, &fctx); - val_stmt = expr->stmt; + val_stmt = flatten_rvalue(stmt->val, &fctx); } std::vector indices_stmt(stmt->indices.size(), nullptr); for (int i = 0; i < (int)stmt->indices.size(); i++) { - flatten_rvalue(stmt->indices[i], &fctx); - indices_stmt[i] = stmt->indices[i]->stmt; + indices_stmt[i] = flatten_rvalue(stmt->indices[i], &fctx); } if (stmt->snode->type == SNodeType::dynamic) { @@ -459,16 +454,13 @@ class LowerAST : public IRVisitor { Stmt *val_stmt = nullptr; auto fctx = make_flatten_ctx(); if (stmt->cond.expr) { - auto expr = stmt->cond; - flatten_rvalue(expr, &fctx); - val_stmt = expr->stmt; + val_stmt = flatten_rvalue(stmt->cond, &fctx); } auto &fargs = stmt->args; // frontend stmt args std::vector args_stmts(fargs.size()); for (int i = 0; i < (int)fargs.size(); ++i) { - flatten_rvalue(fargs[i], &fctx); - args_stmts[i] = fargs[i]->stmt; + args_stmts[i] = flatten_rvalue(fargs[i], &fctx); } fctx.push_back(val_stmt, stmt->text, args_stmts); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); @@ -489,12 +481,10 @@ class LowerAST : public IRVisitor { std::vector arg_statements, output_statements; if (stmt->so_func != nullptr || !stmt->asm_source.empty()) { for (auto &s : stmt->args) { - flatten_rvalue(s, &ctx); - arg_statements.push_back(s->stmt); + arg_statements.push_back(flatten_rvalue(s, &ctx)); } for (auto &s : stmt->outputs) { - flatten_lvalue(s, &ctx); - output_statements.push_back(s->stmt); + output_statements.push_back(flatten_lvalue(s, &ctx)); } ctx.push_back(std::make_unique( (stmt->so_func != nullptr) ? ExternalFuncCallStmt::SHARED_OBJECT @@ -506,8 +496,7 @@ class LowerAST : public IRVisitor { TI_ASSERT_INFO( s.is(), "external func call via bitcode must pass in local variables.") - flatten_lvalue(s, &ctx); - arg_statements.push_back(s->stmt); + arg_statements.push_back(flatten_lvalue(s, &ctx)); } ctx.push_back(std::make_unique( ExternalFuncCallStmt::BITCODE, nullptr, "", stmt->bc_filename,