diff --git a/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc b/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc index ef408b7b7778a..f7fffa0e0ff4b 100644 --- a/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc +++ b/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc @@ -20,6 +20,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/common/context.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" @@ -49,9 +50,9 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) { ir::Tensor B = lang::Compute( {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); - poly::StageMap stages = poly::CreateStages({A, B}); - std::vector funcs = lang::LowerVec( - "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({A, B}); + std::vector funcs = + lang::LowerToAstVec("SimpleAssign", {A, B}, &tensor_group, target); ASSERT_FALSE(funcs.empty()); ir::Expr ast_expr = funcs[0]->body; @@ -115,9 +116,9 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) { ir::Tensor C = lang::Compute( {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); - poly::StageMap stages = poly::CreateStages({C}); - std::vector funcs = lang::LowerVec( - "AddDiffShape", stages, {C}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({C}); + std::vector funcs = + lang::LowerToAstVec("AddDiffShape", {C}, &tensor_group, target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr before MultiLevelTiling: "; @@ -169,9 +170,9 @@ TEST(AnalyzeIr, ContainsNodeType) { ir::Tensor B = lang::Compute( {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); - poly::StageMap stages = poly::CreateStages({A, B}); - std::vector funcs = lang::LowerVec( - "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({A, B}); + std::vector funcs = + lang::LowerToAstVec("SimpleAssign", {A, B}, &tensor_group, target); ASSERT_FALSE(funcs.empty()); ir::Expr ast_expr = funcs[0]->body; diff --git a/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc b/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc index 9364374156f4a..3b51eac2600e3 100644 --- a/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc +++ b/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc @@ -21,6 +21,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/common/context.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" @@ -48,9 +49,9 @@ TEST(FeatureExtractor, SimpleAssign) { ir::Tensor B = lang::Compute( {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); - poly::StageMap stages = poly::CreateStages({A, B}); - std::vector funcs = lang::LowerVec( - "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({A, B}); + std::vector funcs = + lang::LowerToAstVec("SimpleAssign", {A, B}, &tensor_group, target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr to test: " << ast_expr; @@ -109,9 +110,9 @@ TEST(FeatureExtractor, MatrixMultiply) { [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); - poly::StageMap stages = poly::CreateStages({C}); - std::vector funcs = lang::LowerVec( - "MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({C}); + std::vector funcs = + lang::LowerToAstVec("SimpleAssign", {C}, &tensor_group, target); std::vector vec_ast{funcs[0]->body}; ir::ModuleExpr mod_expr(vec_ast); diff --git a/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc b/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc index 5d6d1be6e0c13..5db6f8999b18a 100644 --- a/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc +++ b/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc @@ -20,6 +20,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/auto_schedule/search_space/search_state.h" #include "paddle/cinn/auto_schedule/task/task_registry.h" #include "paddle/cinn/cinn.h" @@ -47,8 +48,8 @@ std::vector LowerCompute(const std::vector& shape, C = Compute( domain, [&B](Var i, Var j) { return B(i, j); }, "C"); - return cinn::lang::LowerVec( - "test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({A, B}); + return cinn::lang::LowerToAstVec("test_func", {A, B}, &tensor_group, target); } // Create a new IRSchedule with copied ir::LoweredFunc AST diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc index 4cfef12e030e0..e69d3069f1939 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc @@ -21,6 +21,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" #include "paddle/cinn/cinn.h" @@ -59,16 +60,13 @@ TEST(AutoInline, SingleLoopInline) { ir::Tensor C = Compute( {M}, [&](Var i) { return B(i) + ir::Expr(1.f); }, "C"); - poly::StageMap stages = CreateStages({A, B, C}); + ast_gen_ius::TensorGroup tensor_group({A, B, C}); std::vector funcs = - lang::LowerVec("TestAutoInline_SingleLoopInline", - stages, - {A, C}, - {}, - {}, - nullptr, - target, - true); + lang::LowerToAstVec("TestAutoInline_SingleLoopInline", + + {A, C}, + &tensor_Group, + target); VLOG(6) << "Expr after lowering:"; VLOG(6) << funcs[0]->body; diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc index 8b08d2c0658b3..e4b0597cfeed7 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc @@ -17,6 +17,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/lang/lower.h" @@ -38,9 +39,9 @@ TEST(AutoUnroll, Init) { #else Target target = common::DefaultHostTarget(); #endif - auto stages = CreateStages({C}); - auto funcs = cinn::lang::LowerVec( - "test_init", stages, {A, B, C}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({C}); + auto funcs = + cinn::lang::LowerToAstVec("test_init", {A, B, C}, &tensor_group, target); auto ast_expr = funcs[0]->body; ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr})); diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc index 5a5c68537e9a7..62f1bb74f4ac0 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc @@ -21,6 +21,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" #include "paddle/cinn/cinn.h" @@ -106,16 +107,9 @@ TEST(MultiLevelTile, SimpleLoops) { ir::Tensor C = Compute( {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); - poly::StageMap stages = CreateStages({C}); - std::vector funcs = - lang::LowerVec("TestMultiLevelTile_SimpleLoops", - stages, - {C}, - {}, - {}, - nullptr, - target, - true); + ast_gen_ius::TensorGroup tensor_group({C}); + std::vector funcs = lang::LowerToAstVec( + "TestMultiLevelTile_SimpleLoops", {C}, &tensor_group, target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr before MultiLevelTiling: "; diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc index 52f38e0b65b03..5ba15a46fef18 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc @@ -21,6 +21,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/ir/ir.h" @@ -52,9 +53,9 @@ TEST(SkipRule, Basic) { ir::Tensor C = Compute( {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); - poly::StageMap stages = CreateStages({C}); - std::vector funcs = lang::LowerVec( - "TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({C}); + std::vector funcs = + lang::LowerToAstVec("TestSkipRule_Basic", {C}, &tensor_group, target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr before SkipRule: "; @@ -101,9 +102,9 @@ TEST(SkipRule, ApplyOnSpecificBlock) { ir::Tensor C = Compute( {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); - poly::StageMap stages = CreateStages({C}); - std::vector funcs = lang::LowerVec( - "TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); + ast_gen_ius::TensorGroup tensor_group({C}); + std::vector funcs = + lang::LowerToAstVec("TestSkipRule_Basic", {C}, &tensor_group, target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr before SkipRule: "; diff --git a/paddle/cinn/auto_schedule/search_space/search_state_test.cc b/paddle/cinn/auto_schedule/search_space/search_state_test.cc index 61547d228302f..b0f216c4895aa 100644 --- a/paddle/cinn/auto_schedule/search_space/search_state_test.cc +++ b/paddle/cinn/auto_schedule/search_space/search_state_test.cc @@ -17,6 +17,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/common/context.h" @@ -35,35 +36,18 @@ TEST(TestSearchState, SearchStateHash_Equal) { ir::Tensor C = lang::Compute( {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C"); + ast_gen_ius::TensorGroup const_group_1({A, B}); cinn::common::Context::Global().ResetNameId(); - auto a_plus_const_funcs_1 = lang::LowerVec("A_plus_const", - poly::CreateStages({A, B}), - {A, B}, - {}, - {}, - nullptr, - target, - true); - + auto a_plus_const_funcs_1 = + lang::LowerToAstVec("A_plus_const", {A, B}, &const_group_1, target); cinn::common::Context::Global().ResetNameId(); - auto a_plus_const_funcs_2 = lang::LowerVec("A_plus_const", - poly::CreateStages({A, B}), - {A, B}, - {}, - {}, - nullptr, - target, - true); - + ast_gen_ius::TensorGroup const_group_2({A, B}); + auto a_plus_const_funcs_2 = + lang::LowerToAstVec("A_plus_const", {A, B}, &const_group_2, target); cinn::common::Context::Global().ResetNameId(); - auto a_plus_b_funcs = lang::LowerVec("A_plus_B", - poly::CreateStages({A, C}), - {A, C}, - {}, - {}, - nullptr, - target, - true); + ast_gen_ius::TensorGroup plus_group({A, C}); + auto a_plus_b_funcs = + lang::LowerToAstVec("A_plus_B", {A, C}, &plus_group, target); std::string a_plus_const_funcs_1_str = R"ROC(function A_plus_const (_A, _B) { diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc index 443c297c5e722..2e4ecf034b740 100644 --- a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc @@ -17,6 +17,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" @@ -46,16 +47,13 @@ TEST(MutateTileSize, Basic) { [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); - poly::StageMap stages = CreateStages({A, B, C}); + ast_gen_ius::TensorGroup tensor_group({A, B, C}); std::vector funcs = - lang::LowerVec("TestMutateTileSize_Basic", - stages, - {A, B, C}, - {}, - {}, - nullptr, - target, - true); + lang::LowerToAstVec("TestMutateTileSize_Basic", + + {A, B, C}, + &tensor_group, + target); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Original Expr: "; diff --git a/paddle/cinn/cinn.h b/paddle/cinn/cinn.h index 333bc051ead98..e81771ba0c7e7 100644 --- a/paddle/cinn/cinn.h +++ b/paddle/cinn/cinn.h @@ -29,6 +29,7 @@ namespace cinn { +using ast_gen_ius::TensorGroup; using backends::CodeGenC; using backends::CodeGenCX86; using backends::Outputs; @@ -39,6 +40,7 @@ using lang::CallExtern; using lang::CallLowered; using lang::Compute; using lang::Lower; +using lang::LowerToAst; using lang::Placeholder; using lang::ReduceAll; using lang::ReduceAny; diff --git a/paddle/cinn/ir/test/collect_ir_nodes_test.cc b/paddle/cinn/ir/test/collect_ir_nodes_test.cc index d380b4475e37d..859a35a5c0fa9 100644 --- a/paddle/cinn/ir/test/collect_ir_nodes_test.cc +++ b/paddle/cinn/ir/test/collect_ir_nodes_test.cc @@ -42,15 +42,15 @@ TEST(CollectIRNodes, basic) { auto C = Compute( {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C"); - auto stages = CreateStages({C}); + ast_gen_ius::TensorGroup tensor_group({C}); - auto fn = Lower("fn", stages, {A, B, C}); + auto fn = LowerToAst("fn", {A, B, C}, &tensor_group); LOG(INFO) << "fn:\n" << fn; auto tensors = CollectIRNodes(fn, [](const Expr* x) { return x->as_tensor(); }); - ASSERT_EQ(tensors.size(), 5UL); + ASSERT_EQ(tensors.size(), 3UL); auto fn_body = fn.As()->body; LOG(INFO) << "fn.body:\n" << fn_body;