diff --git a/scripts/python/tao_build.py b/scripts/python/tao_build.py index 6dd54a5554f..3d6e6c8e98b 100755 --- a/scripts/python/tao_build.py +++ b/scripts/python/tao_build.py @@ -391,7 +391,7 @@ def bazel_test(target, flag=""): logger.info("Testing bazel target: " + target) flag += " --experimental_ui_max_stdouterr_bytes=-1 " execute(" ".join([BAZEL_BUILD_CMD, flag, target])) - execute(" ".join([BAZEL_TEST_CMD, flag + ' --test_env=TF_CPP_VMODULE=disc_compiler=1 --test_env=TF_ENABLE_ONEDNN_OPTS=0' , target])) + execute(" ".join([BAZEL_TEST_CMD, flag + ' --test_env=TF_CPP_VMODULE=disc_compiler=1,disc_transform_legalize_to_loop=1 --test_env=TF_ENABLE_ONEDNN_OPTS=0' , target])) with cwd(tf_root_dir(root)), gcc_env(args.compiler_gcc): execute( diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD index 792cc2c5301..d412096cea1 100644 --- a/tao_compiler/mlir/disc/BUILD +++ b/tao_compiler/mlir/disc/BUILD @@ -2039,8 +2039,9 @@ disc_cc_library( ) cc_library( - name = "disc_transform_legalize_to_loop", - srcs = ["transforms/disc_transform_legalize_to_loop.cc"], + name = "disc_transform_schedule", + srcs = ["transforms/disc_transform_schedule.cc"], + hdrs = ["transforms/disc_transform_schedule.h"], deps = [ ":codegen_utils", ":disc_lhlo_elemental_utils", @@ -2053,6 +2054,7 @@ cc_library( "//tensorflow/compiler/xla/mlir_hlo:lhlo", "//tensorflow/compiler/mlir/disc/tools/disc-transform:all_passes", "//tensorflow/compiler/mlir/disc/tools/disc-transform:DISCLinalgExtDialect", + "//tensorflow/tsl/platform:logging", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:FuncDialect", @@ -2069,6 +2071,15 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "disc_transform_legalize_to_loop", + srcs = ["transforms/disc_transform_legalize_to_loop.cc"], + deps = [ + ":disc_transform_schedule", + ], + alwayslink = 1, +) + cc_library( name = "all_passes", hdrs = [ diff --git a/tao_compiler/mlir/disc/tests/disc-transform/data/packed_matmul_nn_p_f32_large_schedule.mlir b/tao_compiler/mlir/disc/tests/disc-transform/data/packed_matmul_nn_p_f32_large_schedule.mlir index 8d6c30b9cdf..98962fc3a4b 100644 --- a/tao_compiler/mlir/disc/tests/disc-transform/data/packed_matmul_nn_p_f32_large_schedule.mlir +++ b/tao_compiler/mlir/disc/tests/disc-transform/data/packed_matmul_nn_p_f32_large_schedule.mlir @@ -54,7 +54,7 @@ transform.structured.canonicalized_sequence failures(propagate) { multireduction_lowering = "innerparallel", split_transfers = "linalg-copy", // stages = [0, 1, 2, 3, 4, 5, 6, 7], - stages = [0, 1, 2, 3], + stages = [0, 1, 2, 3, 4], transpose_avx2_lowering = false, transpose_lowering = "eltwise", unroll_vector_transfers = true diff --git a/tao_compiler/mlir/disc/tests/disc-transform/matmul.cc b/tao_compiler/mlir/disc/tests/disc-transform/matmul.cc index 2183fcd0c7f..1473681733b 100644 --- a/tao_compiler/mlir/disc/tests/disc-transform/matmul.cc +++ b/tao_compiler/mlir/disc/tests/disc-transform/matmul.cc @@ -28,9 +28,10 @@ static bool init_threads = []() { }(); TEST(SimpleTest, MatMulF32_11x13x12) { - EnvSetting setting = {{"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}}, - {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}}; + EnvSetting setting = { + {"DISC_TRANSFORM_SCHEDULE_FILE", + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}}, + {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}}; EnvSettingContext ctx(setting); EXPECT_TRUE(feature_test_main( /*mlir_file_path*/ c_ft_path + "matmul_nn_d_f32.mlir", 
@@ -42,9 +43,10 @@ TEST(SimpleTest, MatMulF32_11x13x12) { } TEST(SimpleTest, MatMulF32_111x131x121) { - EnvSetting setting = {{"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}}, - {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}}; + EnvSetting setting = { + {"DISC_TRANSFORM_SCHEDULE_FILE", + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}}, + {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}}; EnvSettingContext ctx(setting); EXPECT_TRUE(feature_test_main( /*mlir_file_path*/ c_ft_path + "matmul_nn_d_f32.mlir", @@ -58,7 +60,7 @@ TEST(SimpleTest, MatMulF32_111x131x121) { TEST(SimpleTest, MatMulF32_304x1024x256) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -78,7 +80,7 @@ TEST(SimpleTest, MatMulF32_304x1024x256) { TEST(SimpleTest, MatMulF32_1024x1024x1024) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -98,7 +100,8 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024) { TEST(SimpleTest, MatMulF32_304x1024x256_2) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -118,7 +121,8 @@ TEST(SimpleTest, MatMulF32_304x1024x256_2) { TEST(SimpleTest, MatMulF32_1024x1024x1024_2) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -138,7 +142,8 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024_2) { TEST(SimpleTest, MatMulF32_304x256x256_3) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -158,7 +163,8 @@ TEST(SimpleTest, MatMulF32_304x256x256_3) { TEST(SimpleTest, MatMulF32_304x512x256_3) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -178,7 +184,8 @@ TEST(SimpleTest, MatMulF32_304x512x256_3) { TEST(SimpleTest, MatMulF32_304x1024x256_3) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + 
"matmul_nn_d_f32_large_schedule_3.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -198,7 +205,8 @@ TEST(SimpleTest, MatMulF32_304x1024x256_3) { TEST(SimpleTest, MatMulF32_304x1024x512_3) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -218,7 +226,8 @@ TEST(SimpleTest, MatMulF32_304x1024x512_3) { TEST(SimpleTest, MatMulF32_1024x1024x1024_3) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -238,7 +247,8 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024_3) { TEST(SimpleTest, MatMulF32_304x1024x512_4) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -258,7 +268,8 @@ TEST(SimpleTest, MatMulF32_304x1024x512_4) { TEST(SimpleTest, MatMulF32_1024x1024x1024_4) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -278,7 +289,8 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024_4) { TEST(SimpleTest, MatMulF32_1026x1024x1024_4) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -298,7 +310,8 @@ TEST(SimpleTest, MatMulF32_1026x1024x1024_4) { TEST(SimpleTest, MatMulF32_304x1024x512_5) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_nn_d_f32_large_schedule_5.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_5.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; diff --git a/tao_compiler/mlir/disc/tests/disc-transform/matmul_multithread.cc b/tao_compiler/mlir/disc/tests/disc-transform/matmul_multithread.cc index 4c830625e16..f8ebcbba0f3 100644 --- a/tao_compiler/mlir/disc/tests/disc-transform/matmul_multithread.cc +++ b/tao_compiler/mlir/disc/tests/disc-transform/matmul_multithread.cc @@ -30,7 +30,8 @@ static bool init_threads = []() { TEST(SimpleMTTest, MatMulF32_111x131x121_Thread_8) { EnvSetting setting = 
{ {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_multithread_nn_d_f32_schedule.mlir", false}}, + {"kGEMM::" + c_ft_path + "matmul_multithread_nn_d_f32_schedule.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}}; EnvSettingContext ctx(setting); EXPECT_TRUE(feature_test_main( @@ -43,10 +44,11 @@ TEST(SimpleMTTest, MatMulF32_111x131x121_Thread_8) { } TEST(SimpleTest, MatMulF32_304x1024x256) { - EnvSetting setting = { - {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "matmul_multithread_nn_d_f32_large_schedule.mlir", false}}, - {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}}; + EnvSetting setting = {{"DISC_TRANSFORM_SCHEDULE_FILE", + {"kGEMM::" + c_ft_path + + "matmul_multithread_nn_d_f32_large_schedule.mlir", + false}}, + {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}}; EnvSettingContext ctx(setting); EXPECT_TRUE(feature_test_main( /*mlir_file_path*/ c_ft_path + "matmul_multithread_nn_d_f32.mlir", diff --git a/tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc b/tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc index dad1f1b796e..5cf780a866c 100644 --- a/tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc +++ b/tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc @@ -30,7 +30,8 @@ static bool init_threads = []() { TEST(PackedMatmul, F32_304x1024x512) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {c_ft_path + "packed_matmul_nn_p_f32_large_schedule.mlir", false}}, + {"kGEMM::" + c_ft_path + "packed_matmul_nn_p_f32_large_schedule.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -47,4 +48,21 @@ TEST(PackedMatmul, F32_304x1024x512) { /*profiling*/ true)); } +TEST(PackedMatmul, F32_304x1024x512_Using_Default_Schedule) { + EnvSetting setting = {{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, + {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, + {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; + EnvSettingContext ctx(setting); + EXPECT_TRUE(feature_test_main( + /*mlir_file_path*/ c_ft_path + "packed_matmul_nn_p_512x1024_f32.mlir", + /*backend_types*/ {BackendType::kAArch64}, + /*num_inputs*/ 1, + /*num_outputs*/ 1, + /*input_descriptors*/ {"304x512xf32_X"}, + /*output_descriptors*/ {"f32_X"}, + /*input_vals*/ {}, + /*expected_output_vals*/ {}, + /*profiling*/ true)); +} + } // namespace mlir_test diff --git a/tao_compiler/mlir/disc/tools/disc-transform/BUILD b/tao_compiler/mlir/disc/tools/disc-transform/BUILD index 30302019b59..04b7c210eba 100644 --- a/tao_compiler/mlir/disc/tools/disc-transform/BUILD +++ b/tao_compiler/mlir/disc/tools/disc-transform/BUILD @@ -319,6 +319,7 @@ cc_library( "@llvm-project//mlir:PDLInterpDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransformOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformDialect", diff --git a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc index 2a94bc13798..faee8351a0e 100644 --- a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc +++ b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc @@ -198,9 +198,11 @@ DiagnosedSilenceableFailure DISCBufferizeOp::apply( })); pm.addNestedPass(bufferization::createBufferDeallocationPass()); - if (failed(pm.run(state.getTopLevel()))) + if 
(failed(pm.run(moduleOp)))
     return DiagnosedSilenceableFailure::definiteFailure();
 
+  results.set(getResult().cast<OpResult>(), {payload});
+
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h b/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h
index 2cf124a32b7..98076247412 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h
@@ -49,6 +49,10 @@
 std::unique_ptr<OperationPass<ModuleOp>>
 createDiscTransformDialectInterpreterPass(const std::string& fileName = "",
                                           bool enableExpensiveChecks = false);
 
+// Erases the transform dialect schedule from the IR.
+std::unique_ptr<OperationPass<ModuleOp>>
+createDiscTransformDialectEraseSchedulePass();
+
 // Converts the transformed payload IR to be suitable for RAL.
 std::unique_ptr<OperationPass<ModuleOp>> createDiscRewritePayloadIRForRALPass(
     bool gpuEnabled = false);
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_dialect_interpreter.cc b/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_dialect_interpreter.cc
index c09ce68fe5a..ebfd463ea69 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_dialect_interpreter.cc
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_dialect_interpreter.cc
@@ -15,7 +15,6 @@
 #include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
 #include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterUtils.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/Support/SourceMgr.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
@@ -29,6 +28,7 @@
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
@@ -36,13 +36,12 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/IR/MLIRContext.h"
-#include "mlir/Parser/Parser.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Support/FileUtilities.h"
 #include "mlir/Transforms/Passes.h"
 #include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtDialect.h"
 #include "tensorflow/compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.h"
 #include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h"
+#include "tensorflow/compiler/mlir/disc/tools/disc-transform/utils.h"
 
 #define DEBUG_TYPE "disc-transform-dialect-interpreter"
 
@@ -50,27 +49,51 @@
 namespace mlir {
 namespace disc_ral {
 
-namespace {
-LogicalResult parseTransformModuleFromFile(
-    MLIRContext* context, llvm::StringRef transformFileName,
-    OwningOpRef<ModuleOp>& transformModule) {
-  // Parse transformFileName content into a ModuleOp.
-  std::string errorMessage;
-  auto memoryBuffer = openInputFile(transformFileName, &errorMessage);
-  if (!memoryBuffer) {
-    llvm::errs() << "failed to parse transform file: " << transformFileName
-                 << "\n";
-    return failure();
-  }
-  // Tell sourceMgr about this buffer, the parser will pick it up.
-  llvm::SourceMgr sourceMgr;
-  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
-  transformModule =
-      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
-  return success();
+void addTransformDialectDependentDialects(DialectRegistry& registry) {
+  // TODO: this is only necessary to make registry subset happy when running
+  // the lowering to LLVM. The lowering should be changed to stop using the
+  // nested pass manager and this will go away.
+
+  // clang-format off
+  registry.insert<arith::ArithDialect,
+                  AffineDialect,
+                  bufferization::BufferizationDialect,
+                  disc_linalg_ext::DISCLinalgExtDialect,
+                  func::FuncDialect,
+                  linalg::LinalgDialect,
+                  mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect,
+                  pdl::PDLDialect,
+                  pdl_interp::PDLInterpDialect,
+                  scf::SCFDialect,
+                  tensor::TensorDialect,
+                  vector::VectorDialect>();
+
+  // TODO: these should be registered by the extension instead, but there is
+  // no support for it in core currently.
+  arith::registerBufferizableOpInterfaceExternalModels(registry);
+  linalg::registerBufferizableOpInterfaceExternalModels(registry);
+  scf::registerBufferizableOpInterfaceExternalModels(registry);
+  bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+      registry);
+  tensor::registerBufferizableOpInterfaceExternalModels(registry);
+  vector::registerBufferizableOpInterfaceExternalModels(registry);
+
+  registry.addExtensions<
+      mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension,
+      transform_ext::StructuredTransformOpsExtension>();
+  linalg::registerTransformDialectExtension(registry);
+  scf::registerTransformDialectExtension(registry);
+  registerTransformDialectCommonExtension(registry);
 }
 
+namespace {
+
 struct DiscTransformDialectInterpreterPass
     : public DiscTransformDialectInterpreterPassBase<
           DiscTransformDialectInterpreterPass> {
@@ -84,43 +107,7 @@ struct DiscTransformDialectInterpreterPass
   }
 
   void getDependentDialects(DialectRegistry& registry) const override {
-    // TODO: this is only necessary to make registry subset happy when running
-    // the lowering to LLVM. The lowering should be changed to stop using the
-    // nested pass manager and this will go away.
-
-    // clang-format off
-    registry.insert<arith::ArithDialect,
-                    AffineDialect,
-                    bufferization::BufferizationDialect,
-                    disc_linalg_ext::DISCLinalgExtDialect,
-                    func::FuncDialect,
-                    linalg::LinalgDialect,
-                    mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect,
-                    pdl::PDLDialect,
-                    pdl_interp::PDLInterpDialect,
-                    scf::SCFDialect,
-                    tensor::TensorDialect,
-                    vector::VectorDialect>();
-
-    // TODO: these should be registered by the extension instead, but there is
-    // no support for it in core currently.
-    arith::registerBufferizableOpInterfaceExternalModels(registry);
-    linalg::registerBufferizableOpInterfaceExternalModels(registry);
-    scf::registerBufferizableOpInterfaceExternalModels(registry);
-    bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
-        registry);
-    tensor::registerBufferizableOpInterfaceExternalModels(registry);
-    vector::registerBufferizableOpInterfaceExternalModels(registry);
-
-    registry.addExtensions<
-        mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension,
-        transform_ext::StructuredTransformOpsExtension>();
-    linalg::registerTransformDialectExtension(registry);
-    registerTransformDialectCommonExtension(registry);
+    addTransformDialectDependentDialects(registry);
   }
 
   void runOnOperation() override;
@@ -161,6 +148,22 @@ void DiscTransformDialectInterpreterPass::runOnOperation() {
   }
 }
 
+struct DiscTransformDialectEraseSchedulePass
+    : public DiscTransformDialectEraseSchedulePassBase<
+          DiscTransformDialectEraseSchedulePass> {
+  void runOnOperation() override;
+};
+
+void DiscTransformDialectEraseSchedulePass::runOnOperation() {
+  getOperation()->walk<WalkOrder::PreOrder>([&](Operation* nestedOp) {
+    if (isa<transform::TransformOpInterface>(nestedOp)) {
+      nestedOp->erase();
+      return WalkResult::skip();
+    }
+    return WalkResult::advance();
+  });
+}
+
 }  // namespace
 
 std::unique_ptr<OperationPass<ModuleOp>>
@@ -170,5 +173,10 @@ createDiscTransformDialectInterpreterPass(const std::string& fileName,
       fileName, enableExpensiveChecks);
 }
 
+std::unique_ptr<OperationPass<ModuleOp>>
+createDiscTransformDialectEraseSchedulePass() {
+  return std::make_unique<DiscTransformDialectEraseSchedulePass>();
+}
+
 }  // namespace disc_ral
 }  // namespace mlir
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td b/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td
index bda97e24d26..203e36a2620 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td
@@ -34,6 +34,11 @@ def DiscTransformDialectInterpreterPass : Pass<"disc-transform-dialect-interpreter", "ModuleOp"> {
   ];
 }
 
+def DiscTransformDialectEraseSchedulePass : Pass<"disc-transform-dialect-erase-schedule", "ModuleOp"> {
+  let summary = "Pass to erase transform dialect schedule from the IR.";
+  let constructor = "createDiscTransformDialectEraseSchedulePass()";
+}
+
 def DiscRewritePayloadIRForRALPass : Pass<"disc-rewrite-payload-ir-for-ral", "ModuleOp"> {
   let summary = "Pass to rewrite the payload IR transformed by transform IR to be suitable for RAL.";
   let constructor = "createDiscRewritePayloadIRForRALPass()";
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/utils.cc b/tao_compiler/mlir/disc/tools/disc-transform/utils.cc
index 035b1575c38..4414e9011ca 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/utils.cc
+++ b/tao_compiler/mlir/disc/tools/disc-transform/utils.cc
@@ -12,7 +12,10 @@
 #include "tensorflow/compiler/mlir/disc/tools/disc-transform/utils.h"
 
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/SourceMgr.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/FileUtilities.h"
 
 #define DEBUG_TYPE "disc-transform-utils"
 
@@ -49,5 +52,25 @@ Operation* createLinalgCopyOp(OpBuilder& b, Location loc, Value from, Value to,
                               attributes);
 }
 
+/// Load transform dialect IR from the given file.
+LogicalResult parseTransformModuleFromFile(
+    MLIRContext* context, llvm::StringRef transformFileName,
+    OwningOpRef<ModuleOp>& transformModule) {
+  // Parse transformFileName content into a ModuleOp.
+  std::string errorMessage;
+  auto memoryBuffer = openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    llvm::errs() << "failed to parse transform file: " << transformFileName
+                 << "\n";
+    return failure();
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule =
+      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
+  return success();
+}
+
 }  // namespace disc_ral
 }  // namespace mlir
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/utils.h b/tao_compiler/mlir/disc/tools/disc-transform/utils.h
index 03be831ea1b..51eabf183b9 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/utils.h
+++ b/tao_compiler/mlir/disc/tools/disc-transform/utils.h
@@ -27,6 +27,14 @@ namespace disc_ral {
 Operation* createLinalgCopyOp(OpBuilder& b, Location loc, Value from, Value to,
                               ArrayRef<NamedAttribute> attributes = {});
 
+/// Load transform dialect IR from the given file.
+LogicalResult parseTransformModuleFromFile(
+    MLIRContext* context, llvm::StringRef transformFileName,
+    OwningOpRef<ModuleOp>& transformModule);
+
+// Appends transform dependent dialects.
+void addTransformDialectDependentDialects(DialectRegistry& registry);
+
 }  // namespace disc_ral
 }  // namespace mlir
diff --git a/tao_compiler/mlir/disc/transforms/disc_passes.td b/tao_compiler/mlir/disc/transforms/disc_passes.td
index ec297bce8fa..b999fe41d74 100644
--- a/tao_compiler/mlir/disc/transforms/disc_passes.td
+++ b/tao_compiler/mlir/disc/transforms/disc_passes.td
@@ -575,13 +575,4 @@ def DiscTransformLegalizeToLoopPass : Pass<"disc-transform-legalize-to-loop", "mlir::ModuleOp"> {
     Option<"enableExpensiveChecks_", "enable-expensive-checks", "bool",
            /*default=*/"false", "perform expensive checks to better report errors in the transform IR.">,
   ];
-  let dependentDialects = [
-    "AffineDialect",
-    "LLVM::LLVMDialect",
-    "arith::ArithDialect",
-    "linalg::LinalgDialect",
-    "scf::SCFDialect",
-    "tensor::TensorDialect",
-    "vector::VectorDialect",
-  ];
 }
diff --git a/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc b/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc
index b70d82f7dac..21b4bf65640 100644
--- a/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -37,11 +38,14 @@
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtDialect.h" #include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.h" #include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/passes.h" +#include "tensorflow/compiler/mlir/disc/tools/disc-transform/utils.h" #include "tensorflow/compiler/mlir/disc/transforms/PassDetail.h" #include "tensorflow/compiler/mlir/disc/transforms/codegen_utils.h" +#include "tensorflow/compiler/mlir/disc/transforms/disc_transform_schedule.h" #include "tensorflow/compiler/mlir/disc/transforms/fusion_utils.h" #include "tensorflow/compiler/mlir/disc/transforms/lhlo_elemental_utils.h" #include "tensorflow/compiler/mlir/disc/transforms/placement_utils.h" +#include "tensorflow/tsl/platform/default/logging.h" // This file implements logic to legalize transform fusion pattern to loop. @@ -83,8 +87,13 @@ struct DiscTransformLegalizeToLoopPass void runOnOperation() override; + void getDependentDialects(DialectRegistry& registry) const override { + addTransformDialectDependentDialects(registry); + } + LogicalResult handleCpuFusionOp(OpBuilder& b, Operation* fusion, - ShapeAnalysis& shapeAnalysis); + ShapeAnalysis& shapeAnalysis, + ScheduleDispatcher& scheduleDispatcher); // Outlines the fusion op to a standalone module op. LogicalResult outlineFusionOp(lmhlo::FusionOp fusionOp, FusionPattern& fusionPattern, @@ -132,10 +141,26 @@ LogicalResult DiscTransformLegalizeToLoopPass::outlineFusionOp( LogicalResult DiscTransformLegalizeToLoopPass::runTransformPipeline( ModuleOp m) { + if (VLOG_IS_ON(1)) + llvm::dbgs() << "/// ------- Apply Transform IR for fusion:\n" + << m << "\n\n"; + PassManager pm(m.getContext()); + auto printingFlags = OpPrintingFlags(); + printingFlags.elideLargeElementsAttrs(16); + pm.enableIRPrinting( + /*shouldPrintBeforePass=*/nullptr, + /*shouldPrintAfterPass=*/ + [](Pass* pass, Operation*) { return VLOG_IS_ON(1); }, + /*printModuleScope=*/false, + /*printAfterOnlyOnChange=*/true, + /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags); + pm.addPass(createDiscLegalizeLmhloFusionToLinalgPass()); - pm.addPass(createDiscTransformDialectInterpreterPass(transformFileName_, - enableExpensiveChecks_)); + // Using transform IR attached in the module. + pm.addPass(createDiscTransformDialectInterpreterPass( + /* transformFileName */ "", enableExpensiveChecks_)); + pm.addPass(createDiscTransformDialectEraseSchedulePass()); pm.addNestedPass(createDiscMemrefCopyToLinalgPass()); pm.addPass(createDiscRewritePayloadIRForRALPass(gpuEnabled_)); return pm.run(m); @@ -182,7 +207,8 @@ LogicalResult DiscTransformLegalizeToLoopPass::inlineTransformedModule( } LogicalResult DiscTransformLegalizeToLoopPass::handleCpuFusionOp( - OpBuilder& b, Operation* fusion, ShapeAnalysis& shapeAnalysis) { + OpBuilder& b, Operation* fusion, ShapeAnalysis& shapeAnalysis, + ScheduleDispatcher& scheduleDispatcher) { auto fusionOp = cast(fusion); assert(fusionOp); FusionPattern fusionPattern(fusionOp, &shapeAnalysis); @@ -193,19 +219,30 @@ LogicalResult DiscTransformLegalizeToLoopPass::handleCpuFusionOp( // 1, Outline the fusion to a standalone module op. OwningOpRef m; - if (failed(outlineFusionOp(fusionOp, fusionPattern, m))) return failure(); + if (failed(outlineFusionOp(fusionOp, fusionPattern, m))) { + return fusionOp->emitError() << "failed to outlineFusionOp\n"; + } LLVM_DEBUG(llvm::dbgs() << "After outline fusion op:\n" << m.get() << "\n"); - // 2, TODO(wyzero): assign a default schedule for each pattern here. 
+  // 2, Assign a default schedule for each pattern here.
+  PatternDescription patternDescription(fusionOp, fusionPattern,
+                                        shapeAnalysis);
+  if (failed(scheduleDispatcher.dispatch(patternDescription, m.get()))) {
+    return fusionOp->emitError() << "failed to assign schedule\n";
+  }
+  LLVM_DEBUG(llvm::dbgs() << "After assign schedule for fusion op:\n"
+                          << m.get() << "\n");
 
   // 3, Build a nested pass pipeline to legalize the outlined fusion op.
-  if (failed(runTransformPipeline(m.get()))) return failure();
+  if (failed(runTransformPipeline(m.get()))) {
+    return fusionOp->emitError() << "failed to run the transform pipeline\n";
+  }
   LLVM_DEBUG(llvm::dbgs() << "After run transform pipeline:\n"
                           << m.get() << "\n");
 
   // 4, Inline the lowered IR into the original module.
-  if (failed(inlineTransformedModule(b, fusion, fusionPattern, m.get())))
-    return failure();
+  if (failed(inlineTransformedModule(b, fusion, fusionPattern, m.get()))) {
+    return fusion->emitError() << "failed to inlineTransformedModule\n";
+  }
   LLVM_DEBUG(llvm::dbgs() << "After inline transformed module:\n"
                           << *fusion << "\n");
   return success();
@@ -230,6 +267,9 @@ void DiscTransformLegalizeToLoopPass::runOnOperation() {
     }
   }
 
+  // Assign a transform schedule for the given fusion pattern.
+  ScheduleDispatcher scheduleDispatcher{transformFileName_};
+
   for (Operation* fusion : gpu_fusion_worklist) {
     // TODO(disc): handling stitch fusion on GPU.
     return signalPassFailure();
@@ -244,7 +284,8 @@
     }
 
     // Error message should be emitted inside the function.
-    if (failed(handleCpuFusionOp(b, fusion, *shapeAnalysisPtr))) {
+    if (failed(handleCpuFusionOp(b, fusion, *shapeAnalysisPtr,
+                                 scheduleDispatcher))) {
      return signalPassFailure();
    }
  }
diff --git a/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc
new file mode 100644
index 00000000000..dd0ee704ba5
--- /dev/null
+++ b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc
@@ -0,0 +1,430 @@
+/* Copyright 2022 The BladeDISC Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/disc/transforms/disc_transform_schedule.h" + +#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" +#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h" +#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h" +#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" +#include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterUtils.h" +#include "llvm/Support/Debug.h" +#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "tensorflow/compiler/mlir/disc/disc_util.h" +#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtDialect.h" +#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.h" +#include "tensorflow/compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.h" +#include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/passes.h" +#include "tensorflow/compiler/mlir/disc/tools/disc-transform/utils.h" +#include "tensorflow/compiler/mlir/disc/transforms/codegen_utils.h" +#include "tensorflow/compiler/mlir/disc/transforms/fusion_utils.h" +#include "tensorflow/compiler/mlir/disc/transforms/lhlo_elemental_utils.h" +#include "tensorflow/compiler/mlir/disc/transforms/placement_utils.h" +#include "tensorflow/tsl/platform/default/logging.h" + +// This file implements logic to assign transform schedule for the given pattern +// description. 
+
+namespace mlir {
+namespace disc_ral {
+
+namespace {
+
+std::unordered_map<PatternKind, std::string>& getPatternKindToStringMap() {
+  static std::unordered_map<PatternKind, std::string> patternKindToStringMap;
+  return patternKindToStringMap;
+}
+
+std::unordered_map<std::string, PatternKind>& getStringToPatternKindMap() {
+  static std::unordered_map<std::string, PatternKind> stringToPatternKindMap;
+  return stringToPatternKindMap;
+}
+
+bool PatternKindAndStringMapRegistrar = []() {
+  auto& patternKindToStringMap = getPatternKindToStringMap();
+  auto& stringToPatternKindMap = getStringToPatternKindMap();
+  patternKindToStringMap.emplace(PatternKind::kNone, "kNone");
+  patternKindToStringMap.emplace(PatternKind::kGEMM, "kGEMM");
+  for (auto& pair : patternKindToStringMap) {
+    stringToPatternKindMap[pair.second] = pair.first;
+  }
+  return true;
+}();
+
+using transform::FuseIntoContainingOp;
+using transform::FuseOp;
+using transform::MatchOp;
+using transform::TileOp;
+using transform::TileToForeachThreadOp;
+using transform::VectorizeOp;
+
+MatchOp buildMatchOp(OpBuilder& b, Location& loc, Value target,
+                     ArrayRef<StringRef> ops) {
+  return b.create<MatchOp>(loc, pdl::OperationType::get(b.getContext()),
+                           target, b.getStrArrayAttr(ops),
+                           transform::MatchInterfaceEnumAttr{},
+                           DictionaryAttr{}, TypeAttr{});
+}
+
+TileToForeachThreadOp buildTileToForEachThreadOp(
+    OpBuilder& b, Location& loc, Value target, ArrayRef<int64_t> numThreads) {
+  auto pdlType = pdl::OperationType::get(b.getContext());
+  return b.create<TileToForeachThreadOp>(
+      loc, TypeRange{pdlType, pdlType}, target,
+      /* num_threads */ ValueRange{},
+      /* tile_sizes */ ValueRange{},
+      /* static_num_threads */ b.getI64ArrayAttr(numThreads),
+      /* static_tile_sizes */ ArrayAttr{},
+      /* thread_dim_mapping */ ArrayAttr{});
+}
+
+FuseIntoContainingOp buildFuseIntoContainingOp(OpBuilder& b, Location& loc,
+                                               Value target, Value anchor) {
+  auto pdlType = pdl::OperationType::get(b.getContext());
+  return b.create<FuseIntoContainingOp>(loc, pdlType, target, anchor);
+}
+
+FuseOp buildFuseOp(OpBuilder& b, Location& loc, Value target,
+                   ArrayRef<int64_t> tileSizes,
+                   ArrayRef<int64_t> interchange) {
+  auto pdlType = pdl::OperationType::get(b.getContext());
+  SmallVector<Type> loopTypes;
+  for (int64_t tileSize : tileSizes)
+    if (tileSize) loopTypes.push_back(pdlType);
+  return b.create<FuseOp>(loc, pdlType, loopTypes, target,
+                          b.getI64ArrayAttr(tileSizes),
+                          b.getI64ArrayAttr(interchange));
+}
+
+TileOp buildTileOp(OpBuilder& b, Location& loc, Value target,
+                   ArrayRef<int64_t> tileSizes,
+                   ArrayRef<int64_t> interchange) {
+  auto pdlType = pdl::OperationType::get(b.getContext());
+  SmallVector<Type> loopTypes;
+  for (int64_t tileSize : tileSizes)
+    if (tileSize) loopTypes.push_back(pdlType);
+  return b.create<TileOp>(loc, pdlType, loopTypes, target, ValueRange{},
+                          b.getI64ArrayAttr(tileSizes),
+                          b.getI64ArrayAttr(interchange));
+}
+
+transform_dialect::ApplyPatternsOp buildRunCanonicalizer(OpBuilder& b,
+                                                         Location& loc,
+                                                         Value target) {
+  return b.create<transform_dialect::ApplyPatternsOp>(loc, target, true);
+}
+
+transform::GetProducerOfOperand buildGetProducerOfOperand(OpBuilder& b,
+                                                          Location& loc,
+                                                          Value target,
+                                                          int64_t operandIdx) {
+  auto pdlType = pdl::OperationType::get(b.getContext());
+  return b.create<transform::GetProducerOfOperand>(loc, pdlType, target,
+                                                   operandIdx);
+}
+
+transform_dialect::FoldProducerExtractSliceOp buildFoldProducerExtractSlice(
+    OpBuilder& b, Location& loc, Value target, int64_t repeat) {
+  return b.create<transform_dialect::FoldProducerExtractSliceOp>(loc, target,
+                                                                 repeat);
+}
+
+transform::PadOp buildPadOp(OpBuilder& b, Location& loc, Value target,
+                            ArrayRef<int64_t> paddingDimensions) {
+  auto pdlType = pdl::OperationType::get(b.getContext());
+  // TODO(wyzero): support other types.
+  SmallVector<Attribute> paddingAttrs(paddingDimensions.size(),
+                                      b.getZeroAttr(b.getF32Type()));
+  return b.create<transform::PadOp>(loc, pdlType, target,
+                                    b.getArrayAttr(paddingAttrs),
+                                    b.getI64ArrayAttr(paddingDimensions),
+                                    ArrayAttr{}, ArrayAttr{}, ArrayAttr{});
+}
+
+transform::GetParentForOp buildGetParentForOp(OpBuilder& b, Location& loc,
+                                              Value target,
+                                              int64_t num_loops) {
+  auto pdlType = pdl::OperationType::get(b.getContext());
+  return b.create<transform::GetParentForOp>(loc, pdlType, target, num_loops);
+}
+
+transform_dialect::CacheReadOp buildCacheRead(OpBuilder& b, Location& loc,
+                                              Value target, Value anchor,
+                                              ArrayRef<int64_t> tileLevels,
+                                              ArrayRef<int64_t> tileSizes,
+                                              bool padded,
+                                              ArrayRef<int64_t> permutation) {
+  return b.create<transform_dialect::CacheReadOp>(
+      loc, target, anchor, tileLevels, tileSizes, padded, permutation);
+}
+
+transform_dialect::LowerMultiLevelPackToLoopOp buildLowerMultiLevelPackToLoop(
+    OpBuilder& b, Location& loc, Value target) {
+  return b.create<transform_dialect::LowerMultiLevelPackToLoopOp>(loc, target);
+}
+
+VectorizeOp buildVectorize(OpBuilder& b, Location& loc, Value target,
+                           bool vectorizePad) {
+  auto pdlType = pdl::OperationType::get(b.getContext());
+  return b.create<VectorizeOp>(loc, pdlType, target, vectorizePad);
+}
+
+transform_dialect::DISCBufferizeOp buildDISCBufferize(OpBuilder& b,
+                                                      Location& loc,
+                                                      Value target) {
+  auto pdlType = pdl::OperationType::get(b.getContext());
+  return b.create<transform_dialect::DISCBufferizeOp>(loc, pdlType, target);
+}
+
+void buildLowerVectors(OpBuilder& b, Location& loc, ArrayRef<int64_t> stages,
+                       StringRef contractionLowering,
+                       StringRef multireductionLowering,
+                       StringRef splitTransfers, bool unrollVectorTransfers,
+                       StringRef transposeLowering,
+                       bool transposeAvx2Lowering) {
+  b.create<transform_ext::LowerVectorsOp>(
+      loc, b.getI64ArrayAttr(stages), contractionLowering,
+      multireductionLowering, splitTransfers, unrollVectorTransfers,
+      transposeLowering, transposeAvx2Lowering);
+}
+
+LogicalResult aarch64GEMMDefaultScheduleFactory(PatternDescription& pd,
+                                                ModuleOp m) {
+  OpBuilder b(m);
+  b.setInsertionPointToStart(&m.getBodyRegion().front());
+  Location loc = m.getLoc();
+  MLIRContext* ctx = m->getContext();
+  auto seqOp = b.create<transform_ext::CanonicalizedSequenceOp>(
+      loc, TypeRange{}, transform::FailurePropagationMode::Propagate, Value{});
+  seqOp.getBody().push_back(new Block);
+  auto& bodyBlock = seqOp.getBody().front();
+  auto pdlOpType = pdl::OperationType::get(ctx);
+  bodyBlock.addArgument(pdl::OperationType::get(ctx), loc);
+  b.setInsertionPointToStart(&bodyBlock);
+  Value variant = bodyBlock.getArgument(0);
+
+  // %fill = transform.structured.match ops{["linalg.fill"]} in %variant
+  Value fill = buildMatchOp(b, loc, variant, {"linalg.fill"});
+  // %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+  Value matmul = buildMatchOp(b, loc, variant, {"linalg.matmul"});
+
+  // transform.structured.tile_to_foreach_thread_op %matmul num_threads [1, 1]
+  auto forEachThreadOp = buildTileToForEachThreadOp(b, loc, matmul, {1, 1});
+  Value forEachThreadLoop = forEachThreadOp->getResult(0);
+  Value tiledMatmul = forEachThreadOp->getResult(1);
+
+  // transform.structured.fuse_into_containing_op %fill into %0#0
+  auto fuseIntoContainingOp =
+      buildFuseIntoContainingOp(b, loc, fill, forEachThreadLoop);
+
+  // first level tile and fuse matmul and fill op.
+  auto fuseOp0 = buildFuseOp(b, loc, tiledMatmul, {288, 48, 0}, {0, 1, 2});
+
+  // second level tile and fuse matmul and fill op.
+  auto fuseOp1 =
+      buildFuseOp(b, loc, fuseOp0->getResult(0), {6, 16, 0}, {0, 1, 2});
+
+  // gemm reduction axis tiling
+  auto tileOp =
+      buildTileOp(b, loc, fuseOp1->getResult(0), {0, 0, 1}, {0, 1, 2});
+
+  variant = buildRunCanonicalizer(b, loc, variant);
+
+  // Fold the two extract_slice ops generated by two-level tiling. This is
+  // needed to enable the following pad and hoist schedule.
+  Value weightInnerSlice =
+      buildGetProducerOfOperand(b, loc, tileOp->getResult(0), 1);
+  buildFoldProducerExtractSlice(b, loc, weightInnerSlice, 2);
+
+  // Pad to match the requirement of hardware vector/tensor instructions.
+  auto padOp = buildPadOp(b, loc, tileOp->getResult(0), {0, 1, 2});
+
+  Value padForInput = buildGetProducerOfOperand(b, loc, padOp, 0);
+  Value padForWeight = buildGetProducerOfOperand(b, loc, padOp, 1);
+  forEachThreadLoop = buildMatchOp(b, loc, variant, {"scf.foreach_thread"});
+  auto outerLoopForN = buildGetParentForOp(b, loc, padForInput, 4);
+  buildCacheRead(b, loc, padForWeight, forEachThreadLoop, {1, 1}, {1, 16},
+                 true, {2, 0, 1, 3});
+  buildCacheRead(b, loc, padForInput, outerLoopForN, {1, 1}, {6, 1}, true,
+                 {0, 2, 3, 1});
+
+  variant = buildRunCanonicalizer(b, loc, variant);
+
+  Value multiLevelPackOps =
+      buildMatchOp(b, loc, variant, {"disc_linalg_ext.multi_level_pack"});
+  buildLowerMultiLevelPackToLoop(b, loc, multiLevelPackOps);
+
+  variant = buildRunCanonicalizer(b, loc, variant);
+
+  Value func = buildMatchOp(b, loc, variant, {"func.func"});
+  buildVectorize(b, loc, func, true);
+
+  variant = buildRunCanonicalizer(b, loc, variant);
+  variant = buildDISCBufferize(b, loc, variant);
+
+  buildLowerVectors(b, loc, {0, 1, 2, 3, 4}, "outerproduct", "innerparallel",
+                    "linalg-copy", true, "eltwise", false);
+  buildLowerVectors(b, loc, {5, 6, 7}, "outerproduct", "innerparallel",
+                    "linalg-copy", true, "eltwise", false);
+  b.create<transform::YieldOp>(loc);
+  return success();
+}
+
+DISC_TRANSFORM_SCHEDULE(PatternKind::kGEMM, "",
+                        aarch64GEMMDefaultScheduleFactory);
+
+}  // namespace
+
+std::string patternKindToString(PatternKind kind) {
+  auto& map = getPatternKindToStringMap();
+  auto it = map.find(kind);
+  if (it != map.end()) return it->second;
+  llvm_unreachable("unknown pattern kind");
+  return "";
+}
+
+PatternKind patternKindFromString(const std::string& str) {
+  auto& map = getStringToPatternKindMap();
+  auto it = map.find(str);
+  if (it != map.end()) return it->second;
+  llvm_unreachable("unknown pattern kind str");
+  return PatternKind::kNone;
+}
+
+PatternDescription::PatternDescription(lmhlo::FusionOp op,
+                                       FusionPattern& fusionPattern,
+                                       ShapeAnalysis& shapeAnalysis)
+    : op_(op), fusionPattern_(fusionPattern), shapeAnalysis_(shapeAnalysis) {
+  // TODO(wyzero): select the pattern kind according to the `fusionPattern`.
+  patternKind_ = PatternKind::kGEMM;
+}
+
+PatternKind PatternDescription::getPatternKind() const { return patternKind_; }
+
+std::string PatternDescription::getPatternTagStr() const {
+  return getFusionTagStr(op_).str();
+}
+
+std::string PatternDescription::getTaggedPatternStr() const {
+  return patternKindToString(patternKind_) + "@" + getPatternTagStr();
+}
+
+/* static */ ScheduleFactoryRegistry& ScheduleFactoryRegistry::get() {
+  static ScheduleFactoryRegistry instance;
+  return instance;
+}
+
+bool ScheduleFactoryRegistry::registerScheduleFactory(PatternKind kind,
+                                                      const std::string& tag,
+                                                      ScheduleFactory factory) {
+  return factoryMap_[kind].emplace(tag, factory).second;
+}
+
+ScheduleFactory ScheduleFactoryRegistry::getScheduleFactory(
+    PatternKind kind, const std::string& tag) {
+  auto it = factoryMap_.find(kind);
+  if (it == factoryMap_.end()) return nullptr;
+  auto factoryIt = it->second.find(tag);
+  if (factoryIt == it->second.end()) return nullptr;
+  return factoryIt->second;
+}
+
+ScheduleDispatcher::ScheduleDispatcher(const std::string& transformFileName)
+    : transformFileName_(transformFileName) {}
+
+LogicalResult ScheduleDispatcher::parseModuleFromFile(MLIRContext* ctx) {
+  if (transformFileName_.empty() || !parsedModuleMap_.empty())
+    return success();
+  std::string expectedFormatStr =
+      "<kind-0>:<tag-0>:<filename-0>;<kind-1>:<tag-1>:<"
+      "filename-1>";
+  SmallVector<StringRef> patternSettings;
+  StringRef(transformFileName_)
+      .split(patternSettings, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
+  for (auto& patternSetting : patternSettings) {
+    SmallVector<StringRef> items;
+    patternSetting.split(items, ":", /*MaxSplit=*/-1, /*KeepEmpty=*/true);
+    if (items.size() != 3) {
+      llvm::dbgs() << "illegal transform file setting, expected format: "
+                   << expectedFormatStr << "\n";
+      return failure();
+    }
+    PatternKind kind = patternKindFromString(items[0].str());
+    if (kind == PatternKind::kNone) {
+      llvm::dbgs() << "illegal transform file setting, unknown pattern kind: "
+                   << items[0] << "\n";
+      return failure();
+    }
+
+    if (failed(parseTransformModuleFromFile(
+            ctx, items[2], parsedModuleMap_[kind][items[1].str()]))) {
+      llvm::dbgs()
+          << "illegal transform file setting, unable to load module from: "
+          << items[2] << "\n";
+      return failure();
+    }
+  }
+  return success();
+}
+
+bool ScheduleDispatcher::tryToApplyScheduleFromParsedFile(
+    PatternDescription& pd, ModuleOp m) {
+  auto it = parsedModuleMap_.find(pd.getPatternKind());
+  if (it == parsedModuleMap_.end()) return false;
+  auto tagIt = it->second.find(pd.getPatternTagStr());
+  if (tagIt == it->second.end()) return false;
+  OpBuilder b(m);
+  for (auto& op : tagIt->second.get().getBody()->getOperations()) {
+    if (!isa<transform::TransformOpInterface>(&op)) continue;
+    m.push_back(b.clone(op));
+  }
+  return true;
+}
+
+LogicalResult ScheduleDispatcher::dispatch(PatternDescription& pd, ModuleOp m) {
+  if (failed(parseModuleFromFile(m.getContext()))) return failure();
+  if (tryToApplyScheduleFromParsedFile(pd, m)) return success();
+
+  auto factory = ScheduleFactoryRegistry::get().getScheduleFactory(
+      pd.getPatternKind(), pd.getPatternTagStr());
+  if (!factory) {
+    llvm::dbgs() << "no default schedule for pattern: "
+                 << pd.getTaggedPatternStr() << "\n";
+    return failure();
+  }
+
+  return factory(pd, m);
+}
+
+}  // namespace disc_ral
+}  // namespace mlir
diff --git a/tao_compiler/mlir/disc/transforms/disc_transform_schedule.h b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.h
new file mode 100644
index 00000000000..1b05d6c9795
--- /dev/null
+++ b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.h
@@ -0,0 +1,130 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef DISC_TRANSFORMS_TRANSFORM_SCHEDULE_H_
+#define DISC_TRANSFORMS_TRANSFORM_SCHEDULE_H_
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/compiler/mlir/disc/transforms/fusion_utils.h"
+
+namespace mlir {
+namespace disc_ral {
+
+// PatternKind represents the category of a given schedule.
+// For the same `PatternKind`, we may still have different schedule strategies
+// for different shape ranges. We further use different tags to distinguish
+// such schedules within the same category.
+enum class PatternKind : int32_t { kNone, kGEMM };
+
+// Converts a pattern kind to its string representation.
+std::string patternKindToString(PatternKind kind);
+
+// Creates a pattern kind from its string representation.
+PatternKind patternKindFromString(const std::string& str);
+
+// PatternDescription collects everything needed to assign a schedule for a
+// given fusion pattern.
+class PatternDescription {
+ public:
+  explicit PatternDescription(lmhlo::FusionOp op, FusionPattern& fusionPattern,
+                              ShapeAnalysis& shapeAnalysis);
+
+  // Returns the kind of this `PatternDescription`.
+  PatternKind getPatternKind() const;
+
+  // Returns the tags attached to this fusion pattern.
+  std::string getPatternTagStr() const;
+
+  // Returns the full pattern kind str + tag str.
+  std::string getTaggedPatternStr() const;
+
+ private:
+  lmhlo::FusionOp op_;
+  FusionPattern& fusionPattern_;
+  ShapeAnalysis& shapeAnalysis_;
+  PatternKind patternKind_;
+};
+
+// Factory used to assign a specific schedule for the given
+// PatternDescription.
+using ScheduleFactory =
+    std::function<LogicalResult(PatternDescription&, ModuleOp)>;
+
+// A registry for different schedule factories.
+class ScheduleFactoryRegistry {
+ public:
+  // Returns the singleton.
+  static ScheduleFactoryRegistry& get();
+
+  // Inserts the new `ScheduleFactory`. Returns true if inserted, otherwise
+  // false.
+  bool registerScheduleFactory(PatternKind kind, const std::string& tag,
+                               ScheduleFactory);
+
+  // Returns the schedule factory according to `kind` and `tag`.
+  // Returns nullptr if not found.
+  ScheduleFactory getScheduleFactory(PatternKind kind, const std::string& tag);
+
+ private:
+  ScheduleFactoryRegistry() = default;
+  std::unordered_map<PatternKind,
+                     std::unordered_map<std::string, ScheduleFactory>>
+      factoryMap_;
+};
+
+// Macros used to define disc transform schedule factory.
+#define DISC_TRANSFORM_SCHEDULE(kind, tag, ...) \
+  DISC_TRANSFORM_SCHEDULE_UNIQ_HELPER(kind, tag, __COUNTER__, __VA_ARGS__)
+
+#define DISC_TRANSFORM_SCHEDULE_UNIQ_HELPER(kind, tag, ctr, ...) \
+  DISC_TRANSFORM_SCHEDULE_UNIQ(kind, tag, ctr, __VA_ARGS__)
+
+#define DISC_TRANSFORM_SCHEDULE_UNIQ(kind, tag, ctr, ...) \
+  static bool unused_ret_val_##ctr =                      \
+      ::mlir::disc_ral::ScheduleFactoryRegistry::get()    \
+          .registerScheduleFactory(kind, tag, __VA_ARGS__);
+
+// Assigns a schedule for the given PatternDescription according to its kind
+// and tag.
+class ScheduleDispatcher {
+ public:
+  // Users may override the default schedule by providing their own
+  // implementation and passing the schedule files to the dispatcher.
+  // Format of `transformFileName`:
+  //   "<kind-0>:<tag-0>:<filename-0>;<kind-1>:<tag-1>:<filename-1>;..."
+  explicit ScheduleDispatcher(const std::string& transformFileName);
+
+  // Attaches a schedule for the given pattern description.
+  LogicalResult dispatch(PatternDescription& pd, ModuleOp m);
+
+ private:
+  // Parses schedule modules from the given files.
+  LogicalResult parseModuleFromFile(MLIRContext* ctx);
+  // Returns true when applied, otherwise false.
+  bool tryToApplyScheduleFromParsedFile(PatternDescription& pd, ModuleOp m);
+
+ private:
+  std::string transformFileName_;
+  // <kind, <tag, module>>
+  std::unordered_map<
+      PatternKind, std::unordered_map<std::string, OwningOpRef<ModuleOp>>>
+      parsedModuleMap_;
+};
+
+}  // namespace disc_ral
+}  // namespace mlir
+
+#endif  // DISC_TRANSFORMS_TRANSFORM_SCHEDULE_H_
diff --git a/tao_compiler/mlir/disc/transforms/fusion_utils.cc b/tao_compiler/mlir/disc/transforms/fusion_utils.cc
index 6a1d592a6d1..d05ed7c4098 100644
--- a/tao_compiler/mlir/disc/transforms/fusion_utils.cc
+++ b/tao_compiler/mlir/disc/transforms/fusion_utils.cc
@@ -262,6 +262,11 @@ void addFusionTag(OpBuilder& b, lmhlo::FusionOp op, StringRef tag) {
   op->setAttr(kFusionOpTagAttr, b.getStringAttr((Twine(oldTag) + tag).str()));
 }
 
+StringRef getFusionTagStr(lmhlo::FusionOp op) {
+  auto attr = op->getAttrOfType<StringAttr>(kFusionOpTagAttr);
+  return attr ? attr.getValue() : "";
+}
+
 // Returns the full name of the fusion op
 // Here full name is composed of the name and tag of the fusion op.
 std::string getFusionFullName(lmhlo::FusionOp op) {
diff --git a/tao_compiler/mlir/disc/transforms/fusion_utils.h b/tao_compiler/mlir/disc/transforms/fusion_utils.h
index 52ce578bd2f..03fdbf288a9 100644
--- a/tao_compiler/mlir/disc/transforms/fusion_utils.h
+++ b/tao_compiler/mlir/disc/transforms/fusion_utils.h
@@ -417,6 +417,9 @@ void setFusionName(OpBuilder& b, lmhlo::FusionOp op, StringRef name);
 // Here different tags is mapping to different variants of the fusion op.
 void addFusionTag(OpBuilder& b, lmhlo::FusionOp op, StringRef tag);
 
+// Returns the tag string attached to the fusion op.
+StringRef getFusionTagStr(lmhlo::FusionOp op);
+
 // Returns the full name of the fusion op
 // Here full name is composed of the name and tag of the fusion op.
 std::string getFusionFullName(lmhlo::FusionOp op);
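
Usage note (reviewer addition, not part of the patch): with this change, DISC_TRANSFORM_SCHEDULE_FILE no longer takes a bare file path. Each entry has the form <kind>:<tag>:<filename>, and multiple entries are separated by ';' (see ScheduleDispatcher::parseModuleFromFile). The updated tests select the kGEMM kind with an empty tag. A minimal sketch mirroring the tests, where the schedule file name is illustrative:

  EnvSetting setting = {
      {"DISC_TRANSFORM_SCHEDULE_FILE",
       {"kGEMM::" + c_ft_path + "my_custom_schedule.mlir", false}},
      {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}};
  EnvSettingContext ctx(setting);

If the schedule-file variable is left unset (as in the new F32_304x1024x512_Using_Default_Schedule test), ScheduleDispatcher::dispatch falls back to the factory registered for the pattern kind and tag, e.g. aarch64GEMMDefaultScheduleFactory for untagged kGEMM patterns.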
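For downstream users, here is a minimal sketch of registering an additional built-in schedule through the DISC_TRANSFORM_SCHEDULE macro introduced above. The tag "small" and the factory body are hypothetical; only the registration mechanics follow the patch:

  #include "tensorflow/compiler/mlir/disc/transforms/disc_transform_schedule.h"

  namespace mlir {
  namespace disc_ral {
  namespace {

  // Hypothetical factory: builds transform IR into the outlined module `m`,
  // as aarch64GEMMDefaultScheduleFactory does for the default kGEMM case.
  LogicalResult myGEMMSmallScheduleFactory(PatternDescription& pd, ModuleOp m) {
    // ... emit a transform_ext::CanonicalizedSequenceOp schedule here ...
    return success();
  }

  // Dispatch picks this factory for kGEMM fusions tagged "small" whenever no
  // parsed schedule file overrides it.
  DISC_TRANSFORM_SCHEDULE(PatternKind::kGEMM, "small",
                          myGEMMSmallScheduleFactory);

  }  // namespace
  }  // namespace disc_ral
  }  // namespace mlir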