From f84053c49ff98b7f7269e97ad557a8a4087b4f88 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sat, 10 Dec 2022 19:31:29 -0800 Subject: [PATCH] Delete unused code in tensorflow/iree-dialects This code doesn't make part of the tensorflow import and seems to have just been maintained in severed state this past year. Delete instead of having folks update unused code. --- integrations/tensorflow/iree-dialects/BUILD | 490 --- .../include/iree-dialects-c/Dialects.h | 26 - .../iree-dialects/Dialect/CMakeLists.txt | 2 - .../Dialect/LinalgExt/CMakeLists.txt | 3 - .../Dialect/LinalgExt/IR/CMakeLists.txt | 28 - .../Dialect/LinalgExt/IR/LinalgExtBase.td | 114 - .../Dialect/LinalgExt/IR/LinalgExtDialect.h | 17 - .../LinalgExt/IR/LinalgExtInterfaces.h | 38 - .../LinalgExt/IR/LinalgExtInterfaces.td | 418 --- .../Dialect/LinalgExt/IR/LinalgExtOps.h | 32 - .../Dialect/LinalgExt/IR/LinalgExtOps.td | 1090 ------- .../Dialect/LinalgExt/Passes/CMakeLists.txt | 5 - .../Dialect/LinalgExt/Passes/PassDetail.h | 30 - .../Dialect/LinalgExt/Passes/Passes.h | 366 --- .../Dialect/LinalgExt/Passes/Passes.td | 220 -- .../Dialect/LinalgExt/Passes/Transforms.h | 73 - .../iree-dialects/Dialect/LinalgExt/README.md | 13 - .../LinalgExt/TransformOps/CMakeLists.txt | 9 - .../TransformOps/LinalgExtTransformOps.h | 43 - .../TransformOps/LinalgExtTransformOps.td | 104 - .../LinalgExt/Transforms/CodegenStrategy.h | 298 -- .../Dialect/LinalgExt/Transforms/Transforms.h | 485 --- .../Dialect/LinalgExt/Transforms/Utils.h | 122 - .../Dialect/LinalgExt/Utils/Utils.h | 66 - .../LinalgExt/Utils/WinogradConstants.h | 90 - .../Dialect/LinalgTransform/CMakeLists.txt | 20 - .../LinalgTransform/LinalgTransformOps.h | 29 - .../LinalgTransform/LinalgTransformOps.td | 74 - .../Dialect/LinalgTransform/Passes.h | 31 - .../Dialect/LinalgTransform/ScopedTransform.h | 31 - .../LinalgTransform/SimplePatternRewriter.h | 23 - .../StructuredTransformOpsExt.h | 87 - .../StructuredTransformOpsExt.td | 204 -- .../TransformInterpreterUtils.h | 40 - .../iree-dialects/Transforms/Listener.h | 121 - .../iree-dialects/Transforms/ListenerCSE.h | 21 - .../ListenerGreedyPatternRewriteDriver.h | 40 - .../Transforms/TransformMatchers.h | 470 --- .../iree-dialects/lib/CAPI/CMakeLists.txt | 4 - .../iree-dialects/lib/CAPI/Dialects.cpp | 43 - .../iree-dialects/lib/CMakeLists.txt | 1 - .../iree-dialects/lib/Dialect/CMakeLists.txt | 2 - .../lib/Dialect/LinalgExt/CMakeLists.txt | 5 - .../lib/Dialect/LinalgExt/IR/CMakeLists.txt | 34 - .../Dialect/LinalgExt/IR/LinalgExtDialect.cpp | 50 - .../LinalgExt/IR/LinalgExtInterfaces.cpp | 70 - .../lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp | 2889 ----------------- .../Dialect/LinalgExt/Passes/CMakeLists.txt | 32 - .../Passes/ConvertConv2DToWinograd.cpp | 400 --- .../LinalgExt/Passes/ConvertToLoops.cpp | 128 - .../Passes/FoldIntoPackAndUnpackOps.cpp | 102 - .../LinalgExt/Passes/MaterializeEncoding.cpp | 456 --- .../Passes/PadContractionToBlockSize.cpp | 141 - .../lib/Dialect/LinalgExt/Passes/Passes.cpp | 102 - .../LinalgExt/Passes/SplitReduction.cpp | 430 --- .../Passes/TileAndDecomposeWinogradPass.cpp | 381 --- .../lib/Dialect/LinalgExt/Passes/Tiling.cpp | 445 --- .../LinalgExt/TransformOps/CMakeLists.txt | 16 - .../TransformOps/LinalgExtTransformOps.cpp | 154 - .../LinalgExt/Transforms/CMakeLists.txt | 34 - .../LinalgExt/Transforms/CodegenStrategy.cpp | 46 - .../Transforms/ForeachThreadToAsync.cpp | 92 - .../ForeachThreadToSequentialFor.cpp | 123 - .../Dialect/LinalgExt/Transforms/Fusion.cpp | 60 - .../Dialect/LinalgExt/Transforms/Tiling.cpp | 209 -- .../LinalgExt/Transforms/Transforms.cpp | 1053 ------ .../Dialect/LinalgExt/Transforms/Utils.cpp | 96 - .../Dialect/LinalgExt/Utils/CMakeLists.txt | 12 - .../lib/Dialect/LinalgExt/Utils/Utils.cpp | 86 - .../Dialect/LinalgTransform/CMakeLists.txt | 2 - .../Dialect/LinalgTransform/IR/CMakeLists.txt | 38 - .../LinalgTransform/IR/LinalgTransformOps.cpp | 84 - .../LinalgTransform/IR/ScopedTransform.cpp | 83 - .../IR/StructuredTransformOpsExt.cpp | 1359 -------- .../LinalgTransform/Passes/CMakeLists.txt | 24 - .../Passes/ExpertExpansion.cpp | 118 - .../Passes/TransformInterpreter.cpp | 283 -- .../lib/Transforms/CMakeLists.txt | 19 - .../iree-dialects/lib/Transforms/Listener.cpp | 49 - .../lib/Transforms/ListenerCSE.cpp | 448 --- .../ListenerGreedyPatternRewriteDriver.cpp | 469 --- .../lib/Transforms/TransformMatchers.cpp | 359 -- .../iree-dialects/python/CMakeLists.txt | 27 - .../python/IREEDialectsModule.cpp | 50 - .../compiler/dialects/IreeLinalgExtBinding.td | 13 - .../dialects/IreeStructuredTransformOps.td | 13 - .../dialects/LinalgTransformBinding.td | 13 - .../_iree_linalg_transform_ops_ext.py | 111 - .../_iree_structured_transform_ops_ext.py | 137 - .../iree/compiler/dialects/iree_linalg_ext.py | 8 - .../dialects/iree_linalg_transform.py | 8 - .../Dialect/iree_linalg_ext/canonicalize.mlir | 42 - .../iree_linalg_ext/conv2d_to_winograd.mlir | 77 - .../iree_linalg_ext/convert_to_loops.mlir | 1443 -------- .../fold_into_pack_unpack_ops.mlir | 44 - .../foreach-thread-to-async.mlir | 57 - .../foreach-thread-to-scf-for.mlir | 51 - .../iree_linalg_ext/fuse-operands.mlir | 122 - .../test/Dialect/iree_linalg_ext/invalid.mlir | 694 ---- .../iree_linalg_ext/materialize_encoding.mlir | 166 - .../pad_contraction_to_block_size.mlir | 92 - .../Dialect/iree_linalg_ext/pad_tiling.mlir | 41 - .../resolve-shaped-type-result-dims.mlir | 31 - .../Dialect/iree_linalg_ext/roundtrip.mlir | 1000 ------ .../iree_linalg_ext/split-reduction.mlir | 164 - .../tile_and_decompose_winograd.mlir | 195 -- .../test/Dialect/iree_linalg_ext/tiling.mlir | 1424 -------- .../iree_linalg_ext/vectorization.mlir | 354 -- .../Dialect/linalg_transform/bufferize.mlir | 24 - .../linalg_transform/drop-schedule.mlir | 31 - .../test/Dialect/linalg_transform/expert.mlir | 168 - .../Dialect/linalg_transform/failure.mlir | 23 - .../Dialect/linalg_transform/invalid.mlir | 53 - .../Dialect/linalg_transform/roundtrip.mlir | 26 - .../test/Dialect/linalg_transform/scoped.mlir | 30 - .../linalg_transform/selective-targeting.mlir | 159 - .../single-tiling-full-script.mlir | 25 - .../test/lib/Dialect/CMakeLists.txt | 1 - .../Dialect/LinalgTransform/CMakeLists.txt | 12 - .../LinalgTransform/TestScopedTransform.cpp | 64 - .../test/lib/Transforms/CMakeLists.txt | 12 - .../lib/Transforms/TestListenerPasses.cpp | 103 - .../dialects/iree_structured_transform.py | 30 - .../iree-dialects/test/python/smoketest.py | 4 - .../tools/iree-dialects-opt/CMakeLists.txt | 8 - .../iree-dialects-opt/iree-dialects-opt.cpp | 19 - 126 files changed, 23573 deletions(-) delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Transforms.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/README.md delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Utils.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/Passes.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/ScopedTransform.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformInterpreterUtils.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/Listener.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h delete mode 100644 integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertToLoops.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/PadContractionToBlockSize.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/SplitReduction.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/TransformOps/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/CodegenStrategy.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToAsync.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Utils.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/ScopedTransform.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/ExpertExpansion.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/TransformInterpreter.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Transforms/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/lib/Transforms/Listener.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Transforms/ListenerCSE.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp delete mode 100644 integrations/tensorflow/iree-dialects/lib/Transforms/TransformMatchers.cpp delete mode 100644 integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/IreeLinalgExtBinding.td delete mode 100644 integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/IreeStructuredTransformOps.td delete mode 100644 integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/LinalgTransformBinding.td delete mode 100644 integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/_iree_linalg_transform_ops_ext.py delete mode 100644 integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py delete mode 100644 integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/iree_linalg_ext.py delete mode 100644 integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/iree_linalg_transform.py delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/canonicalize.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/conv2d_to_winograd.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/convert_to_loops.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/fold_into_pack_unpack_ops.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/pad_contraction_to_block_size.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/pad_tiling.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/split-reduction.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_winograd.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/expert.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/failure.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/invalid.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/scoped.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/selective-targeting.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir delete mode 100644 integrations/tensorflow/iree-dialects/test/lib/Dialect/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/test/lib/Dialect/LinalgTransform/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/test/lib/Dialect/LinalgTransform/TestScopedTransform.cpp delete mode 100644 integrations/tensorflow/iree-dialects/test/lib/Transforms/CMakeLists.txt delete mode 100644 integrations/tensorflow/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp delete mode 100644 integrations/tensorflow/iree-dialects/test/python/dialects/iree_structured_transform.py diff --git a/integrations/tensorflow/iree-dialects/BUILD b/integrations/tensorflow/iree-dialects/BUILD index 46b52178e268..497431dcda30 100644 --- a/integrations/tensorflow/iree-dialects/BUILD +++ b/integrations/tensorflow/iree-dialects/BUILD @@ -30,8 +30,6 @@ filegroup( name = "TdFilegroup", srcs = glob([ "include/iree-dialects/Dialect/Input/*.td", - "include/iree-dialects/Dialect/LinalgExt/IR/*.td", - "include/iree-dialects/Dialect/LinalgExt/Passes/*.td", ]), ) @@ -39,9 +37,6 @@ td_library( name = "TdFiles", srcs = glob([ "include/iree-dialects/Dialect/Input/*.td", - "include/iree-dialects/Dialect/LinalgExt/IR/*.td", - "include/iree-dialects/Dialect/LinalgExt/Passes/*.td", - "include/iree-dialects/Dialect/LinalgTransform/*.td", "python/iree/compiler/dialects/*.td", ]) + [ "@llvm-project//mlir:include/mlir/Bindings/Python/Attributes.td", @@ -130,477 +125,6 @@ gentbl_filegroup( ], ) -################################################################################ -# IREELinalgExt Dialect -################################################################################ - -cc_library( - name = "IREEDialectsTransforms", - srcs = glob([ - "lib/Transforms/*.cpp", - ]), - hdrs = glob([ - "include/iree-dialects/Transforms/*.h", - ]), - includes = ["include"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:Rewrite", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformDialect", - "@llvm-project//mlir:Transforms", - ], -) - -gentbl_cc_library( - name = "IREELinalgExtIncGen", - strip_include_prefix = "include", - tbl_outs = [ - ( - [ - "--dialect=iree_linalg_ext", - "--gen-dialect-decls", - ], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h.inc", - ), - ( - [ - "--dialect=iree_linalg_ext", - "--gen-dialect-defs", - ], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.cpp.inc", - ), - ( - ["--gen-attrdef-decls"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtAttrs.h.inc", - ), - ( - ["--gen-attrdef-defs"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtAttrs.cpp.inc", - ), - ( - ["--gen-enum-decls"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtEnums.h.inc", - ), - ( - ["--gen-enum-defs"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtEnums.cpp.inc", - ), - ( - ["--gen-op-decls"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h.inc", - ), - ( - ["--gen-op-defs"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc", - ), - ( - ["--gen-typedef-decls"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtTypes.h.inc", - ), - ( - ["--gen-typedef-defs"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtTypes.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td", - deps = [ - ":TdFiles", - "@llvm-project//mlir:CallInterfacesTdFiles", - "@llvm-project//mlir:ControlFlowInterfacesTdFiles", - "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", - "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - "@llvm-project//mlir:TilingInterfaceTdFiles", - "@llvm-project//mlir:ViewLikeInterfaceTdFiles", - ], -) - -gentbl_cc_library( - name = "IREELinalgExtInterfacesIncGen", - strip_include_prefix = "include", - tbl_outs = [ - ( - ["--gen-op-interface-decls"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.h.inc", - ), - ( - ["--gen-op-interface-defs"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.cpp.inc", - ), - ( - ["--gen-type-interface-decls"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtTypeInterfaces.h.inc", - ), - ( - ["--gen-type-interface-defs"], - "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtTypeInterfaces.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td", - deps = [ - ":TdFiles", - ], -) - -gentbl_cc_library( - name = "IREELinalgExtPassIncGen", - strip_include_prefix = "include", - tbl_outs = [ - ( - ["--gen-pass-decls"], - "include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h.inc", - ), - ( - ["--gen-pass-capi-header"], - "include/iree-dialects/Dialect/LinalgExt/Passes/Passes.capi.h.inc", - ), - ( - ["--gen-pass-capi-impl"], - "include/iree-dialects/Dialect/LinalgExt/Passes/Passes.capi.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td", - deps = [ - ":TdFiles", - "@llvm-project//mlir:PassBaseTdFiles", - ], -) - -cc_library( - name = "IREELinalgExtUtils", - srcs = glob([ - "lib/Dialect/LinalgExt/Utils/*.cpp", - ]), - hdrs = glob([ - "include/iree-dialects/Dialect/LinalgExt/Utils/*.h", - ]), - includes = ["include"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TensorUtils", - ], -) - -cc_library( - name = "IREELinalgExtDialect", - srcs = glob([ - "lib/Dialect/LinalgExt/IR/*.cpp", - ]), - hdrs = glob([ - "include/iree-dialects/Dialect/LinalgExt/IR/*.h", - ]), - includes = ["include"], - deps = [ - ":IREELinalgExtIncGen", - ":IREELinalgExtInterfacesIncGen", - ":IREELinalgExtPassIncGen", - ":IREELinalgExtUtils", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineUtils", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ArithUtils", - "@llvm-project//mlir:ControlFlowInterfaces", - "@llvm-project//mlir:DestinationStyleOpInterface", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:InferTypeOpInterface", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LinalgUtils", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SideEffectInterfaces", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:TilingInterface", - "@llvm-project//mlir:ViewLikeInterface", - ], -) - -# TODO(#9827): Remove aliases and/or backref. -alias( - name = "IREELinalgExtPasses", - actual = ":IREELinalgExtPassesAndTransforms", -) - -alias( - name = "IREELinalgExtTransforms", - actual = ":IREELinalgExtPassesAndTransforms", -) - -gentbl_cc_library( - name = "IREELinalgExtTransformOpsIncGen", - strip_include_prefix = "include", - tbl_outs = [ - ( - ["--gen-op-decls"], - "include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h.inc", - ), - ( - ["--gen-op-defs"], - "include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td", - deps = [ - ":TdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - ], -) - -cc_library( - name = "IREELinalgExtTransformOps", - srcs = glob(["lib/Dialect/LinalgExt/TransformOps/*.cpp"]), - hdrs = glob(["include/iree-dialects/Dialect/LinalgExt/TransformOps/*.h"]), - deps = [ - ":IREEDialectsTransforms", - ":IREELinalgExtDialect", - ":IREELinalgExtTransformOpsIncGen", - ":IREELinalgExtTransforms", - ":IREELinalgTransformDialect", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:PDLDialect", - "@llvm-project//mlir:TransformDialect", - ], -) - -cc_library( - name = "IREELinalgExtPassesAndTransforms", - srcs = glob([ - "lib/Dialect/LinalgExt/Passes/*.cpp", - "lib/Dialect/LinalgExt/Transforms/*.cpp", - ]), - hdrs = glob([ - "include/iree-dialects/Dialect/LinalgExt/Passes/*.h", - "include/iree-dialects/Dialect/LinalgExt/Transforms/*.h", - ]), - deps = [ - ":IREEInputDialect", - ":IREELinalgExtDialect", - ":IREELinalgExtPassIncGen", - ":IREELinalgExtUtils", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineUtils", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ArithUtils", - "@llvm-project//mlir:AsyncDialect", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LinalgStructuredOpsIncGen", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:LinalgUtils", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:MemRefTransforms", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SCFTransforms", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TensorTransforms", - "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:TilingInterface", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorDialect", - "@llvm-project//mlir:VectorTransforms", - ], -) - -################################################################################ -# IREELinalgTransform Dialect -################################################################################ - -gentbl_cc_library( - name = "IREELinalgTransformIncGen", - strip_include_prefix = "include", - tbl_outs = [ - ( - [ - "--dialect=iree_linalg_transform", - "--gen-dialect-decls", - ], - "include/iree-dialects/Dialect/LinalgTransform/LinalgTransformDialect.h.inc", - ), - ( - [ - "--dialect=iree_linalg_transform", - "--gen-dialect-defs", - ], - "include/iree-dialects/Dialect/LinalgTransform/LinalgTransformDialect.cpp.inc", - ), - ( - ["--gen-op-decls"], - "include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h.inc", - ), - ( - ["--gen-op-defs"], - "include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td", - deps = [ - ":TdFiles", - "@llvm-project//mlir:ControlFlowInterfacesTdFiles", - ], -) - -gentbl_cc_library( - name = "IREELinalgTransformStructuredIncGen", - strip_include_prefix = "include", - tbl_outs = [ - ( - ["--gen-op-decls"], - "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h.inc", - ), - ( - ["--gen-op-defs"], - "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td", - deps = [ - ":TdFiles", - ], -) - -cc_library( - name = "IREELinalgTransformDialect", - srcs = glob([ - "lib/Dialect/LinalgTransform/IR/*.cpp", - "lib/Dialect/LinalgTransform/IR/*.h", - ]), - hdrs = glob([ - "include/iree-dialects/Dialect/LinalgTransform/*.h", - ]), - includes = ["include"], - deps = [ - ":IREEDialectsTransforms", - ":IREELinalgExtDialect", - ":IREELinalgExtPasses", - ":IREELinalgExtTransforms", - ":IREELinalgTransformIncGen", - ":IREELinalgTransformStructuredIncGen", - "@llvm-project//llvm:Support", - # Dialects - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineUtils", - "@llvm-project//mlir:AsyncDialect", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:BufferizationTransforms", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:PDLDialect", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SCFUtils", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformDialect", - # IR - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Rewrite", - # Interfaces - "@llvm-project//mlir:ControlFlowInterfaces", - - # Transforms - "@llvm-project//mlir:AsyncTransforms", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:AffineToStandard", - "@llvm-project//mlir:SCFTransforms", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:ReconcileUnrealizedCasts", - # Utils - "@llvm-project//mlir:ArithUtils", - "@llvm-project//mlir:DialectUtils", - # Conversions - "@llvm-project//mlir:AsyncToLLVM", - "@llvm-project//mlir:FuncToLLVM", - "@llvm-project//mlir:LinalgToLLVM", - "@llvm-project//mlir:LinalgToStandard", - "@llvm-project//mlir:MathToLLVM", - "@llvm-project//mlir:MemRefToLLVM", - "@llvm-project//mlir:SCFToControlFlow", - "@llvm-project//mlir:VectorToLLVM", - ], -) - -cc_library( - name = "IREELinalgTransformDialectPasses", - srcs = glob([ - "lib/Dialect/LinalgTransform/Passes/*.cpp", - ]), - deps = [ - ":IREEDialectsTransforms", - ":IREELinalgExtDialect", - ":IREELinalgTransformDialect", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineUtils", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ArithTransforms", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:BufferizationTransforms", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:PDLDialect", - "@llvm-project//mlir:PDLInterpDialect", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Rewrite", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SCFTransforms", - "@llvm-project//mlir:SCFUtils", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TensorTransforms", - "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:TransformDialect", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorDialect", - "@llvm-project//mlir:VectorToLLVM", - "@llvm-project//mlir:VectorTransforms", - ], -) - ################################################################################ # CAPI ################################################################################ @@ -612,10 +136,6 @@ cc_library( includes = ["include"], deps = [ ":IREEInputDialect", - ":IREELinalgExtDialect", - ":IREELinalgExtTransformOps", - ":IREELinalgTransformDialect", - ":IREELinalgTransformDialectPasses", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransformOps", @@ -634,10 +154,6 @@ cc_library( "test/lib/**/*.cpp", ]), deps = [ - ":IREEDialectsTransforms", - ":IREELinalgExtDialect", - ":IREELinalgTransformDialect", - ":IREELinalgTransformDialectPasses", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Rewrite", @@ -656,14 +172,8 @@ cc_binary( ], tags = ["hostonly"], deps = [ - "IREELinalgExtTransforms", ":IREEDialectsTest", ":IREEInputDialect", - ":IREELinalgExtDialect", - ":IREELinalgExtPasses", - ":IREELinalgExtTransformOps", - ":IREELinalgTransformDialect", - ":IREELinalgTransformDialectPasses", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects-c/Dialects.h b/integrations/tensorflow/iree-dialects/include/iree-dialects-c/Dialects.h index 2ababb17d1af..a605f9993986 100644 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects-c/Dialects.h +++ b/integrations/tensorflow/iree-dialects/include/iree-dialects-c/Dialects.h @@ -21,32 +21,6 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IREEInput, iree_input); -//===--------------------------------------------------------------------===// -// IREELinalgExt -//===--------------------------------------------------------------------===// - -MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IREELinalgExt, iree_linalg_ext); - -//===--------------------------------------------------------------------===// -// LinalgTransform -//===--------------------------------------------------------------------===// - -MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LinalgTransform, iree_linalg_transform); - -/// Register all passes for LinalgTransform. -MLIR_CAPI_EXPORTED void mlirIREELinalgTransformRegisterPasses(); - -//===--------------------------------------------------------------------===// -// TransformDialect -//===--------------------------------------------------------------------===// - -MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform); - -MLIR_CAPI_EXPORTED void ireeRegisterTransformExtensions(MlirContext context); - -/// Register all passes for the transform dialect. -MLIR_CAPI_EXPORTED void mlirIREETransformRegisterPasses(); - #ifdef __cplusplus } #endif diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt index 16d52d437fde..ab1d7407c50a 100644 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt +++ b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt @@ -1,3 +1 @@ add_subdirectory(Input) -add_subdirectory(LinalgExt) -add_subdirectory(LinalgTransform) diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/CMakeLists.txt b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/CMakeLists.txt deleted file mode 100644 index 4391ced12426..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Passes) -add_subdirectory(TransformOps) diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/CMakeLists.txt b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/CMakeLists.txt deleted file mode 100644 index dfbb58fd9ada..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/CMakeLists.txt +++ /dev/null @@ -1,28 +0,0 @@ -function(_add_interfaces) - set(LLVM_TARGET_DEFINITIONS LinalgExtInterfaces.td) - mlir_tablegen(LinalgExtOpInterfaces.h.inc -gen-op-interface-decls) - mlir_tablegen(LinalgExtOpInterfaces.cpp.inc -gen-op-interface-defs) - mlir_tablegen(LinalgExtTypeInterfaces.h.inc -gen-type-interface-decls) - mlir_tablegen(LinalgExtTypeInterfaces.cpp.inc -gen-type-interface-defs) - add_public_tablegen_target(IREELinalgExtInterfacesIncGen) - add_dependencies(IREELinalgExtIncGen IREELinalgExtInterfacesIncGen) -endfunction() - -function(_add_dialect) - set(LLVM_TARGET_DEFINITIONS LinalgExtOps.td) - mlir_tablegen(LinalgExtAttrs.h.inc -gen-attrdef-decls) - mlir_tablegen(LinalgExtAttrs.cpp.inc -gen-attrdef-defs) - mlir_tablegen(LinalgExtEnums.h.inc -gen-enum-decls) - mlir_tablegen(LinalgExtEnums.cpp.inc -gen-enum-defs) - mlir_tablegen(LinalgExtOps.h.inc -gen-op-decls) - mlir_tablegen(LinalgExtOps.cpp.inc -gen-op-defs) - mlir_tablegen(LinalgExtTypes.h.inc -gen-typedef-decls) - mlir_tablegen(LinalgExtTypes.cpp.inc -gen-typedef-defs) - mlir_tablegen(LinalgExtDialect.h.inc --gen-dialect-decls --dialect=iree_linalg_ext) - mlir_tablegen(LinalgExtDialect.cpp.inc --gen-dialect-defs --dialect=iree_linalg_ext) - add_public_tablegen_target(IREELinalgExtIncGen) - add_dependencies(mlir-headers IREELinalgExtIncGen) -endfunction() - -_add_dialect() -_add_interfaces() diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td deleted file mode 100644 index c0530fe7021b..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECT_LINALGEXT_BASE -#define IREE_DIALECT_LINALGEXT_BASE - -include "mlir/IR/OpBase.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/EnumAttr.td" - -//===----------------------------------------------------------------------===// -// Dialect definition -//===----------------------------------------------------------------------===// - -def IREELinalgExt_Dialect : Dialect { - let name = "iree_linalg_ext"; - let cppNamespace = "::mlir::iree_compiler::IREE::LinalgExt"; - let description = [{ - The `iree_linalg_ext` dialect is intended to experiment more support for - non-structured operations, ie, can not be represented in Linalg operations. - }]; - let hasCanonicalizer = 1; - let useDefaultAttributePrinterParser = 1; -} - -//===----------------------------------------------------------------------===// -// Type definitions -//===----------------------------------------------------------------------===// - -class RankedTensorOrMemRefOf allowedTypes> : - ShapedContainerType]>, - "ranked tensor or memref", "::mlir::ShapedType">; - -def AnyRankedTensorOrMemRefType : RankedTensorOrMemRefOf<[AnyType]>; - -//===---------------------------------------------------------------------===// -// Data layout encoding attributes -//===---------------------------------------------------------------------===// - -class IREELinalgExt_Attr traits = []> - : AttrDef; - -// List of pre-defined data layout encoding attributes. -def MATMUL_F32F32F32_LHS - : I32EnumAttrCase<"MATMUL_F32F32F32_LHS", 0>; -def MATMUL_F32F32F32_RHS - : I32EnumAttrCase<"MATMUL_F32F32F32_RHS", 1>; -def MATMUL_F32F32F32_RHS_TRANSPOSE - : I32EnumAttrCase<"MATMUL_F32F32F32_RHS_TRANSPOSE", 2>; -def MATMUL_F32F32F32_RESULT - : I32EnumAttrCase<"MATMUL_F32F32F32_RESULT", 3>; -def MATMUL_I8I8I32_LHS - : I32EnumAttrCase<"MATMUL_I8I8I32_LHS", 4>; -def MATMUL_I8I8I32_RHS - : I32EnumAttrCase<"MATMUL_I8I8I32_RHS", 5>; -def MATMUL_I8I8I32_RHS_TRANSPOSE - : I32EnumAttrCase<"MATMUL_I8I8I32_RHS_TRANSPOSE", 6>; -def MATMUL_I8I8I32_RESULT - : I32EnumAttrCase<"MATMUL_I8I8I32_RESULT", 7>; - -def TensorEncodingEnum - : I32EnumAttr<"TensorEncoding", - "identifier for encoding used for the tensor",[ - MATMUL_F32F32F32_LHS, MATMUL_F32F32F32_RHS, MATMUL_F32F32F32_RHS_TRANSPOSE, MATMUL_F32F32F32_RESULT, - MATMUL_I8I8I32_LHS, MATMUL_I8I8I32_RHS, MATMUL_I8I8I32_RHS_TRANSPOSE, MATMUL_I8I8I32_RESULT, - ]> { - let cppNamespace = "::mlir::iree_compiler::IREE::LinalgExt"; - let genSpecializedAttr = 0; -} - -def TensorEncodingAttr : - EnumAttr { - let assemblyFormat = "``$value"; -} - -def IREELinalgExt_EncodingAttr : IREELinalgExt_Attr<"Encoding"> { - let mnemonic = "encoding"; - let summary = [{tensor layout encoding}]; - let description = [{ - This attribute describes the change in the layout for - a given tensor to execute subsequent operations on - the tiled layout. The encoding serves as a way to - represent the change in the way the data is laid out in - memory without changing the logical rank/extent of - the tensor itself. When required, the encoding - can be used to explicitly manifest the layout change - through operations like pack/unpack. - - Currently the encoding is just an enum that describes - in an ad-hoc fashions the data layouts we initially care - about. In fullness of time the encoding attribute can be - made richer. - }]; - - let parameters = (ins - AttrParameter<"IREE::LinalgExt::TensorEncodingAttr", - "Tensor encoding to use for a tensor">:$encoding - ); - - let assemblyFormat = [{ - `<` `` $encoding `>` - }]; - - let builders = [ - AttrBuilder<(ins "TensorEncoding":$encoding)> - ]; -} - - -#endif // IREE_DIALECT_LINALGEXT_BASE diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h deleted file mode 100644 index a9fe5d9c9414..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTDIALECT_H_ -#define IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTDIALECT_H_ - -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" - -// clang-format off: must be included after all LLVM/MLIR headers -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h.inc" // IWYU pragma: keep -// clang-format on - -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTDIALECT_H_ diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h deleted file mode 100644 index c6e12e663703..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_ -#define IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_ - -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Support/LLVM.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { -class LinalgExtOp; - -namespace detail { -LogicalResult verifyLinalgExtOpInterface(Operation *op); -} - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h.inc" // IWYU pragma: export - -/// Include the generated interface declarations. -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.h.inc" // IWYU pragma: export - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_ diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td deleted file mode 100644 index 0757eb914851..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td +++ /dev/null @@ -1,418 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECT_LINALGEXT_INTERFACES -#define IREE_DIALECT_LINALGEXT_INTERFACES - -include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td" - -// The interface is a subset of LinalgStructuredInterface. -def LinalgExtInterface : OpInterface<"LinalgExtOp"> { - let methods = [ - //===------------------------------------------------------------------===// - // Num input/output arguments handling. - //===------------------------------------------------------------------===// - // `inputs` must be defined by each op that wants to implement the - // LinalgStructuredInterface. - InterfaceMethod< - /*desc=*/[{ - Return the input shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"getInputs", - /*args=*/(ins) - >, - // These special methods rely on `inputs` and `outputs` being defined by - // each op that wants to implement the LinalgStructuredInterface. - InterfaceMethod< - /*desc=*/[{ - Return the number of inputs. - }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumInputs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op.getInputs().size(); - }] - >, - // `outputs` must be defined by each op that wants to implement the - // LinalgStructuredInterface. - InterfaceMethod< - /*desc=*/[{ - Return the output shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"getOutputs", - /*args=*/(ins) - >, - InterfaceMethod< - /*desc=*/[{ - Return the number of outputs. - }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumOutputs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op.getOutputs().size(); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the number of inputs and outputs. - }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumInputsAndOutputs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return getNumInputs() + getNumOutputs(); - }] - >, - //===------------------------------------------------------------------===// - // Input operands handling. - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/[{ - Return the input operands. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - int64_t numInputs = getNumInputs(); - OpOperandVector result; - result.reserve(numInputs); - llvm::transform( - this->getOperation()->getOpOperands().take_front(numInputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the `i`-th input operand. - }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getInputOperand", - /*args=*/(ins "int64_t":$i), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumInputs()); - return &this->getOperation()->getOpOperand(i); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the subset of input operands that are of buffer type. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputBufferOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the subset of input operands that are of tensor type. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputTensorOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; - }] - >, - //===------------------------------------------------------------------===// - // Output operands handling. - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/[{ - Return the output operands. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - int64_t numOutputs = getNumOutputs(); - OpOperandVector result; - result.reserve(numOutputs); - llvm::transform( - this->getOperation()->getOpOperands() - .drop_front(getNumInputs()) - .take_front(numOutputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the `i`-th output operand. - }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getOutputOperand", - /*args=*/(ins "int64_t":$i), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); - return &this->getOperation()->getOpOperand(getNumInputs() + i); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Set the `i`-th output operand. - }], - /*retTy=*/"void", - /*methodName=*/"setOutputOperand", - /*args=*/(ins "int64_t":$i, "Value":$value), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); - this->getOperation()->setOperand(getNumInputs() + i, value); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the subset of output operands that are of buffer type. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputBufferOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the subset of output operands that are of tensor type. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputTensorOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the types of the subset of output operands that are of buffer type. - }], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputBufferTypes", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputBufferOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the types of the subset of output operands that are of tensor type. - }], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputTensorTypes", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputTensorOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); - return result; - }] - >, - //===------------------------------------------------------------------===// - // Input and Output arguments handling. - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/[{ - Return the range over input and output operands. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputAndOutputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - int64_t numInputsAndOutputs = getNumInputsAndOutputs(); - OpOperandVector result; - result.reserve(numInputsAndOutputs); - llvm::transform( - this->getOperation()->getOpOperands() - .take_front(numInputsAndOutputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return true if `opOperand` is an input tensor. - }], - /*retTy=*/"bool", - /*methodName=*/"isInputTensor", - /*args=*/(ins "OpOperand *":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() < $_op.getNumInputs()) - return true; - return false; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return true if `opOperand` is an output tensor. - }], - /*retTy=*/"bool", - /*methodName=*/"isOutputTensor", - /*args=*/(ins "OpOperand *":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() >= $_op.getNumInputs()) - return true; - return false; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the `opOperand` rank or zero for scalars. - }], - /*retTy=*/"int64_t", - /*methodName=*/"getRank", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - if (auto shapedType = - opOperand->get().getType().template dyn_cast()) - return shapedType.getRank(); - return 0; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the `opOperand` shape or an empty vector for scalars. - }], - /*retTy=*/"ArrayRef", - /*methodName=*/"getShape", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - if (auto shapedType = - opOperand->get().getType().template dyn_cast()) - return shapedType.getShape(); - return {}; - }] - >, - //===------------------------------------------------------------------===// - // Non input and output operands handling - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/[{ - Return operands that are neither inputs nor outputs. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getNonInputOrOutputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - int64_t numInputsAndOutputs = getNumInputsAndOutputs(); - int64_t numOperands = this->getOperation()->getNumOperands(); - assert(numInputsAndOutputs <= numOperands); - if (numInputsAndOutputs == numOperands) - return {}; - OpOperandVector result; - result.reserve(numOperands - numInputsAndOutputs); - llvm::transform( - this->getOperation()->getOpOperands() - .drop_front(numInputsAndOutputs), - std::back_inserter(result), - [](OpOperand &opOperand) {return &opOperand;}); - return result; - }] - > - ]; - - let extraClassDeclaration = [{ - /// Returns the value that expresses the shape of the output in terms of - /// shape of the input operands where possible. - LogicalResult reifyResultShapes(OpBuilder &b, - mlir::ReifiedRankedShapedTypeDims &reifiedReturnShapes); - - //========================================================================// - // Helper functions to mutate the `operand_segment_sizes` attribute. - // These are useful when cloning and changing operand types. - //========================================================================// - void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); } - void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); } - - private: - void setOperandSegmentAt(unsigned idx, unsigned val) { - auto attr = (*this)->getAttr("operand_segment_sizes") - .cast(); - unsigned i = 0; - auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), - [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); - getOperation()->setAttr("operand_segment_sizes", newAttr); - } - }]; - - let verify = [{ return detail::verifyLinalgExtOpInterface($_op); }]; -} - -#endif // IREE_DIALECT_LINALGEXT_INTERFACES diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h deleted file mode 100644 index bd01789563a0..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_ -#define IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_ - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Interfaces/TilingInterface.h" - -// clang-format off - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtEnums.h.inc" // IWYU pragma: export - -#define GET_ATTRDEF_CLASSES -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtAttrs.h.inc" // IWYU pragma: export - -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h.inc" // IWYU pragma: export - -// clang-format on - -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_ diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td deleted file mode 100644 index 25733c77a010..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td +++ /dev/null @@ -1,1090 +0,0 @@ - // Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECT_LINALGEXT_OPS -#define IREE_DIALECT_LINALGEXT_OPS - -include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td" -include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/DestinationStyleOpInterface.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/TilingInterface.td" -include "mlir/Interfaces/ViewLikeInterface.td" - - -def IREELinalgExt_DoNotDCEOperandsOp : - Op { - let summary = "Unfoldable op that just keeps its operands live"; - let description = [{ - Unfoldable op that just keeps its operands live. This is to use with the - transform dialect in case where transforms introduce IR that would be - otherwise DCE'd by canonicalizations. - - This op should be added to the transform dialect in the fullness of time but - it can't be registered dynamically on the IREE side as that triggers errors - since the op does not implement any transform interface. - }]; - - let arguments = (ins Variadic:$operands); - let results = (outs); - let assemblyFormat = "attr-dict $operands `:` type($operands)"; -} - -//===----------------------------------------------------------------------===// -// Base class. -//===----------------------------------------------------------------------===// - -class IREELinalgExt_PureOp traits = []> : - Op { -} - -class IREELinalgExt_Op traits = []> : - IREELinalgExt_PureOp, - DestinationStyleOpInterface, LinalgExtInterface, - SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp"> - ])> { - let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; - code extraLinalgExtOpClassDeclaration = ""; -} - -//===----------------------------------------------------------------------===// -// Non-structured ops -//===----------------------------------------------------------------------===// - -def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter", - [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { - let summary = "Scatter operator"; - let description = [{ - Based on XLA operation semantics, takes two `inputs` (`update` and - `indices`) and `outputs` value (`original`). The operation updates - the value at the slices specified by `indices` by combining the - current value with the value in `updates` using the computation - specified in `region`. The `region` specifies a binary operation - of signature (T, T) -> T, where `T` is the element-type of - `updates` (and `original`). The first argument correspond the - value to be updated (i.e. from `updates`), and the second the - current value (i.e. value from `original`). - - The `indices` is a 2D tensor/memref type. The first dim is the number of - updates, and the second dim is index depth. The index depth should always be - static. - - The first dim of `updates` and `indices` is identical, since they represent - the number of updates. - - The rank of the `original`/`result` is at least - `index_depth + rank(%updates) - 1`. The first `index_depth` indices are - derived from `indices` and the shape of update value has the last - rank(%original) - index_depth values match %(originals) last dimensions, - with the previous dims extending from the index offsets. - - The dimension_map attributes describes which index value maps to which - dimension in the destionation. It cannot contain duplicate values, must - have as many entries as index depth, and values must be within the rank of - the destination. - - The unique_indices attribute carries the information whether all the indices - are unique. If there are repeated indices, the first iteration loop will be - marked as reduction. - - The shapes definition follows tensorflow operations execept that it force - batch dims to be 1D. See more information in - https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update - }]; - let arguments = (ins - Variadic:$inputs, - Variadic:$outputs, - DenseI64ArrayAttr:$dimension_map, - DefaultValuedAttr:$unique_indices - ); - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); - let assemblyFormat = [{ - attr-dict `dimension_map` `=` $dimension_map - `unique_indices` `(` $unique_indices `)` - (`ins` `(` $inputs^ `:` type($inputs) `)`)? - `outs` `(` $outputs `:` type($outputs) `)` - $region (`->` type($results)^)? - }]; - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - - int64_t getIndexDepth() { - return getInputOperand(1) - ->get() - .getType() - .cast() - .getShape() - .back(); - } - - Value updates() { - return getInputOperand(0)->get(); - } - - ShapedType getUpdateType() { - return updates().getType().cast(); - } - - Value indices() { - return getInputOperand(1)->get(); - } - - ShapedType getIndicesType() { - return indices().getType().cast(); - } - - Value original() { - return getOutputOperand(0)->get(); - } - - ShapedType getOriginalType() { - return original().getType().cast(); - } - - int64_t getUpdateSliceRank() { - return updates().getType().cast().getRank() - 1; - } - - bool isScalarUpdate() { - return getUpdateSliceRank() == 0; - } - - // Method to implement for specifying output range for - // DestinationStyleOpInterface - std::pair getDpsInitsPositionRange() { - std::pair outputsIndexAndLength = - getODSOperandIndexAndLength(1); - return std::make_pair( - outputsIndexAndLength.first, - outputsIndexAndLength.first + outputsIndexAndLength.second); - } - }]; -} - -def IREELinalgExt_SortOp : IREELinalgExt_Op<"sort", - [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { - let summary = "Sort operator"; - let description = [{ - Based on XLA operation semantics, sorts the given `operands` at the given - `dimension` with the given `comparator`. - - See https://www.tensorflow.org/xla/operation_semantics#sort. - }]; - - let arguments = (ins Variadic:$inputs, - Variadic:$outputs, - I64Attr:$dimension - ); - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); - let assemblyFormat = [{ - attr-dict - `dimension` `(` $dimension `)` - (`ins` `(` $inputs^ `:` type($inputs) `)`)? - `outs` `(` $outputs `:` type($outputs) `)` - $region (`->` type($results)^)? - }]; - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - Value operand(int index) { - return getOutputs()[index]; - } - ShapedType getOperandType(int index) { - return operand(index).getType().cast(); - } - int64_t getOperandRank() { - return getOperandType(0).getRank(); - } - ArrayRef getOperandShape() { - return getOperandType(0).getShape(); - } - - // Method to implement for specifying output range for - // DestinationStyleOpInterface - std::pair getDpsInitsPositionRange() { - std::pair outputsIndexAndLength = - getODSOperandIndexAndLength(1); - return std::make_pair( - outputsIndexAndLength.first, - outputsIndexAndLength.first + outputsIndexAndLength.second); - } - }]; -} - -def IREELinalgExt_FftOp : IREELinalgExt_Op<"fft", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { - let summary = "Fft operator"; - let description = [{ - Apply 1D FFT to innermost dim. This is an iterative FFT, not recurrsive. - Thus, the bit reversal is assumed applied on the input. The op carries an - input -- stage, which indicates the level of reduction loop in the - algorithm. It represents the computation body. For more details, see - "Data reordering, bit reversal, and in-place algorithms" section in - https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm - - The size of innermost dim is expected to be a power of 2. - - It is optional to carry coefficient tensors/buffers as inputs. In this - context, they will be the second and third inputs. - }]; - - let arguments = (ins Variadic:$inputs, - Variadic:$outputs - ); - let results = (outs Variadic:$results); - let assemblyFormat = [{ - attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)? - `outs` `(` $outputs `:` type($outputs) `)` - (`:` type($results)^)? - }]; - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - Value getStage() { return getInputs()[0]; } - Value getReal() { return getOutputs()[0]; } - Value getImag() { return getOutputs()[1]; } - bool hasCoeff() { return getNumInputs() > 1; } - void generateScalarImplWithoutCoeffBuf( - OpBuilder & b, Location loc, ArrayRef operands, Value wholeSize); - void generateScalarImplWithCoeffBuf(OpBuilder & b, Location loc, - ArrayRef operands); - Value getRealCoeff() { - if (!hasCoeff()) return Value(); - return getInputs()[1]; - } - Value getImagCoeff() { - if (!hasCoeff()) return Value(); - return getInputs()[2]; - } - ShapedType getOperandType() { - return getReal().getType().cast(); - } - int64_t getOperandRank() { - return getOperandType().getRank(); - } - ArrayRef getOperandShape() { - return getOperandType().getShape(); - } - int64_t getFftLength() { - return getOperandShape().back(); - } - - // Method to implement for specifying output range for - // DestinationStyleOpInterface - std::pair getDpsInitsPositionRange() { - std::pair outputsIndexAndLength = - getODSOperandIndexAndLength(1); - return std::make_pair( - outputsIndexAndLength.first, - outputsIndexAndLength.first + outputsIndexAndLength.second); - } - }]; -} - -def IREELinalgExt_ScanOp : IREELinalgExt_Op<"scan", - [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { - let summary = "Scan operator"; - let description = [{ - Computes the inclusive/exclusive scan along a given dimension. - }]; - - let arguments = (ins Variadic:$inputs, - Variadic:$outputs, - I64Attr:$dimension, - BoolAttr:$inclusive - ); - - let builders = [ - OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs, - CArg<"int64_t", "0">:$dimension, CArg<"bool", "true">:$inclusive)> - ]; - - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); - let hasFolder = 1; - let assemblyFormat = [{ - attr-dict - `dimension` `(` $dimension `)` - `inclusive` `(` $inclusive `)` - `ins` `(` $inputs `:` type($inputs) `)` - `outs` `(` $outputs `:` type($outputs) `)` - $region (`->` type($results)^)? - }]; - - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - Value input() { - return getInputOperand(0)->get(); - } - Value accumulator() { - return getOutputOperand(1)->get(); - } - Value output() { - return getOutputOperand(0)->get(); - } - ShapedType getOperandType() { - return input().getType().cast(); - } - int64_t getOperandRank() { - return getOperandType().getRank(); - } - - // Method to implement for specifying output range for - // DestinationStyleOpInterface - std::pair getDpsInitsPositionRange() { - std::pair outputsIndexAndLength = - getODSOperandIndexAndLength(1); - return std::make_pair( - outputsIndexAndLength.first, - outputsIndexAndLength.first + outputsIndexAndLength.second); - } - }]; -} - -def IREELinalgExt_ReverseOp : IREELinalgExt_Op<"reverse", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods< - TilingInterface, - ["generateScalarImplementation", - "getIterationDomain", - "getLoopIteratorTypes", - "getResultTilePosition", - "getTiledImplementation"]>, - DeclareOpInterfaceMethods]> { - let summary = "Reverse operator"; - let description = [{ - A temporary solution for lowering reverse ops into IREE, allowing IREE to - tile and distribute them. - } - }]; - - let arguments = (ins Variadic:$inputs, - Variadic:$outputs, - I64ElementsAttr:$dimensions - ); - let results = (outs Variadic:$results); - let assemblyFormat = [{ - attr-dict `dimensions` `(` $dimensions `)` - (`ins` `(` $inputs^ `:` type($inputs) `)`)? - (`outs` `(` $outputs^ `:` type($outputs) `)`)? - (`:` type($results)^)? - }]; - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - Value input() { - return getInputOperand(0)->get(); - } - Value output() { - return getOutputOperand(0)->get(); - } - ShapedType getOperandType() { - return input().getType().cast(); - } - int64_t getOperandRank() { - return getOperandType().getRank(); - } - ArrayRef getOprerandShape() { - return getOperandType().getShape(); - } - SmallVector dims() { - SmallVector ret; - for (const APInt& elem : getDimensions()) { - ret.push_back(elem.getLimitedValue()); - } - return ret; - } - - // Method to implement for specifying output range for - // DestinationStyleOpInterface - std::pair getDpsInitsPositionRange() { - std::pair outputsIndexAndLength = - getODSOperandIndexAndLength(1); - return std::make_pair( - outputsIndexAndLength.first, - outputsIndexAndLength.first + outputsIndexAndLength.second); - } - }]; -} - -def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods -]>{ - let summary = "Top-K operator"; - let description = [{ - A Top-K operation for N-D tensors. Reduces the target dimension from the input - size N down to K elements based on the supplied binary region. - - Accepts an N-D tensor input consisting of values and an optioanl N-D tensor - for indices of those values (i32 type). If input indices aren't provided, the - index mapping is inferred based on the k dim. Both input values/indices - tensors and output values/indicies tensors must have the same shape. Top-K is - computed along the target dimension (from dimension()). Returns two output - tensors of values and the indicies of Top-K results. The output dimensions - must match the input save for the dimension that is reduced to K results. - - Region accepts lhs=[next N input] and rhs=[exiting K output] and yeilds an - i1. If true, the two values are swapped: - - For Top-K compoarision: > - - For Min-K comparision: < - Note: when the two values are equal, the first occurence is always selected. - }]; - - let arguments = (ins Variadic:$inputs, - Variadic:$outputs, - I64Attr:$dimension - ); - - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); - let assemblyFormat = [{ - attr-dict - `dimension` `(` $dimension `)` - `ins` `(` $inputs `:` type($inputs) `)` - `outs` `(` $outputs `:` type($outputs) `)` - $region (`->` type($results)^)? - }]; - - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - Value values() { - return getInputOperand(0)->get(); - } - Optional indices() { - if (getNumInputs() < 2) { - return {}; - } else { - return getInputOperand(1)->get(); - } - } - Value outputValues() { - return getOutputOperand(0)->get(); - } - Value outputIndices() { - return getOutputOperand(1)->get(); - } - ShapedType getInputType() { - return values().getType().cast(); - } - int64_t getInputRank() { - return getInputType().getRank(); - } - - // Method to implement for specifying output range for - // DestinationStyleOpInterface - std::pair getDpsInitsPositionRange() { - std::pair outputsIndexAndLength = - getODSOperandIndexAndLength(1); - return std::make_pair( - outputsIndexAndLength.first, - outputsIndexAndLength.first + outputsIndexAndLength.second); - } - }]; -} - -def IREELinalgExt_PackOp : IREELinalgExt_Op<"pack", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods -]>{ - let summary = "pack operation"; - let description = [{ - The pack operation converts an `input` into a tiled and packed layout. The - dimensions to be tiled are obtained from `inner_dims_pos` and the size of the - tile is obtained from `inner_tiles`. The dimensions listed in `inner_dims_pos` - do not need to be contiguous in which case the tile will get transposed. We - handle only full tiles if `padding_value` is not set; it is UB if the tile does - not perfectly divide the dimension. If `padding_value` is set, it will pad - along high dimensions, i.e., it pads at the bottom and on the right if the - input has rank 2, and the result type shape, will be dynamic in any dimension - if and only if the input shape is. As optional input, the operation takes - `outer_dims_perm` that allows to permute the tiled loops. - - Example KC_to_KCck: - - ```mlir - iree_linalg_ext.pack %arg0 inner_dims_pos = [1, 0] - inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<16x8x32x8xf32>) - ``` - - Example NC_to_NCnc: - - ```mlir - iree_linalg_ext.pack %arg0 inner_dims_pos = [0, 1] - inner_tiles = [8, 32] into %arg1 : (memref<128x256xf32> memref<16x8x8x32xf32>) - ``` - Example KC_to_CKkc - - ```mlir - iree_linalg_ext.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<32x4x32x8xf32>) - ``` - - In all cases, dimension at position 0 in the input memref (128) is tiled - with a factor of 8, while dimension at position 1 (256) is tiled with a factor - of 32. In the KC_to_KCck example, the point loops are interchanged, while in the - KC_to_CKkc example the tiled loops. - - Example NC_to_NCnc with padding: - - ```mlir - iree_linalg_ext.pack %arg padding_value(%pad : f32) inner_dims_pos = [0, 1] - inner_tiles = [8, 2] into %arg1 : (memref<13x15xf32> memref<2x8x8x2xf32>) - ``` - - }]; - - let arguments = (ins Variadic:$inputs, - Variadic:$outputs, - DefaultValuedOptionalAttr:$outer_dims_perm, - DenseI64ArrayAttr:$inner_dims_pos, - Variadic:$inner_tiles, - DenseI64ArrayAttr:$static_inner_tiles, - Optional:$padding_value); - - let results = (outs Variadic:$results); - let assemblyFormat = [{ - attr-dict - $inputs - (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)? - (`outer_dims_perm` `=` $outer_dims_perm^)? - `inner_dims_pos` `=` $inner_dims_pos - `inner_tiles` `=` - custom($inner_tiles, $static_inner_tiles) - `into` $outputs `:` `(` type($inputs) type($outputs) `)` - (`->` type($results)^)? - }]; - - let builders = [ - OpBuilder<(ins "Value":$source, "Value":$output, - "ArrayRef":$innerDimsPos, - "ArrayRef":$innerTiles, - CArg<"Optional", "llvm::None">:$paddingValue, - CArg<"ArrayRef", "{}">:$outerDimsPerm)> - ]; - - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - - // Return the output operand. - Value getOutput() { - return getOutputOperand(0)->get(); - } - - // Return the input operand. - Value getInput() { - return getInputOperand(0)->get(); - } - - // Return the output rank. - int64_t getOutputRank() { - return getOutputType().getRank(); - } - - // Return the output type. - ShapedType getOutputType() { - return getOutput().getType(); - } - - // Return the input type. - ShapedType getInputType() { - return getInput().getType(); - } - - // Return the output shape. - ArrayRef getOutputShape() { - return getOutputType().getShape(); - } - - // Return the input shape. - ArrayRef getInputShape() { - return getInputType().getShape(); - } - - // Return the element type. - Type getElementType() { - return getInputType().getElementType(); - } - - // Return the rank of the input operand. - int64_t getInputRank() { - return getInputType().getRank(); - } - - // Return the tile sizes. - SmallVector getMixedTiles(); - SmallVector getStaticTiles(); - - // Return a mapping from positions `dims_pos` to their tile factors. - DenseMap getDimAndTileMapping(); - - // Method to get the shape of the result as `SmallVector`. - // This is a static method to allow getting the shape of the destination - // expected while creating a `pack` op. - static SmallVector getResultShape(OpBuilder &builder, - Location loc, ArrayRef sourceDims, - ArrayRef innerTileDims, ArrayRef innerDimsPos, - ArrayRef outerDimsPerm = {}); - // Method to return the shape of the result as `SmallVector`. - SmallVector getResultShape(OpBuilder &builder); - - // Method to get the `ShapedType` of the result. This is a static method - // to allow getting the type of the destination while creating the `pack` - // op. - static ShapedType getPackedType(ShapedType sourceType, - ArrayRef innerTileSizes, ArrayRef innerDimsPos, - ArrayRef outerDimsPerm = {}); - - // Method to implement for specifying output range for - // DestinationStyleOpInterface - std::pair getDpsInitsPositionRange() { - std::pair outputsIndexAndLength = - getODSOperandIndexAndLength(1); - return std::make_pair( - outputsIndexAndLength.first, - outputsIndexAndLength.first + outputsIndexAndLength.second); - } - }]; -} - -def IREELinalgExt_UnPackOp : IREELinalgExt_Op<"unpack", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods -]>{ - let summary = "unpack operation"; - - let description = [{ - The unpack operation converts a tiled and packed input to an unpacked - output. See `pack` for more details on `inner_tiles` and `dims_pos`; it is UB - if the tile does not perfectly divide the dimension. Optionally, the operation - also supports permuting the tiled loops. - - Example KCck_to_KC: - - ```mlir - iree_linalg_ext.unpack %arg0 dims_pos = [1, 0] - inner_tiles = [32, 8] into %arg1 : (memref<16x8x32x8xf32> memref<128x256xf32>) - ``` - - Example NCnc_to_NC: - - ```mlir - iree_linalg_ext.unpack %arg0 dims_pos = [0, 1] - inner_tiles = [8, 32] into %arg1 : (memref<16x8x8x32xf32> memref<128x256xf32>) - ``` - - Example CKkc_to_KC: - - ```mlir - iree_linalg_ext.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [32, 8] into %arg0 : (memref<32x4x32x8xf32> memref<128x256xf32>) - ``` - }]; - - let arguments = (ins Variadic:$inputs, - Variadic:$outputs, - DefaultValuedOptionalAttr:$outer_dims_perm, - DefaultValuedAttr:$inner_dims_pos, - Variadic:$inner_tiles, - DenseI64ArrayAttr:$static_inner_tiles); - - let results = (outs Variadic:$results); - let assemblyFormat = [{ - attr-dict - $inputs - (`outer_dims_perm` `=` $outer_dims_perm^)? - `inner_dims_pos` `=` $inner_dims_pos - `inner_tiles` `=` - custom($inner_tiles, $static_inner_tiles) - `into` $outputs `:` `(` type($inputs) type($outputs) `)` - (`->` type($results)^)? - }]; - - let builders = [ - OpBuilder<(ins "Value":$source, "Value":$output, - "ArrayRef":$innerDimsPos, - "ArrayRef":$innerTiles, - CArg<"ArrayRef", "{}">:$outerDimsPerm)> - ]; - - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - - // Return the output operand. - Value getOutput() { - return getOutputOperand(0)->get(); - } - - // Return the input operand. - Value getInput() { - return getInputOperand(0)->get(); - } - - // Return the output rank. - int64_t getOutputRank() { - return getOutputType().getRank(); - } - - // Return the output type. - ShapedType getOutputType() { - return getOutput().getType(); - } - - // Return the input type. - ShapedType getInputType() { - return getInput().getType(); - } - - // Return the output shape. - ArrayRef getOutputShape() { - return getOutputType().getShape(); - } - - // Return the input shape. - ArrayRef getInputShape() { - return getInputType().getShape(); - } - - // Return the rank of the input operand. - int64_t getInputRank() { - return getInputType().getRank(); - } - - // Return the tile sizes. - SmallVector getMixedTiles(); - SmallVector getStaticTiles(); - - // Return a mapping from positions `dims_pos` to their tile factors. - DenseMap getDimAndTileMapping(); - - // Method to implement for specifying output range for - // DestinationStyleOpInterface - std::pair getDpsInitsPositionRange() { - std::pair outputsIndexAndLength = - getODSOperandIndexAndLength(1); - return std::make_pair( - outputsIndexAndLength.first, - outputsIndexAndLength.first + outputsIndexAndLength.second); - } - }]; -} - -def IREELinalgExt_SetEncodingOp : IREELinalgExt_PureOp<"set_encoding",[ - DeclareOpInterfaceMethods, Pure - ]> { - let summary = "perform pack and pad operation on source"; - let description = [{ - Operation to assign an encoding to a tensor. The operation - does not change the rank or extent of a tensor. Instead it - adds an encoding attribute to the tensor type to represent - a change in layout. - }]; - - let arguments = (ins AnyRankedTensor:$source); - let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = [{ - attr-dict $source `:` type($source) `->` type($result) - }]; - -let builders = [ - OpBuilder<(ins "Value":$source, "TensorEncoding":$encoding)> - ]; - let hasVerifier = 1; - - let extraClassDeclaration = [{ - RankedTensorType getSourceType() { - return getSource().getType().cast(); - } - RankedTensorType getResultType() { - return getResult().getType().cast(); - } - TensorEncoding getResultTensorEncoding() { - return getResultType().getEncoding().cast() - .getEncoding().getValue(); - } - }]; -} - -def IREELinalgExt_UnsetEncodingOp : IREELinalgExt_PureOp<"unset_encoding", [ - DeclareOpInterfaceMethods, Pure - ]> { - let summary = "perfom unpack and extract operation on source"; - let description = [{ - Operation to convert an tensor with encoding that represents - its data layout into a tensor with default layout (i.e. no encoding). - For now in IREE the default layout is row-major. - }]; - let arguments = (ins AnyRankedTensor:$source); - let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = [{ - attr-dict $source `:` type($source) `->` type($result) - }]; - -let builders = [ - OpBuilder<(ins "Value":$source)> - ]; - let hasVerifier = 1; - - let extraClassDeclaration = [{ - RankedTensorType getSourceType() { - return getSource().getType().cast(); - } - RankedTensorType getResultType() { - return getResult().getType().cast(); - } - TensorEncoding getSourceTensorEncoding() { - return getSourceType().getEncoding().cast() - .getEncoding().getValue(); - } - }]; -} - -//===----------------------------------------------------------------------===// -// Winograd ops -//===----------------------------------------------------------------------===// - -def IREELinalgExt_WinogradInputTransformOp : IREELinalgExt_Op<"winograd.input_transform", - [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { - let summary = "Winograd Input Transform operator"; - let description = [{ - This operator is the first step in converting a convolution to - its Winograd equivalent. Given a tile of an input image (I), - this operator computes matmul(tranpose(B), matmul(I, B)). - The input tile is assumed to be square with each side of size m + r - 1, - where the convolutional kernel is m x m and the output tile size is r x r. - B is a constant 2-d square matrix of the same shape as the input tile I. - The input to the operator is an image of shape (N, H, W, C) and the - output is an operator of shape (m + r - 1, m + r - 1, N, H', W', C) - where H' = ceil((H - m + 1)/r) and W' = ceil((W - m + 1)/r). The result - of this operator is first collapsed and then fed to a batch matmul op. - }]; - - let arguments = (ins Variadic:$inputs, - Variadic:$outputs, - I64Attr:$output_tile_size, - I64Attr:$kernel_size, - DenseI64ArrayAttr:$image_dimensions - ); - - let builders = [ - OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs, - CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size, - CArg<"ArrayRef", "{1, 2}">:$image_dimensions)> - ]; - - let results = (outs Variadic:$result); - let hasFolder = 1; - let assemblyFormat = [{ - attr-dict - `output_tile_size` `(` $output_tile_size `)` - `kernel_size` `(` $kernel_size `)` - `image_dimensions` `(` $image_dimensions `)` - `ins` `(` $inputs `:` type($inputs) `)` - `outs` `(` $outputs `:` type($outputs) `)` - (`->` type($result)^)? - }]; - - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - Value input() { - return getInputOperand(0)->get(); - } - Value output() { - return getOutputOperand(0)->get(); - } - ShapedType getInputOperandType() { - return input().getType().cast(); - } - ShapedType getOutputOperandType() { - return output().getType().cast(); - } - int64_t getInputOperandRank() { - return getInputOperandType().getRank(); - } - int64_t getOutputOperandRank() { - return getOutputOperandType().getRank(); - } - int64_t getInputTileSize() { - return getOutputTileSize() + getKernelSize() - 1; - } - SmallVector imageDimensions() { - return llvm::to_vector(getImageDimensions()); - } - int64_t getIterationDomainRank() { - SmallVector imageDims = imageDimensions(); - return getInputOperandRank() - imageDims.size(); - } - // Method to implement for specifying output range for - // DestinationStyleOpInterface - std::pair getDpsInitsPositionRange() { - std::pair outputsIndexAndLength = - getODSOperandIndexAndLength(1); - return std::make_pair( - outputsIndexAndLength.first, - outputsIndexAndLength.first + outputsIndexAndLength.second); - } - }]; -} - -def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_transform", - [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { - let summary = "Winograd Output Transform operator"; - let description = [{ - This operator is the last transform in converting a convolution to - its Winograd equivalent. After convolution in the Winograd domain - (which turns into an elementwise product for a single channel and - batch matrix multiplication for many channels), this operator converts - the output back into the original domain. Given a tile of the - output (O) in the Winograd domain, this operator computes - matmul(transpose(A), matmul(O, A)). The output tile is square with - each side of size m + r - 1, where the convolutional kernel is m x m - and the output tile size is r x r. A is a constant 2-d matrix of - shape (m + r - 1) x r. The input to the operator is a tensor of - shape (m + r - 1, m + r - 1, N, H', W', C) and the output is a - tensor of shape (N, H, W, C) where H = r H' and W = r W'. This operator - is followed by a tensor.extract_slice which extracts only the non-padded - part of the output. - }]; - - let arguments = (ins Variadic:$inputs, - Variadic:$outputs, - I64Attr:$output_tile_size, - I64Attr:$kernel_size, - DenseI64ArrayAttr:$image_dimensions - ); - - let builders = [ - OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs, - CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size, - CArg<"ArrayRef", "{1, 2}">:$image_dimensions)> - ]; - - let results = (outs Variadic:$result); - let hasFolder = 1; - let assemblyFormat = [{ - attr-dict - `output_tile_size` `(` $output_tile_size `)` - `kernel_size` `(` $kernel_size `)` - `image_dimensions` `(` $image_dimensions `)` - `ins` `(` $inputs `:` type($inputs) `)` - `outs` `(` $outputs `:` type($outputs) `)` - (`->` type($result)^)? - }]; - - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - Value input() { - return getInputOperand(0)->get(); - } - Value output() { - return getOutputOperand(0)->get(); - } - ShapedType getInputOperandType() { - return input().getType().cast(); - } - ShapedType getOutputOperandType() { - return output().getType().cast(); - } - SmallVector imageDimensions() { - return llvm::to_vector(getImageDimensions()); - } - int64_t getInputOperandRank() { - return getInputOperandType().getRank(); - } - int64_t getOutputOperandRank() { - return getOutputOperandType().getRank(); - } - int64_t getIterationDomainRank() { - SmallVector imageDims = imageDimensions(); - return getOutputOperandRank() - imageDims.size(); - } - int64_t getInputTileSize() { - return getOutputTileSize() + getKernelSize() - 1; - } - // Method to implement for specifying output range for - // DestinationStyleOpInterface - std::pair getDpsInitsPositionRange() { - std::pair outputsIndexAndLength = - getODSOperandIndexAndLength(1); - return std::make_pair( - outputsIndexAndLength.first, - outputsIndexAndLength.first + outputsIndexAndLength.second); - } - }]; -} - -//===----------------------------------------------------------------------===// -// Pure ops -//===----------------------------------------------------------------------===// - -def IREELinalgExt_YieldOp : IREELinalgExt_PureOp<"yield", [Pure, ReturnLike, Terminator]> { - let summary = "LinalgExt yield op"; - let description = [{ - `iree_linalg_ext.yield` is a special terminator operation for blocks inside - regions in `iree_linalg_ext` ops. - }]; - - let arguments = (ins Variadic:$operands); - - let builders = [ - OpBuilder<(ins), [{ /* nothing to do */ }]>, - ]; - - let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; -} - -#endif // IREE_DIALECT_LINALGEXT_OPS diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/CMakeLists.txt b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/CMakeLists.txt deleted file mode 100644 index 07379ca71495..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) -mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header) -mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl) -add_public_tablegen_target(IREELinalgExtPassesIncGen) diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h deleted file mode 100644 index e07c82487a4a..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_ -#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_ - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -#define GEN_PASS_CLASSES - -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h.inc" // IWYU pragma: keep - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_ diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h deleted file mode 100644 index 1d16af50f375..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h +++ /dev/null @@ -1,366 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASSES_H_ -#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASSES_H_ - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { - -class ConversionTarget; -class TypeConverter; - -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { -// Marker used as attribute name in generated Linalg rewriting transformations. -struct LinalgTransforms { - static const StringLiteral kLinalgTransformMarker; -}; - -/// Helper class to control application of linalg transformation patterns. -/// Control comes in 2 forms: -/// 1. attribute matching and setting behavior using the attribute named -/// `kLinalgTransformMarker`. This can be used to build a state machine -/// using attributes and incrementally applying patterns to advance states. -/// 2. filter function, which is a simple lambda on the Operation* that -/// returns a LogicalResult. -struct LinalgTransformationFilter { - using FilterFunction = std::function; - - explicit LinalgTransformationFilter( - ArrayRef matchDisjunction = {}, - Optional replacement = None); - - explicit LinalgTransformationFilter( - const FilterFunction &f, ArrayRef matchDisjunction = {}, - Optional replacement = None); - - LinalgTransformationFilter(LinalgTransformationFilter &&) = default; - LinalgTransformationFilter(const LinalgTransformationFilter &) = default; - LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const; - void replaceLinalgTransformationFilter(PatternRewriter &rewriter, - Operation *op) const; - bool hasReplacementFilter(Operation *op) const; - - LinalgTransformationFilter &addFilter(const FilterFunction &f) { - if (f) - filters.push_back(f); - return *this; - } - - template - LinalgTransformationFilter &addOpFilter() { - return addFilter( - [](Operation *op) { return success(isa(op)); }); - } - - LinalgTransformationFilter &addOpNameFilter(StringRef opName) { - return addFilter([opName](Operation *op) { - return success(op->getName().getStringRef() == opName); - }); - } - - LinalgTransformationFilter &setMatchByDefault() { - matchByDefault = true; - return *this; - } - -private: - SmallVector filters; - SmallVector matchDisjunction; - Optional replacement; - /// When set to true, if the attribute is not set, it will be treated as - /// a match. Default is false. - bool matchByDefault; -}; - -std::unique_ptr> createTilingInterfaceTilingPass(); - -std::unique_ptr> createLinalgExtToLoopsPass(); - -/// Container of information needed to materialize the pack operation. -struct MaterializeEncodingInfo { - SmallVector innerDimsPos; - SmallVector innerTileSizes; - SmallVector outerDimsPerm; - unsigned srcRank = 0; -}; -using MaterializeEncodingFn = - std::function(RankedTensorType)>; - -/// TypeConverter to use for materializing the encoding. -struct MaterializeEncodingTypeConverter : public TypeConverter { - MaterializeEncodingTypeConverter(MaterializeEncodingFn materializeEncodingFn); - MaterializeEncodingFn &getMaterializeEncodingFn() { - return materializeEncodingFn; - } - -private: - MaterializeEncodingFn materializeEncodingFn; -}; - -/// Conversion target to use for for materializing the encoding. -struct MaterializeEncodingConversionTarget : public ConversionTarget { - MaterializeEncodingConversionTarget(MLIRContext &context); -}; - -/// Base class for patterns that materialize encoding. -template -struct OpMaterializeEncodingPattern : public OpConversionPattern { - OpMaterializeEncodingPattern(MaterializeEncodingTypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} -}; - -/// Method to populate the patterns to convert operations that have operands -/// with tensor encodings into ops that materialize the layout specified by the -/// encoding, as well as ops that perform the computation on the materialized -/// layout. For now these hard-code a fixed way the lowering is encoded, but the -/// encoding can be made backend specific. Also initializes the -/// `conversionTarget` and `typeConverter`. -void populateMaterializeEncodingPatterns( - RewritePatternSet &patterns, - MaterializeEncodingConversionTarget &conversionTarget, - MaterializeEncodingTypeConverter &typeConverter); - -/// Pass to apply patterns specified by `populateMaterializeEncodingPass`. -std::unique_ptr> createMaterializeEncodingPass(); - -/// Patterns to fold operations like `tensor.pad` and `tensor.extract_slice` -/// into `linalg_ext.pack` and `linalg_ext.unpack` operations respectively. -void populateFoldIntoPackAndUnpackOpsPatterns(RewritePatternSet &patterns); - -/// Pass to apply patterns specified by `populateFoldIntoPackAndUnpackOps`. -std::unique_ptr> createFoldIntoPackAndUnpackOps(); - -std::unique_ptr> createPadContractionToBlockSizePass(); - -/// Function signature to control reduction splitting. This returns the split -/// reduction ratio used to split the reduction dimension. The ratio is applied -/// to the reduction dimension of TopK. If the ratio value is less or equal to 1 -/// then nothing will be done. Input is the current depth of recursive split -/// reduction, starting from 0 (first level). -using TopkSplitReductionControlFn = - std::function; - -/// Patterns to apply `topk split reduction` pass. -void populateTopkSplitReductionPattern( - RewritePatternSet &patterns, - const TopkSplitReductionControlFn &splitReductionFn, - const LinalgExt::LinalgTransformationFilter &f = - LinalgExt::LinalgTransformationFilter()); - -std::unique_ptr> createTopkSplitReductionPass(); - -std::unique_ptr> createLinalgExtVectorizationPass(); - -/// Tile and decompose the winograd transform ops into a sequence -/// of linalg ops. -std::unique_ptr> -createTileAndDecomposeWinogradTransformPass(); - -// Creates a pass to convert linalg convolution ops into a sequence of -// linalg_ext.winograd.* ops and linalg.batch_matmul ops using the winograd -// tranformation. -std::unique_ptr createConvertConv2DToWinogradPass(); - -// Marker used as attribute the depth of the split reduction transformations. -const StringLiteral kSplitReductionDepthMarker = "__split_reduction_depth__"; - -//===---------------------------------------------------------------------===// -// Codegen Strategy passes that are moved into IREE. -//===---------------------------------------------------------------------===// -/// Options to control the application of enabling transformations. -/// Hoisting transformations are always deemed beneficial and must be disabled -/// explicitly. -struct LinalgEnablingOptions { - /// Enable LICM. - bool licm = true; - LinalgEnablingOptions &enableLICM(bool val = true) { - licm = val; - return *this; - } - /// Enable hoisting of redundant vector transfer ops. - bool hoistRedundantVectorTransfers = true; - LinalgEnablingOptions &enableHoistRedundantVectorTransfers(bool val = true) { - hoistRedundantVectorTransfers = val; - return *this; - } - /// Enable hoisting of redundant vector transfer ops on tensor. - bool hoistRedundantVectorTransfersOnTensor = true; - LinalgEnablingOptions & - enableHoistRedundantVectorTransfersOnTensor(bool val = true) { - hoistRedundantVectorTransfersOnTensor = val; - return *this; - } -}; - -/// Create a LinalgStrategyTileAndFusePass. -std::unique_ptr> -createLinalgStrategyTileAndFusePass( - StringRef opName = "", const scf::SCFTileAndFuseOptions &options = {}, - const LinalgExt::LinalgTransformationFilter &filter = - LinalgExt::LinalgTransformationFilter()); - -/// Create a LinalgStrategyTilePass. -std::unique_ptr> createLinalgStrategyTilePass( - StringRef opName = "", - const scf::SCFTilingOptions &options = scf::SCFTilingOptions(), - const LinalgExt::LinalgTransformationFilter &filter = - LinalgExt::LinalgTransformationFilter()); - -/// Create a LinalgStrategyPadPass. -std::unique_ptr> createLinalgStrategyPadPass( - StringRef opName = "", - const linalg::LinalgPaddingOptions &opt = linalg::LinalgPaddingOptions(), - const LinalgExt::LinalgTransformationFilter &filter = - LinalgExt::LinalgTransformationFilter()); - -/// Create a LinalgStrategyDecomposePass. -// TODO: if/when we need finer control add an `opName` parameter. -std::unique_ptr> createLinalgStrategyDecomposePass( - const LinalgExt::LinalgTransformationFilter &filter = - LinalgExt::LinalgTransformationFilter()); - -/// Create a LinalgStrategyPeelPass. -using LoopsToPeelComputationFunction = std::function &)>; - -struct LinalgPeelOptions { - LoopsToPeelComputationFunction loopsToPeelComputationFunction = nullptr; -}; -std::unique_ptr> createLinalgStrategyPeelPass( - StringRef opName = "", const LinalgPeelOptions &opt = LinalgPeelOptions(), - const LinalgExt::LinalgTransformationFilter &filter = - LinalgExt::LinalgTransformationFilter()); - -/// Create a LinalgStrategyVectorizePass. -std::unique_ptr> createLinalgStrategyVectorizePass( - StringRef opName = "", - const LinalgExt::LinalgTransformationFilter &filter = - LinalgExt::LinalgTransformationFilter(), - bool padVectorize = false); - -/// Create a LinalgStrategyEnablePass. -std::unique_ptr> createLinalgStrategyEnablePass( - LinalgEnablingOptions opt = LinalgEnablingOptions(), - const LinalgExt::LinalgTransformationFilter &filter = - LinalgExt::LinalgTransformationFilter()); - -/// Create a LinalgStrategyLowerVectorsPass. -/// Vector lowering options control how ops are lowered down to 1-D and scf.for -/// form. -struct LinalgVectorLoweringOptions { - /// Enable lowering of vector.contract. - /// In a progressive lowering of vectors, this would be the 1st step. - bool contractionLowering = false; - LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) { - contractionLowering = val; - return *this; - } - /// Enable lowering of vector.multi_reduce. - /// In a progressive lowering of vectors, this would be the 2nd step. - bool multiReductionLowering = false; - LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) { - multiReductionLowering = val; - return *this; - } - /// Trigger full / partial vector.transfer splits. - /// In a progressive lowering of vectors, this would be the 3rd step. - bool transferPartialRewrite = false; - LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) { - transferPartialRewrite = val; - return *this; - } - /// Enable lowering of vector.transfer to scf. - /// In a progressive lowering of vectors, this would be the 4th step. - bool transferToSCFConversion = false; - LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) { - transferToSCFConversion = val; - return *this; - } - /// Maximal transfer rank under which we do not lower further. - int64_t maxTransferRank = 1; - LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) { - maxTransferRank = val; - return *this; - } - /// Vector lowering operations may result in surprising behavior when - /// composing multiple codegen strategies and must be enabled explicitly. - /// In a progressive lowering of vectors, this would be the 5th step. - bool transferLowering = true; - LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) { - transferLowering = val; - return *this; - } - /// Enable lowering of vector.shape_cast to insert/extract. - /// In a progressive lowering of vectors, this would be the 6th step. - bool shapeCastLowering = true; - LinalgVectorLoweringOptions &enableShapeCastLowering(bool val = true) { - shapeCastLowering = val; - return *this; - } - /// Enable lowering of vector.transpose. - /// In a progressive lowering of vectors, this would be the 7th step. - bool transposeLowering = false; - LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) { - transposeLowering = val; - return *this; - } - /// Enable AVX2-specific lowerings. - bool avx2Lowering = false; - LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) { - avx2Lowering = val; - return *this; - } - - /// Configure the post staged-patterns late vector.transfer to scf - /// conversion. - VectorTransferToSCFOptions vectorTransferToSCFOptions; - LinalgVectorLoweringOptions & - setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) { - vectorTransferToSCFOptions = options; - return *this; - } - /// Configure late vector transformations. - vector::VectorTransformsOptions vectorTransformOptions; - LinalgVectorLoweringOptions & - setVectorTransformsOptions(vector::VectorTransformsOptions options) { - vectorTransformOptions = options; - return *this; - } - /// Configure specialized vector lowerings. - x86vector::avx2::LoweringOptions avx2LoweringOptions; - LinalgVectorLoweringOptions & - setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options) { - avx2LoweringOptions = options; - return *this; - } -}; - -std::unique_ptr> -createLinalgStrategyLowerVectorsPass( - LinalgVectorLoweringOptions opt = LinalgVectorLoweringOptions(), - const LinalgTransformationFilter &filter = LinalgTransformationFilter()); - -/// Create a LinalgStrategyRemoveMarkersPass. -std::unique_ptr> -createLinalgStrategyRemoveMarkersPass(); - -void registerPasses(); - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASSES_H_ diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td deleted file mode 100644 index badb6584ea06..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECT_LINALGEXT_PASSES -#define IREE_DIALECT_LINALGEXT_PASSES - -include "mlir/Pass/PassBase.td" - -def LinalgExtToLoops : - Pass<"iree-linalg-ext-to-loops", "func::FuncOp"> { - let summary = "Convert LinalgExt ops to loops and Linalg ops."; - let constructor = "mlir::iree_compiler::IREE::LinalgExt::createLinalgExtToLoopsPass()"; -} - -def TilingInterfaceTiling : - Pass<"iree-linalg-ext-tile", "func::FuncOp"> { - let summary = "Test pass for tiling using TiledOpInterface"; - let constructor = "mlir::iree_compiler::IREE::LinalgExt::createTilingInterfaceTilingPass()"; -} - -def MaterializeEncoding : - Pass<"iree-linalg-ext-materialize-encoding", "func::FuncOp"> { - let summary = "Test pass to materialize ops with tensor encoding into ops with explicit data movement"; - let constructor = "mlir::iree_compiler::IREE::LinalgExt::createMaterializeEncodingPass()"; -} - -def FoldIntoPackAndUnpackOps : - Pass<"iree-linalg-ext-fold-into-pack-unpack-ops", "func::FuncOp"> { - let summary = "Test pass to fold operations into pack and unpacl operations"; - let constructor = "mlir::iree_compiler::IREE::LinalgExt::createFoldIntoPackAndUnpackOps()"; -} - -def PadContractionToBlockSize : - Pass<"iree-linalg-pad-contraction-to-block-size", ""> { - let summary = "Pads contraction (matmul) ops to next multiple of block size"; - let description = [{ - This pass will apply padding to any supported linalg contractions: - * Row-major matmul: - Padded to - - Both rowAlignment and columnAlignment must be power-of-two values. If an - op is already statically padded properly, no change will be made. However, - if dynamic dimensions exist, padding will be applied regardless. Because - of the dynamic case, applying this pass multiple times can result in - mutation on each run. - }]; - let constructor = "mlir::iree_compiler::IREE::LinalgExt::createPadContractionToBlockSizePass()"; - let options = [ - Option<"rowAlignment", "rowAlignment", "int", /*default=*/"16", - "The row-wise output block size">, - Option<"columnAlignment", "columnAlignment", "int", /*default=*/"16", - "The column-wise output block size">, - ]; -} - -def TopkSplitReduction: - Pass<"iree-linalg-ext-topk-split-reduction", "func::FuncOp"> { - let summary = "Topk split reduction pass."; - let description = [{ - Produces a "map-reduce" style of parallelizing a Topk Op. The op is split - into two, on containing reducitons in parallel and the other contianing the - combination of the parallel reductions into a final result. - }]; - let constructor = "mlir::iree_compiler::IREE::LinalgExt::createTopkSplitReductionPass()"; - let options = [ - ListOption<"splitRatios", "split-ratios", "int", - "List of split reduction ratios">, - ]; -} - -def LinalgExtVectorization: - Pass<"iree-linalg-ext-vectorization", "func::FuncOp"> { - let summary = "Vectorization pass for LinalgExt pack ops."; - let description = [{ - Vectorizes LinalgExt ops when they meet the conditions, e.g., having static - shapes, etc. - }]; - let constructor = "mlir::iree_compiler::IREE::LinalgExt::" - "createLinalgExtVectorizationPass()"; -} - -def TileAndDecomposeWinogradTransform : - Pass<"iree-linalg-ext-tile-and-decompose-winograd", "func::FuncOp"> { - let summary = - "Tiles and decomposes winograd transform ops into linalg ops"; - let constructor = "mlir::iree_compiler::IREE::LinalgExt::" - "createTileAndDecomposeWinogradTransformPass()"; -} - -def ConvertConv2DToWinograd : - Pass<"iree-linalg-ext-convert-conv2d-to-winograd", ""> { - let summary = "Convert linalg convolution ops to winograd based implementation"; - let constructor = "mlir::iree_compiler::IREE::LinalgExt::createConvertConv2DToWinogradPass()"; -} - - -//===---------------------------------------------------------------------====// -// Codegen Strategy passes moved into IREE -// TODO: Deprecate all this. -//===---------------------------------------------------------------------====// - -def LinalgStrategyTileAndFusePass - : Pass<"iree-linalg-strategy-tile-and-fuse-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based tiling and fusion."; - let constructor = "createLinalgStrategyTileAndFusePass()"; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", - "Which linalg op within the func is the anchor to latch on.">, - ]; -} - -def LinalgStrategyTilePass - : Pass<"iree-linalg-strategy-tile-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based linalg tiling."; - let constructor = "createLinalgStrategyTilePass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", - "Which linalg op within the func is the anchor to latch on.">, - ]; -} - -def LinalgStrategyPadPass - : Pass<"iree-linalg-strategy-pad-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply padding and hoisting."; - let constructor = "createLinalgStrategyPadPass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", - "Which linalg op within the func is the anchor to latch on.">, - ]; -} - -// TODO: if/when we need finer control add an anchorOp option. -def LinalgStrategyDecomposePass - : Pass<"iree-linalg-strategy-decompose-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based generalization."; - let constructor = "createLinalgStrategyDecomposePass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - ]; -} - -def LinalgStrategyPeelPass - : Pass<"iree-linalg-strategy-peel-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based linalg peeling."; - let constructor = "createLinalgStrategyPeelPass()"; - let dependentDialects = [ - "linalg::LinalgDialect", - "scf::SCFDialect" - ]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", - "Which linalg op within the func is the anchor to latch on.">, - ]; -} - -def LinalgStrategyVectorizePass - : Pass<"iree-linalg-strategy-vectorize-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based linalg vectorization."; - let constructor = "createLinalgStrategyVectorizePass()"; - let dependentDialects = ["linalg::LinalgDialect", "vector::VectorDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", - "Which linalg op within the func is the anchor to latch on.">, - Option<"vectorizePadding", "vectorize-padding", "bool", "false", - "Enable vectorization of padding ops.">, - ]; -} - -def LinalgStrategyEnablePass - : Pass<"iree-linalg-strategy-enable-pass", "func::FuncOp"> { - let summary = "Configurable pass to enable the application of other " - "pattern-based linalg passes."; - let constructor = "createLinalgStrategyEnablePass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - ]; -} - -def LinalgStrategyLowerVectorsPass - : Pass<"iree-linalg-strategy-lower-vectors-pass", "func::FuncOp"> { - let summary = "Configurable pass to lower vector operations."; - let constructor = "createLinalgStrategyLowerVectorsPass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - ]; -} - -def LinalgStrategyRemoveMarkersPass - : Pass<"iree-linalg-strategy-remove-markers-pass", "func::FuncOp"> { - let summary = "Cleanup pass that drops markers."; - let constructor = "createLinalgStrategyRemoveMarkersPass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - ]; -} - -#endif // IREE_DIALECT_LINALGEXT_PASSES diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Transforms.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Transforms.h deleted file mode 100644 index ab41b39f4258..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Transforms.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_PASSES_TRANSFORMS_H_ -#define IREE_DIALECTS_DIALECT_LINALGEXT_PASSES_TRANSFORMS_H_ - -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Interfaces/TilingInterface.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -/// Structure to represent the result of tiling operation. -struct TiledOp { - /// Tiled operations that are created during tiling. - SmallVector op; - /// Loops generated during tiling. - SmallVector loops; - /// Values that are replacements for the untiled operations. - SmallVector results; -}; - -/// Main entry point for tiling LinalgExtOps using TiledOpInterface. -FailureOr tileLinalgExtOp(OpBuilder &b, TilingInterface tilableOp, - const linalg::LinalgTilingOptions &options); - -/// Base rewrite pattern to tile and distribute operations that implement the -/// `TiledOpInterface`. -/// Base pattern for tiling TiledOpInterfaceOps. -struct TilingInterfaceBaseTilingPattern - : public OpInterfaceRewritePattern { - TilingInterfaceBaseTilingPattern( - MLIRContext *context, linalg::LinalgTilingOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(context, benefit), filter(filter), - options(options) {} - - LogicalResult matchAndRewriteBase(TilingInterface tilableOp, - PatternRewriter &rewriter, - TiledOp &result) const; - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; - /// Options to control tiling; - linalg::LinalgTilingOptions options; -}; - -struct TilingInterfaceTilingPattern : public TilingInterfaceBaseTilingPattern { - TilingInterfaceTilingPattern( - MLIRContext *context, linalg::LinalgTilingOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : TilingInterfaceBaseTilingPattern(context, options, filter, benefit) {} - - LogicalResult matchAndRewrite(TilingInterface tilableOp, - PatternRewriter &rewriter) const; -}; - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_PASSES_TRANSFORMS_H_ diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/README.md b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/README.md deleted file mode 100644 index 60509c11c36e..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/README.md +++ /dev/null @@ -1,13 +0,0 @@ -This folder defines dialects, interfaces, operations and transformations that are -- experimental -- meant to eventually be upstreamed to LLVM. - -These are used (or will be used) within IREE as and when required. They are not -meant to be part of "features" that IREE exposes, or part of IREEs public -API. Their use within IREE is an internal implementation detail. - -Some of the transformations here might not be as well tested as others, mostly -depending on how load-bearing it is within IREE. Those that are heavily used are -expected to be well tested, but that might not be the case for experimental -features. They are expected to achieve the same level of fidelity and testing as -upstream MLIR when they are being transitioned out of IREE. diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/CMakeLists.txt b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/CMakeLists.txt deleted file mode 100644 index 29eb8233029b..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -function(_add_transform_dialect_extension) - set(LLVM_TARGET_DEFINITIONS LinalgExtTransformOps.td) - mlir_tablegen(LinalgExtTransformOps.h.inc -gen-op-decls) - mlir_tablegen(LinalgExtTransformOps.cpp.inc -gen-op-defs) - add_public_tablegen_target(IREELinalgExtTransformOpsIncGen) - add_dependencies(mlir-headers IREELinalgExtTransformOpsIncGen) -endfunction() - -_add_transform_dialect_extension() diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h deleted file mode 100644 index 1fe0baf3579a..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMOPS_LINALGEXTTRANSFORMOPS_H -#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMOPS_LINALGEXTTRANSFORMOPS_H - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" - -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/PDL/IR/PDLTypes.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/IR/OpDefinition.h" - -namespace mlir { -namespace scf { -class ForOp; -class ForeachThreadOp; -} // namespace scf -} // namespace mlir - -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h.inc" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { -class LinalgExtTransformOpsExtension - : public transform::TransformDialectExtension< - LinalgExtTransformOpsExtension, IREELinalgExtDialect> { -public: - LinalgExtTransformOpsExtension(); -}; -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMOPS_LINALGEXTTRANSFORMOPS_H diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td deleted file mode 100644 index 28b7dbce12d8..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECT_LINALGEXT_TRANSFORMOPS -#define IREE_DIALECT_LINALGEXT_TRANSFORMOPS - -include "mlir/Dialect/PDL/IR/PDLTypes.td" -include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/OpBase.td" - -def FuseProducersOp : Op]> { - let description = [{Fuses the producers for the operands to fuse.}]; - - let arguments = (ins PDL_Operation:$target, - DefaultValuedAttr:$operands_to_fuse); - let results = (outs PDL_Operation:$transformed, - Variadic:$fused_ops); - - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; - let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt"; -} - -def RewriteForeachThreadToAsyncOp : - Op { - - let description = [{ - Rewrite a bufferized scf.foreach_thread op to the async dialect. - - Return modes: - ============= - This operation ignores non-Linalg ops and drops them in the return. - This transform is currently only implemented for 1-D scf.foreach_thread that - have been bufferized and definitely fail for the rest. - - If all the operations referred to by the `target` PDLOperation lower - properly, the transform succeeds. Otherwise the transform silently fails. - - The returned handle points to only the subset of successfully produced - async operations, which can all be empty. - }]; - let arguments = (ins PDL_Operation:$target); - let results = (outs PDL_Operation:$transformed); - - let assemblyFormat = "$target attr-dict"; - let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt"; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::scf::ForeachThreadOp target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, - ::mlir::transform::TransformState &state); - }]; -} - -def RewriteForeachThreadToScfForOp : - Op { - - let description = [{ - Rewrite a bufferized scf.foreach_thread to a sequential scf.for. - - Return modes: - ============= - This operation ignores non-Linalg ops and drops them in the return. - This transform is currently only implemented for 1-D scf.foreach_thread that - have been bufferized and definitely fail for the rest. - - If all the operations referred to by the `target` PDLOperation lower - properly, the transform succeeds. Otherwise the transform silently fails. - - The returned handle points to only the subset of successfully produced - scf.for operations, which can all be empty. - }]; - let arguments = (ins PDL_Operation:$target); - let results = (outs PDL_Operation:$transformed); - - let assemblyFormat = "$target attr-dict"; - let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt"; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::scf::ForeachThreadOp target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, - ::mlir::transform::TransformState &state); - }]; -} - -#endif // IREE_DIALECT_LINALGEXT_TRANSFORMOPS diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h deleted file mode 100644 index d803588ef89a..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_CODEGENSTRATEGY_H_ -#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_CODEGENSTRATEGY_H_ - -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Pass/PassManager.h" - -#include - -//===----------------------------------------------------------------------===// -// Strategies moved from upstream MLIR as IREE still heavily relies on patterns -// that compose through filters. -// TODO: Deprecate everything below. -//===----------------------------------------------------------------------===// - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -/// Abstract Transformation class applied in a sequence that also handles state -/// through markers. -struct Transformation { - explicit Transformation( - LinalgExt::LinalgTransformationFilter::FilterFunction f) - : filter(std::move(f)) {} - virtual ~Transformation() = default; - virtual void - addToPassPipeline(OpPassManager &pm, - LinalgExt::LinalgTransformationFilter m) const = 0; - LinalgExt::LinalgTransformationFilter::FilterFunction filter = nullptr; -}; - -/// Represent one application of LinalgStrategyTileAndFusePass. -struct TileAndFuse : public Transformation { - TileAndFuse(StringRef name, scf::SCFTileAndFuseOptions options, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)), opName(name), - options(std::move(options)) {} - - void - addToPassPipeline(OpPassManager &pm, - LinalgExt::LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m)); - } - -private: - std::string opName; - scf::SCFTileAndFuseOptions options; -}; - -/// Represent one application of LinalgStrategyTilePass. -struct Tile : public Transformation { - Tile(StringRef name, scf::SCFTilingOptions options, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)), opName(name), - options(std::move(options)) {} - - void - addToPassPipeline(OpPassManager &pm, - LinalgExt::LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyTilePass(opName, options, m)); - } - -private: - std::string opName; - scf::SCFTilingOptions options; -}; - -/// Represent one application of LinalgStrategyPadPass. -struct Pad : public Transformation { - Pad(StringRef name, linalg::LinalgPaddingOptions options, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)), opName(name), - options(std::move(options)) {} - - void - addToPassPipeline(OpPassManager &pm, - LinalgExt::LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyPadPass(opName, options, m)); - } - -private: - std::string opName; - linalg::LinalgPaddingOptions options; -}; - -/// Represent one application of createLinalgStrategyDecomposePass. -struct Decompose : public Transformation { - explicit Decompose( - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)) {} - - void - addToPassPipeline(OpPassManager &pm, - LinalgExt::LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyDecomposePass(m)); - } -}; - -/// Represent one application of createLinalgStrategyPeelPass. -struct Peel : public Transformation { - explicit Peel( - LinalgPeelOptions options, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)), options(options) {} - - Peel(StringRef name, LinalgPeelOptions options, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)), opName(name), options(options) {} - - void - addToPassPipeline(OpPassManager &pm, - LinalgExt::LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyPeelPass(opName, options, m)); - } - -private: - std::string opName; - LinalgPeelOptions options; -}; - -/// Represent one application of createLinalgStrategyVectorizePass. -struct Vectorize : public Transformation { - explicit Vectorize( - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr, - bool padVectorize = false) - : Transformation(std::move(f)), vectorizePadding(padVectorize) {} - - Vectorize(StringRef name, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr, - bool padVectorize = false) - : Transformation(std::move(f)), opName(name), - vectorizePadding(padVectorize) {} - - void - addToPassPipeline(OpPassManager &pm, - LinalgExt::LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyVectorizePass(opName, m, vectorizePadding)); - } - -private: - std::string opName; - bool vectorizePadding; -}; - -/// Represent one application of createLinalgStrategyLowerVectorsPass. -struct VectorLowering : public Transformation { - explicit VectorLowering( - LinalgVectorLoweringOptions options, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)), options(options) {} - - void - addToPassPipeline(OpPassManager &pm, - LinalgExt::LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyLowerVectorsPass(options, m)); - } - -private: - LinalgVectorLoweringOptions options; -}; - -/// Codegen strategy controls how a Linalg op is progressively lowered. -struct CodegenStrategy { - /// Append a pattern to tile the Op `opName` and fuse its producers with - /// tiling and fusion `options`. - CodegenStrategy & - tileAndFuse(StringRef opName, const scf::SCFTileAndFuseOptions &options, - const LinalgExt::LinalgTransformationFilter::FilterFunction &f = - nullptr) { - transformationSequence.emplace_back( - std::make_unique(opName, options, f)); - return *this; - } - /// Conditionally append a pattern to tile the Op `opName` and fuse its - /// producers with tiling and fusion `options`. - CodegenStrategy &tileAndFuseIf( - bool b, StringRef opName, scf::SCFTileAndFuseOptions options, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? tileAndFuse(opName, std::move(options), std::move(f)) : *this; - } - /// Append a pattern to add a level of tiling for Op `opName` with tiling - /// `options`. - CodegenStrategy & - tile(StringRef opName, const scf::SCFTilingOptions &options, - const LinalgExt::LinalgTransformationFilter::FilterFunction &f = - nullptr) { - transformationSequence.emplace_back( - std::make_unique(opName, options, f)); - return *this; - } - /// Conditionally append a pattern to add a level of tiling for - /// `LinalgOpType` with tiling `options`. - CodegenStrategy & - tileIf(bool b, StringRef opName, scf::SCFTilingOptions options, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? tile(opName, std::move(options), std::move(f)) : *this; - } - /// Append a pattern to pad and hoist the operands of Op `opName` with padding - /// `options`. - CodegenStrategy & - pad(StringRef opName, const linalg::LinalgPaddingOptions &options, - const LinalgExt::LinalgTransformationFilter::FilterFunction &f = - nullptr) { - transformationSequence.emplace_back( - std::make_unique(opName, options, f)); - return *this; - } - /// Conditionally append a pattern to pad and hoist the operands of Op - /// `opName` with padding `options`. - CodegenStrategy & - padIf(bool b, StringRef opName, linalg::LinalgPaddingOptions options, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? pad(opName, std::move(options), std::move(f)) : *this; - } - /// Append patterns to decompose convolutions. - CodegenStrategy & - decompose(const LinalgExt::LinalgTransformationFilter::FilterFunction &f = - nullptr) { - transformationSequence.emplace_back(std::make_unique(f)); - return *this; - } - /// Conditionally append patterns to decompose convolutions. - CodegenStrategy &decomposeIf( - bool b, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? decompose(std::move(f)) : *this; - } - /// Append a pattern to peel 'LinalgOpType'. - CodegenStrategy & - peel(StringRef opName, const LinalgPeelOptions &options, - const LinalgExt::LinalgTransformationFilter::FilterFunction &f = - nullptr) { - transformationSequence.emplace_back( - std::make_unique(opName, options, f)); - return *this; - } - /// Conditionally append a pattern to peel 'LinalgOpType'. - CodegenStrategy & - peelIf(bool b, StringRef opName, const LinalgPeelOptions &options, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? peel(opName, options, std::move(f)) : *this; - } - /// Append a pattern to rewrite `LinalgOpType` as a vector operation. - CodegenStrategy &vectorize( - StringRef opName, - const LinalgExt::LinalgTransformationFilter::FilterFunction &f = nullptr, - bool vectorizePadding = false) { - transformationSequence.emplace_back( - std::make_unique(opName, f, vectorizePadding)); - return *this; - } - /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector - /// operation. - CodegenStrategy & - vectorizeIf(bool b, StringRef opName, - LinalgExt::LinalgTransformationFilter::FilterFunction f = nullptr, - bool vectorizePadding = false) { - return b ? vectorize(opName, std::move(f), vectorizePadding) : *this; - } - /// Append a pattern to lower all vector operations. - CodegenStrategy &vectorLowering(LinalgVectorLoweringOptions options) { - transformationSequence.emplace_back( - std::make_unique(options)); - return *this; - } - /// Configure the post staged-patterns global enabling passes options. - CodegenStrategy & - setVectorTransferToSCFOptions(LinalgEnablingOptions options) { - linalgEnablingOptions = options; - return *this; - } - - /// Apply the transformation patterns in sequence with cleanup - /// transformations interleaved. - void configurePassPipeline(OpPassManager &pm, MLIRContext *context, - bool addEnablePass = true) const; - -private: - LogicalResult postPatternTransforms(Operation *func) const; - - LinalgEnablingOptions linalgEnablingOptions; - SmallVector, 4> transformationSequence; -}; - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_CODEGENSTRATEGY_H_ diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h deleted file mode 100644 index 34db971559dc..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h +++ /dev/null @@ -1,485 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_ -#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_ - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/IR/PatternMatch.h" - -namespace mlir { -namespace scf { -class ForOp; -class ForeachThreadOp; -} // namespace scf -namespace linalg { -class LinalgOp; -} - -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -/// Pattern to swap a `TilingInterface` op -> `tensor::ExtractSliceOp`. -struct SwapTilingInterfaceOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - FailureOr - returningMatchAndRewrite(tensor::ExtractSliceOp sliceOp, - PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(sliceOp, rewriter); - } -}; - -/// Pattern to rewrite a scf::ForEachThreadOp to the async dialect. -struct ForeachThreadOpToAsyncRewriter - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - FailureOr - returningMatchAndRewrite(scf::ForeachThreadOp foreachThreadOp, - PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(scf::ForeachThreadOp foreachThreadOp, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(foreachThreadOp, rewriter); - } -}; - -/// Pattern to rewrite a ForeachThreadOp to an scf::ForOp. -struct ForeachThreadOpToScfForRewriter - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - FailureOr - returningMatchAndRewrite(scf::ForeachThreadOp foreachThreadOp, - PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(scf::ForeachThreadOp foreachThreadOp, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(foreachThreadOp, rewriter); - } -}; - -struct FusionResult { - TilingInterface consumerOp; - SmallVector fusedOps; -}; - -/// Pattern to fuse the producers of a tileable op. -struct LinalgExtFusionPattern - : public OpInterfaceRewritePattern { - LinalgExtFusionPattern(MLIRContext *context, ArrayRef operandsToFuse) - : OpInterfaceRewritePattern(context), - operandsToFuse(operandsToFuse.begin(), operandsToFuse.end()) {} - - FailureOr - returningMatchAndRewrite(TilingInterface consumerOp, - PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(TilingInterface consumerOp, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(consumerOp, rewriter); - } - -private: - SmallVector operandsToFuse; -}; - -//===----------------------------------------------------------------------===// -// Transformations exposed as patterns, moved from upstream MLIR as IREE still -// heavily relies on patterns that compose through filters. -// TODO: Deprecate all the code below. -//===----------------------------------------------------------------------===// -/// Wrap upstream linalg::splitReduction with a filter. -inline FailureOr -splitReduction(PatternRewriter &b, linalg::LinalgOp op, - const linalg::ControlSplitReductionFn &controlSplitReductionFn, - const LinalgTransformationFilter &filter, - bool useAlloc = false) { - if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() || - op.getNumReductionLoops() != 1 || op.getNumDpsInits() != 1 || - !op.hasOnlyProjectedPermutations()) - return b.notifyMatchFailure(op, "precondition not met"); - - FailureOr res = - linalg::splitReduction(b, op, controlSplitReductionFn, useAlloc); - if (failed(res)) - return failure(); - - filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp); - filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp); - - return res->splitLinalgOp; -} - -/// -/// Linalg tiling pattern. -/// -/// Apply the `tiling` transformation as a pattern. -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `tiling` for more details. -// TODO: TiledOpInterface -struct LinalgTilingPattern - : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `filter`. - LinalgTilingPattern( - MLIRContext *context, linalg::LinalgTilingOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - /// Construct a pattern specifically applied to `opName`. - LinalgTilingPattern( - StringRef opName, MLIRContext *context, - linalg::LinalgTilingOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - FailureOr - returningMatchAndRewrite(linalg::LinalgOp op, - PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(linalg::LinalgOp op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); - } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; - /// Options to control tiling; - linalg::LinalgTilingOptions options; -}; - -/// -/// Linalg SCF tiling pattern. -/// -/// Apply the `tiling` transformation as a pattern. -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `tiling` for more details. -struct LinalgSCFTilingPattern - : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `filter`. - LinalgSCFTilingPattern( - MLIRContext *context, scf::SCFTilingOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - /// Construct a pattern specifically applied to `opName`. - LinalgSCFTilingPattern( - StringRef opName, MLIRContext *context, scf::SCFTilingOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - LogicalResult returningMatchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); - } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; - /// Options to control tiling; - scf::SCFTilingOptions options; -}; - -template -class TilingPatterns; - -template <> -class TilingPatterns<> { -public: - static void insert(RewritePatternSet &patterns, - const linalg::LinalgTilingOptions &options, - const LinalgTransformationFilter &f) {} -}; - -template -class TilingPatterns { -public: - static void insert(RewritePatternSet &patterns, - const linalg::LinalgTilingOptions &options, - const LinalgTransformationFilter &f) { - patterns.add(OpTy::getOperationName(), - patterns.getContext(), options, f); - TilingPatterns::insert(patterns, options, f); - } -}; - -/// -/// Linalg SCF tile and fuse patterns. -/// -/// `filter` controls LinalgTransformMarker matching and update when specified. -struct LinalgSCFTileAndFusePattern - : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `filter`. - LinalgSCFTileAndFusePattern( - MLIRContext *context, - scf::SCFTileAndFuseOptions options = scf::SCFTileAndFuseOptions(), - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - /// Construct a pattern specifically applied to `opName`. - LinalgSCFTileAndFusePattern( - StringRef opName, MLIRContext *context, - scf::SCFTileAndFuseOptions options = scf::SCFTileAndFuseOptions(), - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override; - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; - - scf::SCFTileAndFuseOptions options; -}; - -/// -/// Linalg vectorization patterns. -/// -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `vectorizeLinalgOp` for more details. -struct LinalgVectorizationPattern - : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `filter`. - LinalgVectorizationPattern( - MLIRContext *context, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - /// Construct a pattern specifically applied to `opName`. - LinalgVectorizationPattern( - StringRef opName, MLIRContext *context, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, - PatternRewriter &rewriter) const override; - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; -}; - -template -class VectorizationPatterns; - -template <> -class VectorizationPatterns<> { -public: - static void insert(RewritePatternSet &patterns, - const LinalgTransformationFilter &f) {} -}; - -template -class VectorizationPatterns { -public: - static void insert(RewritePatternSet &patterns, - const LinalgTransformationFilter &f) { - patterns.add(OpTy::getOperationName(), - patterns.getContext(), f); - VectorizationPatterns::insert(patterns, f); - } -}; - -/// -/// Linalg promotion patterns. -/// -/// Apply the `promoteSubViews` transformation as a pattern. -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `promoteSubViews` for more details. -struct LinalgBasePromotionPattern : public RewritePattern { - /// Entry point to match any LinalgOp - /// OpInterface. MatchAnyOpTag-based constructor - /// with a mandatory `filter`. - LinalgBasePromotionPattern( - MLIRContext *context, LinalgTransformationFilter f, - linalg::LinalgPromotionOptions options = linalg::LinalgPromotionOptions(), - PatternBenefit benefit = 1) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), - filter(std::move(f)), options(std::move(options)) {} - /// Entry point to match a specific Linalg op. - LinalgBasePromotionPattern( - StringRef opName, MLIRContext *context, - linalg::LinalgPromotionOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : RewritePattern(opName, benefit, context, {}), filter(std::move(f)), - options(std::move(options)) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - if (failed(promoteSubviewsPrecondition(op, options))) - return failure(); - - // TODO: We cannot use root update here. This - // pattern is creating other ops, so if the - // promotion fails, those need to be cleaned - // up, which doesnt seem to be happening here. - // So to fail properly, we should be cloning - // the op and deleting the previous op. This - // needs more investigation. - rewriter.startRootUpdate(op); - Optional promotedOp = - promoteSubViews(rewriter, op, options); - if (!promotedOp) { - rewriter.cancelRootUpdate(op); - return op->emitError("subview promotion failed"); - } - rewriter.finalizeRootUpdate(op); - filter.replaceLinalgTransformationFilter(rewriter, op); - return success(); - } - -private: - /// LinalgTransformMarker handles special - /// attribute manipulations. - LinalgTransformationFilter filter; - /// Promotion options. - linalg::LinalgPromotionOptions options; -}; - -template -struct LinalgPromotionPattern : public LinalgBasePromotionPattern { - /// SFINAE: This constructor can only trigger for - /// concrete ops that have a static - /// `getOperationName` method. - template - LinalgPromotionPattern( - MLIRContext *context, linalg::LinalgPromotionOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options, - f, benefit) {} - /// This constructor is available to anyone. - LinalgPromotionPattern( - StringRef opName, MLIRContext *context, - linalg::LinalgPromotionOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBasePromotionPattern(opName, context, options, f, benefit) {} -}; - -/// Wraps upstream Linalg pattern in a filter check + update. -template -struct DownscaleSizeOneWindowed2DConvolution final - : public OpRewritePattern { - DownscaleSizeOneWindowed2DConvolution(MLIRContext *context, - LinalgTransformationFilter f) - : OpRewritePattern(context, /*benefit=*/1), - filter(std::move(f)) {} - - LogicalResult matchAndRewrite(Conv2DOp convOp, - PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, convOp))) - return failure(); - linalg::DownscaleSizeOneWindowed2DConvolution p( - convOp.getContext()); - auto maybeConv1DOp = p.returningMatchAndRewrite(convOp, rewriter); - if (failed(maybeConv1DOp)) - return failure(); - filter.replaceLinalgTransformationFilter(rewriter, *maybeConv1DOp); - return success(); - } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; -}; - -/// Wraps upstream Linalg pattern in a filter check + update. -struct DownscaleDepthwiseConv2DNhwcHwcOp final - : public OpRewritePattern { - DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context, - LinalgTransformationFilter f) - : OpRewritePattern(context, - /*benefit=*/1), - filter(std::move(f)) {} - - LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp, - PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, convOp))) - return failure(); - linalg::DownscaleDepthwiseConv2DNhwcHwcOp p(convOp.getContext()); - auto maybeConv1DOp = p.returningMatchAndRewrite(convOp, rewriter); - if (failed(maybeConv1DOp)) - return failure(); - filter.replaceLinalgTransformationFilter(rewriter, *maybeConv1DOp); - return success(); - } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; -}; - -/// Wraps upstream Linalg pattern in a filter check + update. -struct LinalgPaddingPattern - : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `filter`. - LinalgPaddingPattern( - MLIRContext *context, - linalg::LinalgPaddingOptions options = linalg::LinalgPaddingOptions(), - LinalgTransformationFilter f = LinalgTransformationFilter()) - : OpInterfaceRewritePattern(context, - /*benefit=*/1), - filter(std::move(f)), options(options) {} - - /// Construct a pattern specifically applied to `opName`. - LinalgPaddingPattern( - StringRef opName, MLIRContext *context, - linalg::LinalgPaddingOptions options = linalg::LinalgPaddingOptions(), - LinalgTransformationFilter f = LinalgTransformationFilter()) - : OpInterfaceRewritePattern(context, /*benefit=*/1), - filter(f.addOpNameFilter(opName)), options(std::move(options)) {} - - LogicalResult matchAndRewrite(linalg::LinalgOp op, - PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - linalg::LinalgPaddingPattern p(op.getContext(), options); - auto maybeRes = p.returningMatchAndRewrite(op, rewriter); - if (failed(maybeRes)) - return failure(); - filter.replaceLinalgTransformationFilter(rewriter, *maybeRes); - return success(); - } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; - /// Options to control padding and hoisting. - linalg::LinalgPaddingOptions options; -}; - -FailureOr tileConsumerAndFuseProducers( - OpBuilder &b, linalg::LinalgOp consumerOp, ArrayRef tileSizes, - ArrayRef tileInterchange, - const Optional &tileDistribution); - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_ diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Utils.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Utils.h deleted file mode 100644 index d6998f110d43..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Utils.h +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_UTILS_H_ -#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_UTILS_H_ - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/Support/LLVM.h" - -namespace mlir { -class Location; -class OpBuilder; -class Operation; -class Value; - -namespace tensor { -class ExtractSliceOp; -} - -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -/// Helper function which auto-completes the missing trailing dimensions to -/// always be offset = 0, size = dim, stride = 1. -void completeOffsetsSizesAndStrides(OpBuilder &b, Location loc, Value tensor, - ArrayRef leadingOffsets, - ArrayRef leadingSizes, - ArrayRef leadingStrides, - SmallVectorImpl &offsets, - SmallVectorImpl &sizes, - SmallVectorImpl &strides); - -/// Create a tensor::ExtractSliceOp by auto-completing the missing trailing -/// dimensions to always be offset = 0, size = dim, stride = 1. -Value createSubsetExtractOpFromLeadingOffsetsSizesAndStrides( - OpBuilder &b, Location loc, Value tensor, - llvm::ArrayRef leadingOffsets, ArrayRef leadingSizes, - ArrayRef leadingStrides); - -/// Create a tensor::InsertSliceOp by auto-completing the missing trailing -/// dimensions to always be offset = 0, size = dim, stride = 1. -Value createSubsetInsertOpFromLeadingOffsetsSizesAndStrides( - OpBuilder &b, Location loc, Value tensor, Value dest, - ArrayRef leadingOffsets, ArrayRef leadingSizes, - ArrayRef leadingStrides); - -/// Insert the `source` tensor into the `dest` tensor by creating the relevant -/// `subset_insert` op. The details of the `subset_insert` op are retrieved -/// from the `subset_extract` op so that they form a matching extract/insert -/// pair. -Value createMatchingSubsetInsertOp(OpBuilder &b, Location loc, - tensor::ExtractSliceOp subsetExtractOp, - Value source, Value dest); - -/// Create the parallel insertion terminator version of -/// `createMatchingSubsetInsertOp`. -void createMatchingParallelSubsetInsertOp( - OpBuilder &b, Location loc, tensor::ExtractSliceOp subsetExtractOp, - Value source, Value dest); - -struct AffineValueExpr { - explicit AffineValueExpr(AffineExpr e) : e(e) {} - AffineValueExpr bind(Value v) { - this->v = v; - return *this; - } - operator AffineExpr() const { return e; } - operator Value() const { return v; } - AffineExpr e; - Value v; -}; - -/// Helper struct to build simple arithmetic quantiAffineValueExprs with minimal -/// type inference support. -// TODO: move into ArithBuilder once ops have been moved into arith. -struct AffineBuilder { - AffineBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {} - - Value add(AffineValueExpr lhs, AffineValueExpr rhs) { - return b.createOrFold( - loc, ArrayRef{lhs.e + rhs.e}, ValueRange{lhs, rhs}); - } - Value sub(AffineValueExpr lhs, AffineValueExpr rhs) { - return b.createOrFold( - loc, ArrayRef{lhs.e - rhs.e}, ValueRange{lhs, rhs}); - } - Value mul(AffineValueExpr lhs, AffineValueExpr rhs) { - return b.createOrFold( - loc, ArrayRef{lhs.e * rhs.e}, ValueRange{lhs, rhs}); - } - Value ceil(AffineValueExpr lhs, AffineValueExpr rhs) { - return b.createOrFold( - loc, ArrayRef{lhs.e.ceilDiv(rhs.e)}, ValueRange{lhs, rhs}); - } - Value min(ValueRange vals) { - return b.createOrFold( - loc, AffineMap::getMultiDimIdentityMap(vals.size(), b.getContext()), - vals); - } - Value max(ValueRange vals) { - return b.createOrFold( - loc, AffineMap::getMultiDimIdentityMap(vals.size(), b.getContext()), - vals); - } - -private: - OpBuilder &b; - Location loc; -}; - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_UTILS_H_ diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h deleted file mode 100644 index 722e1a96d0bc..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_UTILS_H_ -#define IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_UTILS_H_ - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/PatternMatch.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -/// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at -/// `dim`. -Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim); - -/// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at -/// `dim`. If the shape is constant, returns the shape as an `IntegerAttr`. -OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim); -SmallVector getDims(OpBuilder &builder, Location loc, Value v); - -/// Returns a vector that interchanges `elements` starting at offset `offset` -/// based on the indexes in `interchangeVector`. -template -SmallVector interchange(ArrayRef elements, - ArrayRef interchangeVector, - int offset = 0) { - SmallVector vec = llvm::to_vector(elements); - for (auto en : llvm::enumerate(interchangeVector)) { - vec[en.index() + offset] = elements[en.value() + offset]; - } - return vec; -} -template -SmallVector undoInterchange(ArrayRef elements, - ArrayRef interchangeVector, - int offset = 0) { - SmallVector vec = llvm::to_vector(elements); - for (auto en : llvm::enumerate(interchangeVector)) { - vec[en.value() + offset] = elements[en.index() + offset]; - } - return vec; -} - -/// Returns the `interchangeVector` based on `dimsPos`. -SmallVector computeInterchangeFromDimPos(ArrayRef dimsPos, - int64_t rank); - -/// Converts a 2D float array to a constant value. The 2D array is stored as -/// a 1D row-major array in `val` and has shape `rows` x `cols`. -Value createValueFrom2DConstant(const float *val, int64_t rows, int64_t cols, - Location loc, PatternRewriter &rewriter); - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_UTILS_H_ diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h deleted file mode 100644 index 77dbb09135b2..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_WINOGRAD_CONSTANTS_H_ -#define IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_WINOGRAD_CONSTANTS_H_ - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { -namespace Winograd { - -// This file contains the Winograd constant matrices for different -// output tile sizes - -//===----------------------------------------------------------------------===// -// Output tile size = 6, Kernel size = 3 -//===----------------------------------------------------------------------===// -// These constants were obtained from this paper: -// -// Liu, J. et al (2021) Optimizing Winograd-Based Convolution with Tensor Cores. -// https://dl.acm.org/doi/abs/10.1145/3472456.3472473 -// - -// clang-format off - -const float BT_6x6_3x3[] = { - 1, 0, -21./4., 0, 21./4., 0, -1, 0, - 0, 1, 1, -17./4., -17./4., 1, 1, 0, - 0, -1, 1, 17./4., -17./4., -1, 1, 0, - 0, 1./2, 1./4., -5./2., -5./4., 2, 1, 0, - 0, -1./2, 1./4., 5./2., -5./4., -2, 1, 0, - 0, 2, 4, -5./2., -5, 1./2., 1, 0, - 0, -2, 4, 5./2., -5, -1./2., 1, 0, - 0, -1, 0, 21./4., 0, -21./4., 0, 1 -}; - -const float B_6x6_3x3[] = { - 1, 0, 0, 0, 0, 0, 0, 0, - 0, 1, -1, 1./2, -1./2, 2, -2, -1, - -21./4., 1, 1, 1./4., 1./4., 4, 4, 0, - 0, -17./4., 17./4., -5./2., 5./2., -5./2., 5./2., 21./4., - 21./4., -17./4., -17./4., -5./4., -5./4., -5, -5, 0, - 0, 1, -1, 2, -2, 1./2., -1./2., -21./4., - -1, 1, 1, 1, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 1 -}; - -const float G_6x6_3x3[] = { - 1, 0, 0, - -2./9., -2./9., -2./9., - -2./9., 2./9., -2./9., - 1./90, 1./45, 2./45, - 1./90, -1./45, 2./45, - 32./45, 16./45, 8./45, - 32./45, -16./45, 8./45, - 0, 0, 1 -}; - -const float AT_6x6_3x3[] = { - 1, 1, 1, 1, 1, 1, 1, 0, - 0, 1, -1, 2, -2, 1./2, -1./2, 0, - 0, 1, 1, 4, 4, 1./4, 1./4, 0, - 0, 1, -1, 8, -8, 1./8, -1./8, 0, - 0, 1, 1, 16, 16, 1./16, 1./16, 0, - 0, 1, -1, 32, -32, 1./32, -1./32, 1 -}; - -const float A_6x6_3x3[] = { - 1, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, - 1, -1, 1, -1, 1, -1, - 1, 2, 4, 8, 16, 32, - 1, -2, 4, -8, 16, -32, - 1, 1./2, 1./4, 1./8, 1./16, 1./32, - 1, -1./2, 1./4, -1./8, 1./16, -1./32, - 0, 0, 0, 0, 0, 1 -}; - -// clang-format on - -} // namespace Winograd -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir -#endif // IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_WINOGRAD_CONSTANTS_H_ diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt deleted file mode 100644 index 6e2356423800..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -function(_add_dialect) - set(LLVM_TARGET_DEFINITIONS LinalgTransformOps.td) - mlir_tablegen(LinalgTransformOps.h.inc -gen-op-decls) - mlir_tablegen(LinalgTransformOps.cpp.inc -gen-op-defs) - mlir_tablegen(LinalgTransformDialect.h.inc --gen-dialect-decls --dialect=iree_linalg_transform) - mlir_tablegen(LinalgTransformDialect.cpp.inc --gen-dialect-defs --dialect=iree_linalg_transform) - add_public_tablegen_target(IREELinalgTransformIncGen) - add_dependencies(mlir-headers IREELinalgTransformIncGen) -endfunction() - -function(_add_transform_dialect_extension) - set(LLVM_TARGET_DEFINITIONS StructuredTransformOpsExt.td) - mlir_tablegen(StructuredTransformOpsExt.h.inc -gen-op-decls) - mlir_tablegen(StructuredTransformOpsExt.cpp.inc -gen-op-defs) - add_public_tablegen_target(IREELinalgTransformExtIncGen) - add_dependencies(mlir-headers IREELinalgTransformExtIncGen) -endfunction() - -_add_dialect() -_add_transform_dialect_extension() diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h deleted file mode 100644 index 5bf920ebf180..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef MLIR_DIALECT_LINALG_IR_LINALGTRANSFORMOPS_H -#define MLIR_DIALECT_LINALG_IR_LINALGTRANSFORMOPS_H - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/PDL/IR/PDLTypes.h" -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/OpDefinition.h" - -namespace mlir { -namespace scf { -class ForOp; -} // namespace scf -} // namespace mlir - -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformDialect.h.inc" - -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h.inc" - -#endif // MLIR_DIALECT_LINALG_IR_LINALGTRANSFORMOPS_H diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td deleted file mode 100644 index 565d4d458009..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef LINALG_TRANSFORM_OPS -#define LINALG_TRANSFORM_OPS - -include "mlir/IR/OpBase.td" -include "mlir/IR/OpAsmInterface.td" -include "mlir/Dialect/PDL/IR/PDLTypes.td" -include "mlir/Dialect/Transform/IR/TransformInterfaces.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" - -def Linalg_Transform_Dialect : Dialect { - let name = "iree_linalg_transform"; - let cppNamespace = "::mlir::linalg::transform"; - let dependentDialects = [ - "linalg::LinalgDialect", - ]; -} - -// Operations with this trait must provide the following methods: -// - `Value target()` - returns the operation handle (value of !pdl.operation -// type) targeted by this transformation, if available; -// - `Optional matcher()` - returns the name of the PDL matcher -// that selects the ops targeted by this transformation, if provided. -class Linalg_Transform_Operation props = []> - : Op { - let cppNamespace = "::mlir::linalg::transform"; -} - -class Transform_Op props = []> - : Linalg_Transform_Operation])>; - -//===----------------------------------------------------------------------===// - -def ScopeOp : Linalg_Transform_Operation<"util.scope", - [IsolatedFromAbove, DeclareOpInterfaceMethods]> { - let description = [{An operation to restrict transformation scopes.}]; - - let regions = (region AnyRegion:$body); - let arguments = (ins Variadic:$ins); - let results = (outs Variadic:$outs); - let assemblyFormat = [{ `(` operands `)` attr-dict-with-keyword $body - `:` functional-type(operands, results) }]; -} - -def ForwardOp : Linalg_Transform_Operation<"util.forward", - [Terminator, HasParent<"ScopeOp">]> { - let description = [{Terminator for a scope operation, indicating the results - that should be forwarded out of the scope.}]; - - let arguments = (ins Variadic:$ins); - let assemblyFormat = "operands attr-dict `:` type(operands)"; -} - -//===----------------------------------------------------------------------===// - -def ExpertOp : Linalg_Transform_Operation<"expert"> { - let description = [{A "transformation expert" that can be lowered to a - sequence of transformations. The details of the lowering depend on the name - and are expressed declaratively.}]; - - let arguments = (ins PDL_Operation:$target, - StrAttr:$expertName); - let results = (outs PDL_Operation:$transformed); - - let assemblyFormat = "`apply` $expertName `to` $target attr-dict"; -} - -#endif // LINALG_TRANSFORM_OPS diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/Passes.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/Passes.h deleted file mode 100644 index 47bdbcfc3d30..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/Passes.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "mlir/Support/LLVM.h" -#include - -namespace mlir { -namespace linalg { -namespace transform { - -void registerTransformDialectInterpreterPass(); -void registerLinalgTransformExpertExpansionPass(); -void registerDropSchedulePass(); - -} // namespace transform -} // namespace linalg -} // namespace mlir - -namespace mlir { -class Pass; - -// Pass to schedule a dispatch region by using the transform dialect. -// The schedule is specified by the transform module that is parsed from -// `transformFileName`. -std::unique_ptr createTransformDialectInterpreterPass( - llvm::StringRef transformFileName = llvm::StringRef()); -std::unique_ptr createDropSchedulePass(); -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/ScopedTransform.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/ScopedTransform.h deleted file mode 100644 index a0594fe76e66..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/ScopedTransform.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_SCOPEDTRANSFORM_H -#define IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_SCOPEDTRANSFORM_H - -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h" - -namespace mlir { -namespace linalg { -namespace transform { -ScopeOp wrapInScope(Operation *op); -FailureOr> unwrapScope(ScopeOp scope); - -template -auto scoped(Operation *target, TransformT &&transform) { - auto scope = wrapInScope(target); - Operation &op = *scope.getBody().front().begin(); - auto result = transform(scope, &op); - if (failed(unwrapScope(scope)) || failed(result)) - return decltype(result)(failure()); - return result; -} -} // namespace transform -} // namespace linalg -} // namespace mlir - -#endif // IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_SCOPEDTRANSFORM_H diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h deleted file mode 100644 index 99cb64399a95..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "mlir/IR/PatternMatch.h" - -namespace mlir { -class MLIRContext; - -/// The only purpose of this class is to enable creation of PatternRewriter -/// instances as the base class doesn't have a public constructor. -/// The op-based constructor sets the insertion point before the `op`. -class SimplePatternRewriter : public PatternRewriter { -public: - SimplePatternRewriter(MLIRContext *context) : PatternRewriter(context) {} - - SimplePatternRewriter(Operation *op) : PatternRewriter(op->getContext()) { - setInsertionPoint(op); - } -}; -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h deleted file mode 100644 index 954deaad994d..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H -#define IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H - -#include "iree-dialects/Transforms/Listener.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/Dialect/Transform/IR/TransformOps.h" -#include "mlir/IR/OpDefinition.h" - -namespace mlir { -namespace linalg { -class LinalgOp; -} // namespace linalg -namespace scf { -class ForOp; -} // namespace scf - -class TrackingListener : public RewriteListener, - public transform::TransformState::Extension { -public: - explicit TrackingListener(transform::TransformState &state) - : transform::TransformState::Extension(state) {} - - ~TrackingListener() override { -#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS - assert(errorStateChecked && "must check listener error state"); -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - } - - void notifyRootReplaced(Operation *op, ValueRange newValues) override; - - void notifyOperationRemoved(Operation *op) override; - - LogicalResult checkErrorState() const { -#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS - errorStateChecked = true; -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - return failure(hadErrors); - } - - /// Remove the mappings between the given operation and any handle that may be - /// associated with it in the transform op. - void removeMappings(Operation *op); - -private: - InFlightDiagnostic emitError(Operation *op, const llvm::Twine &message = {}) { - mayFail(failure()); - return op->emitError(message); - } - - void mayFail(LogicalResult result) { - hadErrors |= result.failed(); -#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS - errorStateChecked = false; -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - } - - bool hadErrors = false; - -#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS - mutable bool errorStateChecked = false; -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS -}; - -} // namespace mlir - -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h.inc" - -namespace mlir { -namespace transform_ext { -class StructuredTransformOpsExtension - : public mlir::transform::TransformDialectExtension< - StructuredTransformOpsExtension> { -public: - StructuredTransformOpsExtension(); -}; -} // namespace transform_ext -} // namespace mlir - -#endif // IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td deleted file mode 100644 index e2b716f2c9df..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef STRUCTURED_TRANSFORM_OPS_EXT -#define STRUCTURED_TRANSFORM_OPS_EXT - -include "mlir/Dialect/PDL/IR/PDLTypes.td" -include "mlir/Dialect/Transform/IR/TransformAttrs.td" -include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformInterfaces.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/OpAsmInterface.td" -include "mlir/IR/OpBase.td" - -def CanonicalizedSequenceOp - : TransformDialectOp<"structured.canonicalized_sequence", - [OpAsmOpInterface, - PossibleTopLevelTransformOpTrait, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { - - let summary = "A transformation sequence interspersed with canonicalizations"; - let description = [{ - This op is a copy of `transform.sequence`, but applies canonicalizations - after each step in the sequence. - }]; - - let arguments = (ins FailurePropagationMode:$failure_propagation_mode, - Optional:$root); - let results = (outs Variadic:$results); - let regions = (region SizedRegion<1>:$body); - - let assemblyFormat = - "($root^)? `failures` `(` $failure_propagation_mode `)` attr-dict-with-keyword regions (`:` type($results)^)?"; - - let builders = [ - OpBuilder<(ins "::mlir::transform::FailurePropagationMode":$failure_propagation_mode, - CArg<"::llvm::function_ref", - "nullptr">)> - ]; - - let extraClassDeclaration = [{ - using BodyBuilderFn = - ::llvm::function_ref; - - /// Allow the dialect prefix to be omitted. - static ::llvm::StringRef getDefaultDialect() { return "transform"; } - }]; - - let cppNamespace = "mlir::transform_ext"; - let hasVerifier = 1; -} - -//===----------------------------------------------------------------------===// - -def BufferizeOp : Op]> { - let description = [{Indicates that the entire module should be bufferized.}]; - let assemblyFormat = "attr-dict"; - let cppNamespace = "mlir::transform_ext"; -} - -def LowerVectorsOp : Op]> { - let description = [{Indicates that the vector operations in the entire - module should be lowered to simpler primitives (multiple stages of lowering - be executed at once).}]; - - let arguments = - (ins DefaultValuedAttr:$stages, - DefaultValuedAttr:$contraction_lowering, - DefaultValuedAttr:$multireduction_lowering, - DefaultValuedAttr:$split_transfers, - DefaultValuedAttr:$unroll_vector_transfers, - DefaultValuedAttr:$transpose_lowering, - DefaultValuedAttr:$transpose_avx2_lowering - ); - - let assemblyFormat = "attr-dict"; - let cppNamespace = "mlir::transform_ext"; -} - -def LowerToLLVMOp : Op]> { - let description = [{Indicates that the entire module should be converted - to the LLVM dialect. This is expected to be the last transformation in - a sequence.}]; - - let arguments = - (ins DefaultValuedAttr:$reassociate_fp_reductions, - DefaultValuedAttr:$enable_index_optimizations, - DefaultValuedAttr:$enable_arm_neon, - DefaultValuedAttr:$enable_arm_sve, - DefaultValuedAttr:$enable_amx, - DefaultValuedAttr:$enable_x86vector, - DefaultValuedAttr:$enable_async); - - let assemblyFormat = "attr-dict"; - let cppNamespace = "mlir::transform_ext"; -} - - -def RegisterMatchCallbacksOp : - Op, - DeclareOpInterfaceMethods]> { - let description = [{ - Registers named structured op matcher callbacks specific for IREE to use - with `transform.iree.match_callback`. This should be called before first - `match_callback` may be executed following the transform dialect control - flow. - - The callbacks must have a unique name and a signature compatible with - `MatchCallbacksRegistry::MatchCallbackFn`, which currently means - `DiagnosedSilenceableFailure(MatchCallbackResult &, Location, - const TransformState &, ValueRange)`. The callback receives a "result", - followed by a location at which errors should be reported, a transform - state at the moment of the _match_ (not registration) and a list of - handle values passed as operands to the `match_callback` operation. - It is expected to populate the "result" object with lists of payload - operations that will be bound to the handles produced by the - `match_callback` operation. The callback may fail, at which point - it should produce a silenceable error. The callback currently is not - allowed to modify the payload IR (though this may be revised in the - future for the purpose of communicating the properties of the IR - captured by the match). Therefore, it should not have a reason to - produce a definite error. - }]; - - let arguments = (ins); - let results = (outs); - let assemblyFormat = "attr-dict"; - let cppNamespace = "mlir::transform_ext"; -} - -def MatchCallbackOp : - Op, - DeclareOpInterfaceMethods]> { - let description = [{ - Performs payload IR matching using a C++ callback registered beforehand. - The callback is identified by name and is passed the current transform - state and the list of handle operands, along with information necessary - for error propagation. See `register_match_callbacks` for the description - of the callback contract. - - If `failure_propagation_mode` is set to `suppress`, any silenceable errors - in the callback (typically, "failure to match") will be ignored and the - resulting handles will be associated with empty lists of payload - operations. Otherwise, silenceable failures are propagated. - }]; - - let arguments = (ins StrAttr:$callback_name, - FailurePropagationMode:$failure_propagation_mode, - Variadic:$inputs); - let results = (outs Variadic:$outputs); - let assemblyFormat = "`failures` `(` $failure_propagation_mode `)` " - "$callback_name `(` $inputs `)` attr-dict " - "`:` functional-type($inputs, $outputs)"; - let cppNamespace = "mlir::transform_ext"; -} - -def TakeFirstOp : - Op, - DeclareOpInterfaceMethods]> { - let description = [{ - Given an arbitrary list of handles associated with potentially empty lists - of payload operations, produces two new handles: - - - a handle pointing to the same payload operations as the first operand - handle with a non-empty list of payload operations; - - a handle pointing to the concatenated list of payload operations - associated with any other handle. - - Note that this does not perform any deduplication. - - This operation is useful to select a single target after some potentially - unsuccessful matches. - }]; - - let arguments = (ins Variadic:$inputs); - let results = (outs TransformTypeInterface:$first, - TransformTypeInterface:$rest); - let assemblyFormat = - "$inputs attr-dict `:` functional-type($inputs, results)"; - let cppNamespace = "mlir::transform_ext"; -} - -#endif // STRUCTURED_TRANSFORM_OPS_EXT diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformInterpreterUtils.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformInterpreterUtils.h deleted file mode 100644 index 126042b0d631..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformInterpreterUtils.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_LINALG_TRANSFORM_TRANSFORM_INTERPRETER_UTILS_H -#define IREE_DIALECTS_LINALG_TRANSFORM_TRANSFORM_INTERPRETER_UTILS_H - -#include - -namespace mlir { -class Operation; -class LogicalResult; -namespace transform { -class TransformOpInterface; - -/// Utility to parse the content of a `transformFileName` mlir file containing -/// a transform dialect specification. -LogicalResult -parseTransformModuleFromFile(MLIRContext *context, - llvm::StringRef transformFileName, - OwningOpRef &transformModule); - -/// Utility to extract the `TransformOpInterface` ops that have the trait -/// `PossibleTopLevelTransformOpTrait`. Such ops are -LogicalResult -extractTopLevelTransformOps(Region &r, - SmallVectorImpl &res); - -/// Utility to run a transform dialect specification contained in a -/// `transformRegion`, on a `target` op. -/// Since the transform dialect may use PDL which may modify the IR, the -/// underlying implementation clones the transform dialect operations before -/// applying them. -LogicalResult applyTransformsInRegion(Region &transformRegion, - Operation *target); -} // namespace transform -} // namespace mlir -#endif // IREE_DIALECTS_LINALG_TRANSFORM_TRANSFORM_INTERPRETER_UTILS_H diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/Listener.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/Listener.h deleted file mode 100644 index 41a0972d675f..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/Listener.h +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H -#define IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H - -#include "mlir/IR/PatternMatch.h" - -namespace mlir { - -//===----------------------------------------------------------------------===// -// RewriteListener -//===----------------------------------------------------------------------===// - -/// This class represents a listener that can be used to hook on to various -/// rewrite events in an `OpBuilder` or `PatternRewriter`. The class is notified -/// by when: -/// -/// - an operation is removed -/// - an operation is inserted -/// - an operation is replaced -/// - a block is created -/// -/// Listeners can be used to track IR mutations throughout pattern rewrites. -struct RewriteListener { - virtual ~RewriteListener(); - - /// These are the callback methods that subclasses can choose to implement if - /// they would like to be notified about certain types of mutations. - - /// Notification handler for when an operation is inserted into the builder. - /// op` is the operation that was inserted. - virtual void notifyOperationInserted(Operation *op) {} - - /// Notification handler for when a block is created using the builder. - /// `block` is the block that was created. - virtual void notifyBlockCreated(Block *block) {} - - /// Notification handler for when the specified operation is about to be - /// replaced with another set of operations. This is called before the uses of - /// the operation have been replaced with the specific values. - virtual void notifyRootReplaced(Operation *op, ValueRange newValues) {} - - /// Notification handler for when an the specified operation is about to be - /// deleted. At this point, the operation has zero uses. - virtual void notifyOperationRemoved(Operation *op) {} - - /// Notify the listener that a pattern failed to match the given operation, - /// and provide a callback to populate a diagnostic with the reason why the - /// failure occurred. This method allows for derived listeners to optionally - /// hook into the reason why a rewrite failed, and display it to users. - virtual LogicalResult - notifyMatchFailure(Location loc, - function_ref reasonCallback) { - return failure(); - } -}; - -//===----------------------------------------------------------------------===// -// ListenerList -//===----------------------------------------------------------------------===// - -/// This class contains multiple listeners to which rewrite events can be sent. -class ListenerList : public RewriteListener { -public: - /// Add a listener to the list. - void addListener(RewriteListener *listener) { listeners.push_back(listener); } - - /// Send notification of an operation being inserted to all listeners. - void notifyOperationInserted(Operation *op) override; - /// Send notification of a block being created to all listeners. - void notifyBlockCreated(Block *block) override; - /// Send notification that an operation has been replaced to all listeners. - void notifyRootReplaced(Operation *op, ValueRange newValues) override; - /// Send notification that an operation is about to be deleted to all - /// listeners. - void notifyOperationRemoved(Operation *op) override; - /// Notify all listeners that a pattern match failed. - LogicalResult - notifyMatchFailure(Location loc, - function_ref reasonCallback) override; - -private: - /// The list of listeners to send events to. - SmallVector listeners; -}; - -//===----------------------------------------------------------------------===// -// PatternRewriterListener -//===----------------------------------------------------------------------===// - -/// This class implements a pattern rewriter with a rewrite listener. Rewrite -/// events are forwarded to the provided rewrite listener. -class PatternRewriterListener : public PatternRewriter, public ListenerList { -public: - PatternRewriterListener(MLIRContext *context) : PatternRewriter(context) {} - - /// When an operation is about to be replaced, send out an event to all - /// attached listeners. - void replaceOp(Operation *op, ValueRange newValues) override { - ListenerList::notifyRootReplaced(op, newValues); - PatternRewriter::replaceOp(op, newValues); - } - - void notifyOperationInserted(Operation *op) override { - ListenerList::notifyOperationInserted(op); - } - void notifyBlockCreated(Block *block) override { - ListenerList::notifyBlockCreated(block); - } - void notifyOperationRemoved(Operation *op) override { - ListenerList::notifyOperationRemoved(op); - } -}; - -} // namespace mlir - -#endif // IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h deleted file mode 100644 index a062f2aee36d..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H -#define LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H - -#include "iree-dialects/Transforms/Listener.h" - -namespace mlir { -class DominanceInfo; -class Operation; - -LogicalResult eliminateCommonSubexpressions(Operation *op, - DominanceInfo *domInfo, - RewriteListener *listener); -} // namespace mlir - -#endif // LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h deleted file mode 100644 index 2c0adf667e95..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Transforms/Listener.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -// The following are iree-dialects extensions to MLIR. -namespace mlir { -struct GreedyRewriteConfig; -struct RewriteListener; - -/// Applies the specified patterns on `op` alone while also trying to fold it, -/// by selecting the highest benefits patterns in a greedy manner. Returns -/// success if no more patterns can be matched. `erased` is set to true if `op` -/// was folded away or erased as a result of becoming dead. Note: This does not -/// apply any patterns recursively to the regions of `op`. Accepts a listener -/// so the caller can be notified of rewrite events. -LogicalResult applyPatternsAndFoldGreedily( - Operation *op, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, RewriteListener *listener); - -/// Apply the given list of transformations to the regions of the -/// isolated-from-above operation `root` greedily until convergence. Update -/// Linalg operations in values of `trackedOperations` if they are replaced by -/// other Linalg operations during the rewriting process. Tracked operations -/// must be replaced with Linalg operations and must not be erased in the -/// patterns. -static inline LogicalResult applyPatternsTrackAndFoldGreedily( - Operation *root, RewriteListener &listener, - const FrozenRewritePatternSet &patterns, - GreedyRewriteConfig config = GreedyRewriteConfig()) { - return applyPatternsAndFoldGreedily(root, patterns, config, &listener); -} - -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h deleted file mode 100644 index 430d879870b0..000000000000 --- a/integrations/tensorflow/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h +++ /dev/null @@ -1,470 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_ -#define IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_ - -#include -#include -#include - -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/IR/Matchers.h" -#include "llvm/ADT/StringMap.h" - -namespace mlir { -namespace transform_ext { - -//===---------------------------------------------------------------------===// -// StructuredOpMatcher and predicates. -//===---------------------------------------------------------------------===// - -class StructuredOpMatcher; -StructuredOpMatcher m_StructuredOp(); - -/// A tag indicating the shape being static or dynamic, for use with the -/// structured op matcher. -enum class ShapeKind { Static, Dynamic }; - -/// A placeholder indicating the structured op matcher to check the predicate -/// for all dimensions. -struct AllDims {}; - -/// A placeholder indicating the structured op matcher to check the predicate -/// for all operands of the relevant kind. -struct AllOperands {}; - -/// A tag indicating to look for any user of the operation's result that would -/// satisfy the predicate. -struct HasAnyUse {}; - -/// Base class for predicate parameters that can be described with the single -/// value. Concrete predicate parameters should inherit this and forward the -/// constructor via `using Base::Base`. -template -struct SingleValuePredicateParam { - using Base = SingleValuePredicateParam; - explicit SingleValuePredicateParam(T value) : value(value) {} - const T value; -}; - -/// Indicates that the dimension must be divisible by the given value. -struct DivisibleBy : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the number of entities must be equal to the given value. -struct NumEqualsTo : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the bit width of the elemental type must be equal to the give -/// value. -struct ElementTypeBitWidth : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Predicate tag indicating that the affine map is a permutation. -struct IsPermutation {}; - -/// Indicates that the match optional. The matcher is still expected to run and -/// capture if successful. The parameter can be set to false -struct OptionalMatch : public SingleValuePredicateParam { - OptionalMatch() : Base(true) {} - explicit OptionalMatch(bool set) : Base(set) {} -}; - -/// Predicate tag indicating that the reduction is produced by a single combiner -/// operation. -struct SingleCombinerReduction {}; - -/// Indicates that it suffices for only a subset of an operand or result value -/// to be used. -struct SubsetOf { - explicit SubsetOf(StructuredOpMatcher &matcher) : matcher(matcher) {} - StructuredOpMatcher &matcher; -}; - -namespace detail { -template -using has_reset_capture_t = decltype(std::declval().resetCapture()); -} // namespace detail - -/// Structured op matcher with additional predicates attachable through the -/// fluent, a.k.a. chainable, API. Note that public API must *not* accept -/// additional callbacks even; new predicates should be added instead when -/// necessary. Not only this decreases the depth of the callback stack and -/// increases readability, it also allows us to port the matcher to a -/// declarative format using PDL and/or Transform dialect in the future. The -/// latter will become impossible with arbitrary C++ callbacks. -class StructuredOpMatcher { - friend StructuredOpMatcher m_StructuredOp(); - using PredicateFn = std::function; - using CaptureResetFn = std::function; - - /// Matches a structured operation if the given predicate is satisfied. - StructuredOpMatcher(PredicateFn &&firstPredicate) { - predicates.push_back(std::move(firstPredicate)); - } - -public: - /// Matches any structured operation, i.e., operation with LinalgOp interface. - StructuredOpMatcher() {} - - /// Creates a matcher for a structured operation with one of the given types. - template - static StructuredOpMatcher create() { - return StructuredOpMatcher( - [](linalg::LinalgOp op) { return isa(op.getOperation()); }); - } - - /// Returns the matched operation if the match was successful. - linalg::LinalgOp getCaptured() const { return captured; } - - /// Matches the given operation, hook for `matchPattern`. - bool match(Operation *op); - - /// Adds a predicate checking that the given iteration space dimension is - /// static/dynamic. The dimension index may be negative, in which case - /// dimensions are counted from the last one (i.e. Python-style), or be an - /// AllDims tag, in which case all dimensions are checked. This may be - /// eventually extended to slices and/or lists of dimensions. - StructuredOpMatcher &dim(int64_t dimension, ShapeKind kind); - StructuredOpMatcher &dim(AllDims tag, ShapeKind kind); - - /// Adds a predicate checking that the given iteration space dimension has the - /// given iterator type, e.g., parallel or reduction. The dimension index may - /// be negative, in which case dimensions are counted from the last one - /// (i.e. Python-style), or be an AllDims tag, in which case all dimensions - /// are checked. This may be eventually extended to slices and/or lists of - /// dimensions. - StructuredOpMatcher &dim(int64_t dimension, utils::IteratorType kind); - StructuredOpMatcher &dim(AllDims tag, utils::IteratorType kind); - - /// Adds a predicate checking that the given iteration space dimension is - /// statically known to be divisible by the given value. The dimension index - /// may be negative, in which case dimensions are counted from the last one - /// (i.e. Python-style). - StructuredOpMatcher &dim(int64_t dimension, DivisibleBy divisibleBy); - - /// Adds a predicate checking that the structured op has the given number of - /// inputs. - StructuredOpMatcher &input(NumEqualsTo num) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - return linalgOp.getNumDpsInputs() == num.value; - }); - return *this; - } - - /// Adds a predicate that recursively applies other predicates to the - /// operation defining the `position`-th operand. The position may be - /// negative, in which case positions are counted from the last one - /// (i.e. Python-style). When the match is optional, the predicate check - /// succeeds as long as the `position` is in bounds. The matcher is executed - /// if there is a defining operation for the input operand. - template - std::enable_if_t< - llvm::is_detected<::mlir::detail::has_operation_or_value_matcher_t, T, - Operation *>::value, - StructuredOpMatcher &> - input(int64_t position, T &operandMatcher, - OptionalMatch optional = OptionalMatch(false)) { - predicates.push_back([position, optional, - &operandMatcher](linalg::LinalgOp linalgOp) -> bool { - int64_t transformedPosition = - position >= 0 ? position : linalgOp.getNumDpsInputs() + position; - if (transformedPosition >= linalgOp.getNumDpsInputs()) - return false; - - Operation *definingOp = linalgOp.getDpsInputOperand(transformedPosition) - ->get() - .getDefiningOp(); - if (!definingOp) - return optional.value; - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. - if (operandMatcher.match(definingOp)) - return true; - return optional.value; - }); - recordNestedMatcher(operandMatcher); - return *this; - } - - /// Adds a predicate checking that all input operands of the structured op - /// have a permutation indexing map. - StructuredOpMatcher &input(AllOperands tag, IsPermutation); - - /// Adds a predicate that recursively applies another predicate to the - /// operation defining the `position`-th input operand, looking through any - /// "subsetting" operation such as "tensor.extract_slice". - StructuredOpMatcher &input(int64_t position, SubsetOf subset); - - /// Adds a predicate that recursively applies another predicate to the - /// operation defining the `position`-th output operand, looking through any - /// "subsetting" operation such as "tensor.extract_slice". - StructuredOpMatcher &output(int64_t position, SubsetOf subset); - - /// Adds a predicate checking that the structured op has the given number of - /// outputs. - StructuredOpMatcher &output(NumEqualsTo num) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - return linalgOp.getNumDpsInits() == num.value; - }); - return *this; - } - - /// Adds a predicate checking that all output operands of the structured op - /// have a permutation indexing map. - StructuredOpMatcher &output(AllOperands tag, IsPermutation); - - /// Adds a predicate checking that the bit width of the elemental type of the - /// structured op output at the given position is equal to the given value. - StructuredOpMatcher &output(int64_t position, ElementTypeBitWidth width); - - /// Adds a predicate checking that the output of the structured op is produced - /// by a reduction with a single-operation combinator (such as addf or mulf, - /// but not a compare+select pair). - StructuredOpMatcher &output(int64_t position, SingleCombinerReduction tag); - - /// Adds a predicate that recursively applies other predicates to the - /// operation defining the init/out operand corresponding to `position`-th - /// output. The position may be negative, in which case positions are counted - /// from the last one (i.e. Python-style). When the match is optional, the - /// predicate check succeeds as long as the `position` is in bounds. The - /// matcher executed if there is a defining operation for the output operand. - template - std::enable_if_t< - llvm::is_detected<::mlir::detail::has_operation_or_value_matcher_t, T, - Operation *>::value, - StructuredOpMatcher &> - output(int64_t position, T &operandMatcher, - OptionalMatch optional = OptionalMatch(false)) { - predicates.push_back([position, optional, - &operandMatcher](linalg::LinalgOp linalgOp) -> bool { - int64_t transformedPosition = - position >= 0 ? position : linalgOp.getNumDpsInits() + position; - if (transformedPosition >= linalgOp.getNumDpsInits()) - return false; - - Operation *definingOp = linalgOp.getDpsInitOperand(transformedPosition) - ->get() - .getDefiningOp(); - if (!definingOp) - return optional.value; - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. - if (operandMatcher.match(definingOp)) - return true; - return optional.value; - }); - recordNestedMatcher(operandMatcher); - return *this; - } - - /// Adds a predicate that recursively applies to users of the `position`-th - /// result of the structured op. Succeeds if any user matches the predicate. - /// When the match is optional, the predicate check succeeds as long as the - /// `position` is in bounds, after running the given matcher. - template - std::enable_if_t< - llvm::is_detected<::mlir::detail::has_operation_or_value_matcher_t, T, - Operation *>::value, - StructuredOpMatcher &> - result(int64_t position, HasAnyUse tag, T &resultUserMatcher, - OptionalMatch optional = OptionalMatch(false)) { - predicates.push_back([&resultUserMatcher, optional, - position](linalg::LinalgOp linalgOp) -> bool { - int64_t transformedPosition = - position >= 0 ? position : linalgOp->getNumResults() + position; - if (transformedPosition >= linalgOp->getNumResults()) - return false; - - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. - if (llvm::any_of(linalgOp->getResult(transformedPosition).getUsers(), - [&resultUserMatcher](Operation *op) { - return resultUserMatcher.match(op); - })) { - return true; - } - return optional.value; - }); - recordNestedMatcher(resultUserMatcher); - return *this; - } - - /// Adds a predicate that recursively applies to users of the `positions`-th - /// result, looking through any "subsetting" operation such as - /// "tensor.extract_slice". Succeeds if any user matches the predicate. - /// When the match is optional, the predicate check succeeds as long as the - /// `position` is in bounds, after running the given matcher. - StructuredOpMatcher &result(int64_t position, HasAnyUse tag, SubsetOf subset, - OptionalMatch optional = OptionalMatch(false)); - - /// Resets the captured value to null. This should be called if the same - /// pattern needs to be applied more than once as it may keep captured values - /// for optional nested predicates from the previous application. - void resetCapture() { - captured = nullptr; - for (const CaptureResetFn &fn : captureResetFns) - fn(); - } - -private: - /// Informs the matcher that it has another, nested matcher. Practically, - /// records the captured value cleanup function so it runs when required. - template - std::enable_if_t::value> - recordNestedMatcher(T &nested) { - captureResetFns.push_back([&nested] { nested.resetCapture(); }); - } - template - std::enable_if_t::value> - recordNestedMatcher(T &nested) {} - - /// Additional predicates to be checked on the structured op. - SmallVector predicates; - - /// Callbacks to reset captures of nested matchers. - SmallVector captureResetFns; - - /// Matched value. - linalg::LinalgOp captured = nullptr; -}; - -/// Creates a matcher of an arbitrary structured op. -inline StructuredOpMatcher m_StructuredOp() { return StructuredOpMatcher(); } - -/// Creates a matcher of a structured op with kinds provided as template -/// arguments. -template -inline StructuredOpMatcher m_StructuredOp() { - return StructuredOpMatcher::create(); -} - -//===---------------------------------------------------------------------===// -// MatchCallback functionality. -//===---------------------------------------------------------------------===// - -/// Additional results of the C++ callback usable in the `match_callback` -/// transform operation. Conceptually, a list of lists of payload operations to -/// be associated with each result handle. -class MatchCallbackResult { -public: - /// Returns the number of lists of payload operations. - unsigned getNumPayloadGroups() const { return payloadGroupLengths.size(); } - - /// Returns the `position`-th list of payload operations. - ArrayRef getPayloadGroup(unsigned position) const; - - /// Adds a new list of payload operations to the list of lists. The new list - /// must not contain null operations. - template - unsigned addPayloadGroup(Range operations) { - int64_t originalLength = payloadOperations.size(); - assert(llvm::all_of(operations, [](Operation *op) -> bool { return op; }) && - "null operation"); - llvm::append_range(payloadOperations, operations); - payloadGroupLengths.push_back(payloadOperations.size() - originalLength); - return payloadGroupLengths.size() - 1; - } - void addPayloadGroup(ArrayRef operations) { - addPayloadGroup>(operations); - } - - /// Adds a new singleton list of payload operation to the list of lists if the - /// operation is non-null, adds an empty list otherwise. Useful for results of - /// optional matches. - void addPotentiallyEmptyPayloadGroup(Operation *op) { - if (!op) - addPayloadGroup(ArrayRef()); - else - addPayloadGroup(ArrayRef(op)); - } - -private: - /// The flat list of all payload opreations. `payloadGroupLengths` can be used - /// to compute the sublist that corresponds to one nested list. - // TODO: if somebody implements such a flattened vector generically, use it. - SmallVector payloadOperations; - SmallVector payloadGroupLengths; -}; - -/// A transform state extension that maintains the mapping between callback -/// names as strings usable in `match_callback` and their implementations. -class MatchCallbacksRegistry : public transform::TransformState::Extension { -public: - using MatchCallbackFn = std::function; - - /// Constructs the extension. - MatchCallbacksRegistry(transform::TransformState &state) - : transform::TransformState::Extension(state) {} - - /// Registers the given function as a callback with the given name. The name - /// must not be already present in the registry. The callback must be - /// convertible to MatchCallbackFn. - template - void registerCallback(StringRef name, Fn &&fn) { - bool succeeded = callbacks.try_emplace(name, std::forward(fn)).second; - (void)succeeded; - assert(succeeded && "adding a callback with a repeated name"); - } - - /// Returns a pointer to the implementation of the callback with the given - /// name, or null if it is not present in the registry. - const MatchCallbackFn *get(StringRef name) const { - auto iter = callbacks.find(name); - if (iter == callbacks.end()) - return nullptr; - return &iter->getValue(); - } - -private: - llvm::StringMap callbacks; -}; - -//===---------------------------------------------------------------------===// -// Case-specific matcher builders. -//===---------------------------------------------------------------------===// - -/// Creates a group of matchers for: -/// -/// trailing(reduction(leading(), fill())) -/// -/// where trailing and leading are elementwise operations whose presence is -/// optional. Each matcher will capture the corresponding operation. -void makeReductionMatcher(StructuredOpMatcher &reduction, - StructuredOpMatcher &fill, - StructuredOpMatcher &leading, - StructuredOpMatcher &trailing); - -/// Creates a group of matchers for: -/// -/// trailing( -/// combiner_reduction( -/// parallel_reduction(leading(), parallel_fill()), -/// original_fill()))) -/// -/// where trailing and leading are elementwise operations whose presence is -/// optional, and with subsetting ops potentially present on the operand use-def -/// chains. -void makeSplitReductionMatcher(StructuredOpMatcher ¶llel_reduction, - StructuredOpMatcher &combiner_reduction, - StructuredOpMatcher ¶llel_fill, - StructuredOpMatcher &original_fill, - StructuredOpMatcher &leading, - StructuredOpMatcher &trailing); - -} // namespace transform_ext -} // namespace mlir - -#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_ diff --git a/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt index 17561a6e7d4d..3cfd7dfa265a 100644 --- a/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt +++ b/integrations/tensorflow/iree-dialects/lib/CAPI/CMakeLists.txt @@ -4,10 +4,6 @@ add_mlir_public_c_api_library(IREEDialectsCAPI MLIRIR MLIRTransformDialect IREEInputDialect - IREELinalgExtDialect - IREELinalgExtTransformOps - IREELinalgTransformDialect - IREELinalgTransformDialectPasses ) iree_dialects_target_includes(IREEDialectsCAPI) diff --git a/integrations/tensorflow/iree-dialects/lib/CAPI/Dialects.cpp b/integrations/tensorflow/iree-dialects/lib/CAPI/Dialects.cpp index 576e626d63fd..0a90e2a30fba 100644 --- a/integrations/tensorflow/iree-dialects/lib/CAPI/Dialects.cpp +++ b/integrations/tensorflow/iree-dialects/lib/CAPI/Dialects.cpp @@ -7,11 +7,6 @@ #include "iree-dialects-c/Dialects.h" #include "iree-dialects/Dialect/Input/InputDialect.h" -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h" -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h" -#include "iree-dialects/Dialect/LinalgTransform/Passes.h" -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Pass.h" #include "mlir/CAPI/Registration.h" @@ -32,41 +27,3 @@ using namespace mlir::iree_compiler::IREE; MLIR_DEFINE_CAPI_DIALECT_REGISTRATION( IREEInput, iree_input, mlir::iree_compiler::IREE::Input::IREEInputDialect) -//===--------------------------------------------------------------------===// -// IREELinalgExt -//===--------------------------------------------------------------------===// - -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION( - IREELinalgExt, iree_linalg_ext, - mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect) - -//===--------------------------------------------------------------------===// -// IREELinalgTransform -//===--------------------------------------------------------------------===// - -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION( - IREELinalgTransform, iree_linalg_transform, - mlir::linalg::transform::LinalgTransformDialect) - -void mlirIREELinalgTransformRegisterPasses() { - mlir::linalg::transform::registerTransformDialectInterpreterPass(); - mlir::linalg::transform::registerLinalgTransformExpertExpansionPass(); - mlir::linalg::transform::registerDropSchedulePass(); -} -//===--------------------------------------------------------------------===// -// TransformDialect -//===--------------------------------------------------------------------===// - -void ireeRegisterTransformExtensions(MlirContext context) { - MLIRContext *ctx = unwrap(context); - DialectRegistry registry; - registry.addExtensions< - mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension, - mlir::transform_ext::StructuredTransformOpsExtension>(); - ctx->appendDialectRegistry(registry); -} - -void mlirIREETransformRegisterPasses() { - mlir::linalg::transform::registerDropSchedulePass(); - mlir::linalg::transform::registerTransformDialectInterpreterPass(); -} diff --git a/integrations/tensorflow/iree-dialects/lib/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/CMakeLists.txt index 76b98c3aea72..47ce6dd5a538 100644 --- a/integrations/tensorflow/iree-dialects/lib/CMakeLists.txt +++ b/integrations/tensorflow/iree-dialects/lib/CMakeLists.txt @@ -1,3 +1,2 @@ add_subdirectory(CAPI) add_subdirectory(Dialect) -add_subdirectory(Transforms) diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/CMakeLists.txt index 16d52d437fde..ab1d7407c50a 100644 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/CMakeLists.txt +++ b/integrations/tensorflow/iree-dialects/lib/Dialect/CMakeLists.txt @@ -1,3 +1 @@ add_subdirectory(Input) -add_subdirectory(LinalgExt) -add_subdirectory(LinalgTransform) diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt deleted file mode 100644 index 3e119c93901e..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Passes) -add_subdirectory(TransformOps) -add_subdirectory(Transforms) -add_subdirectory(Utils) diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/CMakeLists.txt deleted file mode 100644 index 777b9c10f363..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/CMakeLists.txt +++ /dev/null @@ -1,34 +0,0 @@ -add_mlir_library(IREELinalgExtDialect - LinalgExtDialect.cpp - LinalgExtInterfaces.cpp - LinalgExtOps.cpp - - ADDITIONAL_HEADER_DIRS - ${IREE_DIALECTS_SOURCE_DIR}/include - - DEPENDS - IREELinalgExtIncGen - - LINK_LIBS PUBLIC - IREELinalgExtUtils - MLIRAffineDialect - MLIRArithUtils - MLIRDestinationStyleOpInterface - MLIRDialectUtils - MLIRIR - MLIRInferTypeOpInterface - MLIRLinalgDialect - MLIRMathDialect - MLIRMemRefDialect - MLIRPass - MLIRSideEffectInterfaces - MLIRSupport - MLIRSCFDialect - MLIRFuncDialect - MLIRTensorDialect - MLIRTensorUtils - MLIRTilingInterface - MLIRViewLikeInterface -) - -iree_dialects_target_includes(IREELinalgExtDialect) diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp deleted file mode 100644 index fe92b3ca0246..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/SourceMgr.h" - -using namespace mlir; -using namespace mlir::iree_compiler::IREE::LinalgExt; - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtEnums.cpp.inc" // IWYU pragma: keep - -#define GET_ATTRDEF_CLASSES -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtAttrs.cpp.inc" // IWYU pragma: keep - -void IREELinalgExtDialect::initialize() { - // TODO(hanchung): Add interface to the dialect. - // addInterfaces(); - - addAttributes< -#define GET_ATTRDEF_LIST -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtAttrs.cpp.inc" - >(); - -#define GET_OP_LIST - addOperations< -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc" - >(); -} - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.cpp.inc" - -//==----------------------------------------------------------------------===// -// iree_linalg_ext.encoding -//==----------------------------------------------------------------------===// - -EncodingAttr EncodingAttr::get(MLIRContext *context, TensorEncoding encoding) { - auto tensorEncodingAttr = TensorEncodingAttr::get(context, encoding); - return get(context, tensorEncodingAttr); -} \ No newline at end of file diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp deleted file mode 100644 index b78ba0a8b518..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" - -using namespace mlir; -namespace IREE = mlir::iree_compiler::IREE; -using namespace IREE::LinalgExt; - -LogicalResult -IREE::LinalgExt::detail::verifyLinalgExtOpInterface(Operation *op) { - LinalgExtOp linalgExtOp = cast(op); - if (op->getNumResults()) { - if (op->getNumResults() != linalgExtOp.getNumOutputs()) { - return linalgExtOp.emitOpError( - "expected number of outputs to be same as the number of results"); - } - for (auto en : llvm::enumerate(op->getResultTypes())) { - Type outputType = linalgExtOp.getOutputs()[en.index()].getType(); - if (en.value() != outputType) { - return linalgExtOp.emitOpError("expected type of `outs` operand #") - << en.index() << " " << outputType - << " to be same as result type " << en.value(); - } - } - } - return success(); -} - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.cpp.inc" // IWYU pragma: export - -template -static void getDimValues(OpBuilder &b, Location loc, Value v, Ty t, - SmallVector &dimVals) { - for (auto dim : llvm::enumerate(t.getShape())) { - if (ShapedType::isDynamic(dim.value())) { - dimVals.push_back(b.create(loc, v, dim.index())); - } else { - dimVals.push_back(b.create(loc, dim.value())); - } - } -} - -LogicalResult LinalgExtOp::reifyResultShapes( - OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - Operation *op = getOperation(); - for (auto output : getOutputs()) { - SmallVector dims; - Type outputType = output.getType(); - if (auto rankedTensorType = outputType.dyn_cast()) { - getDimValues(b, op->getLoc(), output, - rankedTensorType, dims); - } else if (auto memrefType = outputType.dyn_cast()) { - getDimValues(b, op->getLoc(), output, - memrefType, dims); - } else if (!outputType.isIntOrIndexOrFloat()) { - return op->emitOpError( - "invalid type for output operand, expected tensor, " - "memref or scalar type"); - } - reifiedReturnShapes.emplace_back(std::move(dims)); - } - return success(); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp deleted file mode 100644 index fc0db72f8dd8..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ /dev/null @@ -1,2889 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Affine/Utils.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/FunctionImplementation.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Support/MathExtras.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/SMLoc.h" - -using namespace mlir; -using namespace mlir::iree_compiler::IREE::LinalgExt; -namespace IREE = mlir::iree_compiler::IREE; - -//===----------------------------------------------------------------------===// -// Utils. -//===----------------------------------------------------------------------===// - -static void getEffectsImpl( - SmallVectorImpl> - &effects, - ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) { - for (Value value : results) { - effects.emplace_back(MemoryEffects::Allocate::get(), value, - SideEffects::DefaultResource::get()); - } - for (Value value : inputBuffers) { - effects.emplace_back(MemoryEffects::Read::get(), value, - SideEffects::DefaultResource::get()); - } - for (Value value : outputBuffers) { - effects.emplace_back(MemoryEffects::Read::get(), value, - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, - SideEffects::DefaultResource::get()); - } -} - -/// Returns a memref.subview or a tensor.extract_slice based on the type of the -/// `source`. -static Value getSlice(OpBuilder &b, Location loc, Value source, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides) { - return TypeSwitch(source.getType()) - .Case([&](RankedTensorType t) -> Value { - return b.create(loc, source, offsets, sizes, - strides); - }) - .Case([&](MemRefType type) -> Value { - return b.create(loc, source, offsets, sizes, - strides); - }) - .Default([&](Type t) { return nullptr; }); -} - -/// Returns true if the dimensions of ShapedType are compatible. -static bool isShapedTypeDimCompatible(int64_t lhs, int64_t rhs) { - return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || - lhs == rhs; -} - -/// Returns true if the dimensions of ShapedType are compatible. -static bool areShapesCompatible(ArrayRef lhs, ArrayRef rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - return llvm::all_of(llvm::zip(lhs, rhs), [](std::tuple it) { - return isShapedTypeDimCompatible(std::get<0>(it), std::get<1>(it)); - }); -} - -/// Return true if `dimsPos` is invalid. It is invalid when: a) it contains -/// duplicate. b) At least one dimension is out of bound (`dimPos` is >= 0 and < -/// rank). c) the number of elements in `dimsPos` is > than `rank`. -static bool isInvalid(ArrayRef dimsPos, int64_t rank) { - // early exit. - if (dimsPos.size() > rank) - return true; - DenseSet uniqued; - for (int64_t dim : dimsPos) - uniqued.insert(dim); - if (dimsPos.size() != uniqued.size()) - return true; - return llvm::any_of( - dimsPos, [rank](int64_t dimPos) { return dimPos < 0 || dimPos >= rank; }); -} - -/// Returns true if the dimension of `sourceShape` is smaller than the dimension -/// of the `limitShape`. -static bool isSmallerThan(ArrayRef sourceShape, - ArrayRef limitShape) { - assert( - sourceShape.size() == limitShape.size() && - "expected source shape rank, and limit of the shape to have same rank"); - return llvm::all_of( - llvm::zip(sourceShape, limitShape), [](std::tuple it) { - int64_t sourceExtent = std::get<0>(it); - int64_t limit = std::get<1>(it); - return sourceExtent == ShapedType::kDynamic || - limit == ShapedType::kDynamic || sourceExtent <= limit; - }); -} - -//===----------------------------------------------------------------------===// -// ScatterOp -//===----------------------------------------------------------------------===// - -LogicalResult ScatterOp::verify() { - Operation *op = getOperation(); - if (getInputs().size() != 2) { - return op->emitOpError("expected two input operands"); - } - if (getOutputs().size() != 1) { - return op->emitOpError("expected one output operand"); - } - auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) { - return t1.getShape()[dim] == t2.getShape()[dim]; - }; - - auto indicesType = getIndicesType(); - if (indicesType.getRank() != 2 || - !indicesType.getElementType().isInteger(32)) { - return op->emitOpError( - "expected indices to be of rank 2 of i32 element type"); - } - auto indexDepth = getIndexDepth(); - if (indexDepth == ShapedType::kDynamic) { - return op->emitOpError("expected index depth is static"); - } - - ArrayRef dimMap = getDimensionMap(); - if (dimMap.size() != indexDepth) { - return op->emitOpError("invalid number of dimension map entries "); - } - - auto originalType = getOriginalType(); - if (isInvalid(dimMap, originalType.getRank())) - return op->emitOpError("dimension map is invalid"); - - // The first dimension of the indices should match the first dimension of the - // output. They indicate to the number of updates. - auto updateType = getUpdateType(); - if (updateType.getRank() < 1) { - return op->emitOpError("expected update value to be at least rank 1"); - } - if (!checkDimensionsMatch(indicesType, updateType, 0)) { - return op->emitOpError( - "mismatch in shape of indices and update value at dim#0"); - } - if (updateType.getRank() - 1 > originalType.getRank()) { - return op->emitOpError( - "update value rank exceeds the rank of the original value"); - } - - // indexDepth + update dims should cover the original dims. The first dim of - // update is the number of updates. - if (originalType.getRank() > indexDepth + updateType.getRank() - 1) { - return op->emitOpError( - "index depth and update value does not cover rank of original value"); - } - - // Validate the non-indexed update dims cover the full slice size of the - // original tensor. - int64_t fullSliceDims = originalType.getRank() - indexDepth; - for (auto it : - llvm::zip(llvm::seq(indexDepth, originalType.getRank()), - llvm::seq(updateType.getRank() - fullSliceDims, - updateType.getRank()))) { - int64_t originalDim = std::get<0>(it); - int64_t updateDim = std::get<1>(it); - if (updateType.getDimSize(updateDim) > - originalType.getDimSize(originalDim)) { - return op->emitOpError("shape of update value dim#") - << updateDim << " exceeds original value at dim#" << originalDim; - } - } - - // Check that the remaining update indices do not exceed the update length. - int64_t insertDims = originalType.getRank() - updateType.getRank() + 1; - for (auto it : llvm::zip( - llvm::seq(insertDims, indexDepth), - llvm::seq(1, updateType.getRank() - fullSliceDims))) { - int64_t originalDim = std::get<0>(it); - int64_t updateDim = std::get<1>(it); - if (updateType.getDimSize(updateDim) > - originalType.getDimSize(originalDim)) { - return op->emitOpError("indexed shape of update value dim#") - << updateDim << " exceeds original value at dim#" << originalDim - << " " << updateType.getDimSize(updateDim) << " " - << originalType.getDimSize(originalDim); - } - } - - Region ®ion = this->getRegion(); - Block *body = ®ion.front(); - if (body->getNumArguments() != 2) { - return op->emitOpError("expected region to have two arguments"); - } - Type arg0Type = body->getArgument(0).getType(); - Type arg1Type = body->getArgument(1).getType(); - if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) { - return op->emitOpError( - "expected region to have scalar argument of integer or float types"); - } - if (arg0Type != updateType.getElementType()) { - return op->emitOpError("mismatch in argument 0 of region ") - << arg0Type << " and element type of update value " - << updateType.getElementType(); - } - if (arg1Type != originalType.getElementType()) { - return op->emitOpError("mismatch in argument 1 of region ") - << arg1Type << " and element type of original value " - << originalType.getElementType(); - } - if (arg0Type != arg1Type) { - return op->emitOpError("mismatch in region argument types ") - << arg0Type << " and " << arg1Type; - } - auto yieldOp = cast(body->getTerminator()); - if (yieldOp->getNumOperands() != 1) { - return yieldOp.emitOpError("expected region to yield a single value"); - } - auto yieldedType = yieldOp->getOperand(0).getType(); - if (yieldedType != arg0Type) { - return yieldOp.emitOpError("mismatch in type of yielded value ") - << yieldedType << " and argument of the region " << arg0Type; - } - return success(); -} - -SmallVector ScatterOp::getLoopIteratorTypes() { - SmallVector iteratorTypes(getUpdateType().getRank(), - utils::IteratorType::parallel); - if (!getUniqueIndices()) { - iteratorTypes[0] = utils::IteratorType::reduction; - } - return iteratorTypes; -} - -SmallVector ScatterOp::getIterationDomain(OpBuilder &builder) { - Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - SmallVector ranges; - for (auto dim : llvm::seq(0, getUpdateType().getRank())) { - Value ub = getDimValue(builder, loc, updates(), dim); - ranges.emplace_back(Range{zero, ub, one}); - } - return ranges; -} - -SmallVector -ScatterOp::getTiledImplementation(OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes) { - assert(offsets.size() >= 1 && sizes.size() >= 1); - Location loc = getLoc(); - auto zeroAttr = builder.getI64IntegerAttr(0); - auto oneAttr = builder.getI64IntegerAttr(1); - - // Slice of the updates. - auto updateRank = getUpdateType().getRank(); - SmallVector updateStrides(updateRank, oneAttr); - Value tiledUpdate = - getSlice(builder, loc, updates(), offsets, sizes, updateStrides); - assert(tiledUpdate && "failed to get slice of update"); - - // Slice of indices. - auto indicesRank = getIndicesType().getRank(); - SmallVector indicesOffsets(indicesRank, zeroAttr); - SmallVector indicesSizes(indicesRank); - indicesOffsets[0] = offsets[0]; - indicesSizes[0] = sizes[0]; - for (auto dim : llvm::seq(1, indicesRank)) { - indicesSizes[dim] = getDim(builder, loc, indices(), dim); - } - SmallVector indicesStrides(indicesRank, oneAttr); - Value tiledIndices = getSlice(builder, loc, indices(), indicesOffsets, - indicesSizes, indicesStrides); - assert(tiledIndices && "failed to get slice of indices"); - - // Slice of the original. - SmallVector originalOffsets, originalSizes; - if (failed(getResultTilePosition(builder, 0, offsets, sizes, originalOffsets, - originalSizes))) { - return {}; - } - auto originalRank = getOriginalType().getRank(); - SmallVector originalStrides(originalRank, oneAttr); - Value tiledOriginal = getSlice(builder, loc, original(), originalOffsets, - originalSizes, originalStrides); - assert(tiledOriginal && "failed to get slice of original tensor"); - - SmallVector resultTypes; - if (getNumResults()) { - resultTypes.push_back(tiledOriginal.getType()); - } - Operation *tiledScatterOp = - mlir::clone(builder, getOperation(), resultTypes, - ValueRange{tiledUpdate, tiledIndices, tiledOriginal}); - return {tiledScatterOp}; -} - -LogicalResult ScatterOp::getResultTilePosition( - OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - auto zeroAttr = builder.getI64IntegerAttr(0); - // Slice of the original. - auto originalRank = getOriginalType().getRank(); - resultOffsets.resize(originalRank, zeroAttr); - resultSizes.resize(originalRank); - - auto updateRank = getUpdateType().getRank(); - Location loc = getLoc(); - for (auto dim : llvm::seq(0, originalRank - updateRank + 1)) { - resultSizes[dim] = getDim(builder, loc, original(), dim); - } - for (auto dim : - llvm::seq(originalRank - updateRank + 1, originalRank)) { - resultOffsets[dim] = offsets[dim - (originalRank - updateRank)]; - resultSizes[dim] = sizes[dim - (originalRank - updateRank)]; - } - return success(); -} - -LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, - Location loc, - ValueRange ivs) { - auto indexDepth = getIndexDepth(); - Value update = b.create(loc, updates(), ivs); - SmallVector starts; - SmallVector loadIndices; - loadIndices.push_back(ivs.front()); - loadIndices.push_back(Value()); - - // Populate with empty values. - auto originalTy = original().getType().cast(); - starts.resize(originalTy.getRank(), Value()); - auto updateIvs = ivs.drop_front(1); - - int64_t offset = starts.size() - updateIvs.size(); - for (auto it : llvm::enumerate(updateIvs)) { - starts[it.index() + offset] = it.value(); - } - - ArrayRef dimMap = getDimensionMap(); - - for (auto i : llvm::seq(0, indexDepth)) { - loadIndices.back() = b.create(loc, i); - Value idx = b.create(loc, indices(), loadIndices); - Value ret = b.create(loc, b.getIndexType(), idx); - - auto dim = dimMap[i]; - - if (starts[dim]) - ret = b.create(loc, ret, starts[dim]); - starts[dim] = ret; - } - - Value init = b.create(loc, original(), starts); - - BlockAndValueMapping bvm; - Block &block = getRegion().front(); - bvm.map(block.getArgument(0), update); - bvm.map(block.getArgument(1), init); - for (auto &blockOp : block.without_terminator()) { - b.clone(blockOp, bvm); - } - // The last op is linalg_ext.yield op. Store the operand to - // destination. - b.create( - loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)), - original(), starts); - return success(); -} - -LogicalResult -ScatterOp::reifyResultShapes(OpBuilder &b, - ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(b, reifiedReturnShapes); -} - -//===----------------------------------------------------------------------===// -// SortOp -//===----------------------------------------------------------------------===// - -LogicalResult SortOp::verify() { - Operation *op = getOperation(); - if (getNumInputs()) { - return op->emitOpError("does not expect to take any inputs"); - } - if (getNumOutputs() == 0) { - return op->emitOpError("expected at least one `outs` operand"); - } - - Block &block = getRegion().front(); - size_t numOutputs = getNumOutputs(); - if (block.getNumArguments() != 2 * numOutputs) { - return op->emitOpError("region block should have ") - << 2 * numOutputs << " arguments"; - } - - int64_t rank = getOperandRank(); - int sortDim = getDimension(); - if (sortDim < 0 || sortDim >= rank) { - return op->emitOpError("dimension must be within (0, ") << rank << "]"; - } - - ArrayRef shape = getOperandShape(); - for (auto indexedOperand : llvm::enumerate(getOutputs())) { - int index = indexedOperand.index(); - auto operandType = getOperandType(index); - if (operandType.getRank() != rank) { - return op->emitOpError("expected operand ") - << index << " to be rank " << rank << ", same as other operands"; - } - if (operandType.getShape() != shape) { - return op->emitOpError("expected operand ") - << index << " to have same shape as other operands"; - } - Type elemType = operandType.getElementType(); - for (int i : {2 * index, 2 * index + 1}) { - Type argType = block.getArgument(i).getType(); - if (argType != elemType) { - return op->emitOpError("region block argument #") - << i << " should be of type " << elemType << " but got " - << argType; - } - } - } - - auto yieldOp = cast(block.getTerminator()); - if (yieldOp.getNumOperands() != 1) { - return op->emitOpError("should yield exactly one operand"); - } - auto ty = yieldOp.getOperand(0).getType().dyn_cast(); - if (!ty || ty.getWidth() != 1) { - return op->emitOpError("should yield i1 type"); - } - - return success(); -} - -SmallVector SortOp::getLoopIteratorTypes() { - // All loops except the dimension to sort along are parallel. - SmallVector iteratorTypes(getOperandRank(), - utils::IteratorType::parallel); - iteratorTypes[getDimension()] = utils::IteratorType::reduction; - return iteratorTypes; -} - -SmallVector SortOp::getIterationDomain(OpBuilder &builder) { - int64_t operandRank = getOperandRank(); - SmallVector loopBounds(operandRank); - Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - Value source = operand(0); - for (auto dim : llvm::seq(0, operandRank)) { - loopBounds[dim].offset = zero; - loopBounds[dim].size = getDimValue(builder, loc, source, dim); - loopBounds[dim].stride = one; - } - return loopBounds; -} - -SmallVector -SortOp::getTiledImplementation(OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes) { - int64_t rank = getOperandRank(); - assert(offsets.size() == static_cast(rank) && - sizes.size() == static_cast(rank)); - auto oneAttr = builder.getI64IntegerAttr(1); - SmallVector strides(rank, oneAttr); - SmallVector tiledOperands(getOutputs().size()); - for (auto en : llvm::enumerate(getOutputs())) { - tiledOperands[en.index()] = - getSlice(builder, getLoc(), en.value(), offsets, sizes, strides); - assert(tiledOperands[en.index()] && "failed to get slice of operand"); - } - SmallVector resultTypes; - if (getNumResults()) { - resultTypes = llvm::to_vector<4>( - llvm::map_range(tiledOperands, [&](Value v) { return v.getType(); })); - } - Operation *tiledSortOp = - mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - return {tiledSortOp}; -} - -LogicalResult SortOp::getResultTilePosition( - OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - resultOffsets = llvm::to_vector(offsets); - resultSizes = llvm::to_vector(sizes); - return success(); -} - -LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc, - ValueRange ivs) { - auto sortDim = getDimension(); - SmallVector indices, sortBlkArgs; - indices.append(ivs.begin(), ivs.end()); - // Bubble sort innermost loop. - Value zero = b.create(loc, 0); - Value one = b.create(loc, 1); - Value ub; - if (getOperandType(0).isDynamicDim(sortDim)) { - ub = b.create(loc, operand(0), sortDim); - } else { - ub = b.create( - loc, getOperandType(0).getDimSize(sortDim)); - } - ub = b.create(loc, ub, one); - auto scfFor = b.create( - loc, zero, ub, one, ValueRange{}, - [&](OpBuilder &b, Location loc, Value iv, ValueRange iters) { - SmallVector indices(ivs); - Value ivPlusOne = b.create(loc, iv, one); - for (auto output : getOutputOperands()) { - indices[sortDim] = iv; - sortBlkArgs.push_back( - b.create(loc, output->get(), indices)); - indices[sortDim] = ivPlusOne; - sortBlkArgs.push_back( - b.create(loc, output->get(), indices)); - } - }); - - auto &srcBlock = getRegion().front(); - Region ®ion = scfFor.getRegion(); - BlockAndValueMapping bvm; - { - OpBuilder::InsertionGuard guard(b); - auto &block = region.front(); - b.setInsertionPointToEnd(&block); - for (auto it : llvm::zip(srcBlock.getArguments(), sortBlkArgs)) { - bvm.map(std::get<0>(it), std::get<1>(it)); - } - for (auto &blockOp : srcBlock.without_terminator()) { - b.clone(blockOp, bvm); - } - } - Value cond = bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)); - - OpBuilder::InsertionGuard g(b); - b.setInsertionPointToEnd(®ion.front()); - b.create( - loc, TypeRange{}, cond, - [&](OpBuilder &b, Location loc) { - // Do not swap the pairs if true. - b.create(loc); - }, - [&](OpBuilder &b, Location loc) { - // Swap the pairs if false. - SmallVector indices(ivs.begin(), ivs.end()); - Value ivPlusOne = - b.create(loc, scfFor.getInductionVar(), one); - for (int i = 0, e = getNumOutputs(); i < e; ++i) { - Value v1 = sortBlkArgs[i * 2]; - Value v2 = sortBlkArgs[i * 2 + 1]; - indices[sortDim] = scfFor.getInductionVar(); - b.create(loc, v2, getOutputOperand(i)->get(), - indices); - indices[sortDim] = ivPlusOne; - b.create(loc, v1, getOutputOperand(i)->get(), - indices); - } - b.create(loc); - }); - b.create(loc); - return success(); -} - -LogicalResult -SortOp::reifyResultShapes(OpBuilder &b, - ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(b, reifiedReturnShapes); -} - -//===----------------------------------------------------------------------===// -// FftOp -//===----------------------------------------------------------------------===// - -LogicalResult FftOp::verify() { - Operation *op = getOperation(); - auto length = getFftLength(); - // After tiling, it could be dynamic shape. (Because - // subview/subtensor does not inference the type correctly - // on (1 << x)) cases). - if (length == ShapedType::kDynamic) - return success(); - if (length & (length - 1)) { - return op->emitOpError("only powers of 2 are handled currently"); - } - if (!getNumInputs() || !isScalar(getInputOperand(0))) { - return op->emitOpError("expected to carry `stage` input"); - } - if (getNumInputs() != 1) { - if (getNumInputs() != 3 || isScalar(getInputOperand(1)) || - isScalar(getInputOperand(2))) { - return op->emitOpError("expected to carry real and imag coeff inputs"); - } - } - if (getNumOutputs() != 2) { - return op->emitOpError( - "expected outputs to be real and imag tensor/memref"); - } - return success(); -} - -SmallVector FftOp::getLoopIteratorTypes() { - // There are `rank-1` outer loops. The fft itselfs has one loop for each - // stage, which handles the merge step -- taking two half size tensors and - // merge them into one tensor. - SmallVector iteratorTypes(getOperandRank(), - utils::IteratorType::parallel); - return iteratorTypes; -} - -SmallVector FftOp::getIterationDomain(OpBuilder &builder) { - SmallVector res; - Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - for (auto en : llvm::enumerate(getOperandShape().drop_back())) { - Value size; - if (en.value() == ShapedType::kDynamic) { - size = getDimValue(builder, loc, getReal(), en.index()); - } else { - size = builder.create(loc, en.value()); - } - res.emplace_back(Range{/*offset=*/zero, size, /*stride=*/one}); - } - - Value size = getDimValue(builder, loc, getReal(), getOperandRank() - 1); - Value stride = builder.create(loc, one, getStage()); - res.emplace_back(Range{/*offset=*/zero, size, /*stride=*/stride}); - return res; -} - -void FftOp::generateScalarImplWithoutCoeffBuf(OpBuilder &b, Location loc, - ArrayRef operands, - Value wholeSize) { - auto rank = getOperandRank(); - SmallVector maps(operands.size(), b.getMultiDimIdentityMap(rank)); - - auto f32Type = b.getF32Type(); - auto indexToF32 = [](OpBuilder &builder, Location loc, Value v) -> Value { - v = builder.create(loc, builder.getI32Type(), v); - return builder.create(loc, builder.getF32Type(), v); - }; - - // We will need exp(-2 * PI * j / m * I), compute "-2 * PI / m" for imag part - // first. - Value coeff = b.create( - loc, llvm::APFloat(static_cast(-2 * acos(-1))), f32Type); - coeff = b.create(loc, coeff, indexToF32(b, loc, wholeSize)); - - b.create( - loc, TypeRange{}, ValueRange{}, operands, maps, getLoopIteratorTypes(), - [&](OpBuilder &b, Location loc, ValueRange args) { - Value lhsReal = args[0]; - Value lhsImag = args[1]; - Value rhsReal = args[2]; - Value rhsImag = args[3]; - - // Compute "-2 * PI / m * j" - Value w = b.create( - loc, coeff, - indexToF32(b, loc, b.create(loc, rank - 1))); - Value wReal = b.create(loc, w); - Value wImag = b.create(loc, w); - - // t = w * a[k + j + mh]; - // -> (x + yi)(u + vi) = (xu - yv) + (xv + yu)i - Value xu = b.create(loc, wReal, rhsReal); - Value yv = b.create(loc, wImag, rhsImag); - Value xv = b.create(loc, wReal, rhsImag); - Value yu = b.create(loc, wImag, rhsReal); - Value tReal = b.create(loc, xu, yv); - Value tImag = b.create(loc, xv, yu); - - // cplx u = a[k + j]; - // a[k + j] = u + t; - // a[k + j + mh] = u - t; - Value r1 = b.create(loc, lhsReal, tReal); - Value r2 = b.create(loc, lhsImag, tImag); - Value r3 = b.create(loc, lhsReal, tReal); - Value r4 = b.create(loc, lhsImag, tImag); - b.create(loc, ValueRange{r1, r2, r3, r4}); - }); -} - -void FftOp::generateScalarImplWithCoeffBuf(OpBuilder &b, Location loc, - ArrayRef operands) { - auto rank = getOperandRank(); - SmallVector maps; - // The size of coefficent buffer is epxected to match `2^(stage-1)`, which - // equals to the last dim of operands. - maps.append( - 2, AffineMap::get(rank, 0, b.getAffineDimExpr(rank - 1), b.getContext())); - maps.append(operands.size(), b.getMultiDimIdentityMap(rank)); - - b.create( - loc, TypeRange{}, ValueRange{getRealCoeff(), getImagCoeff()}, operands, - maps, getLoopIteratorTypes(), - [&](OpBuilder &b, Location loc, ValueRange args) { - Value wReal = args[0]; - Value wImag = args[1]; - Value lhsReal = args[2]; - Value lhsImag = args[3]; - Value rhsReal = args[4]; - Value rhsImag = args[5]; - - // t = w * a[k + j + mh]; - // -> (x + yi)(u + vi) = (xu - yv) + (xv + yu)i - Value xu = b.create(loc, wReal, rhsReal); - Value yv = b.create(loc, wImag, rhsImag); - Value xv = b.create(loc, wReal, rhsImag); - Value yu = b.create(loc, wImag, rhsReal); - Value tReal = b.create(loc, xu, yv); - Value tImag = b.create(loc, xv, yu); - - // cplx u = a[k + j]; - // a[k + j] = u + t; - // a[k + j + mh] = u - t; - Value r1 = b.create(loc, lhsReal, tReal); - Value r2 = b.create(loc, lhsImag, tImag); - Value r3 = b.create(loc, lhsReal, tReal); - Value r4 = b.create(loc, lhsImag, tImag); - b.create(loc, ValueRange{r1, r2, r3, r4}); - }); -} - -// Generates FFT stage scalar implementation. This follows Cooley–Tukey FFT -// algorithm. The pseudo reference code is: -// let s <- stage of linalg_ext.fft -// int m = 1 << s; -// int mh = m >> 1; -// for (int k = 0; k < n; k += m) { -// for (int j = 0; j < mh; ++j) { -// cplx w = exp(-2 * PI * j / m * I); -// cplx t = w * a[k + j + mh]; -// cplx u = a[k + j]; -// a[k + j] = u + t; -// a[k + j + mh] = u - t; -// } -// } -LogicalResult FftOp::generateScalarImplementation(OpBuilder &b, Location loc, - ValueRange ivs) { - Value real = getReal(); - Value imag = getImag(); - Value stage = getStage(); - Value one = b.create(loc, 1); - Value wholeSize = b.create(loc, one, stage); - Value halfSize = b.create(loc, wholeSize, one); - - auto rank = getOperandRank(); - SmallVector operands; - SmallVector lhsIvs(ivs.begin(), ivs.end()); - SmallVector ones(rank, b.getIndexAttr(1)); - SmallVector sizes(rank, b.getIndexAttr(1)); - sizes.back() = halfSize; - operands.push_back( - b.create(loc, real, lhsIvs, sizes, ones)); - operands.push_back( - b.create(loc, imag, lhsIvs, sizes, ones)); - - SmallVector rhsIvs(ivs.begin(), ivs.end()); - rhsIvs.back() = - b.create(loc, ivs.back(), halfSize).getResult(); - operands.push_back( - b.create(loc, real, rhsIvs, sizes, ones)); - operands.push_back( - b.create(loc, imag, rhsIvs, sizes, ones)); - - if (hasCoeff()) { - generateScalarImplWithCoeffBuf(b, loc, operands); - } else { - generateScalarImplWithoutCoeffBuf(b, loc, operands, wholeSize); - } - - return success(); -} - -SmallVector -FftOp::getTiledImplementation(OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes) { - int64_t rank = getOperandRank(); - SmallVector strides(rank, builder.getI64IntegerAttr(1)); - SmallVector tiledOperands(3); - tiledOperands[0] = getStage(); - tiledOperands[1] = getRealCoeff(); - tiledOperands[2] = getImagCoeff(); - SmallVector resultTypes; - - for (auto out : getOutputs()) { - tiledOperands.push_back( - getSlice(builder, getLoc(), out, offsets, sizes, strides)); - if (hasTensorSemantics()) { - resultTypes.push_back(tiledOperands.back().getType()); - } - } - Operation *tiledFftOp = - mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - return {tiledFftOp}; -} - -LogicalResult FftOp::getResultTilePosition( - OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - resultOffsets.assign(offsets.begin(), offsets.end()); - resultSizes.assign(sizes.begin(), sizes.end()); - return success(); -} - -LogicalResult -FftOp::reifyResultShapes(OpBuilder &b, - ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(b, reifiedReturnShapes); -} - -//===----------------------------------------------------------------------===// -// ScanOp -//===----------------------------------------------------------------------===// - -LogicalResult ScanOp::verify() { - Operation *op = getOperation(); - if (getNumInputs() != 1) { - return op->emitOpError("expected one input operands"); - } - if (getNumOutputs() != 2) { - return op->emitOpError("expected two output operands"); - } - if (!input().getType().isa()) { - return op->emitOpError("expected first input element type to be shaped"); - } - auto accumulatorType = accumulator().getType().cast(); - auto inputType = input().getType().cast(); - auto outputType = output().getType().cast(); - ArrayRef inputShapes = inputType.getShape(); - ArrayRef outputShapes = outputType.getShape(); - if (accumulatorType.getElementType() != inputType.getElementType()) { - return op->emitOpError( - "expected input/accumulator element types to be identical"); - } - ArrayRef accumulatorShape = accumulatorType.getShape(); - int64_t accumulatorRank = accumulatorType.getRank(); - if (accumulatorRank != inputType.getRank() - 1) { - return op->emitOpError( - "expected accumulator rank to be equal to input rank - 1"); - } - SmallVector expectedAccumulatorShape; - for (int i = 0; i < inputType.getRank(); i++) { - if (i != getDimension()) - expectedAccumulatorShape.push_back(inputShapes[i]); - } - if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape), - [](std::tuple s) { - return std::get<0>(s) != ShapedType::kDynamic && - std::get<1>(s) != ShapedType::kDynamic && - std::get<0>(s) != std::get<1>(s); - })) { - return op->emitOpError("incompatible input/accumulator shapes"); - } - if (inputType.getElementType() != outputType.getElementType()) { - return op->emitOpError( - "expected input/output element types to be identical"); - } - if (inputShapes.size() != outputShapes.size()) { - return op->emitOpError("expected input/output to have identical ranks"); - } - if (llvm::any_of(llvm::zip(inputShapes, outputShapes), - [](std::tuple s) { - return std::get<0>(s) != ShapedType::kDynamic && - std::get<1>(s) != ShapedType::kDynamic && - std::get<0>(s) != std::get<1>(s); - })) { - return op->emitOpError("incompatible input/output shapes"); - } - return success(); -} - -SmallVector ScanOp::getIterationDomain(OpBuilder &builder) { - int64_t operandRank = getOperandRank(); - SmallVector loopBounds(operandRank); - Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - Value source = input(); - for (auto dim : llvm::seq(0, operandRank)) { - loopBounds[dim].offset = zero; - loopBounds[dim].size = getDimValue(builder, loc, source, dim); - loopBounds[dim].stride = one; - } - return loopBounds; -} - -SmallVector ScanOp::getLoopIteratorTypes() { - SmallVector iteratorTypes(getOperandRank(), - utils::IteratorType::parallel); - iteratorTypes[getDimension()] = utils::IteratorType::reduction; - return iteratorTypes; -} - -// Generates naive scalar implementation of scan for a given operator f. -// For inclusive, -// output[0] = input[0] -// output[i] = f(output[i-1], input[i]) -// -// For exclusive, -// output[0] = 0 -// output[i] = f(output[i-1], input[i-1]) - -LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc, - ValueRange ivs) { - SmallVector indices, scanBlkArgs; - indices.append(ivs.begin(), ivs.end()); - Value zero = b.create(loc, 0); - Value one = b.create(loc, 1); - auto scanDim = getDimension(); - auto cond = b.create(loc, arith::CmpIPredicate::eq, - indices[scanDim], zero); - bool isInclusive = getInclusive(); - SmallVector accIndices; - for (int i = 0; i < indices.size(); i++) { - if (i != scanDim) - accIndices.push_back(indices[i]); - } - - auto scfIf = b.create( - loc, TypeRange{}, cond, - [&](OpBuilder &b, Location loc) { - if (isInclusive) { - auto value = b.create(loc, input(), indices); - b.create(loc, value, output(), indices); - } else { - auto value = b.create(loc, accumulator(), accIndices); - b.create(loc, value, output(), indices); - } - b.create(loc); - }, - [&](OpBuilder &b, Location loc) { - SmallVector indices(ivs.begin(), ivs.end()); - Value iv = indices[scanDim]; - Value ivMinusOne = b.create(loc, iv, one); - indices[scanDim] = ivMinusOne; - scanBlkArgs.push_back(b.create(loc, output(), indices)); - Value i0; - if (!isInclusive) - i0 = b.create(loc, input(), indices); - indices[scanDim] = iv; - if (isInclusive) - i0 = b.create(loc, input(), indices); - scanBlkArgs.push_back(i0); - }); - - auto &srcBlock = getRegion().front(); - Region ®ion = scfIf.getElseRegion(); - BlockAndValueMapping bvm; - { - OpBuilder::InsertionGuard guard(b); - auto &block = region.front(); - b.setInsertionPointToEnd(&block); - for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) { - bvm.map(std::get<0>(it), std::get<1>(it)); - } - for (auto &blockOp : srcBlock.without_terminator()) { - b.clone(blockOp, bvm); - } - b.create( - loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)), - output(), indices); - b.create( - loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)), - accumulator(), accIndices); - b.create(loc); - } - return success(); -} - -SmallVector -ScanOp::getTiledImplementation(OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes) { - int64_t rank = getOperandRank(); - assert(offsets.size() == static_cast(rank) && - sizes.size() == static_cast(rank)); - auto oneAttr = builder.getI64IntegerAttr(1); - SmallVector strides(rank, oneAttr); - SmallVector tiledOperands; - tiledOperands.emplace_back( - getSlice(builder, getLoc(), input(), offsets, sizes, strides)); - tiledOperands.emplace_back( - getSlice(builder, getLoc(), getOutputs()[0], offsets, sizes, strides)); - if (rank > 1) { - SmallVector accumOffsets, accumSizes; - if (failed(getResultTilePosition(builder, 1, offsets, sizes, accumOffsets, - accumSizes))) { - return {}; - } - SmallVector accumStrides(rank - 1, oneAttr); - tiledOperands.emplace_back(getSlice(builder, getLoc(), getOutputs()[1], - accumOffsets, accumSizes, - accumStrides)); - } else { - tiledOperands.emplace_back(getOutputs()[1]); - } - - SmallVector resultTypes; - if (hasTensorSemantics()) { - resultTypes.push_back(tiledOperands[1].getType()); - resultTypes.push_back(tiledOperands[2].getType()); - } - - Operation *tiledScanOp = - mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - return {tiledScanOp}; -} - -LogicalResult ScanOp::getResultTilePosition( - OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - if (resultNumber == 0) { - resultOffsets.assign(offsets.begin(), offsets.end()); - resultSizes.assign(sizes.begin(), sizes.end()); - return success(); - } - if (resultNumber == 1) { - int64_t rank = getOperandRank(); - if (rank > 1) { - for (auto i : llvm::seq(0, rank)) { - if (i == getDimension()) - continue; - resultOffsets.push_back(offsets[i]); - resultSizes.push_back(sizes[i]); - } - } - return success(); - } - return failure(); -} - -LogicalResult ScanOp::fold(ArrayRef, - SmallVectorImpl &) { - return memref::foldMemRefCast(*this); -} - -LogicalResult -ScanOp::reifyResultShapes(OpBuilder &b, - ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(b, reifiedReturnShapes); -} - -//===----------------------------------------------------------------------===// -// ReverseOp -//===----------------------------------------------------------------------===// - -LogicalResult ReverseOp::verify() { - Operation *op = getOperation(); - if (getNumInputs() != 1) { - return op->emitOpError("expected exactly one input"); - } - if (getNumOutputs() != 1) { - return op->emitOpError("expected exactly one output"); - } - auto inputType = input().getType().cast(); - auto outputType = output().getType().cast(); - if (inputType.getElementType() != outputType.getElementType()) { - return op->emitOpError( - "expected input/output element types to be identical"); - } - ArrayRef inputShapes = inputType.getShape(); - ArrayRef outputShapes = outputType.getShape(); - if (inputShapes.size() != outputShapes.size()) { - return op->emitOpError("expexted input/output to have identical ranks"); - } - if (llvm::any_of(llvm::zip(inputShapes, outputShapes), - [](std::tuple s) { - return std::get<0>(s) != ShapedType::kDynamic && - std::get<1>(s) != ShapedType::kDynamic && - std::get<0>(s) != std::get<1>(s); - })) { - return op->emitOpError("incompatible input/output shapes"); - } - - int64_t rank = getOperandRank(); - llvm::SmallSetVector s; - for (auto dim : dims()) { - if (dim < 0 || dim >= rank) { - return op->emitOpError("all the dimensions must be within [0, ") - << rank << ")"; - } - if (s.contains(dim)) { - return op->emitOpError("expected dimensions numbers are all unique"); - } - s.insert(dim); - } - - return success(); -} - -SmallVector ReverseOp::getLoopIteratorTypes() { - SmallVector iteratorTypes(getOperandRank(), - utils::IteratorType::parallel); - return iteratorTypes; -} - -SmallVector ReverseOp::getIterationDomain(OpBuilder &builder) { - Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - SmallVector ranges; - for (auto dim : llvm::seq(0, getOperandRank())) { - Value ub = getDimValue(builder, loc, input(), dim); - ranges.emplace_back(Range{zero, ub, one}); - } - return ranges; -} - -LogicalResult ReverseOp::generateScalarImplementation(OpBuilder &b, - Location loc, - ValueRange ivs) { - SmallVector mirrorIndices(ivs.begin(), ivs.end()); - for (auto dim : dims()) { - auto size = getDimValue(b, loc, input(), dim); - size = b.create(loc, size, - b.create(loc, 1)); - mirrorIndices[dim] = b.create(loc, size, mirrorIndices[dim]); - } - Value val = b.create(loc, input(), ivs); - b.create(loc, val, output(), mirrorIndices); - return success(); -} - -SmallVector -ReverseOp::getTiledImplementation(OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes) { - int64_t rank = getOperandRank(); - SmallVector strides(rank, builder.getI64IntegerAttr(1)); - Location loc = getLoc(); - SmallVector mirrorOffsets, mirrorSizes; - if (failed(getResultTilePosition(builder, 0, offsets, sizes, mirrorOffsets, - mirrorSizes))) { - return {}; - } - - SmallVector tiledOperands; - tiledOperands.emplace_back( - getSlice(builder, loc, input(), offsets, sizes, strides)); - - SmallVector resultTypes; - if (hasTensorSemantics()) { - tiledOperands.emplace_back( - getSlice(builder, loc, output(), mirrorOffsets, sizes, strides)); - resultTypes.push_back(tiledOperands[1].getType()); - } else { - tiledOperands.emplace_back( - getSlice(builder, loc, output(), mirrorOffsets, sizes, strides)); - } - - Operation *tiledRevOp = - mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - - return {tiledRevOp}; -} - -LogicalResult ReverseOp::getResultTilePosition( - OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - AffineExpr sym0, sym1, sym2; - bindSymbols(builder.getContext(), sym0, sym1, sym2); - AffineMap map = - AffineMap::get(/*dimCount=*/0, /*symbolCount=*/3, {sym0 - sym1 - sym2}); - resultOffsets.assign(offsets.begin(), offsets.end()); - Location loc = getLoc(); - for (auto dim : dims()) { - Value size = getDimValue(builder, loc, input(), dim); - Value offset = - getValueOrCreateConstantIndexOp(builder, loc, resultOffsets[dim]); - Value tileSize = getValueOrCreateConstantIndexOp(builder, loc, sizes[dim]); - resultOffsets[dim] = - builder - .create(loc, map, ValueRange{size, offset, tileSize}) - .getResult(); - } - resultSizes.assign(sizes.begin(), sizes.end()); - return success(); -} - -LogicalResult -ReverseOp::reifyResultShapes(OpBuilder &b, - ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(b, reifiedReturnShapes); -} - -//===----------------------------------------------------------------------===// -// TopkOp -//===----------------------------------------------------------------------===// - -LogicalResult TopkOp::verify() { - Operation *op = getOperation(); - if (getNumInputs() != 1 && getNumInputs() != 2) { - return op->emitOpError("expected one or two input operands"); - } - if (getNumOutputs() != 2) { - return op->emitOpError("expected two output operands"); - } - if (getDimension() >= getInputRank()) { - return op->emitOpError("dimension exceeds rank"); - } - // Ensure input/output element types match - auto inputValuesType = values().getType().cast(); - auto outputValuesType = outputValues().getType().cast(); - if (inputValuesType.getElementType() != outputValuesType.getElementType()) { - return op->emitOpError("expected input/output value types to be identical"); - } - // Indices must be int if provided - auto outputIndicesType = outputIndices().getType().cast(); - if (auto inputIndices = indices()) { - auto inputIndicesType = inputIndices->getType().cast(); - if (!inputIndicesType.getElementType().isInteger(32) || - !outputIndicesType.getElementType().isInteger(32)) { - return op->emitOpError("expected input/output indices types to be int32"); - } - } - - // Ranks must match - if (inputValuesType.getRank() != outputValuesType.getRank()) { - return op->emitOpError("expected input/output to have the same rank"); - } - if (auto inputIndices = indices()) { - auto inputIndicesType = inputIndices->getType().cast(); - if (inputIndicesType.getRank() != outputIndicesType.getRank()) { - return op->emitOpError("expected input/output to have the same rank"); - } - } - // Input indicies and values must have the same shape. - if (auto inputIndices = indices()) { - auto inputIndicesType = inputIndices->getType().cast(); - if (!areShapesCompatible(inputValuesType.getShape(), - inputIndicesType.getShape())) - return op->emitOpError("input indices/values shape must match"); - } - // Output indicies and values must have the same shape. - if (!areShapesCompatible(outputValuesType.getShape(), - outputIndicesType.getShape())) - return op->emitOpError("output indices/values shape must match"); - // Input shape must match the output shape except for the dimension() - uint64_t dim = getDimension(); - if (!llvm::all_of(llvm::enumerate(llvm::zip(inputValuesType.getShape(), - outputValuesType.getShape())), - [dim](auto e) { - if (e.index() == dim) { - return true; - } - std::tuple s = e.value(); - return isShapedTypeDimCompatible(std::get<0>(s), - std::get<1>(s)); - })) { - return op->emitOpError("incompatible input/output shapes"); - } - // Check region compatibility - Block &block = getRegion().front(); - if (block.getNumArguments() != 2) { - return op->emitOpError("region block should have 2 arguments"); - } - if (block.getArgument(0).getType() != inputValuesType.getElementType() || - block.getArgument(1).getType() != inputValuesType.getElementType()) { - return op->emitOpError("region block types must match input"); - } - auto terminatorOp = llvm::dyn_cast(block.getTerminator()); - if (!terminatorOp || !terminatorOp.getOperand(0).getType().isInteger(1)) { - return op->emitOpError("region block must end with a linalg_ext.yield i1!"); - } - return success(); -} - -SmallVector TopkOp::getIterationDomain(OpBuilder &builder) { - int64_t operandRank = getInputRank(); - SmallVector loopBounds(operandRank); - Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - Value source = values(); - for (auto dim : llvm::enumerate(getInputType().getShape())) { - loopBounds[dim.index()].offset = zero; - loopBounds[dim.index()].size = - getDimValue(builder, loc, source, dim.index()); - loopBounds[dim.index()].stride = one; - } - return loopBounds; -} - -SmallVector TopkOp::getLoopIteratorTypes() { - SmallVector iteratorTypes(getInputRank(), - utils::IteratorType::parallel); - iteratorTypes[getDimension()] = utils::IteratorType::reduction; - return iteratorTypes; -} - -LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc, - ValueRange ivs) { - uint64_t kDim = getDimension(); - Value zero = b.create(loc, 0); - Value one = b.create(loc, 1); - Value initialValue = b.create(loc, values(), ivs); - - // If the indices tensor is not provided, the value index is derived from the - // loop induction variables. - Value initialIndex; - if (indices()) { - initialIndex = b.create(loc, *indices(), ivs); - } else { - Value rawInitialIndex = ivs[kDim]; - initialIndex = - b.create(loc, b.getI32Type(), rawInitialIndex); - } - - // Compute K (ub) from the selected dim of the output - Value ub = b.create(loc, outputValues(), getDimension()); - - // Inner K loop functions: - // Load current K value and index - // Compare N/K using inserted block compare - // Check if N == K using strict weak ordering, select which index came first - // Select new K value from N/K comparison - // Select new K index from N/K comparison or which index came first - // Store new k value and index - // Yield loop carry values after K selection - Value kValue, kIndex; - auto scfFor = b.create( - loc, zero, ub, one, ValueRange{initialValue, initialIndex}, - [&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) { - SmallVector indices(ivs); - indices[kDim] = iv; - kValue = b.create(loc, outputValues(), indices); - kIndex = b.create(loc, outputIndices(), indices); - }); - - SmallVector indices(ivs); - indices[kDim] = scfFor.getInductionVar(); - auto loopCarryValues = scfFor.getRegionIterArgs(); - - // Retrieve region as black box comparision function f(x,y). Plug into op. - auto &srcBlock = getRegion().front(); - BlockAndValueMapping bvmF; // f(x,y) - BlockAndValueMapping bvmR; // f(y,x) - { - // Save previous insertion point. Continue within loop body. - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToEnd(&scfFor.getRegion().front()); - SmallVector forwardValues{loopCarryValues[0], kValue}; - SmallVector reverseValues{kValue, loopCarryValues[0]}; - for (auto it : llvm::zip(srcBlock.getArguments(), forwardValues)) { - bvmF.map(std::get<0>(it), std::get<1>(it)); - } - for (auto it : llvm::zip(srcBlock.getArguments(), reverseValues)) { - bvmR.map(std::get<0>(it), std::get<1>(it)); - } - for (auto &blockOp : srcBlock.without_terminator()) { - b.clone(blockOp, bvmF); - b.clone(blockOp, bvmR); - } - Value forwardCmpRes = bvmF.lookup(srcBlock.getTerminator()->getOperand(0)); - Value reverseCmpRes = bvmR.lookup(srcBlock.getTerminator()->getOperand(0)); - - // Check value equality using strictly weak ordering from the region: - // f(x,y) --> forwardCmpRes - // f(y,x) --> reverseCmpRes - // if forwardCmpRes == reverseCmpRes then select which came first - Value cmpValuesEqual = b.create( - loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes); - Value cmpFirstIndex = b.create( - loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex); - Value combinedCmpEqRes = - b.create(loc, cmpValuesEqual, cmpFirstIndex); - // True if N > K or N came before K - Value indexCmpRes = - b.create(loc, forwardCmpRes, combinedCmpEqRes); - // Select results for K based on comparisons - Value resultKValue = b.create(loc, forwardCmpRes, - loopCarryValues[0], kValue); - Value resultKIndex = - b.create(loc, indexCmpRes, loopCarryValues[1], kIndex); - b.create(loc, resultKValue, outputValues(), indices); - b.create(loc, resultKIndex, outputIndices(), indices); - // Select loop carry, opposite of K results - Value resultCarryValue = b.create( - loc, forwardCmpRes, kValue, loopCarryValues[0]); - Value resultCarryIndex = - b.create(loc, indexCmpRes, kIndex, loopCarryValues[1]); - b.create(loc, ValueRange{resultCarryValue, resultCarryIndex}); - } - return success(); -} - -SmallVector -TopkOp::getTiledImplementation(OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes) { - int64_t rank = getInputRank(); - assert(offsets.size() == static_cast(rank) && - sizes.size() == static_cast(rank)); - SmallVector strides(rank, builder.getI64IntegerAttr(1)); - Location loc = getLoc(); - - SmallVector outputOffsets, outputSizes; - if (failed(getResultTilePosition(builder, 0, offsets, sizes, outputOffsets, - outputSizes))) { - return {}; - } - - SmallVector tiledOperands; - tiledOperands.emplace_back( - getSlice(builder, loc, values(), offsets, sizes, strides)); - if (indices()) { - tiledOperands.emplace_back( - getSlice(builder, loc, *indices(), offsets, sizes, strides)); - } - - // Replace the tile size for the K dimension to use the output size instead of - // the input size. - Value kSize = getDimValue(builder, getLoc(), outputValues(), getDimension()); - outputSizes[getDimension()] = getAsOpFoldResult(kSize); - - tiledOperands.emplace_back( - getSlice(builder, loc, getOutputs()[0], offsets, outputSizes, strides)); - tiledOperands.emplace_back( - getSlice(builder, loc, getOutputs()[1], offsets, outputSizes, strides)); - SmallVector resultTypes; - if (hasTensorSemantics()) { - resultTypes.push_back(tiledOperands[tiledOperands.size() - 2].getType()); - resultTypes.push_back(tiledOperands[tiledOperands.size() - 1].getType()); - } - - Operation *tiledTopkOp = - mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - return {tiledTopkOp}; -} - -LogicalResult TopkOp::getResultTilePosition( - OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - resultOffsets.assign(offsets.begin(), offsets.end()); - resultSizes.assign(sizes.begin(), sizes.end()); - Value kSize = getDimValue( - builder, getLoc(), getOutputOperand(resultNumber)->get(), getDimension()); - resultSizes[getDimension()] = getAsOpFoldResult(kSize); - return success(); -} - -LogicalResult -TopkOp::reifyResultShapes(OpBuilder &b, - ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(b, reifiedReturnShapes); -} - -//===----------------------------------------------------------------------===// -// PackOp and UnPackOp utils -//===----------------------------------------------------------------------===// - -/// Return true if at least one element in `tiles` is zero. -static bool hasZeros(ArrayRef tiles) { - return llvm::any_of( - tiles, [&](OpFoldResult tile) { return isConstantIntValue(tile, 0); }); -} - -/// Check if we have enough static information to catch undefined behavior when -/// the tile size does not divide perfectly the dimension of the input tensor. -static bool -areNotFullTiles(ArrayRef inputShape, - DenseMap const &dimAndTileMapping) { - int64_t rank = inputShape.size(); - for (int64_t dim = 0; dim < rank; dim++) { - if (inputShape[dim] == ShapedType::kDynamic) - continue; - auto it = dimAndTileMapping.find(dim); - if (it != dimAndTileMapping.end()) { - Optional constantTile = getConstantIntValue(it->second); - if (!constantTile) - continue; - if (inputShape[dim] % (*constantTile) != 0) - return true; - } - } - return false; -} - -/// Utility function shared between Pack and UnPack to get the tile sizes as -/// OpFoldResults. -// TODO: interface or base class in .td -template -static SmallVector getMixedTiles(OpTy op) { - static_assert(llvm::is_one_of::value, - "applies to only pack or unpack operations"); - SmallVector mixedInnerTiles; - unsigned dynamicValIndex = 0; - OpBuilder b(op.getContext()); - for (int64_t tileSize : op.getStaticInnerTiles()) { - if (!ShapedType::isDynamic(tileSize)) - mixedInnerTiles.push_back(b.getIndexAttr(tileSize)); - else - mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]); - } - return mixedInnerTiles; -} - -/// Return the tile sizes as `int64_t`. If a tile size is dynamic a sentinel -/// `kDynamic` is introduced at that position in the returned vector. -template -static SmallVector getStaticTiles(OpTy op) { - static_assert(llvm::is_one_of::value, - "applies to only pack or unpack operations"); - SmallVector dynamicTiles; - SmallVector staticTiles; - dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles, - ShapedType::kDynamic); - return staticTiles; -} - -/// Utility function shared between Pack and UnPack to get a map between -/// `dim_pos` and `inner_tiles`. -// TODO: interface or base class in .td -template -static DenseMap getDimAndTileMapping(OpTy op) { - static_assert(llvm::is_one_of::value, - "applies to only pack or unpack operations"); - DenseMap dimAndTileMapping; - ArrayRef dimsToBlock = op.getInnerDimsPos(); - SmallVector tiles = op.getMixedTiles(); - assert(tiles.size() == dimsToBlock.size() && - "tiles must match indices of dimension to block"); - // bind the dimension with the tile factor. - for (auto i : llvm::seq(0, dimsToBlock.size())) - dimAndTileMapping[dimsToBlock[i]] = tiles[i]; - return dimAndTileMapping; -} - -/// Utility function to build the iteration domain for `packOp` or `unPackOp`. -template -static SmallVector getIterationDomain(OpTy op, OpBuilder &builder) { - static_assert(llvm::is_one_of::value, - "applies to only pack or unpack operations"); - OpBuilder::InsertionGuard g(builder); - Location loc = op.getLoc(); - int64_t rank = (std::is_same::value) ? op.getInputRank() - : op.getOutputRank(); - SmallVector loopBounds(rank); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - ReifiedRankedShapedTypeDims resultShape; - (void)op.reifyResultShapes(builder, resultShape); - for (auto dim : llvm::seq(0, rank)) { - loopBounds[dim].offset = zero; - loopBounds[dim].stride = one; - loopBounds[dim].size = resultShape[0][dim]; - } - return loopBounds; -} - -/// Common verifier for `PackOp` and `UnPackOp`. -template -static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { - static_assert(llvm::is_one_of::value, - "applies to only pack or unpack operations"); - Operation *op = packOrUnPack.getOperation(); - ShapedType unpackedType = (std::is_same::value) - ? packOrUnPack.getInputType() - : packOrUnPack.getOutputType(); - int64_t unpackedRank = unpackedType.getRank(); - ArrayRef innerDimsPos = packOrUnPack.getInnerDimsPos(); - ArrayRef outerDimPerm = packOrUnPack.getOuterDimsPerm(); - // Verify tiles. Make sure each provided tile is non-zero. - SmallVector mixedTiles = packOrUnPack.getMixedTiles(); - if (hasZeros(mixedTiles)) - return op->emitError("invalid tile factor"); - if (isInvalid(innerDimsPos, unpackedRank)) - return op->emitError("invalid inner_dims_pos vector"); - if (isInvalid(outerDimPerm, unpackedRank)) - return op->emitError("invalid outer_dims_perm vector"); - if (mixedTiles.size() != innerDimsPos.size()) { - return op->emitError( - "blocking factors must equal the number of dimensions to block"); - } - - // Blocking factors must be less or equal than the input rank, and must - // match the number of `dims_pos`. - if (mixedTiles.size() > unpackedRank) { - return op->emitError( - "blocking factors must be less or equal than the input rank"); - } - - ShapedType packedType = (std::is_same::value) - ? packOrUnPack.getOutputType() - : packOrUnPack.getInputType(); - int64_t packedRank = packedType.getRank(); - // Require output rank to match input rank + number of blocking factors. - if (unpackedRank + mixedTiles.size() != packedRank) { - return op->emitError( - "packed rank must equal unpacked rank + blocking factors"); - } - - // Verify result shape is greater than the minimum expected - // by the pack operation, and that the output shape - // represents full tiles. - ShapedType expectedPackedType = PackOp::getPackedType( - unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); - if (!isSmallerThan(expectedPackedType.getShape(), packedType.getShape())) { - return op->emitError("the shape of output is not large enough to hold the " - "packed data. Expected at least ") - << expectedPackedType << ", got " << packedType; - } - if (!llvm::all_of( - llvm::zip(packedType.getShape().take_back(mixedTiles.size()), - mixedTiles), - [](std::tuple it) { - Optional constTileSize = - getConstantIntValue(std::get<1>(it)); - int64_t shape = std::get<0>(it); - if (!constTileSize) { - // If specified tile size is dynamic, output shape should - // be dynamic too. - return shape == ShapedType::kDynamic; - } else { - if (shape == ShapedType::kDynamic) { - // For the shape being dynamic when tile size is - // specified, return true. In canonical form a constant - // tile size should lead to constant shape of the tiled - // dimension, but not needed for verification. - return true; - } - return shape == constTileSize.value(); - } - })) { - return op->emitError("mismatch in inner tile sizes specified and shaped of " - "tiled dimension in the packed type"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// PackOp -//===----------------------------------------------------------------------===// - -/// Custom builder methods for pack ops. -void PackOp::build(OpBuilder &builder, OperationState &state, Value source, - Value output, ArrayRef innerDimsPos, - ArrayRef innerTiles, - Optional paddingValue, - ArrayRef outerDimsPerm) { - assert(innerDimsPos.size() == innerTiles.size() && - "number of tile sizes specified must match the specified number of " - "original dimensions to be tiled"); - SmallVector staticTileSizes; - SmallVector dynamicTileSizes; - dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes, - ShapedType::kDynamic); - build(builder, state, output.getType(), source, output, - outerDimsPerm.empty() ? nullptr - : builder.getDenseI64ArrayAttr(outerDimsPerm), - builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes, - builder.getDenseI64ArrayAttr(staticTileSizes), - (paddingValue ? paddingValue.value() : nullptr)); -} - -LogicalResult PackOp::verify() { - if (failed(commonVerifierPackAndUnPackOp(*this))) { - return failure(); - } - - // Bail out if the tile does not divide the dimension fully. In the case of - // dynamic tile factors or dimensions, having a partial tile is undefined - // behavior. - auto dimAndTileMapping = getDimAndTileMapping(); - if (!getPaddingValue() && - areNotFullTiles(getInputShape(), dimAndTileMapping)) { - return emitOpError("invalid tile factor provided. Only full tiles are " - "supported when padding_value is not set"); - } - - if (auto paddingValue = getPaddingValue()) { - if (paddingValue.getType() != getInputType().getElementType()) { - return emitOpError("expected padding_value has ") - << getInputType().getElementType() - << " but got: " << paddingValue.getType(); - } - } - return success(); -} - -SmallVector PackOp::getMixedTiles() { - return ::getMixedTiles(*this); -} - -SmallVector PackOp::getStaticTiles() { - return ::getStaticTiles(*this); -} - -SmallVector PackOp::getResultShape( - OpBuilder &builder, Location loc, ArrayRef sourceDims, - ArrayRef innerTileSizes, ArrayRef innerDimsPos, - ArrayRef outerDimsPerm) { - SmallVector resultDims = llvm::to_vector(sourceDims); - - AffineExpr s0, s1; - bindSymbols(builder.getContext(), s0, s1); - AffineExpr ceilDivExpr = s0.ceilDiv(s1); - for (auto tiledDim : llvm::enumerate(innerDimsPos)) { - resultDims[tiledDim.value()] = makeComposedFoldedAffineApply( - builder, loc, ceilDivExpr, - {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]}); - } - if (!outerDimsPerm.empty()) { - resultDims = - interchange(resultDims, outerDimsPerm, /*offset=*/0); - } - resultDims.append(innerTileSizes.begin(), innerTileSizes.end()); - return resultDims; -} - -SmallVector PackOp::getResultShape(OpBuilder &builder) { - return tensor::createDimValues(builder, getLoc(), getOutput()); -} - -ShapedType PackOp::getPackedType(ShapedType sourceType, - ArrayRef innerTileSizes, - ArrayRef innerDimsPos, - ArrayRef outerDimsPerm) { - SmallVector resultShape = llvm::to_vector(sourceType.getShape()); - for (auto tiledDim : llvm::enumerate(innerDimsPos)) { - if (ShapedType::isDynamic(resultShape[tiledDim.value()])) - continue; - if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) { - resultShape[tiledDim.value()] = ShapedType::kDynamic; - continue; - } - resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()], - innerTileSizes[tiledDim.index()]); - } - - // Swap tile loops if outer_dims_perm is available. - resultShape = interchange(resultShape, outerDimsPerm, /*offset=*/0); - - // Append the inner tile dimensions. - resultShape.append(innerTileSizes.begin(), innerTileSizes.end()); - return TypeSwitch(sourceType) - .Case([&](auto shapedType) { - return RankedTensorType::get(resultShape, shapedType.getElementType()); - }) - .Case([&](auto shapedType) { - return MemRefType::get(resultShape, shapedType.getElementType()); - }) - .Default([&](Type t) { - assert(false && "unexpected type"); - return nullptr; - }); -} - -SmallVector PackOp::getLoopIteratorTypes() { - // Note that here we consider only the tiled loops, the point loops are - // materialized when building the body of the operation. - SmallVector iteratorTypes(getInputRank(), - utils::IteratorType::parallel); - return iteratorTypes; -} - -DenseMap PackOp::getDimAndTileMapping() { - return ::getDimAndTileMapping(*this); -} - -SmallVector PackOp::getIterationDomain(OpBuilder &builder) { - return ::getIterationDomain(*this, builder); -} - -/// Generate the body of the innermost loop of the scalar implementation -/// of `pack` operation. -static void generatePackOpScalarImplementationBody(PackOp packOp, - OpBuilder &builder, - Location loc, - ValueRange ivs) { - // Note: `ivs` are already in the correct order, possibly interchanged based - // on `dims_pos`. However, connecting the loops with the access patterns is - // difficult - What is the relation between the position of the tile loop and - // the point loop? However, if we interchange `ivs` once more to go to the - // canonical blocking format: ABCabc, this connection becomes trivial: Each - // point loop is pointLoopsOffset + inputRank away from the tiled loop. - ArrayRef dimsToInnerBlock = packOp.getInnerDimsPos(); - ArrayRef dimsToOuterBlock = packOp.getOuterDimsPerm(); - - SmallVector interchangedIvs = ivs; - SmallVector interchangeVector = - computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getInputRank()); - interchangedIvs = interchange(interchangedIvs, interchangeVector, - /*offset=*/packOp.getInputRank()); - if (!dimsToOuterBlock.empty()) { - interchangeVector = - computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getInputRank()); - interchangedIvs = - interchange(interchangedIvs, interchangeVector, /*offset=*/0); - } - - SmallVector tiles = packOp.getMixedTiles(); - DenseMap dimAndTileMapping = - packOp.getDimAndTileMapping(); - SmallVector sourceIndices; - size_t pointLoopsOffset = 0; - int64_t inputRank = packOp.getInputRank(); - for (auto dim : llvm::seq(0, inputRank)) { - if (dimAndTileMapping.count(dim)) { - AffineExpr i, j, tile; - bindDims(builder.getContext(), i, j); - bindSymbols(builder.getContext(), tile); - OpFoldResult sourceIndex = makeComposedFoldedAffineApply( - builder, loc, i * tile + j, - ArrayRef{ - interchangedIvs[dim], - interchangedIvs[pointLoopsOffset + packOp.getInputRank()], - dimAndTileMapping[dim]}); - sourceIndices.push_back(sourceIndex); - ++pointLoopsOffset; - } else { - sourceIndices.push_back(interchangedIvs[dim]); - } - } - - auto createLoad = [&]() -> Value { - return builder.create( - loc, packOp.getInput(), getAsValues(builder, loc, sourceIndices)); - }; - Value scalar; - if (auto paddingValue = packOp.getPaddingValue()) { - ArithBuilder arithBuilder(builder, loc); - Value isInBounds; - for (auto dim : llvm::seq(0, inputRank)) { - Value idx = - getValueOrCreateConstantIndexOp(builder, loc, sourceIndices[dim]); - Value cond = arithBuilder.slt( - idx, getDimValue(builder, loc, packOp.getInput(), dim)); - isInBounds = dim == 0 ? cond : arithBuilder._and(isInBounds, cond); - } - scalar = builder - .create( - loc, packOp.getElementType(), isInBounds, /*thenBuilder=*/ - [&](OpBuilder &b, Location l) { - b.create(l, createLoad()); - }, - /*elseBuilder=*/ - [&](OpBuilder &b, Location l) { - b.create(l, paddingValue); - }) - .getResult(0); - } else { - scalar = createLoad(); - } - - builder.create(loc, scalar, packOp.getOutput(), ivs); -} - -LogicalResult PackOp::generateScalarImplementation(OpBuilder &builder, - Location loc, - ValueRange ivs) { - OpBuilder::InsertionGuard g(builder); - // The `ivs` already represent the position into the output tensor for the - // non data-tile dimensions. - SmallVector ivVec = llvm::to_vector(ivs); - ReifiedRankedShapedTypeDims outputShape; - if (failed(reifyResultShapes(builder, outputShape))) - return getOperation()->emitOpError("failed to reify result shape"); - if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) { - return getOperation()->emitOpError( - "expected shape of one result value of rank") - << getOutputRank(); - } - - // Generate the loops that iterate over the data tile. - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - - // All loops except the innermost are simple loops that just iterate - // over the tile dimensions. - for (auto dataTileDim : - llvm::seq(getInputRank(), getOutputRank() - 1)) { - Value ub = outputShape[0][dataTileDim]; - scf::ForOp loop = builder.create(loc, zero, ub, one); - builder.setInsertionPointToStart(loop.getBody()); - ivVec.push_back(loop.getInductionVar()); - } - // The body of the innermost loops does the actual data movement. - builder.create(loc, zero, outputShape[0].back(), one, - ValueRange{}, - [&](OpBuilder &bodyBuilder, Location bodyLoc, - Value iv, ValueRange regionIterArgs) { - ivVec.push_back(iv); - generatePackOpScalarImplementationBody( - *this, bodyBuilder, bodyLoc, ivVec); - bodyBuilder.create(bodyLoc); - }); - return success(); -} - -SmallVector -PackOp::getTiledImplementation(OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes) { - Location loc = getLoc(); - auto ctx = builder.getContext(); - - // Take the minimum of two integers. - auto idMap = AffineMap::getMultiDimIdentityMap(2, ctx); - auto min = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult { - return makeComposedFoldedAffineMin(builder, loc, idMap, {v1, v2}); - }; - // Subtract two integers. - AffineExpr dim0, dim1; - bindDims(ctx, dim0, dim1); - auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); - auto sub = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult { - return makeComposedFoldedAffineApply(builder, loc, subMap, {v1, v2}); - }; - - // The tiling is applied on interchanged dimensions. We have to undo the - // interchange to map sizes and offsets to the original input. - ArrayRef dimsToOuterBlock = getOuterDimsPerm(); - SmallVector origOffsets(offsets.begin(), offsets.end()); - SmallVector origSizes(sizes.begin(), sizes.end()); - if (!dimsToOuterBlock.empty()) { - SmallVector vec = - computeInterchangeFromDimPos(dimsToOuterBlock, getInputRank()); - origOffsets = undoInterchange(origOffsets, vec); - origSizes = undoInterchange(origSizes, vec); - } - - int64_t inputRank = getInputRank(); - DenseMap dimAndTileMapping = getDimAndTileMapping(); - SmallVector inputIndices, inputSizes; - for (auto dim : llvm::seq(0, inputRank)) { - if (dimAndTileMapping.count(dim)) { - // If the dimension is tiled, the i-th index is the product of offset_i - // and tile_i, and the i-th size is the product of sizes_i and tile_i. - AffineExpr i, tile; - bindDims(ctx, i); - bindSymbols(ctx, tile); - OpFoldResult inputIndex = makeComposedFoldedAffineApply( - builder, loc, i * tile, - ArrayRef{origOffsets[dim], dimAndTileMapping[dim]}); - inputIndices.push_back(inputIndex); - - OpFoldResult inputSize = makeComposedFoldedAffineApply( - builder, loc, i * tile, - ArrayRef{origSizes[dim], dimAndTileMapping[dim]}); - inputSizes.push_back(inputSize); - } else { - inputIndices.push_back(origOffsets[dim]); - inputSizes.push_back(origSizes[dim]); - } - - // Limit the size of the input operand for incomplete tiles. - OpFoldResult dimSize = getDim(builder, loc, getInput(), dim); - inputSizes.back() = - min(inputSizes.back(), sub(dimSize, inputIndices.back())); - } - - auto oneAttr = builder.getI64IntegerAttr(1); - SmallVector strides(inputRank, oneAttr); - - SmallVector tiledOperands; - tiledOperands.push_back( - getSlice(builder, loc, getInput(), inputIndices, inputSizes, strides)); - - SmallVector outputOffsets, outputSizes; - if (failed(getResultTilePosition(builder, 0, offsets, sizes, outputOffsets, - outputSizes))) { - return {}; - } - strides.append(getOutputRank() - inputRank, oneAttr); - tiledOperands.push_back( - getSlice(builder, loc, getOutput(), outputOffsets, outputSizes, strides)); - - for (auto tile : getInnerTiles()) { - tiledOperands.push_back(tile); - } - if (auto val = getPaddingValue()) { - tiledOperands.push_back(val); - } - - // There are exactly one input and one output, the output is the second - // operand. - SmallVector tiledResultTypes; - if (hasTensorSemantics()) { - tiledResultTypes.push_back(tiledOperands[1].getType()); - } - - Operation *tiledPackOp = - mlir::clone(builder, getOperation(), tiledResultTypes, tiledOperands); - - return {tiledPackOp}; -} - -LogicalResult PackOp::getResultTilePosition( - OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - // The tiling is applied on outer dimensions. In this context, the outer - // dimensions of result tile position is the same. The inner offsets are - // zeros because tiling is not applied to them. - auto zeroAttr = builder.getI64IntegerAttr(0); - resultOffsets.assign(offsets.begin(), offsets.end()); - resultOffsets.append(getOutputRank() - getInputRank(), zeroAttr); - - ReifiedRankedShapedTypeDims outputShape; - if (failed(reifyResultShapes(builder, outputShape))) - return getOperation()->emitOpError("failed to reify result shape"); - if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) { - return getOperation()->emitOpError( - "expected shape of one result value of rank") - << getOutputRank(); - } - - // The outer sizes are the same because the iteration space is over outer - // dimensions. The inner sizes are whole sizes because tiling is not applied - // on them. - resultSizes.assign(sizes.begin(), sizes.end()); - for (auto dataTileDim : - llvm::seq(getInputRank(), getOutputRank())) { - resultSizes.push_back(getAsOpFoldResult(outputShape[0][dataTileDim])); - } - - return success(); -} - -LogicalResult -PackOp::reifyResultShapes(OpBuilder &builder, - ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(builder, reifiedReturnShapes); -} - -//===----------------------------------------------------------------------===// -// UnPackOp -//===----------------------------------------------------------------------===// - -/// Custom builder methods for unpack ops. -void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source, - Value output, ArrayRef innerDimsPos, - ArrayRef innerTiles, - ArrayRef outerDimsPerm) { - SmallVector staticTileSizes; - SmallVector dynamicTileSizes; - dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes, - ShapedType::kDynamic); - build(builder, state, output.getType(), source, output, - outerDimsPerm.empty() ? nullptr - : builder.getDenseI64ArrayAttr(outerDimsPerm), - builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes, - builder.getDenseI64ArrayAttr(staticTileSizes)); -} - -SmallVector UnPackOp::getMixedTiles() { - return ::getMixedTiles(*this); -} - -SmallVector UnPackOp::getStaticTiles() { - return ::getStaticTiles(*this); -} - -DenseMap UnPackOp::getDimAndTileMapping() { - return ::getDimAndTileMapping(*this); -} - -LogicalResult UnPackOp::generateScalarImplementation(OpBuilder &builder, - Location loc, - ValueRange ivs) { - assert(ivs.size() == getOutputRank() && - "number of ivs must match the rank of the output tensor"); - OpBuilder::InsertionGuard g(builder); - ReifiedRankedShapedTypeDims outputShape; - if (failed(reifyResultShapes(builder, outputShape))) - return getOperation()->emitOpError("failed to reify result shape"); - if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) { - return getOperation()->emitOpError( - "expected shape of one result value of rank") - << getOutputRank(); - } - - DenseMap dimAndTileMapping = getDimAndTileMapping(); - // untiled loops and tile loops induction variables. - SmallVector inputIvs; - // point loops induction variables. - SmallVector inputIvsPointLoops; - inputIvs.reserve(getOutputRank()); - inputIvsPointLoops.reserve(dimAndTileMapping.size()); - for (auto dim : llvm::seq(0, getOutputRank())) { - if (dimAndTileMapping.count(dim)) { - DivModValue divMod = getDivMod(builder, loc, ivs[dim], - getValueOrCreateConstantIndexOp( - builder, loc, dimAndTileMapping[dim])); - inputIvsPointLoops.push_back(divMod.remainder); - inputIvs.push_back(divMod.quotient); - } else { - inputIvs.push_back(ivs[dim]); - } - } - - // TODO: (lorenzo) simplify the logic a bit. There is `ivs`, - // `inputIvsPointLoops` and `inputIvs`. - assert(inputIvsPointLoops.size() + inputIvs.size() == getInputRank() && - "expect same number of iduction variables equals to input rank"); - // interchange the point loops induction variables based on `inner_dim_pos`. - ArrayRef innerDims = getInnerDimsPos(); - SmallVector interchangeVector = - computeInterchangeFromDimPos(innerDims, getOutputRank()); - SmallVector interchangedInputIvsPointLoops = inputIvsPointLoops; - interchangedInputIvsPointLoops = interchange( - interchangedInputIvsPointLoops, interchangeVector, /*offset=*/0); - // interchange the tiled loops induction variables based on `outer_dims_perm`. - ArrayRef outerDims = getOuterDimsPerm(); - if (!outerDims.empty()) { - inputIvs = interchange(inputIvs, outerDims, /*offset=*/0); - } - - llvm::append_range(inputIvs, interchangedInputIvsPointLoops); - Value scalar = builder.create(loc, getInput(), inputIvs); - builder.create(loc, scalar, getOutput(), ivs); - return success(); -} - -LogicalResult -UnPackOp::reifyResultShapes(OpBuilder &builder, - ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(builder, reifiedReturnShapes); -} - -SmallVector UnPackOp::getIterationDomain(OpBuilder &builder) { - return ::getIterationDomain(*this, builder); -} - -SmallVector -UnPackOp::getTiledImplementation(OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes) { - // TODO(hanchung): Extend it to handle memref version. - // Tiling on buffers needs extra buffer because tiled unpack op could produce - // more data for incomplete tiles. Tiling on tensors satisfies IREE's needs. - if (!hasTensorSemantics()) - return {}; - - Location loc = getLoc(); - auto ctx = builder.getContext(); - - AffineExpr dim0, dim1; - bindDims(ctx, dim0, dim1); - auto addMap = AffineMap::get(2, 0, {dim0 + dim1}); - auto add = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult { - return makeComposedFoldedAffineApply(builder, loc, addMap, {v1, v2}); - }; - auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); - auto sub = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult { - return makeComposedFoldedAffineApply(builder, loc, subMap, {v1, v2}); - }; - auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult { - return makeComposedFoldedAffineApply(builder, loc, dim0.ceilDiv(dim1), - {v1, v2}); - }; - auto floorDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult { - return makeComposedFoldedAffineApply(builder, loc, dim0.floorDiv(dim1), - {v1, v2}); - }; - - // The perfect tiling case indicates that the tiling sizes is are multiple of - // inner_tile_size. In this context, The indices of input slice are all - // aligned to head. No extra data is needed when representing the tiled unpack - // op. - bool isPerfectTilingCase = true; - - int64_t outputRank = getOutputRank(); - Attribute zeroAttr = builder.getIndexAttr(0); - Attribute oneAttr = builder.getIndexAttr(1); - DenseMap dimAndTileMapping = getDimAndTileMapping(); - SmallVector inputIndices, inputSizes, outputNewOffsets, - outputExpandedSizes; - for (auto dim : llvm::seq(0, outputRank)) { - if (!dimAndTileMapping.count(dim)) { - inputIndices.push_back(offsets[dim]); - inputSizes.push_back(sizes[dim]); - outputNewOffsets.push_back(zeroAttr); - outputExpandedSizes.push_back(sizes[dim]); - continue; - } - - FailureOr cstSize = linalg::getConstantUpperBoundForIndex( - getValueOrCreateConstantIndexOp(builder, loc, sizes[dim])); - Optional cstInnerSize = - getConstantIntValue(dimAndTileMapping[dim]); - bool isAlignedToInnerTileSize = false; - if (!failed(cstSize) && cstInnerSize) { - // If the tiling size equals to the inner tiling size, the outer dims are - // always 1. - if (cstInnerSize.value() == cstSize.value()) { - inputIndices.push_back(floorDiv(offsets[dim], dimAndTileMapping[dim])); - inputSizes.push_back(builder.getIndexAttr(1)); - outputNewOffsets.push_back(zeroAttr); - outputExpandedSizes.push_back(sizes[dim]); - continue; - } - if (cstSize.value() % cstInnerSize.value() == 0) - isAlignedToInnerTileSize = true; - } - - if (!isAlignedToInnerTileSize) - isPerfectTilingCase = false; - - DivModValue firstCoord = getDivMod( - builder, loc, - getValueOrCreateConstantIndexOp(builder, loc, offsets[dim]), - getValueOrCreateConstantIndexOp(builder, loc, dimAndTileMapping[dim])); - DivModValue lastCoord = getDivMod( - builder, loc, - getValueOrCreateConstantIndexOp( - builder, loc, sub(add(offsets[dim], sizes[dim]), oneAttr)), - getValueOrCreateConstantIndexOp(builder, loc, dimAndTileMapping[dim])); - - if (isAlignedToInnerTileSize) { - inputIndices.push_back(floorDiv(offsets[dim], dimAndTileMapping[dim])); - outputNewOffsets.push_back(zeroAttr); - outputExpandedSizes.push_back(sizes[dim]); - - // The ceilDiv is needed here because there could be incomplete tile even - // it is perfect tiling cases. E.g., - // %0 = unpack tensor<33x2xf32> into tensor<64xf32> - // If the tiling size is 32, there will be three tiles. Two of them have - // size=32; one of them have size=2. The size is represented using - // affine_min op; we need ceilDiv. - inputSizes.push_back(ceilDiv(sizes[dim], dimAndTileMapping[dim])); - } else { - inputIndices.push_back(firstCoord.quotient); - inputSizes.push_back( - add(sub(lastCoord.quotient, firstCoord.quotient), oneAttr)); - outputNewOffsets.push_back(firstCoord.remainder); - - AffineExpr i, tile; - bindDims(builder.getContext(), i); - bindSymbols(builder.getContext(), tile); - OpFoldResult size = makeComposedFoldedAffineApply( - builder, loc, i * tile, - ArrayRef{inputSizes.back(), dimAndTileMapping[dim]}); - outputExpandedSizes.push_back(size); - } - } - - // The tiling is applied on output dimensions. We have to apply the - // interchange on input dimensions if outer_dims_perm is set. - int64_t inputRank = getInputRank(); - ArrayRef dimsToOuterBlock = getOuterDimsPerm(); - if (!dimsToOuterBlock.empty()) { - SmallVector vec = - computeInterchangeFromDimPos(dimsToOuterBlock, inputRank); - inputIndices = interchange(inputIndices, vec); - inputSizes = interchange(inputSizes, vec); - } - - inputIndices.append(inputRank - outputRank, zeroAttr); - auto mixedTiles = getMixedTiles(); - inputSizes.append(mixedTiles.begin(), mixedTiles.end()); - SmallVector inputStrides(inputRank, oneAttr); - - SmallVector tiledOperands; - tiledOperands.push_back(getSlice(builder, loc, getInput(), inputIndices, - inputSizes, inputStrides)); - - SmallVector outputStrides(outputRank, oneAttr); - if (isPerfectTilingCase) { - tiledOperands.push_back( - getSlice(builder, loc, getOutput(), offsets, sizes, outputStrides)); - } else { - // The tiling is only avaiable on tensors. It's fine to create a - // tensor.empty instead of tensor.pad because the op is not a - // destination-style op. - auto empty = builder.create( - loc, outputExpandedSizes, getOutputType().getElementType()); - tiledOperands.push_back(empty.getResult()); - } - - SmallVector tiledResultTypes; - tiledResultTypes.push_back(tiledOperands[1].getType()); - - Operation *tiledUnpackOp = - mlir::clone(builder, getOperation(), tiledResultTypes, tiledOperands); - - if (isPerfectTilingCase) - return {tiledUnpackOp}; - - Operation *extractSlice = builder.create( - loc, tiledUnpackOp->getResult(0), outputNewOffsets, sizes, outputStrides); - - return {tiledUnpackOp, extractSlice}; -} - -LogicalResult UnPackOp::getResultTilePosition( - OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - resultOffsets = llvm::to_vector(offsets); - resultSizes = llvm::to_vector(sizes); - return success(); -} - -LogicalResult UnPackOp::verify() { - if (failed(commonVerifierPackAndUnPackOp(*this))) { - return failure(); - } - return success(); -} - -SmallVector UnPackOp::getLoopIteratorTypes() { - SmallVector iteratorTypes(getOutputRank(), - utils::IteratorType::parallel); - return iteratorTypes; -} - -FailureOr -UnPackOp::generateResultTileValue(OpBuilder &b, unsigned resultNumber, - ArrayRef offsets, - ArrayRef sizes) { - return getTiledImplementation(b, offsets, sizes) - .back() - ->getResult(resultNumber); -} - -//===----------------------------------------------------------------------===// -// WinogradInputTransformOp -//===----------------------------------------------------------------------===// - -LogicalResult WinogradInputTransformOp::verify() { - Operation *op = getOperation(); - if (getNumInputs() != 1) { - return op->emitOpError("expected one input operand"); - } - if (getNumOutputs() != 1) { - return op->emitOpError("expected one output operand"); - } - auto inputType = input().getType().cast(); - auto outputType = output().getType().cast(); - ArrayRef inputShape = inputType.getShape(); - if (inputShape.size() != 4) { - return op->emitOpError("expected input operand to have rank 4"); - } - ArrayRef outputShape = outputType.getShape(); - if (outputType.getElementType() != inputType.getElementType()) { - return op->emitOpError( - "expected input/output element types to be identical"); - } - if (getOutputOperandRank() != getInputOperandRank() + 2) { - return op->emitOpError( - "expected output rank to be equal to input rank + 2"); - } - const SmallVector imageDims = imageDimensions(); - const size_t numImageDims = imageDims.size(); - llvm::SmallSetVector imageDimsSet(imageDims.begin(), - imageDims.end()); - if (imageDims.size() != 2) { - return op->emitOpError("expected only 2 image dimensions"); - } - for (auto dim : imageDims) { - if ((dim < 0) || (dim > 3)) { - return op->emitOpError( - "expect image dimensions to be in the range: [0, 3]"); - } - } - const int64_t outputTileSize = getOutputTileSize(); - const int64_t kernelSize = getKernelSize(); - const int64_t inputTileSize = getInputTileSize(); - SmallVector expectedOutputShape(getOutputOperandRank(), - inputTileSize); - int outputIndex; - for (int i = 0; i < inputShape.size(); i++) { - outputIndex = i + numImageDims; - if (ShapedType::isDynamic(inputShape[i])) { - expectedOutputShape[outputIndex] = inputShape[i]; - continue; - } - if (!imageDimsSet.contains(i)) { - expectedOutputShape[outputIndex] = inputShape[i]; - } else { - expectedOutputShape[outputIndex] = - std::ceil((float)(inputShape[i] - kernelSize + 1) / outputTileSize); - } - } - if (!areShapesCompatible(expectedOutputShape, outputShape)) { - return op->emitOpError("incompatible output shape"); - } - return success(); -} - -SmallVector -WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) { - Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - Value source = input(); - SmallVector imageDims = imageDimensions(); - llvm::SmallSetVector imageDimsSet(imageDims.begin(), - imageDims.end()); - SmallVector loopBounds(imageDims.size()); - int count = 0; - for (auto dim : llvm::seq(0, getInputOperandRank())) { - if (!imageDimsSet.contains(dim)) { - loopBounds[count].offset = zero; - loopBounds[count].size = getDimValue(builder, loc, source, dim); - loopBounds[count].stride = one; - count++; - } - } - return loopBounds; -} - -SmallVector -WinogradInputTransformOp::getLoopIteratorTypes() { - SmallVector iteratorTypes(getIterationDomainRank(), - utils::IteratorType::parallel); - return iteratorTypes; -} - -SmallVector -WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes) { - - Location loc = getLoc(); - auto one = builder.getIndexAttr(1); - auto zero = builder.getIndexAttr(0); - - assert(offsets.size() == 2); - SmallVector inputOffsets(getInputOperandRank(), zero); - SmallVector outputOffsets(getOutputOperandRank(), zero); - outputOffsets[2] = inputOffsets[0] = offsets[0]; - outputOffsets[5] = inputOffsets[3] = offsets[1]; - - SmallVector inputStrides(getInputOperandRank(), one); - SmallVector outputStrides(getOutputOperandRank(), one); - - assert(sizes.size() == 2); - auto inputShape = input().getType().cast().getShape(); - auto outputShape = output().getType().cast().getShape(); - SmallVector inputSizes = - getAsOpFoldResult(builder.getIndexArrayAttr(inputShape)); - SmallVector outputSizes = - getAsOpFoldResult(builder.getIndexArrayAttr(outputShape)); - outputSizes[2] = inputSizes[0] = sizes[0]; - outputSizes[5] = inputSizes[3] = sizes[1]; - - SmallVector tiledOperands; - tiledOperands.emplace_back( - getSlice(builder, loc, input(), inputOffsets, inputSizes, inputStrides)); - tiledOperands.emplace_back(getSlice(builder, loc, output(), outputOffsets, - outputSizes, outputStrides)); - - SmallVector resultTypes; - if (hasTensorSemantics()) { - resultTypes.push_back(tiledOperands[1].getType()); - } - - Operation *tiledOp = - mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - - return {tiledOp}; -} - -LogicalResult WinogradInputTransformOp::getResultTilePosition( - OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - if (resultNumber == 0) { - auto resultShape = output().getType().cast().getShape(); - resultSizes = getAsOpFoldResult(builder.getIndexArrayAttr(resultShape)); - resultOffsets = SmallVector(getOutputOperandRank(), - builder.getIndexAttr(0)); - resultOffsets[2] = offsets[0]; - resultOffsets[5] = offsets[1]; - resultSizes[2] = sizes[0]; - resultSizes[5] = sizes[1]; - return success(); - } - return failure(); -} - -LogicalResult WinogradInputTransformOp::fold(ArrayRef, - SmallVectorImpl &) { - return memref::foldMemRefCast(*this); -} - -LogicalResult WinogradInputTransformOp::reifyResultShapes( - OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(b, reifiedReturnShapes); -} - -//===----------------------------------------------------------------------===// -// WinogradOutputTransformOp -//===----------------------------------------------------------------------===// - -LogicalResult WinogradOutputTransformOp::verify() { - Operation *op = getOperation(); - if (getNumInputs() != 1) { - return op->emitOpError("expected one input operand"); - } - if (getNumOutputs() != 1) { - return op->emitOpError("expected one output operand"); - } - auto inputType = input().getType().cast(); - auto outputType = output().getType().cast(); - ArrayRef inputShape = inputType.getShape(); - if (inputShape.size() != 6) { - return op->emitOpError("expected input operand to have rank 6"); - } - ArrayRef outputShape = outputType.getShape(); - if (outputType.getElementType() != inputType.getElementType()) { - return op->emitOpError( - "expected input/output element types to be identical"); - } - if (getOutputOperandRank() != getInputOperandRank() - 2) { - return op->emitOpError( - "expected output rank to be equal to input rank - 2"); - } - const SmallVector imageDims = imageDimensions(); - const size_t numImageDims = imageDims.size(); - llvm::SmallSetVector imageDimsSet(imageDims.begin(), - imageDims.end()); - if (imageDims.size() != 2) { - return op->emitOpError("expected only 2 image dimensions"); - } - for (auto dim : imageDims) { - if ((dim < 0) || (dim > 3)) { - return op->emitOpError( - "expect image dimensions to be in the range: [0, 3]"); - } - } - const int64_t outputTileSize = getOutputTileSize(); - SmallVector expectedOutputShape(getOutputOperandRank(), 1); - int outputIndex; - for (int i = numImageDims; i < inputShape.size(); i++) { - outputIndex = i - numImageDims; - if (!imageDimsSet.contains(outputIndex)) { - expectedOutputShape[outputIndex] = inputShape[i]; - } else { - expectedOutputShape[outputIndex] = outputTileSize * inputShape[i]; - } - } - if (!areShapesCompatible(expectedOutputShape, outputShape)) { - return op->emitOpError("incompatible output shape"); - } - return success(); -} - -SmallVector -WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) { - Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - Value source = output(); - SmallVector imageDims = imageDimensions(); - llvm::SmallSetVector imageDimsSet(imageDims.begin(), - imageDims.end()); - SmallVector loopBounds(imageDims.size()); - int count = 0; - for (auto dim : llvm::seq(0, getOutputOperandRank())) { - if (!imageDimsSet.contains(dim)) { - loopBounds[count].offset = zero; - loopBounds[count].size = getDimValue(builder, loc, source, dim); - loopBounds[count].stride = one; - count++; - } - } - return loopBounds; -} - -SmallVector -WinogradOutputTransformOp::getLoopIteratorTypes() { - SmallVector iteratorTypes(getIterationDomainRank(), - utils::IteratorType::parallel); - return iteratorTypes; -} - -SmallVector WinogradOutputTransformOp::getTiledImplementation( - OpBuilder &builder, ArrayRef offsets, - ArrayRef sizes) { - - Location loc = getLoc(); - auto one = builder.getIndexAttr(1); - auto zero = builder.getIndexAttr(0); - - assert(offsets.size() == 2); - SmallVector inputOffsets(getInputOperandRank(), zero); - SmallVector outputOffsets(getOutputOperandRank(), zero); - inputOffsets[2] = outputOffsets[0] = offsets[0]; - inputOffsets[5] = outputOffsets[3] = offsets[1]; - - SmallVector inputStrides(getInputOperandRank(), one); - SmallVector outputStrides(getOutputOperandRank(), one); - - assert(sizes.size() == 2); - auto inputShape = input().getType().cast().getShape(); - auto outputShape = output().getType().cast().getShape(); - SmallVector inputSizes = - getAsOpFoldResult(builder.getIndexArrayAttr(inputShape)); - SmallVector outputSizes = - getAsOpFoldResult(builder.getIndexArrayAttr(outputShape)); - inputSizes[2] = outputSizes[0] = sizes[0]; - inputSizes[5] = outputSizes[3] = sizes[1]; - - SmallVector tiledOperands; - tiledOperands.emplace_back( - getSlice(builder, loc, input(), inputOffsets, inputSizes, inputStrides)); - tiledOperands.emplace_back(getSlice(builder, loc, output(), outputOffsets, - outputSizes, outputStrides)); - - SmallVector resultTypes; - if (hasTensorSemantics()) { - resultTypes.push_back(tiledOperands[1].getType()); - } - - Operation *tiledOp = - mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - - return {tiledOp}; -} - -LogicalResult WinogradOutputTransformOp::getResultTilePosition( - OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - if (resultNumber == 0) { - auto resultShape = output().getType().cast().getShape(); - resultSizes = getAsOpFoldResult(builder.getIndexArrayAttr(resultShape)); - resultOffsets = SmallVector(getOutputOperandRank(), - builder.getIndexAttr(0)); - resultOffsets[0] = offsets[0]; - resultOffsets[3] = offsets[1]; - resultSizes[0] = sizes[0]; - resultSizes[3] = sizes[1]; - return success(); - } - return failure(); -} - -LogicalResult WinogradOutputTransformOp::fold(ArrayRef, - SmallVectorImpl &) { - return memref::foldMemRefCast(*this); -} - -LogicalResult WinogradOutputTransformOp::reifyResultShapes( - OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(b, reifiedReturnShapes); -} - -#define DEFINE_OP_GET_EFFECTS(OP_NAME) \ - void OP_NAME::getEffects( \ - SmallVectorImpl> \ - &effects) { \ - SmallVector inputBuffers = getInputBufferOperands(); \ - SmallVector outputBuffers = getOutputBufferOperands(); \ - getEffectsImpl(effects, getOperation()->getResults(), inputBuffers, \ - outputBuffers); \ - } - -DEFINE_OP_GET_EFFECTS(ScatterOp) -DEFINE_OP_GET_EFFECTS(SortOp) -DEFINE_OP_GET_EFFECTS(FftOp) -DEFINE_OP_GET_EFFECTS(ReverseOp) -DEFINE_OP_GET_EFFECTS(ScanOp) -DEFINE_OP_GET_EFFECTS(TopkOp) -DEFINE_OP_GET_EFFECTS(PackOp) -DEFINE_OP_GET_EFFECTS(UnPackOp) -DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp) -DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp) - -//===----------------------------------------------------------------------===// -// iree_linalg_ext.set_encoding -//===----------------------------------------------------------------------===// - -void SetEncodingOp::build(OpBuilder &builder, OperationState &state, - Value source, TensorEncoding encoding) { - auto encodingAttr = EncodingAttr::get(builder.getContext(), encoding); - auto sourceType = source.getType().cast(); - RankedTensorType encodingType = RankedTensorType::get( - sourceType.getShape(), sourceType.getElementType(), encodingAttr); - build(builder, state, encodingType, source); -} - -LogicalResult SetEncodingOp::verify() { - // Source and the result have the same rank. - if (getSourceType().getEncoding()) { - return emitOpError( - "source of set_encoding op cannot have a tensor encoding"); - } - if (!getResultType().getEncoding().isa_and_nonnull()) { - return emitOpError( - "result of set_encoding op expected to have a valid tensor encoding"); - } - // The source and result must have the same rank. - if (getResultType().getRank() != getSourceType().getRank()) - return emitOpError("cannot change the rank of the tensor"); - if (!areShapesCompatible(getResultType().getShape(), - getSourceType().getShape())) - return emitOpError("expected to preserve the logical shape of the tensor"); - return success(); -} - -LogicalResult SetEncodingOp::reifyResultShapes( - OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPoint(getOperation()); - reifiedReturnShapes.resize(1); - reifiedReturnShapes[0] = getValueOrCreateConstantIndexOp( - builder, getLoc(), getDims(builder, getLoc(), getSource())); - return success(); -} - -//===----------------------------------------------------------------------===// -// iree_linalg_ext.unset_encoding -//===----------------------------------------------------------------------===// - -void UnsetEncodingOp::build(OpBuilder &builder, OperationState &state, - Value source) { - auto sourceType = source.getType().cast(); - auto resultType = - RankedTensorType::get(sourceType.getShape(), sourceType.getElementType()); - return build(builder, state, resultType, source); -} - -LogicalResult UnsetEncodingOp::verify() { - if (getResultType().getEncoding()) { - return emitOpError( - "result of unset_encoding op cannot have a tensor encoding"); - } - if (!getSourceType().getEncoding().isa_and_nonnull()) { - return emitOpError( - "source of unset_encoding op expected to have a valid tensor encoding"); - } - // The source and result must have the same rank. - if (getResultType().getRank() != getSourceType().getRank()) - return emitOpError("cannot change the rank of the tensor"); - if (!areShapesCompatible(getResultType().getShape(), - getSourceType().getShape())) - return emitOpError("expected to preserve the logical shape of the tensor"); - return success(); -} - -LogicalResult UnsetEncodingOp::reifyResultShapes( - OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPoint(getOperation()); - reifiedReturnShapes.resize(1); - reifiedReturnShapes[0] = getValueOrCreateConstantIndexOp( - builder, getLoc(), getDims(builder, getLoc(), getSource())); - return success(); -} - -namespace { -/// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any -/// changes. -struct FoldTensorCastOp : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - - LogicalResult matchAndRewrite(LinalgExtOp op, - PatternRewriter &rewriter) const override { - // If no operand comes from a tensor::CastOp and can be folded then fail. - bool hasTensorCastOperand = - llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { - if (opOperand->get().isa()) - return false; - auto castOp = opOperand->get().getDefiningOp(); - return castOp && canFoldIntoConsumerOp(castOp); - }); - if (!hasTensorCastOperand) - return failure(); - - SmallVector newResultTypes; - newResultTypes.reserve(op->getNumResults()); - SmallVector newOperands; - newOperands.reserve(op->getNumOperands()); - // Inputs may fold. - for (OpOperand *opOperand : op.getInputOperands()) { - auto tensorCastOp = opOperand->get().getDefiningOp(); - newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) - ? tensorCastOp.getSource() - : opOperand->get()); - } - // Init tensors may fold, in which case the resultType must also change. - for (OpOperand *opOperand : op.getOutputOperands()) { - auto tensorCastOp = opOperand->get().getDefiningOp(); - bool fold = canFoldIntoConsumerOp(tensorCastOp); - newOperands.push_back(fold ? tensorCastOp.getOperand() - : opOperand->get()); - newResultTypes.push_back(newOperands.back().getType()); - } - // Add the other operands. - for (OpOperand *opOperand : op.getNonInputOrOutputOperands()) { - auto tensorCastOp = opOperand->get().getDefiningOp(); - newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) - ? tensorCastOp.getSource() - : opOperand->get()); - } - // Clone op. - Operation *newOp = mlir::clone(rewriter, op, newResultTypes, newOperands); - SmallVector replacements; - replacements.reserve(newOp->getNumResults()); - for (auto result : llvm::zip(op->getResults(), newOp->getResults())) { - Value oldResult = std::get<0>(result); - Value newResult = std::get<1>(result); - if (newResult.getType() != oldResult.getType()) { - replacements.push_back(rewriter.create( - op->getLoc(), oldResult.getType(), newResult)); - } else { - replacements.push_back(newResult); - } - } - rewriter.replaceOp(op, replacements); - - return success(); - } -}; -} // namespace - -//===----------------------------------------------------------------------===// -// LinalgExtDialect -//===----------------------------------------------------------------------===// - -void IREELinalgExtDialect::getCanonicalizationPatterns( - RewritePatternSet &results) const { - results.add(getContext()); -} - -// clang-format off -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc" // IWYU pragma: keep -// clang-format: on diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt deleted file mode 100644 index 68fdab83c168..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt +++ /dev/null @@ -1,32 +0,0 @@ -add_mlir_library(IREELinalgExtPasses - ConvertConv2DToWinograd.cpp - ConvertToLoops.cpp - FoldIntoPackAndUnpackOps.cpp - MaterializeEncoding.cpp - PadContractionToBlockSize.cpp - Passes.cpp - SplitReduction.cpp - TileAndDecomposeWinogradPass.cpp - Tiling.cpp - - DEPENDS - IREELinalgExtPassesIncGen - - LINK_LIBS PUBLIC - IREEInputDialect - IREELinalgExtDialect - IREELinalgExtUtils - MLIRAffineDialect - MLIRIR - MLIRLinalgDialect - MLIRLinalgTransforms - MLIRMathDialect - MLIRMemRefDialect - MLIRMemRefTransforms - MLIRPass - MLIRSCFDialect - MLIRFuncDialect - MLIRSupport - MLIRTensorDialect - MLIRTransforms -) diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp deleted file mode 100644 index 303756b42725..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp +++ /dev/null @@ -1,400 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/SetVector.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -static inline int index(int y, int x, int dimy, int dimx) { - return (x + dimx * y); -} - -static inline int index(int z, int y, int x, int w, int dimz, int dimy, - int dimx, int dimw) { - return (w + dimw * (x + dimx * (y + dimy * z))); -} - -static bool hasAllOneValues(DenseIntElementsAttr attr) { - return llvm::all_of(attr, [](APInt element) { return element.isOne(); }); -} - -// TODO: Make this a user-settable parameter once we have support -// for more tile sizes -static constexpr int64_t outputTileSize = 6; - -/// This function computes the Winograd filter transform when -/// the filter is known to be a constant. Specifically, this -/// function computes matmul(G, matmul(F, transpose(G))) where -/// F is a tile of the convolution filter of size m x m -/// (single input channel, single output channel) and G has -/// shape m x (m + r - 1) where r is the output tile size and -/// (m + r - 1) is the input tile size. -/// The time complexity of this function is O(ic * oc) -/// where ic is the number of input channels and oc is the -/// number of output channels since input tile size and kernel size -/// are constants. So for large ic and oc, this function is -/// time intensive. -/// TODO: Codegen this as a kernel and run once at initialization -static DenseElementsAttr foldFilterTransform( - ArrayRef shape, int64_t inputTileSize, int64_t kernelSize, - Type outputType, const float *G, bool isSplat, float splatValue, - DenseElementsAttr::iterator_range &input, FloatType floatType) { - const int &kh = shape[0]; - const int &kw = shape[1]; - const int &ic = shape[2]; - const int &oc = shape[3]; - const int64_t numElements = inputTileSize * inputTileSize * ic * oc; - SmallVector output(numElements, APFloat(0.0f)); - for (int d0 = 0; d0 < inputTileSize; d0++) { - for (int d1 = 0; d1 < inputTileSize; d1++) { - for (int d2 = 0; d2 < ic; d2++) { - for (int d3 = 0; d3 < oc; d3++) { - APFloat accum(0.0f); - for (int d4 = 0; d4 < kernelSize; d4++) { - for (int d5 = 0; d5 < kernelSize; d5++) { - APFloat ival(splatValue); - if (!isSplat) { - ival = input[index(d4, d5, d2, d3, kh, kw, ic, oc)]; - } - int idx0 = index(d0, d4, inputTileSize, kernelSize); - int idx1 = index(d1, d5, inputTileSize, kernelSize); - accum = accum + APFloat(G[idx0]) * ival * APFloat(G[idx1]); - } - } - int odx = index(d0, d1, d2, d3, inputTileSize, inputTileSize, ic, oc); - output[odx] = accum; - if (floatType.isF16()) { - bool losesInfo; - output[odx].convert(APFloat::IEEEhalf(), - APFloat::rmNearestTiesToEven, &losesInfo); - } - } - } - } - } - return DenseElementsAttr::get(outputType, output); -} - -namespace { - -class FoldWinogradFilterTransform final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, - PatternRewriter &rewriter) const override { - // Check that kernel size = 3x3 - Value kernel = convOp.getInputs()[1]; - auto kernelType = kernel.getType().cast(); - if (!kernelType) - return failure(); - ArrayRef kernelShape = kernelType.getShape(); - const int64_t kh = kernelShape[0]; - const int64_t kw = kernelShape[1]; - if ((kh != 3) || (kw != 3)) - return failure(); - const int64_t kernelSize = kh; - const int64_t inputTileSize = outputTileSize + kernelSize - 1; - - DenseIntOrFPElementsAttr kernelAttr; - if (!matchPattern(kernel, m_Constant(&kernelAttr))) { - return failure(); - } - - Operation *constOp = kernel.getDefiningOp(); - ShapedType type = constOp->getResult(0).getType().cast(); - auto elemType = type.getElementType().cast(); - ArrayRef shape = type.getShape(); - DenseElementsAttr::iterator_range nonSplatValues = - kernelAttr.getValues(); - bool isSplat = kernelAttr.isSplat(); - float splatValue{0.0}; - if (isSplat) { - splatValue = kernelAttr.getSplatValue().convertToFloat(); - } - SmallVector resultShape{inputTileSize * inputTileSize, shape[2], - shape[3]}; - auto resultType = RankedTensorType::get(resultShape, elemType); - auto foldedKernelAttr = - foldFilterTransform(shape, inputTileSize, kernelSize, resultType, - IREE::LinalgExt::Winograd::G_6x6_3x3, isSplat, - splatValue, nonSplatValues, elemType); - rewriter.replaceOpWithNewOp(constOp, foldedKernelAttr); - return success(); - } -}; - -} // namespace - -static Value -createCollapse(Value tensor, Location loc, PatternRewriter &rewriter, - SmallVectorImpl &outputShape, - SmallVectorImpl &reassociations) { - auto tensorType = tensor.getType().cast(); - auto elementTy = tensorType.getElementType(); - auto resultType = RankedTensorType::get(outputShape, elementTy); - return rewriter.create(loc, resultType, tensor, - reassociations); -} - -static Value -createExpand(Value tensor, Location loc, PatternRewriter &rewriter, - SmallVectorImpl &outputShape, - SmallVectorImpl &reassociations) { - auto tensorType = tensor.getType().cast(); - auto elementTy = tensorType.getElementType(); - auto resultType = RankedTensorType::get(outputShape, elementTy); - return rewriter.create(loc, resultType, tensor, - reassociations); -} - -namespace { - -/// Convert conv2d to a sequence of ops that implement the -/// Winograd transformation. The Winograd transformation -/// is parameterized by the output tile size(r). The larger -/// the tile size, the greater the computational savings, -/// but this comes at the cost of accuracy. -/// -/// For now, we restrict this transform to convolutions -/// where the filter size = 3x3, though extensions to larger -/// filter sizes are possible. We refer to the -/// filter size as (m). The input tile size (i) is defined as -/// m + r - 1. For a given output tile size, the Winograd -/// transformation defines 3 constant matrices: -/// -/// B: i x i [used in input transform] -/// G: m x i [used in the filter transform] -/// A: i x r [used in output transform] -/// -/// The choice of these matrices is not unique and affects -/// the accuracy of the approach. -/// -/// Given a convolution of the form -/// -/// y = conv2d(x, f) -/// -/// where x: (N, H, W, C) -/// f: (H, W, C, F) -/// -/// this pattern converts the convolution to the following -/// sequence: -/// -/// f_winograd = winograd.filter_transform(f) [folded] -/// x_winograd = winograd.input_transform(x) -/// x_winograd_c = collapse(x_winograd) -/// y_winograd = batch_matmul(x_winograd_c, f_winograd) -/// y_winograd_e = expand(y_winograd) -/// y_padded = winograd.output_transform(y_winograd_e) -/// y = extract_slice(y_padded) -/// -/// where the dimensions of the tensors above are: -/// -/// f_winograd: (i * i, C, F) -/// x_winograd: (i, i, N, H', W', C) -/// x_winograd_c: (i * i, N * H' * W', C) -/// y_winograd: (i * i, N * H' * W', F) -/// y_winograd_e: (i, i, N, H', W', F) -/// y_padded: (N, r * H', r * W', F) -/// -/// H': ceil((H - m + 1) / r) -/// W': ceil((W - m + 1) / r) -/// -/// The winograd input transform extracts a tile of the input -/// of size i x i and computes matmul(transpose(B), matmul(tile(x), B)). -/// The winograd filter transform extracts a tile of the filter -/// of size m x m and computes matmul(G, matmul(tile(f), transpose(G)). -/// These two are then combined using elementwise multiplication -/// (which becomes a batch matmul when combining over multiple channels). -/// The winograd output filter extracts a tile of the result of size -/// i x i and computes matmul(transpose(A), matmul(tile(y_winograd_e), A)). -/// -/// For more information and additional references, -/// see here: -/// -/// https://github.com/nod-ai/MLIRWinogradTalk/blob/main/MLIRSummit2022.Nodai.Menon.pdf -/// -class ConvertConv2DNhwcHwcf final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, - PatternRewriter &rewriter) const override { - // Check that strides = 1 - if (!hasAllOneValues(convOp.getStrides())) - return failure(); - - // Check that dilations = 1 - if (!hasAllOneValues(convOp.getDilations())) - return failure(); - - // Check that kernel has been constant folded (by validating rank = 3) - Value kernel = convOp.getInputs()[1]; - auto kernelType = kernel.getType().cast(); - if (!kernelType) - return failure(); - Type elementType = kernelType.getElementType(); - ArrayRef kernelShape = kernelType.getShape(); - if (kernelShape.size() != 3) - return failure(); - - const int64_t kernelSize = 3; - const int64_t inputTileSize = outputTileSize + kernelSize - 1; - - // Create winograd input transform op - Location loc = convOp.getLoc(); - Value zero = rewriter.create( - loc, rewriter.getZeroAttr(elementType)); - Value input = convOp.getInputs()[0]; - auto inputType = input.getType().cast(); - if (!inputType) - return failure(); - ArrayRef inputShape = inputType.getShape(); - if (llvm::any_of(inputShape, ShapedType::isDynamic)) - return failure(); - assert(inputShape.size() == 4); - - SmallVector imageDimensions = {1, 2}; - const size_t numImageDims = imageDimensions.size(); - SmallVector resultShape(6, inputTileSize); - llvm::SmallSetVector imageDimensionsSet(imageDimensions.begin(), - imageDimensions.end()); - int outputIndex; - for (int i = 0; i < inputShape.size(); i++) { - outputIndex = i + numImageDims; - if (!imageDimensionsSet.contains(i)) { - resultShape[outputIndex] = inputShape[i]; - } else { - resultShape[outputIndex] = - std::ceil((float)(inputShape[i] - kernelSize + 1) / outputTileSize); - } - } - Value emptyTensor = - rewriter.create(loc, resultShape, elementType); - auto winogradInputOp = - rewriter.create( - loc, emptyTensor.getType(), ValueRange{input}, - ValueRange{emptyTensor}, outputTileSize, kernelSize, - imageDimensions); - Value winogradInput = winogradInputOp.getResult()[0]; - - // Add collapse shape - SmallVector collapsedShape = { - resultShape[0] * resultShape[1], - resultShape[2] * resultShape[3] * resultShape[4], resultShape[5]}; - SmallVector reassociations = {{0, 1}, {2, 3, 4}, {5}}; - Value collapsedWinogradInput = createCollapse( - winogradInput, loc, rewriter, collapsedShape, reassociations); - - // Add BatchMatmulOp - SmallVector bmmShape(collapsedShape.begin(), collapsedShape.end()); - Value output = convOp.getOutputs()[0]; - auto outputType = output.getType().cast(); - ArrayRef outputShape = outputType.getShape(); - bmmShape[2] = outputShape[3]; - auto bmmOutputType = RankedTensorType::get(bmmShape, elementType); - emptyTensor = rewriter.create(loc, bmmShape, elementType); - auto fillOp = rewriter.create(loc, ValueRange{zero}, - ValueRange{emptyTensor}); - auto bmmOp = rewriter.create( - loc, bmmOutputType, ValueRange({collapsedWinogradInput, kernel}), - ValueRange({fillOp.result()})); - Value bmmResult = bmmOp.getResult(0); - - // Add expand shape - SmallVector expandedShape = {resultShape[0], resultShape[1], - resultShape[2], resultShape[3], - resultShape[4], outputShape[3]}; - reassociations = {{0, 1}, {2, 3, 4}, {5}}; - Value expandedBmmResult = - createExpand(bmmResult, loc, rewriter, expandedShape, reassociations); - - // Convert back into original domain - SmallVector paddedResultShape(outputShape.size(), 0); - for (int i = 0; i < outputShape.size(); i++) { - if (!imageDimensionsSet.contains(i)) { - paddedResultShape[i] = outputShape[i]; - } else { - paddedResultShape[i] = resultShape[i + numImageDims] * outputTileSize; - } - } - emptyTensor = - rewriter.create(loc, paddedResultShape, elementType); - auto winogradOutputOp = - rewriter.create( - loc, emptyTensor.getType(), ValueRange{expandedBmmResult}, - ValueRange{emptyTensor}, outputTileSize, kernelSize, - imageDimensions); - Value paddedOutput = winogradOutputOp.getResult()[0]; - - // Extract slice - SmallVector offsets(outputShape.size(), - rewriter.getIndexAttr(0)); - SmallVector strides(outputShape.size(), - rewriter.getIndexAttr(1)); - SmallVector sizes; - for (int i = 0; i < outputShape.size(); i++) - sizes.push_back(rewriter.getIndexAttr(outputShape[i])); - auto winogradOutput = rewriter.create( - loc, outputType, paddedOutput, offsets, sizes, strides); - - Value result = convOp.getResult(0); - result.replaceAllUsesWith(winogradOutput); - return success(); - } -}; - -struct ConvertConv2DToWinogradPass - : ConvertConv2DToWinogradBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - void runOnOperation() override { - MLIRContext *context = &getContext(); - RewritePatternSet patterns(&getContext()); - patterns.insert( - context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr createConvertConv2DToWinogradPass() { - return std::make_unique(); -} - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertToLoops.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertToLoops.cpp deleted file mode 100644 index 683610afaceb..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertToLoops.cpp +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/TilingInterface.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" - -using namespace mlir; -namespace IREE = mlir::iree_compiler::IREE; -using namespace IREE::LinalgExt; - -/// Recursive method that lowers one dimension of the `TiledOpInterface` to -/// scalar loops at a time. -static LogicalResult lowerToLoopsImpl(OpBuilder &builder, - TilingInterface tilableOp, - ArrayRef loopRanges, - unsigned loopDepth, - SmallVectorImpl &ivs) { - Location loc = tilableOp.getLoc(); - if (loopDepth == loopRanges.size()) { - return tilableOp.generateScalarImplementation(builder, loc, ivs); - } - LogicalResult status = success(); - builder.create( - loc, - getValueOrCreateConstantIndexOp(builder, loc, - loopRanges[loopDepth].offset), - getValueOrCreateConstantIndexOp(builder, loc, loopRanges[loopDepth].size), - getValueOrCreateConstantIndexOp(builder, loc, - loopRanges[loopDepth].stride), - ValueRange{}, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { - ivs.push_back(iv); - status = lowerToLoopsImpl(b, tilableOp, loopRanges, loopDepth + 1, ivs); - b.create(loc); - }); - return status; -} - -/// Main entry point for lowering `TiledOpInterface` op to loops. -static LogicalResult lowerToLoops(OpBuilder &builder, - TilingInterface tilableOp) { - SmallVector loopBounds = tilableOp.getIterationDomain(builder); - SmallVector ivs; - return lowerToLoopsImpl(builder, tilableOp, loopBounds, 0, ivs); -} - -/// Pattern rewriter hook to lower a `TiledOpInterface` to loops. -namespace { -struct TilingInterfaceLowerToLoopsPattern : public RewritePattern { - TilingInterfaceLowerToLoopsPattern(MLIRContext *context, - PatternBenefit benefit = 1) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto tilableOp = dyn_cast(op); - if (!tilableOp) { - return rewriter.notifyMatchFailure(op, "not TilingInterface op"); - } - // Avoid handling `LinalgOp`s here for now. Eventually this should - // be able to handle everything (or this pass would be deprecated to use - // something upstream). - if (isa(op)) { - return rewriter.notifyMatchFailure(op, "ignoring LinalgOps"); - } - if (llvm::any_of(tilableOp->getResults(), - [&](Value v) { return v.getType().isa(); })) { - return rewriter.notifyMatchFailure( - tilableOp, "lower to loops needs to have tensor semantics"); - } - if (failed(lowerToLoops(rewriter, tilableOp))) { - return failure(); - } - rewriter.eraseOp(op); - return success(); - } -}; -} // namespace - -//===----------------------------------------------------------------------===// -// Pass -//===----------------------------------------------------------------------===// - -namespace { -struct LinalgExtToLoopsPass - : public LinalgExtToLoopsBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext *context = &getContext(); - - RewritePatternSet patterns(context); - patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; -} // namespace - -std::unique_ptr> -IREE::LinalgExt::createLinalgExtToLoopsPass() { - return std::make_unique(); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp deleted file mode 100644 index 7477bc041cc0..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/FoldIntoPackAndUnpackOps.cpp +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; -using namespace mlir::iree_compiler; -using namespace mlir::iree_compiler::IREE::LinalgExt; - -//===---------------------------------------------------------------------===// -// Patterns to fold operationsinto pack/unpack ops. -//===---------------------------------------------------------------------===// - -namespace { - -static bool areAllConstantIntValue(ArrayRef ofrs, int64_t value) { - return llvm::all_of( - ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); }); -} - -/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already -/// has extract_slice semantics. -struct FoldUnpackWithExtractSliceOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, - PatternRewriter &rewriter) const override { - auto unpackOp = sliceOp.getSource().getDefiningOp(); - if (!unpackOp) - return failure(); - - // Check all offsets are zeros, and all strides are 1. - if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) || - !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) { - return rewriter.notifyMatchFailure( - sliceOp, "expects offsets to be 0s and strides to be 1s"); - } - - // Create a new empty output tensor. - Type elementType = unpackOp.getOutput() - .getType() - .cast() - .getElementType(); - Value output = rewriter.create( - sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType); - rewriter.replaceOpWithNewOp( - sliceOp, unpackOp.getInput(), output, unpackOp.getInnerDimsPos(), - unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm()); - return success(); - } -}; -} // namespace - -//===---------------------------------------------------------------------===// -// Pass to fold operations into pack and unpack operations. -//===---------------------------------------------------------------------===// - -namespace { -struct FoldIntoPackAndUnpackOpsPass - : public FoldIntoPackAndUnpackOpsBase { - void getDependentDialects(DialectRegistry ®istry) const override { - return; - } - - void runOnOperation() override; -}; -} // namespace - -void FoldIntoPackAndUnpackOpsPass::runOnOperation() { - MLIRContext *context = &getContext(); - RewritePatternSet patterns(context); - populateFoldIntoPackAndUnpackOpsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) - return signalPassFailure(); -} - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -void populateFoldIntoPackAndUnpackOpsPatterns(RewritePatternSet &patterns) { - patterns.insert(patterns.getContext()); -} - -std::unique_ptr> createFoldIntoPackAndUnpackOps() { - return std::make_unique(); -} - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp deleted file mode 100644 index fc2275700fec..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp +++ /dev/null @@ -1,456 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/TypeSwitch.h" - -using namespace mlir; -using namespace mlir::iree_compiler; -using namespace mlir::iree_compiler::IREE::LinalgExt; - -//===---------------------------------------------------------------------===// -// Utility methods -//===---------------------------------------------------------------------===// - -/// Extract encoding from the `tensorType` if specified. -static Optional getEncoding(RankedTensorType tensorType) { - auto encodingAttr = tensorType.getEncoding().dyn_cast_or_null(); - if (!encodingAttr) - return llvm::None; - return encodingAttr.getEncoding().getValue(); -} - -/// For a given tensor type with an encoding, return the materialized -/// type to use for it. If no encoding is set, then return the tensor type -/// itself. -static RankedTensorType -getMaterializedType(RankedTensorType tensorType, - MaterializeEncodingFn materializeEncodingFn) { - Optional encoding = getEncoding(tensorType); - if (!encoding) - return tensorType; - FailureOr materializeEncodingInfo = - materializeEncodingFn(tensorType); - if (failed(materializeEncodingInfo)) { - return tensorType; - } - return PackOp::getPackedType(tensorType, - materializeEncodingInfo->innerTileSizes, - materializeEncodingInfo->innerDimsPos, - materializeEncodingInfo->outerDimsPerm) - .cast(); -} - -/// Helper methods to get `OpFoldResult` from `int64_t` values. -static OpFoldResult getAsOpFoldResult(OpBuilder &builder, int64_t value) { - return builder.getI64IntegerAttr(value); -} -static SmallVector getAsOpFoldResult(OpBuilder &builder, - ArrayRef values) { - return llvm::to_vector(llvm::map_range( - values, [&](int64_t v) { return getAsOpFoldResult(builder, v); })); -} - -//===---------------------------------------------------------------------===// -// Methods to convert the encoding to parameters of the Pack operation -//===---------------------------------------------------------------------===// - -/// Given the `encoding` return the `MaterializeEncodingInfo` to use for -/// materializing the pack op. -// TODO(ravishankarm): This is currently hard-coded here for convenience. When -// used in IREE, this will be computed based on the architecture information in -// `hal.executable.variant`. -// A real implementation would return tile sizes that depend on at least the -// `tensorType`'s element type (e.g. different tile sizes for i8 vs f32, because -// the SIMD instructions may have different shapes). -// Moreover, in a real implementation, the tile sizes would typically also -// depend on target information. This is demonstrated in -// iree/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPass.cpp -static FailureOr -chooseEncodingInfo(RankedTensorType tensorType) { - Optional encoding = getEncoding(tensorType); - if (!encoding) - return failure(); - switch (*encoding) { - case TensorEncoding::MATMUL_F32F32F32_LHS: - case TensorEncoding::MATMUL_I8I8I32_LHS: - return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}}; - break; - case TensorEncoding::MATMUL_F32F32F32_RHS: - case TensorEncoding::MATMUL_I8I8I32_RHS: - return MaterializeEncodingInfo{{0, 1}, {4, 8}, {}}; - break; - case TensorEncoding::MATMUL_F32F32F32_RHS_TRANSPOSE: - case TensorEncoding::MATMUL_I8I8I32_RHS_TRANSPOSE: - return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}}; - break; - case TensorEncoding::MATMUL_F32F32F32_RESULT: - case TensorEncoding::MATMUL_I8I8I32_RESULT: - return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}}; - break; - default: - return failure(); - } -} - -//===---------------------------------------------------------------------===// -// Methods to convert `set_encoding` and `unset_encoding` operations -// to `pack` and `unpack` operations respectively. -//===---------------------------------------------------------------------===// - -/// Utility method to get the optional padding value to use with pack operation -/// if source is defined using a `tensor.pad` operation. Note `source` is -/// passed by reference. It is updated to use the source of the pad operation. -static Optional getPaddingValue(Value &source) { - auto padOp = source.getDefiningOp(); - if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad()) - return llvm::None; - - Value constantPaddingValue = padOp.getConstantPaddingValue(); - if (!constantPaddingValue) - return llvm::None; - - source = padOp.getSource(); - return constantPaddingValue; -} - -/// Utility method to convert from `set_encoding` op to `pack` operation. -/// For now this takes a `paddingValue` as input. The source is also taken -/// as input so that these could be used with `OpConversionPatterns`. -static FailureOr -lowerSetEncodingOpToPackOp(RewriterBase &rewriter, SetEncodingOp encodingOp, - Value source, - MaterializeEncodingFn materializeEncodingFn) { - RankedTensorType resultType = encodingOp.getResultType(); - FailureOr materializeEncodingInfo = - materializeEncodingFn(resultType); - if (failed(materializeEncodingInfo)) { - return rewriter.notifyMatchFailure(encodingOp, "unhandled result encoding"); - } - - // Create `tensor.empty` operation for the result of the pack operation. - Location loc = encodingOp.getLoc(); - SmallVector sourceDims = getDims(rewriter, loc, source); - SmallVector innerTileSizesOfr = - getAsOpFoldResult(rewriter, materializeEncodingInfo->innerTileSizes); - SmallVector resultDims = - PackOp::getResultShape(rewriter, loc, sourceDims, innerTileSizesOfr, - materializeEncodingInfo->innerDimsPos, - materializeEncodingInfo->outerDimsPerm); - auto initTensor = rewriter.create( - loc, resultDims, resultType.getElementType()); - Optional paddingValue = getPaddingValue(source); - return rewriter.create( - loc, source, initTensor, materializeEncodingInfo->innerDimsPos, - innerTileSizesOfr, paddingValue, materializeEncodingInfo->outerDimsPerm); -} - -/// Utility method to convert from `set_encoding` op to `pack` operation. -/// The source is taken as input so that these could be used with -/// `OpConversionPatterns`. -static FailureOr -lowerUnsetEncodingToUnpackOp(RewriterBase &rewriter, UnsetEncodingOp encodingOp, - Value packedValue, - MaterializeEncodingFn materializeEncodingFn) { - RankedTensorType sourceType = encodingOp.getSourceType(); - FailureOr materializeEncodingInfo = - materializeEncodingFn(sourceType); - if (failed(materializeEncodingInfo)) { - return rewriter.notifyMatchFailure(encodingOp, "unhandled source encoding"); - } - // Create an `tensor.empty` for the result of the unpack operation. - Location loc = encodingOp.getLoc(); - SmallVector resultDims = - getDims(rewriter, loc, encodingOp.getSource()); - auto initTensor = rewriter.create( - loc, resultDims, sourceType.getElementType()); - - SmallVector innerTileSizesOfr = - getAsOpFoldResult(rewriter, materializeEncodingInfo->innerTileSizes); - return rewriter.create( - loc, packedValue, initTensor, materializeEncodingInfo->innerDimsPos, - innerTileSizesOfr, materializeEncodingInfo->outerDimsPerm); -} - -/// Utility method to convert from `linalg.matmul` with -/// - lhs encoding of MATMUL_*_LHS -/// - rhs encoding of MATMUL_*_RHS_TRANSPOSE -/// - result encoding of MATMUL_*_RESULT -/// to linalg.mmt4d op. -static FailureOr -lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp, - ValueRange convertedInputOperands, - ValueRange convertedOutputOperands, - MaterializeEncodingFn materializeEncodingFn) { - if (!matmulOp.hasTensorSemantics()) - return failure(); - auto inputs = matmulOp.getDpsInputOperands(); - auto outputs = matmulOp.getDpsInitOperands(); - Optional lhsEncoding = - getEncoding(inputs[0]->get().getType().cast()); - Optional rhsEncoding = - getEncoding(inputs[1]->get().getType().cast()); - Optional resultEncoding = - getEncoding(outputs[0]->get().getType().cast()); - if (!lhsEncoding || - (lhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_LHS && - lhsEncoding.value() != TensorEncoding::MATMUL_I8I8I32_LHS) || - !rhsEncoding || - (rhsEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RHS_TRANSPOSE && - rhsEncoding.value() != TensorEncoding::MATMUL_I8I8I32_RHS_TRANSPOSE) || - !resultEncoding || - (resultEncoding.value() != TensorEncoding::MATMUL_F32F32F32_RESULT && - resultEncoding.value() != TensorEncoding::MATMUL_I8I8I32_RESULT)) { - return failure(); - } - Operation *mmt4DOp = rewriter.create( - matmulOp.getLoc(), convertedOutputOperands[0].getType(), - convertedInputOperands, convertedOutputOperands); - return mmt4DOp; -} - -/// Utility method to convert from `linalg.fill` on `tensor` type with encoding -/// to fill of the materialized type -static FailureOr -lowerOpWithEncoding(RewriterBase &rewriter, linalg::FillOp fillOp, - ValueRange convertedInputOperands, - ValueRange convertedOutputOperands, - MaterializeEncodingFn materializeEncodingFn) { - if (!fillOp.hasTensorSemantics()) - return failure(); - Operation *materializedFillOp = rewriter.create( - fillOp.getLoc(), convertedOutputOperands[0].getType(), - convertedInputOperands, convertedOutputOperands); - return materializedFillOp; -} - -/// Utility method to convert `tensor.empty` with encoding to a `tensor.empty` -/// of the materialized type. -static FailureOr -lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp, - ValueRange convertedOperands, - MaterializeEncodingFn materializeEncodingFn) { - auto resultType = emptyOp.getResult().getType().cast(); - FailureOr materializeEncodingInfo = - materializeEncodingFn(resultType); - if (failed(materializeEncodingInfo)) { - return rewriter.notifyMatchFailure( - emptyOp, "failed to find materialization info for result type"); - } - SmallVector innerTileSizesOfr = - getAsOpFoldResult(rewriter, materializeEncodingInfo->innerTileSizes); - SmallVector newShape = PackOp::getResultShape( - rewriter, emptyOp.getLoc(), emptyOp.getMixedSizes(), innerTileSizesOfr, - materializeEncodingInfo->innerDimsPos, - materializeEncodingInfo->outerDimsPerm); - Operation *newEmptyOp = rewriter.create( - emptyOp.getLoc(), newShape, resultType.getElementType()); - return newEmptyOp; -} - -namespace { -//===---------------------------------------------------------------------===// -// Patterns to lower ops with encodings. These are written as -// dialect conversion patterns for now. These are just drivers around -// the core conversion utilities. -//===---------------------------------------------------------------------===// - -/// Convert `set_encoding` op to `pack` op. -struct SetEncodingOpToPackOpConversion - : public OpMaterializeEncodingPattern { - using OpMaterializeEncodingPattern< - SetEncodingOp>::OpMaterializeEncodingPattern; - - LogicalResult - matchAndRewrite(SetEncodingOp encodingOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MaterializeEncodingFn &materializeEncodingFn = - static_cast(getTypeConverter()) - ->getMaterializeEncodingFn(); - // Pack op needs a padding value. Maybe that is an overkill. For now, just - // use zero. - auto packOp = lowerSetEncodingOpToPackOp( - rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn); - if (failed(packOp)) - return rewriter.notifyMatchFailure(encodingOp, - "failed to convert to pack op"); - rewriter.replaceOp(encodingOp, packOp->getResults()); - return success(); - } -}; - -/// Convert `unset_encoding` op to `unpack` op. -struct UnsetEncodingOpToPackOpConversion - : public OpMaterializeEncodingPattern { - using OpMaterializeEncodingPattern< - UnsetEncodingOp>::OpMaterializeEncodingPattern; - - LogicalResult - matchAndRewrite(UnsetEncodingOp encodingOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MaterializeEncodingFn &materializeEncodingFn = - static_cast(getTypeConverter()) - ->getMaterializeEncodingFn(); - auto unpackOp = lowerUnsetEncodingToUnpackOp( - rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn); - if (failed(unpackOp)) - return rewriter.notifyMatchFailure(encodingOp, - "failed to convert to unpack op"); - rewriter.replaceOp(encodingOp, unpackOp->getResults()); - return success(); - } -}; - -/// Generic pattern to convert operaiton that is in Destination Passing Style. -template -struct MaterializeDPSOperation : public OpMaterializeEncodingPattern { - using OpMaterializeEncodingPattern::OpMaterializeEncodingPattern; - - LogicalResult - matchAndRewrite(OpTy dpsOp, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MaterializeEncodingFn &materializeEncodingFn = - static_cast( - this->getTypeConverter()) - ->getMaterializeEncodingFn(); - FailureOr convertedOp = - lowerOpWithEncoding(rewriter, dpsOp, adaptor.getInputs(), - adaptor.getOutputs(), materializeEncodingFn); - if (failed(convertedOp)) - return failure(); - rewriter.replaceOp(dpsOp, convertedOp.value()->getResults()); - return success(); - } -}; - -/// Generic pattern to convert an operation. -template -struct MaterializeOperation : public OpMaterializeEncodingPattern { - using OpMaterializeEncodingPattern::OpMaterializeEncodingPattern; - - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MaterializeEncodingFn &materializeEncodingFn = - static_cast( - this->getTypeConverter()) - ->getMaterializeEncodingFn(); - FailureOr convertedOp = lowerOpWithEncoding( - rewriter, op, adaptor.getOperands(), materializeEncodingFn); - if (failed(convertedOp)) - return failure(); - rewriter.replaceOp(op, convertedOp.value()->getResults()); - return success(); - } -}; - -//===---------------------------------------------------------------------===// -// Pass to materialize encoding -//===---------------------------------------------------------------------===// - -struct MaterializeEncodingPass - : public MaterializeEncodingBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override; -}; - -void MaterializeEncodingPass::runOnOperation() { - MLIRContext *context = &getContext(); - - { - Operation *op = getOperation(); - RewritePatternSet patterns(context); - MaterializeEncodingTypeConverter typeConverter(chooseEncodingInfo); - MaterializeEncodingConversionTarget target(*context); - populateMaterializeEncodingPatterns(patterns, target, typeConverter); - if (failed(applyPartialConversion(op, target, std::move(patterns)))) - return signalPassFailure(); - } - - // Add patterns to fold pack/unpack ops with pad/extract_slice ops. - { - RewritePatternSet patterns(context); - populateFoldIntoPackAndUnpackOpsPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) - return signalPassFailure(); - } -} -} // namespace - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter( - MaterializeEncodingFn materializeEncodingFn) - : materializeEncodingFn(materializeEncodingFn) { - addConversion([](IntegerType intType) { return intType; }); - addConversion([](IndexType indexType) { return indexType; }); - addConversion([](FloatType floatType) { return floatType; }); - addConversion([](MemRefType memrefType) { return memrefType; }); - addConversion( - [materializeEncodingFn](RankedTensorType t) -> RankedTensorType { - return getMaterializedType(t, materializeEncodingFn); - }); -} - -MaterializeEncodingConversionTarget::MaterializeEncodingConversionTarget( - MLIRContext &context) - : ConversionTarget(context) { - // Mark any operation that has operands/results with encoding as - // illegal. - markUnknownOpDynamicallyLegal([](Operation *op) { - auto typeHasEncoding = [](Type t) -> bool { - auto tensorType = t.dyn_cast(); - return tensorType && tensorType.getEncoding(); - }; - auto valueHasEncoding = [=](Value v) -> bool { - return typeHasEncoding(v.getType()); - }; - bool hasOperandOrResultsWithEncoding = - llvm::any_of(op->getOperands(), valueHasEncoding) || - llvm::any_of(op->getResultTypes(), typeHasEncoding); - return !hasOperandOrResultsWithEncoding; - }); -} - -void populateMaterializeEncodingPatterns( - RewritePatternSet &patterns, MaterializeEncodingConversionTarget &target, - MaterializeEncodingTypeConverter &typeConverter) { - - // Add all patterns for converting from encoded type to the materialized type - patterns.insert, - MaterializeDPSOperation, - MaterializeOperation, - SetEncodingOpToPackOpConversion, - UnsetEncodingOpToPackOpConversion>(typeConverter, - patterns.getContext()); - ::mlir::memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); -} - -std::unique_ptr> createMaterializeEncodingPass() { - return std::make_unique(); -} - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/PadContractionToBlockSize.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/PadContractionToBlockSize.cpp deleted file mode 100644 index 7748b0b7dd48..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/PadContractionToBlockSize.cpp +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/Input/InputDialect.h" -#include "iree-dialects/Dialect/Input/InputOps.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; -namespace IREE = mlir::iree_compiler::IREE; -using namespace IREE::LinalgExt; - -static Operation *sliceTensor(Location loc, Value expanded, Value original, - OpBuilder &builder) { - auto originalType = original.getType().cast(); - auto rank = originalType.getRank(); - SmallVector offsets(rank, builder.getI64IntegerAttr(0)); - SmallVector strides(rank, builder.getI64IntegerAttr(1)); - SmallVector sizes(rank); - for (int i = 0, e = rank; i < e; ++i) { - if (!originalType.isDynamicDim(i)) { - sizes[i] = builder.getI64IntegerAttr(originalType.getDimSize(i)); - } else { - sizes[i] = builder.create(loc, original, i).getResult(); - } - } - - return builder.create(loc, expanded, offsets, sizes, - strides); -} - -static bool padTensor(Location loc, OpOperand *operand, - ArrayRef alignments, OpBuilder &builder) { - Value original = operand->get(); - auto type = original.getType().cast(); - ArrayRef shape = type.getShape(); - assert(shape.size() == alignments.size() && - "expected shape and alignments to match"); - - // New dimensions. - SmallVector newStaticDims; - newStaticDims.resize(shape.size(), ShapedType::kDynamic); - SmallVector newPaddingSizes(shape.size(), - builder.getI64IntegerAttr(0)); - - // Compute padded dims. - bool needsPad = false; - for (int i = 0, e = shape.size(); i < e; ++i) { - auto inputDim = shape[i]; - auto alignment = alignments[i]; - if (inputDim >= 0) { - // Static dim. - if ((inputDim % alignment) == 0) { - newStaticDims[i] = inputDim; - continue; - } - int64_t alignedDim = (inputDim + (alignment - 1)) & ~(alignment - 1); - newStaticDims[i] = alignedDim; - newPaddingSizes[i] = builder.getI64IntegerAttr(alignedDim - inputDim); - needsPad = true; - } else { - // Dynamic dim. - Value inputDimValue = builder.create(loc, original, i); - Value alignedDim = - builder.create(loc, inputDimValue, alignment); - newPaddingSizes[i] = alignedDim; - needsPad = true; - } - } - if (!needsPad) - return false; - - auto resultType = RankedTensorType::get(newStaticDims, type.getElementType()); - Value zeroConstant = builder.create( - loc, builder.getZeroAttr(type.getElementType())); - SmallVector zeroStaticLow(shape.size(), - builder.getI64IntegerAttr(0)); - SmallVector nullLow; - Value padded = builder.create(loc, resultType, operand->get(), - zeroStaticLow, newPaddingSizes, - zeroConstant); - operand->set(padded); - return true; -} - -namespace { - -struct PadContractionToBlockSizePass - : public PadContractionToBlockSizeBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - getOperation()->walk([&](linalg::ContractionOpInterface op) { - auto linalgOp = llvm::cast(op.getOperation()); - Location loc = op.getLoc(); - OpOperand *lhs = linalgOp.getDpsInputOperand(0); - OpOperand *rhs = linalgOp.getDpsInputOperand(1); - OpOperand *output = linalgOp.getDpsInitOperand(0); - Value origOutput = output->get(); - OpResult result = op.getOperation()->getResult(0); - - bool insertSlice = false; - OpBuilder builder(op.getOperation()); - if (op.isRowMajorMatmul()) { - padTensor(loc, lhs, {rowAlignment, rowAlignment}, builder); - padTensor(loc, rhs, {rowAlignment, columnAlignment}, builder); - if (padTensor(loc, output, {rowAlignment, columnAlignment}, builder)) { - result.setType(output->get().getType()); - insertSlice = true; - } - } - - // Insert an appropriate extract. - if (insertSlice) { - builder.setInsertionPointAfter(op.getOperation()); - Operation *slicedResult = sliceTensor(loc, result, origOutput, builder); - result.replaceAllUsesExcept(slicedResult->getResult(0), slicedResult); - } - - return WalkResult::advance(); - }); - } -}; -} // namespace - -std::unique_ptr> -IREE::LinalgExt::createPadContractionToBlockSizePass() { - return std::make_unique(); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp deleted file mode 100644 index 05948c09dcfe..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" - -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/Passes.h" - -using namespace mlir; -namespace IREE = mlir::iree_compiler::IREE; - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -// Marker used as attribute name in generated Linalg rewriting transformations. -const StringLiteral LinalgTransforms::kLinalgTransformMarker = - "__internal_linalg_transform__"; - -LinalgTransformationFilter::LinalgTransformationFilter( - ArrayRef matchDisjunction, Optional replacement) - : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), - replacement(replacement), matchByDefault(false) {} - -LinalgTransformationFilter::LinalgTransformationFilter( - const FilterFunction &f, ArrayRef matchDisjunction, - Optional replacement) - : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), - replacement(replacement), matchByDefault(false) { - if (f) - filters.push_back(f); -} - -LogicalResult -LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter, - Operation *op) const { - if (llvm::any_of(filters, - [&](const FilterFunction &f) { return failed(f(op)); })) - return failure(); - - auto attr = op->template getAttrOfType( - LinalgTransforms::kLinalgTransformMarker); - - if (!attr) { - // 1. Has no filter case and matchDisjunction is empty. - if (matchDisjunction.empty() || matchByDefault) - return success(); - - // 2. Has no filter but was expecting a filter. - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << " does not have any filter from list: "; - interleaveComma(matchDisjunction, diag); - }); - } - - // 4. Match explicit filter. - for (auto filter : matchDisjunction) - if (attr.getValue() == filter) - return success(); - - // 5. Fail to match. - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << " does not have any filter from list: "; - interleaveComma(matchDisjunction, diag); - }); -} - -void LinalgTransformationFilter::replaceLinalgTransformationFilter( - PatternRewriter &rewriter, Operation *op) const { - if (replacement.has_value()) - op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value()); - else - op->removeAttr( - rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker)); -} - -bool LinalgTransformationFilter::hasReplacementFilter(Operation *op) const { - if (!replacement) - return false; - auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker) - .dyn_cast(); - return attr && attr == *replacement; -} - -namespace detail { -#define GEN_PASS_REGISTRATION -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h.inc" // IWYU pragma: export -} // namespace detail - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -void IREE::LinalgExt::registerPasses() { - IREE::LinalgExt::detail::registerPasses(); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/SplitReduction.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/SplitReduction.cpp deleted file mode 100644 index d36abd30add9..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/SplitReduction.cpp +++ /dev/null @@ -1,430 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/STLExtras.h" - -using namespace mlir; -using namespace mlir::iree_compiler::IREE::LinalgExt; - -namespace { - -SmallVector getExpandedShape(ArrayRef shape, - int64_t splitReductionRatio, - int64_t splitDimParallel) { - SmallVector ans; - ans.reserve(shape.size() + 1); - ans.assign(shape.begin(), shape.end()); - ans[splitDimParallel] = splitReductionRatio; - ans.insert(std::next(ans.begin(), splitDimParallel + 1), - shape[splitDimParallel] / splitReductionRatio); - - return ans; -} - -SmallVector getCollapsedShape(ArrayRef shape, - int64_t splitReductionRatio, int64_t k, - int64_t targetDim) { - SmallVector ans(shape.begin(), shape.end()); - ans[targetDim] = k * splitReductionRatio; - return ans; -} - -SmallVector -getReassociationIndices(int64_t rank, int64_t splitDimParallel) { - SmallVector reassociationIndices; - for (int i = 0; i < rank; ++i) { - if (i < splitDimParallel) { - reassociationIndices.push_back({i}); - } else if (i == splitDimParallel) { - reassociationIndices.push_back({i, i + 1}); - } else if (i > splitDimParallel) { - reassociationIndices.push_back({i + 1}); - } - } - return reassociationIndices; -} - -LogicalResult shouldParallelTopk(iree_compiler::IREE::LinalgExt::TopkOp topkOp, - PatternRewriter &rewriter, int64_t kDimOrig, - int64_t splitReductionRatio, - int64_t splitReductionDepth) { - // Determine if we should split the reduction. Requires aligned static shapes - // and no input indicies. - auto valuesOrigType = topkOp.getInputType(); - if (valuesOrigType.isDynamicDim(kDimOrig)) { - return rewriter.notifyMatchFailure(topkOp, - "cannot split dynamic dimension"); - } - if (topkOp.indices() && splitReductionDepth == 0) { - return rewriter.notifyMatchFailure( - topkOp, "input indices aren't supported for first split"); - } - if (splitReductionRatio <= 1) { - return rewriter.notifyMatchFailure(topkOp, "reduction ratio <= 1"); - } - if (valuesOrigType.getDimSize(kDimOrig) % splitReductionRatio != 0) { - return rewriter.notifyMatchFailure( - topkOp, - "reduction dimension must be perfectly aligned to (divisible by) the " - "split ratio"); - } - return success(); -} - -// Creates the first phase of the topk split reduction by reshaping the input -// into parallel computations then feeding them into a topk op. -iree_compiler::IREE::LinalgExt::TopkOp -computeParallelTopk(Location loc, PatternRewriter &rewriter, - iree_compiler::IREE::LinalgExt::TopkOp topkOp, - ArrayRef reassociationIndices, - int64_t splitReductionRatio, int64_t splitDimParallel, - int64_t kDimParallel, int64_t kSize) { - Value valuesOrig = topkOp.values(); - auto valuesOrigType = valuesOrig.getType().cast(); - Type valueElementType = valuesOrigType.getElementType(); - Type indicesElementType = - topkOp.getResultTypes()[1].cast().getElementType(); - - SmallVector expandedShape = getExpandedShape( - valuesOrigType.getShape(), splitReductionRatio, splitDimParallel); - auto valuesExpandedType = - RankedTensorType::get(expandedShape, valueElementType); - - // Expand input values shape for parallel processing - Value valuesExpanded = rewriter.create( - loc, valuesExpandedType, valuesOrig, reassociationIndices); - - // Expand input indices shape for parallel processing if they exist - Optional indicesExpanded; - if (Optional inputIndices = topkOp.indices()) { - // Type inputElementType = inputIndices->getType().cast(); - Type indicesExpandedType = - RankedTensorType::get(expandedShape, indicesElementType); - indicesExpanded = rewriter.create( - loc, indicesExpandedType, inputIndices.value(), reassociationIndices); - } - - // Define the expanded output types - SmallVector expandedResultShape = expandedShape; - expandedResultShape[kDimParallel] = kSize; - auto outputValuesExpandedType = - RankedTensorType::get(expandedResultShape, valueElementType); - auto outputIndicesExpandedType = - RankedTensorType::get(expandedResultShape, indicesElementType); - - // Initialize the expanded output values - SmallVector dynSizes; - for (auto i : llvm::seq(0, valuesExpandedType.getRank())) { - if (valuesExpandedType.isDynamicDim(i)) { - dynSizes.push_back( - rewriter.create(loc, valuesExpanded, i)); - } - } - Value emptyTensorOutputValues = rewriter.create( - loc, outputValuesExpandedType.getShape(), valueElementType, dynSizes); - Value emptyTensorOutputIndices = rewriter.create( - loc, outputIndicesExpandedType.getShape(), indicesElementType, dynSizes); - - // Initialize indices to positive infinity and values to negative infinity - // for a top (maxk) comparison. - Attribute negInfAttr; - if (auto intType = valueElementType.dyn_cast()) { - negInfAttr = rewriter.getIntegerAttr( - intType, APInt::getSignedMinValue(intType.getWidth())); - } else { - auto negApFloat = - APFloat::getInf(valueElementType.cast().getFloatSemantics(), - /*Negative=*/true); - negInfAttr = rewriter.getFloatAttr(valueElementType, negApFloat); - } - Value negInf = rewriter.create(loc, negInfAttr); - Attribute posInfAttr = - rewriter.getIntegerAttr(indicesElementType, APInt::getSignedMaxValue(32)); - Value posInf = rewriter.create(loc, posInfAttr); - Value negInfTensor = - rewriter.create(loc, negInf, emptyTensorOutputValues) - .result(); - Value posInfTensor = - rewriter.create(loc, posInf, emptyTensorOutputIndices) - .result(); - - SmallVector parallelTopkResultTypes = {outputValuesExpandedType, - outputIndicesExpandedType}; - SmallVector parallelTopkIns = {valuesExpanded}; - if (indicesExpanded) { - parallelTopkIns.push_back(indicesExpanded.value()); - } - SmallVector parallelTopkOuts = {negInfTensor, posInfTensor}; - - // Parallel topk - auto parallelTopkOp = rewriter.create( - loc, - /*resultTypes=*/ - parallelTopkResultTypes, - /*ins=*/parallelTopkIns, - /*outs=*/parallelTopkOuts, kDimParallel); - rewriter.cloneRegionBefore(topkOp.getRegion(), parallelTopkOp.getRegion(), - parallelTopkOp.getRegion().end()); - - return parallelTopkOp; -} - -// Update the output indices from the parallel TopK with the correct offsets. -// Each parallel computation uses implicit indices (starting from 0) during -// selection, but the values are part of the large input space split into M = -// splitReductionFn() ways. The following linalg.generic adds the appropriate -// offset to reflect to values original position. "Updated pos" = "initial -// pos" + "splitDimParallel size * "splitDimParallel index" -Value offsetParallelIndices(Location loc, PatternRewriter &rewriter, - Value parallelIndices, int64_t kDimParallelSize, - int64_t splitDimParallel) { - auto parallelIndicesType = parallelIndices.getType().cast(); - size_t parallelIndicesRank = parallelIndicesType.getRank(); - AffineMap mapIdentity = rewriter.getMultiDimIdentityMap(parallelIndicesRank); - SmallVector indexingMaps = {mapIdentity}; - SmallVector iterators(parallelIndicesRank, - utils::IteratorType::parallel); - Value mSplitVal = rewriter.create( - loc, kDimParallelSize, parallelIndicesType.getElementType()); - return rewriter - .create( - loc, - /*resultType=*/parallelIndicesType, - /*inputs=*/ValueRange{}, - /*outputs=*/ValueRange{parallelIndices}, indexingMaps, iterators, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value splitIndex = b.create(loc, splitDimParallel); - Value splitIndexInt = b.create( - loc, parallelIndicesType.getElementType(), splitIndex); - Value mOffset = - b.create(loc, mSplitVal, splitIndexInt); - Value updatedParallelIndex = - b.create(loc, mOffset, args[0]); - b.create(loc, updatedParallelIndex); - }) - .getResult(0); -} - -// Creates the second phase of the topk split reduction by collapsing output -// from parallel topk and computing the final combined result. -TopkOp computeReductionTopk(Location loc, PatternRewriter &rewriter, - TopkOp topkOp, TopkOp parallelTopkOp, - Value updatedParallelIndices, - ArrayRef reassociationIndices, - int64_t splitReductionRatio, int64_t kDimOrig, - int64_t kSize) { - Value valuesOrig = topkOp.values(); - auto valuesOrigType = valuesOrig.getType().cast(); - Type valueElementType = valuesOrigType.getElementType(); - Type indicesElementType = - topkOp.getResultTypes()[1].cast().getElementType(); - - // Define the collapsed input shapes - SmallVector collapsedShape = getCollapsedShape( - valuesOrigType.getShape(), splitReductionRatio, kSize, kDimOrig); - auto valuesCollapsedType = - RankedTensorType::get(collapsedShape, valueElementType); - auto indicesCollapsedType = - RankedTensorType::get(collapsedShape, indicesElementType); - - // Collapse collapse parallel output for the input of final reduction - Value valuesCollapsed = rewriter.create( - loc, valuesCollapsedType, parallelTopkOp.getResults()[0], - reassociationIndices); - Value indicesCollapsed = rewriter.create( - loc, indicesCollapsedType, updatedParallelIndices, reassociationIndices); - - // Combined final topk - auto reductionTopkOp = - rewriter.create( - loc, - /*resultTypes=*/topkOp->getResultTypes(), - /*ins=*/ValueRange{valuesCollapsed, indicesCollapsed}, - /*outs=*/topkOp.getOutputs(), kDimOrig); - rewriter.cloneRegionBefore(topkOp.getRegion(), reductionTopkOp.getRegion(), - reductionTopkOp.getRegion().end()); - return reductionTopkOp; -} - -int64_t getSplitReductionDepth(TopkOp topkOp) { - auto attr = - topkOp->template getAttrOfType(kSplitReductionDepthMarker); - if (attr) { - return attr.getInt(); - } else { - return 0; - } -} - -void setSplitReductionDepth(TopkOp topkOp, PatternRewriter &rewriter, - int64_t depth) { - topkOp->setAttr(kSplitReductionDepthMarker, - rewriter.getI64IntegerAttr(depth)); -} - -struct TopkOpSplitReduction : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - TopkOpSplitReduction(MLIRContext *context, TopkSplitReductionControlFn fn, - LinalgTransformationFilter filt) - : OpRewritePattern(context), splitReductionFn(std::move(fn)), - filter(std::move(filt)) {} - - // Transforms an applicable standard single reduction TopkOp into a parallel - // reduction TopkOp with a reduce step following. - // - // Handles parallel reductions in 2 phases: A "map" parallel phase and the a - // single "reduce" reduction phase. The first phase expands the input tensor - // shape by breaking the reduction dimension into multiple parallel reductions - // (upping the rank of the input). Topk is run on these dimensions in parallel - // The second phase collapses the parallel results into a single final reduce. - // Topk is run again on the combined output to produce a final output. - // - // Currently only topk operations without input indices are supported. - LogicalResult matchAndRewrite(TopkOp topkOp, - PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, topkOp))) { - return rewriter.notifyMatchFailure(topkOp, "preconditions not met"); - } - Location loc = topkOp.getLoc(); - // Original reduction dimension used for the final combined reduction - int64_t kDimOrig = topkOp.getDimension(); - // For parallel topk: the dimension that we compute parallel reductions - int64_t splitDimParallel = kDimOrig; - // For parallel topk: the dimension that we reduce - int64_t kDimParallel = kDimOrig + 1; - int64_t kSize = - topkOp.getResult(0).getType().cast().getDimSize(kDimOrig); - int64_t splitReductionDepth = getSplitReductionDepth(topkOp); - int64_t splitReductionRatio = splitReductionFn(splitReductionDepth); - SmallVector reassociationIndices = - getReassociationIndices(topkOp.getInputRank(), splitDimParallel); - - // Determine if should compute parallel topk - LogicalResult shouldParallelTopkResult = shouldParallelTopk( - topkOp, rewriter, kDimOrig, splitReductionRatio, splitReductionDepth); - if (shouldParallelTopkResult.failed()) { - return shouldParallelTopkResult; - } - - // Topk parallel reduction - TopkOp parallelTopkOp = computeParallelTopk( - loc, rewriter, topkOp, reassociationIndices, splitReductionRatio, - splitDimParallel, kDimParallel, kSize); - - // Update parallel indices to correct offsets if input indices weren't - // provided. If input indices were provided, no offsetting is needed as - // original original indices are already known. - Value updatedParallelIndices = parallelTopkOp.getResult(1); - if (!topkOp.indices()) { - Value parallelIndices = parallelTopkOp.getResult(1); - SmallVector expandedShape = getExpandedShape( - topkOp.values().getType().cast().getShape(), - splitReductionRatio, splitDimParallel); - int64_t kDimParallelSize = expandedShape[kDimParallel]; - updatedParallelIndices = offsetParallelIndices( - loc, rewriter, parallelIndices, kDimParallelSize, splitDimParallel); - } - - // Topk final reduction - TopkOp reductionTopkOp = computeReductionTopk( - loc, rewriter, topkOp, parallelTopkOp, updatedParallelIndices, - reassociationIndices, splitReductionRatio, kDimOrig, kSize); - - // Replace and update result - rewriter.replaceOp(topkOp, reductionTopkOp.getResults()); - filter.replaceLinalgTransformationFilter(rewriter, parallelTopkOp); - setSplitReductionDepth(reductionTopkOp, rewriter, splitReductionDepth + 1); - return success(); - } - -private: - TopkSplitReductionControlFn splitReductionFn; - LinalgTransformationFilter filter; -}; - -} // namespace - -//===----------------------------------------------------------------------===// -// Pass -//===----------------------------------------------------------------------===// - -namespace { -struct TopkSplitReductionPass - : public TopkSplitReductionBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - if (splitRatios.empty()) { - return; - } - RewritePatternSet patterns(&getContext()); - TopkSplitReductionControlFn splitReductionFn = - [&](int64_t splitReductionDepth) -> int64_t { - SmallVector reductionRatios(splitRatios.begin(), - splitRatios.end()); - if (splitReductionDepth >= reductionRatios.size()) { - return -1; - } else { - return reductionRatios[splitReductionDepth]; - } - }; - - patterns.add( - patterns.getContext(), splitReductionFn, - LinalgTransformationFilter( - ArrayRef{}, - StringAttr::get(patterns.getContext(), "SPLIT_REDUCTION"))); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - - // Remove all the markers at the end. - auto funcOp = getOperation(); - funcOp->walk([&](TopkOp op) { - op->removeAttr(LinalgTransforms::kLinalgTransformMarker); - op->removeAttr(kSplitReductionDepthMarker); - }); - } -}; -} // namespace - -void mlir::iree_compiler::IREE::LinalgExt::populateTopkSplitReductionPattern( - RewritePatternSet &patterns, - const TopkSplitReductionControlFn &splitReductionFn, - const LinalgTransformationFilter &f) { - patterns.add(patterns.getContext(), splitReductionFn, - f); -} - -std::unique_ptr> -mlir::iree_compiler::IREE::LinalgExt::createTopkSplitReductionPass() { - return std::make_unique(); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp deleted file mode 100644 index 67b54fdbe4f6..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp +++ /dev/null @@ -1,381 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" -#include "iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/Support/Debug.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -namespace { - -static void computeLoopParams(SmallVectorImpl &lbs, - SmallVectorImpl &ubs, - SmallVectorImpl &steps, Value tensor, - int numImageDims, Location loc, - OpBuilder &builder) { - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - SmallVector dimValues = - tensor::createDimValues(builder, loc, tensor); - for (int i = numImageDims; i < dimValues.size(); i++) { - lbs.push_back(zero); - ubs.push_back(getValueOrCreateConstantIndexOp(builder, loc, dimValues[i])); - steps.push_back(one); - } -} - -class ReifyWinogradInputTransform final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(WinogradInputTransformOp inputOp, - PatternRewriter &rewriter) const override { - Location loc = inputOp.getLoc(); - auto funcOp = inputOp->getParentOfType(); - if (!funcOp) { - return rewriter.notifyMatchFailure( - inputOp, "Could not find parent of type funcOp"); - } - - const float *BT{nullptr}; - const float *B{nullptr}; - const int64_t inputTileSize = inputOp.getInputTileSize(); - const int64_t outputTileSize = inputOp.getOutputTileSize(); - switch (outputTileSize) { - case 6: - B = IREE::LinalgExt::Winograd::B_6x6_3x3; - BT = IREE::LinalgExt::Winograd::BT_6x6_3x3; - break; - default: - return failure(); - } - /// The two values below are the transpose(B) [BTV] - /// and B [BV] constant matrices that convert the input - /// tile to the Winograd domain. - Value BTV = IREE::LinalgExt::createValueFrom2DConstant( - BT, inputTileSize, inputTileSize, loc, rewriter); - Value BV = IREE::LinalgExt::createValueFrom2DConstant( - B, inputTileSize, inputTileSize, loc, rewriter); - - Value input = inputOp.input(); - Value output = inputOp.output(); - auto outputType = output.getType().cast(); - auto inputType = input.getType().cast(); - ArrayRef inputShape = inputType.getShape(); - Type elementType = outputType.getElementType(); - SmallVector imageDims = inputOp.imageDimensions(); - const size_t numImageDims = imageDims.size(); - llvm::SmallSetVector imageDimsSet(imageDims.begin(), - imageDims.end()); - SmallVector inputTileSquare(imageDims.size(), inputTileSize); - - rewriter.setInsertionPointToStart(&funcOp.getBody().front()); - Value zeroF32 = rewriter.create( - loc, rewriter.getZeroAttr(elementType)); - Value scratch = - rewriter.create(loc, inputTileSquare, elementType); - - rewriter.setInsertionPoint(inputOp); - SmallVector lbs, ubs, steps; - computeLoopParams(lbs, ubs, steps, output, numImageDims, loc, rewriter); - // Construct loops - scf::LoopNest loopNest = scf::buildLoopNest( - rewriter, loc, lbs, ubs, steps, ValueRange({output}), - [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs, - ValueRange iterArgs) -> scf::ValueVector { return {iterArgs[0]}; }); - - // Extract input slice - auto one = rewriter.getIndexAttr(1); - auto zero = rewriter.getIndexAttr(0); - auto inputTileSizeAttr = rewriter.getIndexAttr(inputTileSize); - SmallVector strides(inputOp.getInputOperandRank(), one); - SmallVector sizes(inputOp.getInputOperandRank(), one); - SmallVector offsets(inputOp.getInputOperandRank(), zero); - SmallVector ivs; - for (scf::ForOp loop : loopNest.loops) { - ivs.push_back(loop.getInductionVar()); - } - for (int i = 0; i < inputShape.size(); i++) { - if (!imageDimsSet.contains(i)) { - offsets[i] = ivs[i]; - } else { - rewriter.setInsertionPointToStart(loopNest.loops[i].getBody()); - AffineExpr dim0; - auto it = rewriter.getAffineConstantExpr(inputTileSize); - auto ot = rewriter.getAffineConstantExpr(outputTileSize); - auto delta = rewriter.getAffineConstantExpr(inputShape[i]); - bindDims(rewriter.getContext(), dim0); - AffineMap scaleMap = - AffineMap::get(1, 0, {dim0 * ot}, rewriter.getContext()); - offsets[i] = rewriter.createOrFold(loc, scaleMap, - ValueRange{ivs[i]}); - AffineMap minMap = - AffineMap::get(1, 0, {-dim0 + delta, it}, rewriter.getContext()); - sizes[i] = rewriter.createOrFold( - loc, minMap, - ValueRange{ - getValueOrCreateConstantIndexOp(rewriter, loc, offsets[i])}); - } - } - rewriter.setInsertionPointToStart(loopNest.loops.back().getBody()); - auto tensorType = RankedTensorType::get( - SmallVector(numImageDims, ShapedType::kDynamic), elementType); - Value dynamicSlice = rewriter.create( - loc, tensorType, input, offsets, sizes, strides); - - // Copy input slice into zeroed padded scratch space - strides = SmallVector(numImageDims, one); - offsets = SmallVector(numImageDims, zero); - sizes = SmallVector{sizes[1], sizes[2]}; - linalg::FillOp fillOp = rewriter.create( - loc, ValueRange{zeroF32}, ValueRange{scratch}); - Value inputSlice = rewriter.create( - loc, dynamicSlice, fillOp.result(), offsets, sizes, strides); - - // Extract output slice - strides = SmallVector(inputOp.getOutputOperandRank(), one); - offsets = SmallVector(numImageDims, zero); - offsets.append(ivs.begin(), ivs.end()); - sizes = SmallVector(inputOp.getOutputOperandRank(), one); - sizes[0] = sizes[1] = inputTileSizeAttr; - tensorType = RankedTensorType::get(inputTileSquare, elementType); - Value iterArg = loopNest.loops.back().getRegionIterArg(0); - Value outputSlice = rewriter.create( - loc, tensorType, iterArg, offsets, sizes, strides); - - // Create computation - Value result, AMatrix, BMatrix; - linalg::MatmulOp matmulOp; - for (int i = 0; i < 2; i++) { - fillOp = rewriter.create(loc, ValueRange{zeroF32}, - ValueRange{outputSlice}); - if (i == 0) { - AMatrix = inputSlice; - BMatrix = BV; - } else { - AMatrix = BTV; - BMatrix = result; - } - matmulOp = rewriter.create( - loc, tensorType, ValueRange{AMatrix, BMatrix}, fillOp.result()); - result = matmulOp.getResult(0); - } - - // Insert results into output slice - Value updatedOutput = rewriter.create( - loc, result, iterArg, offsets, sizes, strides); - - // Replace returned value - if (scf::YieldOp yieldOp = dyn_cast( - loopNest.loops.back().getBody()->getTerminator())) { - rewriter.replaceOpWithNewOp(yieldOp, updatedOutput); - } - inputOp.getResults()[0].replaceAllUsesWith(loopNest.results[0]); - return success(); - } -}; - -} // namespace - -namespace { - -class ReifyWinogradOutputTransform final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(WinogradOutputTransformOp outputOp, - PatternRewriter &rewriter) const override { - Location loc = outputOp.getLoc(); - auto funcOp = outputOp->getParentOfType(); - if (!funcOp) { - return rewriter.notifyMatchFailure( - outputOp, "Could not find parent of type funcOp"); - } - - const float *AT{nullptr}; - const float *A{nullptr}; - const int64_t inputTileSize = outputOp.getInputTileSize(); - const int64_t outputTileSize = outputOp.getOutputTileSize(); - switch (outputTileSize) { - case 6: - A = IREE::LinalgExt::Winograd::A_6x6_3x3; - AT = IREE::LinalgExt::Winograd::AT_6x6_3x3; - break; - default: - return failure(); - } - /// The two values below are the transpose(A) [ATV] - /// and A [AV] constant matrices that convert the output - /// tile from the Winograd domain to the original domain. - Value ATV = IREE::LinalgExt::createValueFrom2DConstant( - AT, outputTileSize, inputTileSize, loc, rewriter); - Value AV = IREE::LinalgExt::createValueFrom2DConstant( - A, inputTileSize, outputTileSize, loc, rewriter); - - Value input = outputOp.input(); - Value output = outputOp.output(); - auto outputType = output.getType().cast(); - ArrayRef outputShape = outputType.getShape(); - Type elementType = outputType.getElementType(); - SmallVector imageDims = outputOp.imageDimensions(); - const size_t numImageDims = imageDims.size(); - llvm::SmallSetVector imageDimsSet(imageDims.begin(), - imageDims.end()); - SmallVector inputTileSquare(imageDims.size(), inputTileSize); - - rewriter.setInsertionPointToStart(&funcOp.getBody().front()); - Value zeroF32 = rewriter.create( - loc, rewriter.getZeroAttr(elementType)); - SmallVector scratchShape = {inputTileSize, outputTileSize}; - Value scratch = - rewriter.create(loc, scratchShape, elementType); - - rewriter.setInsertionPoint(outputOp); - SmallVector lbs, ubs, steps; - computeLoopParams(lbs, ubs, steps, input, numImageDims, loc, rewriter); - // Construct loops - scf::LoopNest loopNest = scf::buildLoopNest( - rewriter, loc, lbs, ubs, steps, ValueRange({output}), - [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs, - ValueRange iterArgs) -> scf::ValueVector { return {iterArgs[0]}; }); - - // Extract input slice - rewriter.setInsertionPointToStart(loopNest.loops.back().getBody()); - auto one = rewriter.getIndexAttr(1); - auto zero = rewriter.getIndexAttr(0); - auto inputTileSizeAttr = rewriter.getIndexAttr(inputTileSize); - auto outputTileSizeAttr = rewriter.getIndexAttr(outputTileSize); - SmallVector strides(outputOp.getInputOperandRank(), one); - SmallVector sizes(outputOp.getInputOperandRank(), one); - SmallVector offsets(numImageDims, zero); - sizes[0] = sizes[1] = inputTileSizeAttr; - SmallVector ivs; - for (scf::ForOp loop : loopNest.loops) { - ivs.push_back(loop.getInductionVar()); - } - offsets.append(ivs.begin(), ivs.end()); - auto tensorType = RankedTensorType::get(inputTileSquare, elementType); - tensor::ExtractSliceOp extractSliceOp = - rewriter.create(loc, tensorType, input, offsets, - sizes, strides); - Value inputSlice = extractSliceOp.getResult(); - - // Extract output slice - strides = SmallVector(outputOp.getOutputOperandRank(), one); - offsets = SmallVector(outputOp.getOutputOperandRank(), zero); - sizes = SmallVector(outputOp.getOutputOperandRank(), one); - for (int i = 0; i < outputShape.size(); i++) { - if (!imageDimsSet.contains(i)) { - offsets[i] = ivs[i]; - } else { - rewriter.setInsertionPointToStart(loopNest.loops[i].getBody()); - AffineExpr dim0; - auto ot = rewriter.getAffineConstantExpr(outputTileSize); - bindDims(rewriter.getContext(), dim0); - AffineMap scaleMap = - AffineMap::get(1, 0, {dim0 * ot}, rewriter.getContext()); - offsets[i] = rewriter.createOrFold(loc, scaleMap, - ValueRange{ivs[i]}); - sizes[i] = outputTileSizeAttr; - } - } - rewriter.setInsertionPointAfter(extractSliceOp); - tensorType = RankedTensorType::get( - SmallVector(numImageDims, outputTileSize), elementType); - Value iterArg = loopNest.loops.back().getRegionIterArg(0); - Value outputSlice = rewriter.create( - loc, tensorType, iterArg, offsets, sizes, strides); - - // Create computation - Value result, AMatrix, BMatrix; - linalg::MatmulOp matmulOp; - linalg::FillOp fillOp; - Value tmp; - for (int i = 0; i < 2; i++) { - tmp = i == 0 ? scratch : outputSlice; - fillOp = rewriter.create(loc, ValueRange{zeroF32}, - ValueRange{tmp}); - if (i == 0) { - AMatrix = inputSlice; - BMatrix = AV; - } else { - AMatrix = ATV; - BMatrix = result; - } - matmulOp = rewriter.create( - loc, tmp.getType(), ValueRange{AMatrix, BMatrix}, fillOp.result()); - result = matmulOp.getResult(0); - } - - // Insert results into output slice - Value updatedOutput = rewriter.create( - loc, result, iterArg, offsets, sizes, strides); - - // Replace returned value - if (scf::YieldOp yieldOp = dyn_cast( - loopNest.loops.back().getBody()->getTerminator())) { - rewriter.replaceOpWithNewOp(yieldOp, updatedOutput); - } - outputOp.getResults()[0].replaceAllUsesWith(loopNest.results[0]); - return success(); - } -}; - -} // namespace - -namespace { -struct TileAndDecomposeWinogradTransformPass - : public TileAndDecomposeWinogradTransformBase< - TileAndDecomposeWinogradTransformPass> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override; -}; -} // namespace - -void TileAndDecomposeWinogradTransformPass::runOnOperation() { - MLIRContext *context = &getContext(); - RewritePatternSet patterns(&getContext()); - patterns.insert( - context); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { - return signalPassFailure(); - } -} - -std::unique_ptr> -createTileAndDecomposeWinogradTransformPass() { - return std::make_unique(); -} - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp deleted file mode 100644 index 7c4cd797b1f0..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp +++ /dev/null @@ -1,445 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/Input/InputDialect.h" -#include "iree-dialects/Dialect/Input/InputOps.h" -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Transforms.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/TypeSwitch.h" - -using namespace mlir; -namespace IREE = mlir::iree_compiler::IREE; -using namespace IREE::LinalgExt; - -//===----------------------------------------------------------------------===// -// Utility methods for tiling a linalg_ext operation that implements a -// TiledOpInterface -//===----------------------------------------------------------------------===// - -/// Returns failure if the options are unsupported. -static LogicalResult -verifySupportedTilingOptions(PatternRewriter &rewriter, Operation *op, - const linalg::LinalgTilingOptions &options) { - if (!options.interchangeVector.empty()) { - return rewriter.notifyMatchFailure(op, - "unsupported interchange during tiling"); - } - if (options.loopType != linalg::LinalgTilingLoopType::Loops) { - return rewriter.notifyMatchFailure(op, - "only tiling with scf.for is supported"); - } - return success(); -} - -/// Returns true if loop is untiled. Only checks if the value is statically -/// zero. It is assumed that a `Value` defined by a constant op is already -/// converted to an `IntegerAttr` of that value. So here just return true if -/// this is an attribute with a zero value. -static bool isUntiledLoop(OpFoldResult valueOrAttr) { - Optional intVal = getConstantIntValue(valueOrAttr); - return intVal && *intVal == 0; -} - -/// Generates the tiled loops and the body by invoking the interface methods of -/// TiledOpInterface. -/// - `outputs` are the operands to use for outputs of the tiled operation. -/// - `tileSizes` are tile sizes specified for all loops of the operation. If a -/// loop is to be untiled it is set to 0. -/// - `iteratorType` is the type of the loop iterator returned by the -/// TiledOpInterface. -/// - `loopBounds` are the bounds of all the loops of the op returned by the -/// TiledOpInterface. -/// - `loopDepth` is the current loop depth being processed. -/// - `offsets` are the `Value`s that represent the position of the tile being -/// operated on. The offsets are computed as the tiled loops are being -/// generated. -/// - `distributionInfo` is the proc_id and nprocs `Value`s to be used for -/// distributed loops. It is a stack, and once an entry at the top of the -/// stack is used for distribution it is popped before processing the inner -/// loops. -static FailureOr -tileInterfaceOpImpl(OpBuilder &builder, TilingInterface tilableOp, - ValueRange outputs, MutableArrayRef tileSizes, - ArrayRef iteratorTypes, - ArrayRef loopBounds, unsigned loopDepth, - SmallVectorImpl &offsets, - ArrayRef distributionInfo) { - Location loc = tilableOp.getLoc(); - // If this is the innermost loop, then generated the tiled implementation of - // the op by invoking the TiledOpInterface methods. - if (loopDepth == tileSizes.size()) { - TiledOp ret; - SmallVector tiledOps = - tilableOp.getTiledImplementation(builder, offsets, tileSizes); - if (tiledOps.empty()) { - return static_cast( - tilableOp.emitOpError("failed to get tiled implementation")); - } - assert( - (tiledOps.size() == 1 || - (tiledOps.size() == 2 && - isa(tilableOp.getOperation()))) && - "expected only a single operation returned from tiling implementation"); - ret.op.assign(tiledOps); - for (auto result : llvm::enumerate(ret.op.back()->getResults())) { - if (!result.value().getType().isa()) { - ret.results.push_back(result.value()); - continue; - } - SmallVector resultOffsets, resultSizes; - if (succeeded(tilableOp.getResultTilePosition( - builder, result.index(), offsets, tileSizes, resultOffsets, - resultSizes))) { - SmallVector resultStrides(resultOffsets.size(), - builder.getIndexAttr(1)); - Value insertSlice = builder.create( - loc, ret.op.back()->getResult(result.index()), - outputs[result.index()], resultOffsets, resultSizes, resultStrides); - ret.results.push_back(insertSlice); - } - } - return ret; - } - - // If tile size at this depth is empty, do nothing. - if (isUntiledLoop(tileSizes[loopDepth])) { - auto zeroAttr = builder.getI64IntegerAttr(0); - offsets.push_back(zeroAttr); - tileSizes[loopDepth] = loopBounds[loopDepth].size; - return tileInterfaceOpImpl(builder, tilableOp, outputs, tileSizes, - iteratorTypes, loopBounds, loopDepth + 1, - offsets, distributionInfo); - } - - // Generate an scf.for for the current loop depth. - Value lb = getValueOrCreateConstantIndexOp(builder, loc, - loopBounds[loopDepth].offset); - Value ub = - getValueOrCreateConstantIndexOp(builder, loc, loopBounds[loopDepth].size); - // TODO(#7073): Put the check back. This is required by tiling linalg_ext.fft - // op. We can put the check back after updating linalg_ext.fft semantics. - // if (!matchPattern(loopBounds[loopDepth].stride, m_One())) { - // return static_cast( - // tilableOp.emitOpError("expected stride to be 1")); - //} - Value step = - getValueOrCreateConstantIndexOp(builder, loc, tileSizes[loopDepth]); - - // Update lb, ub and step for cyclic distribution. - if (!distributionInfo.empty() && - iteratorTypes[loopDepth] == utils::IteratorType::parallel) { - linalg::updateBoundsForCyclicDistribution( - builder, loc, distributionInfo.front().procId, - distributionInfo.front().nprocs, lb, ub, step); - distributionInfo = distributionInfo.drop_front(); - } - FailureOr innerReturnValue; - bool isBufferTiling = tilableOp->getNumResults() == 0; - ValueRange initValues(isBufferTiling ? ValueRange{} : outputs); - auto forOp = builder.create( - loc, lb, ub, step, initValues, - [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { - offsets.push_back(iv); - auto affineMaps = AffineMap::inferFromExprList({ArrayRef{ - b.getAffineSymbolExpr(0), - b.getAffineSymbolExpr(1) - b.getAffineDimExpr(0)}})[0]; - // Similar to linalg tiling, the tile size is the min(tileSizes, ub - - // iv) to account for cases where tile size does not divide (ub - lb) - // exactly. - Value inBoundsTileSize = b.create( - loc, affineMaps, - ValueRange{iv, - getValueOrCreateConstantIndexOp(builder, loc, - tileSizes[loopDepth]), - ub}); - tileSizes[loopDepth] = getAsOpFoldResult(inBoundsTileSize); - // Recursively proceed to generate the tiled loop for the next level. - innerReturnValue = - tileInterfaceOpImpl(b, tilableOp, (isBufferTiling ? outputs : args), - tileSizes, iteratorTypes, loopBounds, - loopDepth + 1, offsets, distributionInfo); - if (failed(innerReturnValue)) - return; - b.create(loc, innerReturnValue->results); - }); - if (failed(innerReturnValue)) { - return innerReturnValue; - } - innerReturnValue->loops.insert(innerReturnValue->loops.begin(), - forOp.getOperation()); - innerReturnValue->results = forOp.getResults(); - return innerReturnValue; -} - -FailureOr tileInterfaceOp(OpBuilder &b, TilingInterface tilableOp, - const linalg::LinalgTilingOptions &options) { - - // Gather destination tensors. - SmallVector dest; - Location loc = tilableOp.getLoc(); - - if (failed(tensor::getOrCreateDestinations(b, loc, tilableOp, dest))) - return tilableOp->emitOpError("failed to get destination tensors"); - - SmallVector iteratorTypes = - tilableOp.getLoopIteratorTypes(); - SmallVector tileSizesVals = - options.tileSizeComputationFunction(b, tilableOp); - auto zeroAttr = b.getI64IntegerAttr(0); - - // The actual tile sizes used converts `Value` defined as constant 0, to a - // zero integer attributes. Currently if the iterator type is not "parallel", - // the tile size is forced to zero as well. - auto tileSizes = getAsOpFoldResult(tileSizesVals); - tileSizes.resize(iteratorTypes.size(), zeroAttr); - for (auto en : llvm::enumerate(iteratorTypes)) { - if (en.value() == utils::IteratorType::parallel) - continue; - if (!isUntiledLoop(tileSizes[en.index()])) { - return static_cast(tilableOp.emitOpError( - "unimplemented tiling of non-parallel loop iterator type")); - } - } - - // Trivial early exit case of tile sizes being zero for all parallel loops. - if (llvm::all_of(tileSizes, isUntiledLoop)) { - return TiledOp{{tilableOp}, {}, {}}; - } - - SmallVector loopBounds = tilableOp.getIterationDomain(b); - SmallVector distributionInfo; - // If the tiled loops are distributed, get the proc_id and nprocs for the - // distributed loops. First collect the parallel loops by iterating over the - // tileSizes and getting the loops that are distribute, i.e., - // - parallel, i.e. iteratorTypes is "parallel" - // - tiled, i.e. tileSize != 0 - if (options.distribution) { - SmallVector distributedLoopRange; - for (auto i : llvm::seq(0, tileSizes.size())) { - if (isUntiledLoop(tileSizes[i])) - continue; - if (iteratorTypes[i] != utils::IteratorType::parallel) - continue; - distributedLoopRange.push_back(loopBounds[i]); - } - distributionInfo = options.distribution->procInfo(b, tilableOp.getLoc(), - distributedLoopRange); - } - - SmallVector offsets; - return tileInterfaceOpImpl(b, tilableOp, dest, tileSizes, iteratorTypes, - loopBounds, 0, offsets, distributionInfo); -} - -LogicalResult -TilingInterfaceBaseTilingPattern::matchAndRewriteBase(TilingInterface tilableOp, - PatternRewriter &rewriter, - TiledOp &result) const { - if (failed(filter.checkAndNotify(rewriter, tilableOp))) { - return failure(); - } - if (failed(verifySupportedTilingOptions(rewriter, tilableOp, options))) { - return failure(); - } - - FailureOr res = tileInterfaceOp(rewriter, tilableOp, options); - if (failed(res)) { - return res; - } - result = *res; - for (auto op : result.op) { - filter.replaceLinalgTransformationFilter(rewriter, op); - } - return success(); -} - -LogicalResult -TilingInterfaceTilingPattern::matchAndRewrite(TilingInterface tilableOp, - PatternRewriter &rewriter) const { - // `LinalgOp`s also implement the `TilingInterface`. Do not handle LinalgOps - // in this pattern. For now use these only for `LinalgExt` ops. This pattern - // is to be deprecated to use something that can handle all `TilingInterface` - // ops. - if (isa(tilableOp.getOperation())) { - return rewriter.notifyMatchFailure(tilableOp, "ignoring LinalgOps"); - } - TiledOp tiledOp; - // Check for failure. - if (failed(TilingInterfaceBaseTilingPattern::matchAndRewriteBase( - tilableOp, rewriter, tiledOp))) { - return failure(); - } - // Check for do-nothing case. - if (tiledOp.op.empty()) - return failure(); - if (tiledOp.op.back() != tilableOp) { - if (tiledOp.results.empty()) { - rewriter.eraseOp(tilableOp); - } else { - rewriter.replaceOp(tilableOp, tiledOp.results); - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// Test pass for tiling Linalg Ext ops -//===----------------------------------------------------------------------===// - -namespace { -/// A simple pattern rewriter that implements no special logic. -class SimpleRewriter : public PatternRewriter { -public: - SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} -}; - -struct TilingInterfaceTilingPass - : public TilingInterfaceTilingBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert< - AffineDialect, IREE::Input::IREEInputDialect, linalg::LinalgDialect, - IREE::LinalgExt::IREELinalgExtDialect, memref::MemRefDialect, - func::FuncDialect, mlir::arith::ArithDialect, math::MathDialect, - tensor::TensorDialect, scf::SCFDialect>(); - } - void runOnOperation() override; -}; -} // namespace - -template -static Value buildFlowWorkgroupInfoOp(OpBuilder &b, unsigned dim) { - return b.template create(b.getInsertionPoint()->getLoc(), dim); -} - -void TilingInterfaceTilingPass::runOnOperation() { - func::FuncOp funcOp = getOperation(); - MLIRContext *context = funcOp.getContext(); - - RewritePatternSet patterns(context); - patterns.add( - context, linalg::LinalgTilingOptions().setTileSizes({10, 20}), - IREE::LinalgExt::LinalgTransformationFilter( - StringAttr::get(context, "tiling_input"), - StringAttr::get(context, "tiling_output"))); - patterns.add( - context, linalg::LinalgTilingOptions().setTileSizes(ArrayRef{0}), - IREE::LinalgExt::LinalgTransformationFilter( - StringAttr::get(context, "no_tiling_input"), - StringAttr::get(context, "no_tiling_output"))); - - patterns.add( - context, linalg::LinalgTilingOptions().setTileSizes({0, 20}), - IREE::LinalgExt::LinalgTransformationFilter( - StringAttr::get(context, "outer_reduce_input"), - StringAttr::get(context, "outer_reduce_output"))); - patterns.add( - context, linalg::LinalgTilingOptions().setTileSizes({10, 0, 0}), - IREE::LinalgExt::LinalgTransformationFilter( - StringAttr::get(context, "inner_reduce_input"), - StringAttr::get(context, "inner_reduce_output"))); - - static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = { - [](OpBuilder &builder, Location loc, ArrayRef parallelLoopRanges) { - auto numParallelDims = parallelLoopRanges.size(); - - SmallVector procInfo(numParallelDims); - for (size_t dim = 0; dim < numParallelDims; ++dim) { - procInfo[numParallelDims - dim - 1] = { - buildFlowWorkgroupInfoOp( - builder, dim), - buildFlowWorkgroupInfoOp( - builder, dim)}; - } - return procInfo; - }}; - - patterns.add( - context, - linalg::LinalgTilingOptions() - .setTileSizes(ArrayRef{10, 0, 30}) - .setDistributionOptions(workgroupDistributionOptions), - IREE::LinalgExt::LinalgTransformationFilter( - StringAttr::get(context, "distribute_input"), - StringAttr::get(context, "distribute_output"))); - - patterns.add( - context, - linalg::LinalgTilingOptions().setTileSizes(ArrayRef{32}), - IREE::LinalgExt::LinalgTransformationFilter( - StringAttr::get(context, "tiling_1d_stage5_fft_input"), - StringAttr::get(context, "tiling_1d_stage5_fft_output"))); - - patterns.add( - context, - linalg::LinalgTilingOptions().setTileSizes(ArrayRef{10, 32}), - IREE::LinalgExt::LinalgTransformationFilter( - StringAttr::get(context, "tiling_2d_stage5_fft_input"), - StringAttr::get(context, "tiling_2d_stage5_fft_output"))); - - patterns.add( - context, linalg::LinalgTilingOptions().setTileSizes({0, 20}), - IREE::LinalgExt::LinalgTransformationFilter( - StringAttr::get(context, "tiling_repeated_indices_scatter_input"), - StringAttr::get(context, "tiling_repeated_indices_scatter_output"))); - - patterns.add( - context, linalg::LinalgTilingOptions().setTileSizes({1, 32}), - IREE::LinalgExt::LinalgTransformationFilter( - StringAttr::get(context, "tiling_winograd_input_nhwc"))); - - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return signalPassFailure(); - } - - // TODO(hanchung): Deprecate IREE specific logic. We should move to use - // upstream scf::tileUsingSCFForOp method. For now only uses it for packing - // and unpacking ops. - { - SimpleRewriter rewriter(context); - auto filter = IREE::LinalgExt::LinalgTransformationFilter( - StringAttr::get(context, "tiling_pack_input"), - StringAttr::get(context, "tiling_pack_output")); - auto options = scf::SCFTilingOptions().setTileSizes({2, 4}); - auto funcOp = getOperation(); - funcOp->walk([&](Operation *tilableOp) { - if (failed(filter.checkAndNotify(rewriter, tilableOp))) { - return; - } - - FailureOr tilingResult = scf::tileUsingSCFForOp( - rewriter, cast(tilableOp), options); - if (failed(tilingResult)) - return signalPassFailure(); - rewriter.replaceOp(tilableOp, tilingResult->replacements); - - for (auto op : tilingResult.value().tiledOps) { - filter.replaceLinalgTransformationFilter(rewriter, op); - } - }); - } -} - -std::unique_ptr> -IREE::LinalgExt::createTilingInterfaceTilingPass() { - return std::make_unique(); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/TransformOps/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/TransformOps/CMakeLists.txt deleted file mode 100644 index 96e8d8568337..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/TransformOps/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -add_mlir_library(IREELinalgExtTransformOps - LinalgExtTransformOps.cpp - - DEPENDS - mlir-headers - - LINK_LIBS PUBLIC - IREEDialectsTransforms - MLIRRewrite - MLIRTransformDialect - - IREELinalgExtDialect - IREELinalgExtTransforms - - MLIRPDLDialect -) diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp deleted file mode 100644 index 20465dcf2f47..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h" -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h" -#include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/OpImplementation.h" -#include "llvm/Support/FormatVariadic.h" - -using namespace mlir; -using namespace mlir::iree_compiler::IREE; - -LinalgExt::LinalgExtTransformOpsExtension::LinalgExtTransformOpsExtension() { - registerTransformOps< -#define GET_OP_LIST -#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp.inc" - >(); -} - -//===---------------------------------------------------------------------===// -// Utility functions -//===---------------------------------------------------------------------===// - -/// Extracts a vector of int64_t from an array attribute. Asserts if the -/// attribute contains values other than integers. -static SmallVector extractI64Array(ArrayAttr attr) { - SmallVector result; - result.reserve(attr.size()); - for (APInt value : attr.getAsValueRange()) - result.push_back(value.getSExtValue()); - return result; -} - -//===---------------------------------------------------------------------===// -// FuseProducersOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -LinalgExt::FuseProducersOp::apply(transform::TransformResults &transformResults, - transform::TransformState &state) { - SmallVector operandsToFuse = extractI64Array(getOperandsToFuse()); - LinalgExt::LinalgExtFusionPattern pattern(getContext(), operandsToFuse); - size_t numProducers = operandsToFuse.size(); - - SmallVector transformedOps; - SmallVector> fusedOps(numProducers); - for (Operation *target : state.getPayloadOps(getTarget())) { - // Apply the pattern. - SimplePatternRewriter rewriter(target); - FailureOr result = - pattern.returningMatchAndRewrite(target, rewriter); - if (failed(result)) - return DiagnosedSilenceableFailure::definiteFailure(); - - // Update the fused operations. - transformedOps.push_back(result->consumerOp); - for (size_t i = 0; i < numProducers; ++i) - fusedOps[i].push_back(result->fusedOps[i]); - } - - transformResults.set(getTransformed().cast(), transformedOps); - for (size_t i = 0; i < numProducers; ++i) - transformResults.set(getFusedOps()[i], fusedOps[i]); - return DiagnosedSilenceableFailure::success(); -} - -LogicalResult LinalgExt::FuseProducersOp::verify() { - SmallVector operandsToFuse = extractI64Array(getOperandsToFuse()); - llvm::SmallDenseSet operandsSet; - for (int64_t operandToFuse : operandsToFuse) { - if (operandToFuse < 0) { - return emitOpError() << "expects positive operand numbers, found " - << operandToFuse; - } - if (operandsSet.count(operandToFuse) != 0) { - return emitOpError() << "expects unique operand numbers, found " - << operandToFuse << " multiple times"; - } - operandsSet.insert(operandToFuse); - } - return success(); -} - -ParseResult LinalgExt::FuseProducersOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand targetOperand; - SMLoc opLoc; - if (parser.getCurrentLocation(&opLoc)) - return failure(); - if (parser.parseOperand(targetOperand)) - return parser.emitError(opLoc, "expected `target` operand"); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - StringRef operandsToFuseAttrName("operands_to_fuse"); - Attribute operandsToFuseAttr = result.attributes.get(operandsToFuseAttrName); - if (!operandsToFuseAttr) { - return parser.emitError(opLoc, llvm::formatv("expected `{0}` attribute", - operandsToFuseAttrName)); - } - auto operandsToFuseArrayAttr = operandsToFuseAttr.dyn_cast(); - if (!operandsToFuseArrayAttr) { - return parser.emitError(opLoc, - llvm::formatv("`{0}` attribute must be an array", - operandsToFuseAttrName)); - } - Type pdlOpType = parser.getBuilder().getType(); - size_t numProducers = operandsToFuseArrayAttr.size(); - result.addTypes(SmallVector(numProducers + 1, pdlOpType)); - if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) - return failure(); - return success(); -} - -void LinalgExt::FuseProducersOp::print(OpAsmPrinter &p) { - p << ' '; - p << getTarget(); - p.printOptionalAttrDict((*this)->getAttrs()); -} - -DiagnosedSilenceableFailure -LinalgExt::RewriteForeachThreadToAsyncOp::applyToOne( - scf::ForeachThreadOp target, SmallVectorImpl &results, - transform::TransformState &state) { - LinalgExt::ForeachThreadOpToAsyncRewriter pattern(this->getContext()); - SimplePatternRewriter rewriter(target); - FailureOr result = - pattern.returningMatchAndRewrite(target, rewriter); - if (failed(result)) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); - results.assign({*result}); - return DiagnosedSilenceableFailure(success()); -} - -DiagnosedSilenceableFailure -LinalgExt::RewriteForeachThreadToScfForOp::applyToOne( - scf::ForeachThreadOp target, SmallVectorImpl &results, - transform::TransformState &state) { - LinalgExt::ForeachThreadOpToScfForRewriter pattern(this->getContext()); - SimplePatternRewriter rewriter(target); - FailureOr result = - pattern.returningMatchAndRewrite(target, rewriter); - if (failed(result)) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); - results.assign({*result}); - return DiagnosedSilenceableFailure(success()); -} - -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp.inc" diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt deleted file mode 100644 index f02241a93d5e..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt +++ /dev/null @@ -1,34 +0,0 @@ -add_mlir_library(IREELinalgExtTransforms - CodegenStrategy.cpp - ForeachThreadToAsync.cpp - ForeachThreadToSequentialFor.cpp - Fusion.cpp - Tiling.cpp - Transforms.cpp - Utils.cpp - - PARTIAL_SOURCES_INTENDED - DEPENDS - mlir-headers - IREELinalgExtDialect - - LINK_LIBS PUBLIC - IREELinalgExtDialect - - MLIRAffineToStandard - MLIRAsyncDialect - MLIRSCFToControlFlow - MLIRLinalgToLLVM - MLIRDialectUtils - MLIRVectorToLLVM - MLIRMathToLLVM - MLIRMemRefToLLVM - MLIRIR - MLIRMathDialect - MLIRLinalgDialect - MLIRLinalgTransforms - MLIRPass - MLIRSCFDialect - MLIRTensorTransforms - MLIRTransforms -) diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/CodegenStrategy.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/CodegenStrategy.cpp deleted file mode 100644 index af748a794d35..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/CodegenStrategy.cpp +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "mlir/Pass/PassManager.h" - -using namespace mlir; - -#define DEBUG_TYPE "linalg-codegen-strategy" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -void CodegenStrategy::configurePassPipeline(OpPassManager &pm, - MLIRContext *context, - bool addEnablePass) const { - for (unsigned stepCount = 0, e = transformationSequence.size(); stepCount < e; - ++stepCount) { - const std::unique_ptr &t = - transformationSequence[stepCount]; - std::string currentStr = std::to_string(stepCount); - auto currentState = StringAttr::get(context, currentStr); - std::string nextStr = std::to_string(stepCount + 1); - auto nextState = StringAttr::get(context, nextStr); - auto filter = (currentState.str() == std::to_string(0)) - ? LinalgExt::LinalgTransformationFilter( - t->filter, ArrayRef{}, nextState) - : LinalgExt::LinalgTransformationFilter( - t->filter, currentState, nextState); - t->addToPassPipeline(pm, filter); - if (addEnablePass) - pm.addPass(createLinalgStrategyEnablePass(linalgEnablingOptions)); - } - pm.addPass(createLinalgStrategyRemoveMarkersPass()); -} - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToAsync.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToAsync.cpp deleted file mode 100644 index fa3cf3f334a7..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToAsync.cpp +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include - -#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h" -#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Async/IR/Async.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/STLExtras.h" - -using namespace mlir; -using namespace mlir::iree_compiler::IREE::LinalgExt; - -FailureOr -mlir::iree_compiler::IREE::LinalgExt::ForeachThreadOpToAsyncRewriter:: - returningMatchAndRewrite(scf::ForeachThreadOp foreachThreadOp, - PatternRewriter &rewriter) const { - if (foreachThreadOp.getNumResults() > 0) - return foreachThreadOp->emitError( - "only bufferized scf.foreach_thread lowers to async"); - - if (foreachThreadOp.getNumThreads().size() > 1) - return foreachThreadOp->emitError( - "only single-dimension scf.foreach_thread lowers to async"); - - // Only consider the top level ForeachThreadOp op and skip if it already - // contains an ExecuteOp. - if (foreachThreadOp->getParentOfType() || - llvm::any_of(foreachThreadOp.getBody()->getOperations(), - [](Operation &op) { return isa(&op); })) - return failure(); - - auto *ctx = foreachThreadOp.getContext(); - Location loc = foreachThreadOp.getLoc(); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); - // TODO: allow multi-dim. - Value numThreads = foreachThreadOp.getNumThreads().front(); - - // Wrap the scf.foreach_thread into an async::ExecuteOp. - // 1. Create the async::GroupType object on which we synchronize. - Value asyncGroup = rewriter.create( - loc, async::GroupType::get(ctx), numThreads); - - // 2. Create a bodyless forOp. - scf::ForOp forOp = rewriter.create(loc, zero, numThreads, one); - rewriter.setInsertionPointToStart(forOp.getBody()); - - // 3. Create an empty executeOp, nested within the forOp. - auto noopExec = [&](OpBuilder &executeBuilder, Location executeLoc, - ValueRange executeArgs) {}; - auto executeOp = - rewriter.create(loc, /*resultTypes=*/TypeRange(), - /*dependencies=*/ValueRange(), - /*operands=*/ValueRange(), noopExec); - - // 3. Steal the ops nested under scf::ForeachThread, except the terminator, - // into the body of the async::ExecuteOp, just before the terminator. - SmallVector bbArgsTranslated{forOp.getInductionVar()}; - rewriter.mergeBlocks(&foreachThreadOp.getRegion().front(), - executeOp.getBody(), bbArgsTranslated); - // 3.b. Erase the terminator stolen from foreachThreadOp. - rewriter.eraseOp(&executeOp.getBody()->back()); - // 3.c. Erase foreachThreadOp. - rewriter.eraseOp(foreachThreadOp); - // 3.d. Add ExecuteOp terminator. - rewriter.setInsertionPointToEnd(executeOp.getBody()); - rewriter.create(loc, ValueRange{}); - // 3.e. Add to group within the loop. - rewriter.setInsertionPoint(forOp.getBody()->getTerminator()); - rewriter.create(loc, rewriter.getIndexType(), - executeOp.getToken(), asyncGroup); - - // 4. After the iree_compiler::IREE::LinalgExt::ForeachThread, await all async - // tasks in `asyncGroup`. - rewriter.setInsertionPointAfter(forOp); - return rewriter.create(loc, asyncGroup).getOperation(); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp deleted file mode 100644 index 789335aa4578..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h" -#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/STLExtras.h" - -using namespace mlir; -using namespace mlir::iree_compiler::IREE::LinalgExt; - -namespace { - -SmallVector getValuesToYield(scf::PerformConcurrentlyOp op) { - return llvm::to_vector( - llvm::map_range(op.getYieldingOps(), [](Operation &op) -> Value { - return cast(&op).getDest(); - })); -} - -} // namespace - -FailureOr ForeachThreadOpToScfForRewriter::returningMatchAndRewrite( - scf::ForeachThreadOp foreachThreadOp, PatternRewriter &rewriter) const { - if (foreachThreadOp.getNumResults() > 0) - return foreachThreadOp->emitError( - "only bufferized scf.foreach_thread lowers to scf.for"); - - if (foreachThreadOp.getNumThreads().size() > 1) - return foreachThreadOp->emitError( - "only single-dimension scf.foreach_thread lowers to scf.for"); - - // Construct the loop bounds based on the canonical arithmetic progression. - Location loc = foreachThreadOp.getLoc(); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); - // TODO: allow multi-dim. - Value numThreads = foreachThreadOp.getNumThreads().front(); - - // Construct the op without a body builder: we need to clone the ops in the - // body explicitly after having access to the new bbArgs. - // As a consequence, `ensureTerminator` is not called and the `forOp` body - // has no terminator. - scf::PerformConcurrentlyOp performConcurrentlyOp = - foreachThreadOp.getTerminator(); - SmallVector valuesToYield = getValuesToYield(performConcurrentlyOp); - scf::ForOp forOp = - rewriter.create(loc, zero, numThreads, one, valuesToYield); - - // Move the body while replacing the threadId by the forOp iv. - SmallVector bbArgsTranslated{forOp.getInductionVar()}; - Block *body = forOp.getBody(); - bool hasTerminator = - !body->empty() && body->back().hasTrait(); - if (hasTerminator) { - rewriter.mergeBlockBefore(&foreachThreadOp.getRegion().front(), - body->getTerminator(), bbArgsTranslated); - } else { - rewriter.mergeBlocks(&foreachThreadOp.getRegion().front(), body, - bbArgsTranslated); - } - - rewriter.setInsertionPointToStart(body); - BlockAndValueMapping bvm; - bvm.map(valuesToYield, forOp.getRegionIterArgs()); - - // Create sequential insertSlice ops. - SmallVector toYield; - rewriter.setInsertionPoint(performConcurrentlyOp); - for (Operation &operation : performConcurrentlyOp.getYieldingOps()) { - tensor::ParallelInsertSliceOp op = - cast(&operation); - toYield.push_back(rewriter.createOrFold( - loc, op.getSource(), bvm.lookup(op.getDest()), op.getMixedOffsets(), - op.getMixedSizes(), op.getMixedStrides())); - } - - // performConcurrentlyOp.yieldedValues come from above, not from bbArgs. - // There is no rewriter method to make mergeBlocks update non-bbArgs. - // Need to manually clone + bvm all uses that are now nested under forOp. - // Warning: this replacement is currently optimistic and may change the - // semantics as explained in the pass description in Passes.td. - SmallVector opsToReplace; - for (Value toReplace : valuesToYield) { - for (OpOperand &u : toReplace.getUses()) { - Operation *op = u.getOwner(); - if (!forOp->isProperAncestor(op)) - continue; - opsToReplace.push_back(op); - } - } - for (Operation *op : opsToReplace) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(op); - Operation *cloned = rewriter.clone(*op, bvm); - rewriter.replaceOp(op, cloned->getResults()); - } - - // Insert terminator. - if (!hasTerminator) { - rewriter.setInsertionPointToEnd(body); - rewriter.create(loc, toYield); - } - - // Cleanup and replace. - rewriter.eraseOp(performConcurrentlyOp); - rewriter.replaceOp(foreachThreadOp, forOp.getResults()); - - return forOp; -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp deleted file mode 100644 index 7069bed159d8..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/TilingInterface.h" - -using namespace mlir; -using namespace mlir::iree_compiler::IREE::LinalgExt; - -FailureOr LinalgExtFusionPattern::returningMatchAndRewrite( - TilingInterface consumerOp, PatternRewriter &rewriter) const { - // Try to fuse the producers of all operands to fuse. - SmallVector fusedOps; - for (int64_t operandToFuse : operandsToFuse) { - // Check the operand exists. - if (operandToFuse >= consumerOp->getNumOperands()) - return failure(); - - // Check the operand is a slice of a producer result. - auto sliceOp = consumerOp->getOperand(operandToFuse) - .getDefiningOp(); - if (!sliceOp) - return failure(); - auto producerOp = sliceOp.getSource().getDefiningOp(); - if (!producerOp || producerOp->getNumResults() != 1) - return failure(); - - // Tile the producer. - FailureOr tiledProducer = producerOp.generateResultTileValue( - rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes()); - if (failed(tiledProducer)) - return failure(); - fusedOps.push_back(cast(tiledProducer->getDefiningOp())); - } - - // Update the consumer in-place using the tiled producer results. - SmallVector newOperands = consumerOp->getOperands(); - for (auto it : llvm::zip(operandsToFuse, fusedOps)) { - int64_t operandToFuse = std::get<0>(it); - TilingInterface fusedOp = std::get<1>(it); - newOperands[operandToFuse] = fusedOp->getResult(0); - } - rewriter.updateRootInPlace(consumerOp, - [&]() { consumerOp->setOperands(newOperands); }); - - return FusionResult{consumerOp, fusedOps}; -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp deleted file mode 100644 index de4bff6b09a7..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h" -#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; -using namespace mlir::iree_compiler::IREE::LinalgExt; - -// TODO: connect these patterns to PDL. Either via the transform dialect or via -// PDLL. - -static bool isZero(Value v) { - if (auto cst = v.getDefiningOp()) - return cst.value() == 0; - return false; -} - -SmallVector tileToSCF(PatternRewriter &rewriter, TilingInterface op, - TilingInterface clonedOp, ValueRange tileSizes) { - // Compute lower and upper bounds of the loop nest. - SmallVector ranges = clonedOp.getIterationDomain(rewriter); - assert(tileSizes.size() <= ranges.size() && - "expected tile sizes to match the number of loops"); - - // Fill the tile sizes with zeros for the untiled dimensions. - Location loc = op->getLoc(); - SmallVector tileSizesVec = getAsOpFoldResult(tileSizes); - if (ranges.size() != tileSizes.size()) { - tileSizesVec.resize(ranges.size(), rewriter.getIndexAttr(0)); - } - - SmallVector lbs, dims, steps; - SmallVector allDims; - for (auto it : llvm::enumerate(ranges)) { - allDims.push_back(it.value().size); - if (!isConstantIntValue(tileSizesVec[it.index()], 0)) { - lbs.push_back( - getValueOrCreateConstantIndexOp(rewriter, loc, it.value().offset)); - dims.push_back( - getValueOrCreateConstantIndexOp(rewriter, loc, it.value().size)); - steps.push_back(getValueOrCreateConstantIndexOp( - rewriter, loc, tileSizesVec[it.index()])); - } - } - - // Generate loop nest: One loop per dimension. - llvm::SmallPtrSet preservedUses; - SmallVector destOperand; - - if (failed(tensor::getOrCreateDestinations(rewriter, loc, clonedOp, - destOperand))) - clonedOp->emitOpError("failed to get destOperand"); - - auto loopNest = mlir::scf::buildLoopNest( - rewriter, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand), - [&](OpBuilder &b, Location loc, ValueRange localIvs, - ValueRange iterArgs) -> scf::ValueVector { - // Compute offsets and sizes of ExtractSliceOp. - SmallVector offsets = linalg::computeTileOffsets( - b, loc, getAsOpFoldResult(localIvs), tileSizesVec); - SmallVector sizes = - linalg::computeTileSizes(b, loc, tileSizesVec, allDims); - // Create ExtractSliceOp: Extract a tile from the PadOp. - // Note: The PadOp is located outside of the loop nest. It is - // later moved inside by ExtractSliceOfPadTensorSwapPattern. - auto map = - AffineMap::getMultiDimIdentityMap(ranges.size(), b.getContext()); - assert(clonedOp->getNumResults() == 1 && "expected single result op"); - Value tiledOutput = linalg::makeTiledShape( - b, loc, clonedOp->getResult(0), tileSizesVec, map, offsets, allDims, - sizes, /*omitPartialTileCheck=*/false); - auto sliceOp = tiledOutput.getDefiningOp(); - preservedUses.insert(sliceOp); - assert(sliceOp && "expected ExtractSliceOp"); - // Insert the tile into the output tensor. - Value yieldValue = - createMatchingSubsetInsertOp(b, loc, sliceOp, sliceOp, iterArgs[0]); - return scf::ValueVector({yieldValue}); - }); - return loopNest.results; -} - -namespace { - -/// The tiling here works by two steps. The first step is to create a loop based -/// on the loop bounds of the operation obtained from `TilingInterface`. -/// -/// ```mlir -/// %1 = ins(...) outs(%0 : ...) -/// ... ... %1 ... -/// ``` -/// -/// is rewritten using a "noop" subtensor extract/insert pair -/// -/// ```mlir -/// %1 = ins(...) outs(%0 : ...) -/// %2 = scf.for %iv0 = ... iter_args(%arg0 = %0) { -/// %3 = scf.for %iv1 = ... iter_args(%arg1 = %arg0) { -/// ... -/// %4 = tensor.extract_slice %1[%iv0, %iv1].... -/// %5 = tensor.insert_slice %4 into %arg1[%iv0, %iv1]... -/// scf.yield %5 -/// } -/// scf.yield %3 -/// } -/// ... ... %2 ... -/// ``` -/// -/// Following this the `TilingInterface` -> `tensor::ExtractSliceOp` pattern is -/// replaced with -/// -/// /// ```mlir -/// %2 = scf.for %iv0 = ... iter_args(%arg0 = %0) { -/// %3 = scf.for %iv1 = ... iter_args(%arg1 = %arg0) { -/// ... -/// %4 = tensor.extract_slice %0[%iv0, %iv1] -/// %5 = ins(...) outs(%4 : ...) -/// %6 = tensor.insert_slice %5 into %arg1[%iv0, %iv1]... -/// scf.yield %6 -/// } -/// scf.yield %3 -/// } -/// ... ... %2 ... -/// ``` -/// -/// TODO(ravishankarm): The current approach seems to work for only tiling the -/// parallel loops of the operation. Specifically, -/// 1) the `%0` in the third snippet needs to be `%arg1`, for cases where the -/// tiled loop is a reduction. -/// 2) Current implementation is using the `getIterationDomain` method to get -/// the -/// initial loop structure as described in the second snippet. If any of -/// those loops are reductions, then that IR snippet itself is wrong (replace -/// this with the case of `linalg.matmul` and the error becomes apparent). - -/// First pattern to introduce the loop nests. -struct OpTilingPattern : public OpInterfaceRewritePattern { - OpTilingPattern(MLIRContext *context, linalg::LinalgTilingOptions opt, - LinalgTransformationFilter filt) - : OpInterfaceRewritePattern(context), options(opt), - filter(filt) {} - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - - /// Currently only handle single result operations. - if (op->getNumResults() != 1) - return failure(); - - Location loc = op->getLoc(); - // Get rank and tile sizes. - SmallVector tileSizes = - options.tileSizeComputationFunction(rewriter, op); - auto iteratorTypes = op.getLoopIteratorTypes(); - Value zero = rewriter.create(loc, 0); - tileSizes.resize(iteratorTypes.size(), zero); - - /// Currently only handle operations with all parallel iterator types. - for (auto iteratorType : enumerate(iteratorTypes)) { - if (iteratorType.value() != utils::IteratorType::parallel && - !isZero(tileSizes[iteratorType.index()])) { - return rewriter.notifyMatchFailure( - op, "unhandled tiling of non-parallel iterator"); - } - } - - auto clonedOp = cast(rewriter.clone(*op.getOperation())); - SmallVector results = tileToSCF(rewriter, op, clonedOp, tileSizes); - - filter.replaceLinalgTransformationFilter(rewriter, clonedOp); - rewriter.replaceOp(op, results); - return success(); - } - -private: - linalg::LinalgTilingOptions options; - LinalgTransformationFilter filter; -}; -} // namespace - -/// Second pattern to implement the switch of `TilingInterface -> -/// tensor.extract_slice` to `tensor.extract_slice -> `TilingInterface`. -FailureOr SwapTilingInterfaceOp::returningMatchAndRewrite( - tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const { - auto sourceOp = sliceOp.getSource().getDefiningOp(); - if (!sourceOp) - return failure(); - SmallVector tiledOps = sourceOp.getTiledImplementation( - rewriter, sliceOp.getMixedOffsets(), sliceOp.getMixedSizes()); - assert(tiledOps.size() && "expected single tiled op"); - Operation *tiledOp = tiledOps.front(); - rewriter.replaceOp(sliceOp, tiledOp->getResults()); - return tiledOp; -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp deleted file mode 100644 index 935ddd27ebc0..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp +++ /dev/null @@ -1,1053 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h" - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Affine/LoopUtils.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" -#include "mlir/Transforms/Passes.h" - -using namespace mlir; - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -//===----------------------------------------------------------------------===// -// CodegenStrategy patterns and passes. -//===----------------------------------------------------------------------===// - -FailureOr tileConsumerAndFuseProducers( - OpBuilder &b, linalg::LinalgOp consumerOp, ArrayRef tileSizes, - ArrayRef tileInterchange, - const Optional &tileDistribution) { - assert(tileSizes.size() == tileInterchange.size() && - "expect the number of tile sizes and interchange dims to match"); - assert(isPermutationVector(tileInterchange) && - "expect tile interchange is a permutation"); - - // Create an empty tile loop nest. - linalg::TileLoopNest tileLoopNest(consumerOp); - - // Search the number of outer parallel loops to separate them from possible - // inner reduction dimensions. - auto iterTypes = consumerOp.getIteratorTypesArray(); - // Make sure to only look at the leading loops for tiling---we will scan this - // array to find the first non-parallel loop later and use that for indexing - // into the tile sizes. - if (iterTypes.size() > tileSizes.size()) { - iterTypes.resize(tileSizes.size()); - } - applyPermutationToVector(iterTypes, tileInterchange); - auto *it = find_if_not(iterTypes, linalg::isParallelIterator); - int64_t split = std::distance(iterTypes.begin(), it); - - // Helper to fuse the producers greedily using a queue of fusion candidates. - auto fuseProducersGreedily = [&](ArrayRef operands) { - SmallVector candidates(operands.begin(), operands.end()); - while (!candidates.empty()) { - FailureOr fusedProducer = - tileLoopNest.fuseProducer(b, candidates.pop_back_val()); - if (failed(fusedProducer)) - continue; - candidates.append(fusedProducer->getDpsInputOperands()); - candidates.append(fusedProducer->getDpsInitOperands()); - } - }; - - // Perform tiling and fusion in two steps. We need to respect the loop - // interchange here; filter parellel dimensions based on their order *after* - // permutation but pass in the original configuration *before* permuation, - // given the tiling and interchange happen together. - SmallVector outerTileSizes(tileSizes.size(), 0); - SmallVector innerTileSizes(tileSizes.size(), 0); - for (int64_t i : tileInterchange.take_front(split)) - outerTileSizes[i] = tileSizes[i]; - for (int64_t i : tileInterchange.drop_front(split)) - innerTileSizes[i] = tileSizes[i]; - - // Tile the outer parallel loops and fuse the output operands. - if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange, - tileDistribution))) - return failure(); - fuseProducersGreedily(tileLoopNest.getRootOp().getDpsInitOperands()); - - // Tile the remaining loops and fuse the input operands. - if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange, - tileDistribution))) - return failure(); - fuseProducersGreedily(tileLoopNest.getRootOp().getDpsInputOperands()); - - // Exit if the tile loop nest is empty since all tile sizes are zero. - if (tileLoopNest.isEmpty()) - return failure(); - - return tileLoopNest; -} - -/// Peel loops after tiling. -static void peelTiledLinalgOp(RewriterBase &rewriter, - linalg::TiledLinalgOp &res, - ArrayRef peeledLoops, - linalg::LinalgTilingLoopType loopType) { - for (int64_t loop : peeledLoops) { - assert(loop < static_cast(res.loops.size()) && - "requested peeling of non-existing loop"); - SmallVector loopResults; - Operation *loopOp = res.loops[loop]; - loopResults = linalg::peelLoop(rewriter, loopOp); - - // The result of the loop nest may change with peeling. - if (res.tensorResults.size() == loopOp->getNumResults() && - std::equal(res.tensorResults.begin(), res.tensorResults.end(), - loopOp->getResults().begin())) - res.tensorResults = loopResults; - } -} - -/// Linalg tiling pattern. -LinalgTilingPattern::LinalgTilingPattern( - MLIRContext *context, linalg::LinalgTilingOptions options, - LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(std::move(f)), options(std::move(options)) {} - -LinalgTilingPattern::LinalgTilingPattern( - StringRef opName, MLIRContext *context, linalg::LinalgTilingOptions options, - LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(f.addOpNameFilter(opName)), options(std::move(options)) {} - -FailureOr -LinalgTilingPattern::returningMatchAndRewrite(linalg::LinalgOp op, - PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - - FailureOr res = - linalg::tileLinalgOp(rewriter, op, options); - if (failed(res)) - return failure(); - - // Clear filter to stop recursive pattern application. - // This must be done here to properly propagate to peeling branches. - filter.replaceLinalgTransformationFilter(rewriter, res->op); - - // Peel the loops of the TiledLinalgOp. - peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType); - - if (res->tensorResults.empty()) - rewriter.eraseOp(op); - else - rewriter.replaceOp(op, res->tensorResults); - - return res; -} - -/// Linalg SCF tiling pattern. -LinalgSCFTilingPattern::LinalgSCFTilingPattern( - MLIRContext *context, scf::SCFTilingOptions options, - LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(std::move(f)), options(std::move(options)) {} - -LinalgSCFTilingPattern::LinalgSCFTilingPattern( - StringRef opName, MLIRContext *context, scf::SCFTilingOptions options, - LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(f.addOpNameFilter(opName)), options(std::move(options)) {} - -LogicalResult LinalgSCFTilingPattern::returningMatchAndRewrite( - TilingInterface op, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - - FailureOr tiledResults = - scf::tileUsingSCFForOp(rewriter, op, options); - if (failed(tiledResults)) - return failure(); - - rewriter.replaceOp(op, tiledResults->replacements); - - for (auto tiledOp : tiledResults->tiledOps) { - filter.replaceLinalgTransformationFilter(rewriter, tiledOp); - } - - return success(); -} - -/// Linalg tile and fuse tensor ops pattern. -LinalgSCFTileAndFusePattern::LinalgSCFTileAndFusePattern( - MLIRContext *context, scf::SCFTileAndFuseOptions options, - LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(std::move(f)), options(std::move(options)) {} - -LinalgSCFTileAndFusePattern::LinalgSCFTileAndFusePattern( - StringRef opName, MLIRContext *context, scf::SCFTileAndFuseOptions options, - LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(f.addOpNameFilter(opName)), options(std::move(options)) {} - -LogicalResult -LinalgSCFTileAndFusePattern::matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - - FailureOr tiledResults = - tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op, options); - if (failed(tiledResults)) - return rewriter.notifyMatchFailure( - op, - "tileConsumerAndFuseProducerGreedilyUsingSCFForOp failed unexpectedly"); - - // Replace all uses of the tiled loop operation. - SmallVector replacements(op->getNumResults()); - for (auto result : llvm::enumerate(op->getResults())) { - auto it = tiledResults->replacements.find(result.value()); - if (it == tiledResults->replacements.end()) { - replacements[result.index()] = result.value(); - } else { - replacements[result.index()] = it->getSecond(); - } - } - rewriter.replaceOp(op, replacements); - - // Apply the filter if specified. - for (linalg::LinalgOp linalgOp : tiledResults->tiledAndFusedOps) - filter.replaceLinalgTransformationFilter(rewriter, linalgOp); - - return success(); -} - -LinalgVectorizationPattern::LinalgVectorizationPattern( - MLIRContext *context, LinalgExt::LinalgTransformationFilter f, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(std::move(f)) {} - -LinalgVectorizationPattern::LinalgVectorizationPattern( - StringRef opName, MLIRContext *context, - LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(f.addOpNameFilter(opName)) {} - -LogicalResult -LinalgVectorizationPattern::matchAndRewrite(linalg::LinalgOp linalgOp, - PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, linalgOp))) - return failure(); - return vectorize(rewriter, linalgOp); -} - -namespace { - -/// -/// Linalg peeling patterns. -/// - -/// Compute the loops to peel and return them in a SmallVector. Loops will be -/// peeled in order of appearance in the SmallVector. This order will impact the -/// output IR. If an inner-to-outer order is provided, the peeled iterations of -/// the outer loops will also contain the peeled inner loops. If an -/// outer-to-inner order is provided, the peeled iterations of the outer loops -/// will not contain any peeled inner loops. - -/// `filter` controls LinalgTransformMarker matching and update when specified. -struct LinalgPeelingPattern - : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `filter`. - LinalgPeelingPattern(MLIRContext *context, - LinalgExt::LinalgTransformationFilter f = - LinalgExt::LinalgTransformationFilter(), - LinalgPeelOptions options = LinalgPeelOptions(), - PatternBenefit benefit = 1); - - /// Construct a pattern specifically applied to `opName`. - LinalgPeelingPattern(StringRef opName, MLIRContext *context, - LinalgPeelOptions options = LinalgPeelOptions(), - LinalgExt::LinalgTransformationFilter f = - LinalgExt::LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, - PatternRewriter &rewriter) const override; - -private: - /// LinalgTransformMarker handles special attribute manipulations. - const LinalgExt::LinalgTransformationFilter filter; - /// Peeling options. - const LinalgPeelOptions options; -}; - -LinalgPeelingPattern::LinalgPeelingPattern( - MLIRContext *context, LinalgExt::LinalgTransformationFilter f, - LinalgPeelOptions options, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(std::move(f)), options(std::move(options)) {} - -LinalgPeelingPattern::LinalgPeelingPattern( - StringRef opName, MLIRContext *context, LinalgPeelOptions options, - LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(f.addOpNameFilter(opName)), options(std::move(options)) {} - -LogicalResult -LinalgPeelingPattern::matchAndRewrite(linalg::LinalgOp linalgOp, - PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, linalgOp))) - return failure(); - - // Increase marker counter even if peeling doesn't happen for this op. - filter.replaceLinalgTransformationFilter(rewriter, linalgOp); - - if (!options.loopsToPeelComputationFunction) - return failure(); - - SmallVector loopsToPeel; - options.loopsToPeelComputationFunction(rewriter, linalgOp, loopsToPeel); - linalg::peelLoops(rewriter, loopsToPeel); - return success(); -} - -/// Configurable pass to apply pattern-based tiling and fusion. -struct LinalgStrategyTileAndFusePass - : public LinalgStrategyTileAndFusePassBase { - - LinalgStrategyTileAndFusePass() = default; - - LinalgStrategyTileAndFusePass(StringRef opName, - scf::SCFTileAndFuseOptions options, - LinalgExt::LinalgTransformationFilter filt) - : options(std::move(options)), filter(std::move(filt)) { - this->anchorOpName.setValue(opName.str()); - } - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); - if (!anchorOpName.empty()) { - tilingAndFusionPattern.add( - anchorOpName, funcOp.getContext(), options, filter); - } else { - tilingAndFusionPattern.add( - funcOp.getContext(), options, filter); - } - // Search the root operation using bottom up traversal. - GreedyRewriteConfig config; - config.useTopDownTraversal = false; - (void)applyPatternsAndFoldGreedily( - funcOp, std::move(tilingAndFusionPattern), config); - } - - scf::SCFTileAndFuseOptions options; - LinalgExt::LinalgTransformationFilter filter; -}; - -/// Configurable pass to apply pattern-based linalg tiling. -struct LinalgStrategyTilePass - : public LinalgStrategyTilePassBase { - - LinalgStrategyTilePass() = default; - - LinalgStrategyTilePass(StringRef opName, scf::SCFTilingOptions options, - LinalgExt::LinalgTransformationFilter filt) - : options(std::move(options)), filter(std::move(filt)) { - this->anchorOpName.setValue(opName.str()); - } - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - MLIRContext *ctx = funcOp.getContext(); - RewritePatternSet tilingPattern(ctx); - if (!anchorOpName.empty()) - tilingPattern.add(anchorOpName, ctx, options, - filter); - else - tilingPattern.add(ctx, options, filter); - - (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); - } - - scf::SCFTilingOptions options; - LinalgExt::LinalgTransformationFilter filter; -}; - -/// Configurable pass to apply hoisting and padding. -struct LinalgStrategyPadPass - : public LinalgStrategyPadPassBase { - - LinalgStrategyPadPass() = default; - - LinalgStrategyPadPass(StringRef opName, linalg::LinalgPaddingOptions opt, - LinalgExt::LinalgTransformationFilter filt) - : options(std::move(opt)), filter(std::move(filt)) { - this->anchorOpName.setValue(opName.str()); - } - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - RewritePatternSet paddingPattern(funcOp.getContext()); - if (!anchorOpName.empty()) { - paddingPattern.add( - anchorOpName, funcOp.getContext(), options, filter); - } else { - paddingPattern.add(funcOp.getContext(), options, - filter); - } - (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern)); - } - - linalg::LinalgPaddingOptions options; - LinalgExt::LinalgTransformationFilter filter; -}; - -/// Configurable pass to apply lowering of coarser-grained named linalg ops into -/// finer-grained named versions. -struct LinalgStrategyDecomposePass - : public LinalgStrategyDecomposePassBase { - - LinalgStrategyDecomposePass() = default; - - LinalgStrategyDecomposePass(LinalgExt::LinalgTransformationFilter filter) - : filter(std::move(filter)) {} - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - RewritePatternSet decompositionPattern(funcOp.getContext()); - decompositionPattern - .add, - DownscaleSizeOneWindowed2DConvolution, - DownscaleDepthwiseConv2DNhwcHwcOp>(funcOp.getContext(), filter); - if (failed(applyPatternsAndFoldGreedily(funcOp, - std::move(decompositionPattern)))) - signalPassFailure(); - } - - LinalgExt::LinalgTransformationFilter filter; -}; - -/// Configurable pass to apply pattern-based linalg peeling. -struct LinalgStrategyPeelPass - : public LinalgStrategyPeelPassBase { - - LinalgStrategyPeelPass() = default; - - LinalgStrategyPeelPass(StringRef opName, LinalgPeelOptions opt, - LinalgExt::LinalgTransformationFilter filt) - : options(std::move(opt)), filter(std::move(filt)) { - this->anchorOpName.setValue(opName.str()); - } - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - RewritePatternSet peelingPatterns(funcOp.getContext()); - if (!anchorOpName.empty()) { - peelingPatterns.add( - anchorOpName, funcOp.getContext(), options, filter); - } else { - peelingPatterns.add(funcOp.getContext(), filter, - options); - } - if (failed( - applyPatternsAndFoldGreedily(funcOp, std::move(peelingPatterns)))) - return signalPassFailure(); - } - - LinalgPeelOptions options; - LinalgExt::LinalgTransformationFilter filter; -}; - -/// Configurable pass to apply pattern-based linalg vectorization. -struct LinalgStrategyVectorizePass - : public LinalgStrategyVectorizePassBase { - - LinalgStrategyVectorizePass() = default; - - LinalgStrategyVectorizePass(StringRef opName, - LinalgExt::LinalgTransformationFilter filt, - bool padVectorize = false) - : filter(std::move(filt)) { - this->anchorOpName.setValue(opName.str()); - this->vectorizePadding.setValue(padVectorize); - } - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - RewritePatternSet vectorizationPatterns(funcOp.getContext()); - if (!anchorOpName.empty()) { - vectorizationPatterns.add( - anchorOpName, funcOp.getContext(), filter); - } else { - vectorizationPatterns.add(funcOp.getContext(), - filter); - } - vector::populateVectorTransferPermutationMapLoweringPatterns( - vectorizationPatterns); - vector::populateVectorReductionToContractPatterns(vectorizationPatterns); - vectorizationPatterns.add( - funcOp.getContext(), /*benefit=*/2); - vector::TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns, - funcOp.getContext()); - vector::TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns, - funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(vectorizationPatterns)); - - // Apply the pad tensor op vectorization separately to avoid running the - // GenericPadOpVectorizationPattern too early. - // TODO: Improve once we have better infrastructure to control pattern - // application. - if (vectorizePadding) { - RewritePatternSet patterns(funcOp.getContext()); - linalg::populatePadOpVectorizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); - } - } - - LinalgExt::LinalgTransformationFilter filter; -}; - -/// Configurable pass to enable the application of other pattern-based linalg -/// passes. -struct LinalgStrategyEnablePass - : public LinalgStrategyEnablePassBase { - - LinalgStrategyEnablePass(LinalgEnablingOptions opt, - LinalgExt::LinalgTransformationFilter filt) - : options(opt), filter(std::move(filt)) {} - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - MLIRContext *context = funcOp.getContext(); - RewritePatternSet patterns = - linalg::getLinalgTilingCanonicalizationPatterns(context); - scf::populateSCFForLoopCanonicalizationPatterns(patterns); - tensor::populateFoldTensorEmptyPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) - return signalPassFailure(); - - if (options.licm) { - funcOp->walk([&](LoopLikeOpInterface loopLike) { - moveLoopInvariantCode(loopLike); - }); - } - - // Gathers all innermost loops through a post order pruned walk. - funcOp.walk([](Operation *op) { - if (auto forOp = dyn_cast(op)) - (void)promoteIfSingleIteration(forOp); - else if (auto forOp = dyn_cast(op)) - (void)promoteIfSingleIteration(forOp); - }); - if (options.hoistRedundantVectorTransfers) - linalg::hoistRedundantVectorTransfers(funcOp); - - if (options.hoistRedundantVectorTransfersOnTensor) - linalg::hoistRedundantVectorTransfersOnTensor(funcOp); - - // Run CSE to cleanup after canonicalization. - OpPassManager dynamicPM("func.func"); - dynamicPM.addPass(createCSEPass()); - if (failed(runPipeline(dynamicPM, funcOp))) - return signalPassFailure(); - } - - LinalgEnablingOptions options; - LinalgExt::LinalgTransformationFilter filter; -}; - -/// Configurable pass to lower vector operations. -struct LinalgStrategyLowerVectorsPass - : public LinalgStrategyLowerVectorsPassBase< - LinalgStrategyLowerVectorsPass> { - - LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt, - LinalgExt::LinalgTransformationFilter filt) - : options(opt), filter(std::move(filt)) {} - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - MLIRContext *context = funcOp.getContext(); - RewritePatternSet patterns(context); - vector::populateVectorToVectorCanonicalizationPatterns(patterns); - // In a progressive lowering of vectors, this would be the 1st step. - if (options.contractionLowering) { - patterns.add( - options.vectorTransformOptions, context); - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); - } - // In a progressive lowering of vectors, this would be the 2nd step. - if (options.multiReductionLowering) { - vector::populateVectorMultiReductionLoweringPatterns( - patterns, - options.vectorTransformOptions.vectorMultiReductionLowering); - } - // In a progressive lowering of vectors, this would be the 3rd step. - if (options.transferPartialRewrite) { - patterns.add( - context, options.vectorTransformOptions); - } - // In a progressive lowering of vectors, this would be the 4th step. - if (options.transferLowering) { - vector::populateVectorTransferLoweringPatterns(patterns, - options.maxTransferRank); - } - // In a progressive lowering of vectors, this would be the 5th step. - if (options.transferToSCFConversion) { - populateVectorToSCFConversionPatterns( - patterns, options.vectorTransferToSCFOptions.setTargetRank( - options.maxTransferRank)); - } - // In a progressive lowering of vectors, this would be the 6th step. - if (options.shapeCastLowering) { - vector::populateVectorShapeCastLoweringPatterns(patterns); - } - // In a progressive lowering of vectors, this would be the 7th step. - if (options.transposeLowering) { - vector::populateVectorTransposeLoweringPatterns( - patterns, options.vectorTransformOptions); - if (options.avx2Lowering) - x86vector::avx2::populateSpecializedTransposeLoweringPatterns( - patterns, options.avx2LoweringOptions, /*benefit=*/10); - } - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); - } - - LinalgVectorLoweringOptions options; - LinalgExt::LinalgTransformationFilter filter; -}; - -/// Configurable pass to lower vector operations. -struct LinalgStrategyRemoveMarkersPass - : public LinalgStrategyRemoveMarkersPassBase< - LinalgStrategyRemoveMarkersPass> { - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - funcOp.walk([](linalg::LinalgOp op) { - op->removeAttr(LinalgTransforms::kLinalgTransformMarker); - }); - } -}; -} // namespace - -/// Create a LinalgStrategyTileAndFusePass. -std::unique_ptr> -createLinalgStrategyTileAndFusePass( - StringRef opName, const scf::SCFTileAndFuseOptions &options, - const LinalgExt::LinalgTransformationFilter &filter) { - return std::make_unique(opName, options, - filter); -} - -/// Create a LinalgStrategyTilePass. -std::unique_ptr> createLinalgStrategyTilePass( - StringRef opName, const scf::SCFTilingOptions &options, - const LinalgExt::LinalgTransformationFilter &filter) { - return std::make_unique(opName, options, filter); -} - -/// Create a LinalgStrategyPadPass. -std::unique_ptr> createLinalgStrategyPadPass( - StringRef opName, const linalg::LinalgPaddingOptions &opt, - const LinalgExt::LinalgTransformationFilter &filter) { - return std::make_unique(opName, opt, filter); -} - -/// Create a LinalgStrategyDecomposePass. -// TODO: if/when we need finer control add an `opName` parameter. -std::unique_ptr> createLinalgStrategyDecomposePass( - const LinalgExt::LinalgTransformationFilter &filter) { - return std::make_unique(filter); -} - -/// Create a LinalgStrategyPeelPass. -std::unique_ptr> createLinalgStrategyPeelPass( - StringRef opName, const LinalgPeelOptions &opt, - const LinalgExt::LinalgTransformationFilter &filter) { - return std::make_unique(opName, opt, filter); -} - -/// Create a LinalgStrategyVectorizePass. -std::unique_ptr> createLinalgStrategyVectorizePass( - StringRef opName, const LinalgExt::LinalgTransformationFilter &filter, - bool padVectorize) { - return std::make_unique(opName, filter, - padVectorize); -} - -/// Create a LinalgStrategyEnablePass. -std::unique_ptr> createLinalgStrategyEnablePass( - LinalgEnablingOptions opt, - const LinalgExt::LinalgTransformationFilter &filter) { - return std::make_unique(opt, filter); -} - -/// Create a LinalgStrategyLowerVectorsPass. -std::unique_ptr> -createLinalgStrategyLowerVectorsPass( - LinalgVectorLoweringOptions opt, - const LinalgExt::LinalgTransformationFilter &filter) { - return std::make_unique(opt, filter); -} - -/// Create a LinalgStrategyRemoveMarkersPass. -std::unique_ptr> -createLinalgStrategyRemoveMarkersPass() { - return std::make_unique(); -} - -//===----------------------------------------------------------------------===// -// LinalgExt patterns and passes. -//===----------------------------------------------------------------------===// - -namespace { - -/// A simple pattern rewriter that implements no special logic. -class SimpleRewriter : public PatternRewriter { -public: - SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} -}; - -/// Returns a tensor.pad op if padding value is set. Otherwise, returns the -/// input directly. The method assumes that the `packOp` has static shapes. -Value getInputOrPaddedInput(OpBuilder &builder, PackOp packOp) { - Value input = packOp.getInput(); - if (!packOp.getPaddingValue()) { - return input; - } - - Location loc = packOp.getLoc(); - ShapedType inputType = packOp.getInputType(); - int64_t inputRank = inputType.getRank(); - assert(llvm::all_of(packOp.getOutputShape().take_front(inputRank), - [](int64_t val) { return val == 1; })); - - SmallVector paddedShape; - DenseMap tileAndPosMapping = - packOp.getDimAndTileMapping(); - for (int64_t dim = 0; dim < inputRank; ++dim) { - int64_t size = inputType.getDimSize(dim); - if (!tileAndPosMapping.count(dim)) { - paddedShape.push_back(size); - continue; - } - - // The size is less than or equal to tileSize because outer dims are all 1s. - Optional tileSize = - getConstantIntValue(tileAndPosMapping.lookup(dim)); - assert(tileSize.has_value() && "dynamic inner tile size is not supported"); - paddedShape.push_back(tileSize.value()); - } - auto resultType = - RankedTensorType::get(paddedShape, inputType.getElementType()); - return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(), - /*nofold=*/false, loc, builder); -} - -/// Rewrites iree_linalg_ext.pack to tensor.pad + rank-up linalg.generic -/// (transpose) ops. -struct GeneralizePackOpPattern : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(PackOp packOp, - PatternRewriter &rewriter) const final { - if (!packOp.hasTensorSemantics()) { - return rewriter.notifyMatchFailure(packOp, "require tensor semantics"); - } - - // The expand_shape op can be avoided if outer dimensions of result are all - // 1s. This can be relaxed if needed. A tensor.expand_shape will be - // generated in that case. - int64_t inputRank = packOp.getInputRank(); - if (llvm::any_of(packOp.getOutputShape().take_front(inputRank), - [](int64_t val) { return val != 1; })) { - return rewriter.notifyMatchFailure( - packOp, "require the outer dimension of the result are all 1s"); - } - if (llvm::any_of(packOp.getMixedTiles(), - [](OpFoldResult tile) { return tile.is(); })) { - return rewriter.notifyMatchFailure( - packOp, "require inner tile sizes being static"); - } - - Value input = getInputOrPaddedInput(rewriter, packOp); - - SmallVector inputExprs; - for (int64_t dim = 0; dim < inputRank; ++dim) { - inputExprs.push_back(rewriter.getAffineDimExpr(dim)); - } - // The dimensions map in the order of output dimensions. Since the - // interchange is applied, we have to undo it for input. - if (!packOp.getOuterDimsPerm().empty()) { - inputExprs = - undoInterchange(inputExprs, packOp.getOuterDimsPerm()); - } - for (auto en : llvm::enumerate(packOp.getInnerDimsPos())) { - inputExprs[en.value()] = - rewriter.getAffineDimExpr(inputRank + en.index()); - } - - Location loc = packOp.getLoc(); - auto inputType = input.getType().cast(); - auto nloops = packOp.getOutputRank(); - - Value empty = rewriter.create(loc, packOp.getOutputShape(), - inputType.getElementType()); - SmallVector loopAttributeTypes( - nloops, utils::IteratorType::parallel); - SmallVector indexingMaps = { - AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()), - AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())}; - - auto transposedOp = rewriter.create( - loc, empty.getType(), input, empty, indexingMaps, loopAttributeTypes, - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create(nestedLoc, args[0]); - }); - rewriter.replaceOp(packOp, transposedOp.getResult(0)); - return success(); - } -}; - -/// Rewrites iree_linalg_ext.unpack to rank-reduced extract_slice op + transpose -/// op + insert_slice op. It requires the outer dims are all 1s. -struct GeneralizeUnPackOpPattern : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(UnPackOp unpackOp, - PatternRewriter &rewriter) const final { - if (!unpackOp.hasTensorSemantics()) { - return rewriter.notifyMatchFailure(unpackOp, "require tensor semantics"); - } - - int64_t outputRank = unpackOp.getOutputRank(); - if (llvm::any_of(unpackOp.getInputShape().take_front(outputRank), - [](int64_t val) { return val != 1; })) { - return rewriter.notifyMatchFailure( - unpackOp, "require the outer dimension of the result are all 1s"); - } - - int64_t inputRank = unpackOp.getInputRank(); - - Location loc = unpackOp.getLoc(); - Attribute zeroIdxAttr = rewriter.getIndexAttr(0); - Attribute oneIdxAttr = rewriter.getIndexAttr(1); - SmallVector readOffsets(inputRank, zeroIdxAttr); - SmallVector readStrides(inputRank, oneIdxAttr); - - auto mixedTiles = unpackOp.getMixedTiles(); - SmallVector readSizes(outputRank, oneIdxAttr); - readSizes.append(mixedTiles.begin(), mixedTiles.end()); - - // Explicitly create the type for extract_slice op because the inner tile - // size could be 1. We want to represent the whole inner tile in this case. - ArrayRef readShape = - unpackOp.getInputShape().drop_front(outputRank); - Type elemType = unpackOp.getInputType().getElementType(); - auto readType = RankedTensorType::get(readShape, elemType); - Value innerTile = rewriter.create( - loc, readType, unpackOp.getInput(), readOffsets, readSizes, - readStrides); - - ArrayRef innerDimsPos = unpackOp.getInnerDimsPos(); - auto interchangeVector = - computeInterchangeFromDimPos(innerDimsPos, outputRank); - SmallVector transpShape = - interchange(readShape, interchangeVector); - - Value empty = rewriter.create(loc, transpShape, elemType); - auto transposedOp = rewriter.create( - loc, innerTile, empty, interchangeVector); - - // Handle in-complete tiles. - int numLoops = transpShape.size(); - SmallVector tileStrides(numLoops, oneIdxAttr); - SmallVector tileOffsets(numLoops, zeroIdxAttr); - SmallVector tileSizes; - for (int dim : innerDimsPos) { - tileSizes.push_back(getDim(rewriter, loc, unpackOp.getOutput(), dim)); - } - tileSizes = interchange(tileSizes, interchangeVector); - auto partialTile = rewriter.create( - loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides); - - SmallVector writeSizes; - SmallVector writeStrides(outputRank, oneIdxAttr); - SmallVector writeOffsets(outputRank, zeroIdxAttr); - DenseMap dimAndTileMapping = - unpackOp.getDimAndTileMapping(); - for (int i = 0, idx = 0; i < outputRank; ++i) { - if (dimAndTileMapping.count(i)) { - writeSizes.push_back(tileSizes[idx++]); - } else { - writeSizes.push_back(oneIdxAttr); - } - } - auto insert = rewriter.create( - loc, partialTile, unpackOp.getOutput(), writeOffsets, writeSizes, - writeStrides); - rewriter.replaceOp(unpackOp, insert.getResult()); - - return success(); - } -}; - -struct LinalgExtVectorizationPass - : public LinalgExtVectorizationBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext *ctx = &getContext(); - // Apply tiling to make outer dims be all 1s. - { - SimpleRewriter rewriter(ctx); - auto packOptions = scf::SCFTileAndFuseOptions().setTilingOptions( - scf::SCFTilingOptions().setTileSizeComputationFunction( - [](OpBuilder &builder, Operation *op) -> SmallVector { - Location loc = op->getLoc(); - auto packOp = cast(op); - - // Do nothing if any of inner tile sizes is dynamic. - if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) { - return tile.is(); - })) - return {}; - - int inputRank = packOp.getInputRank(); - SmallVector tileSizes( - inputRank, builder.create(loc, 1)); - return tileSizes; - })); - auto funcOp = getOperation(); - funcOp->walk([&](LinalgExt::PackOp op) { - FailureOr tileAndFuseResult = - scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op, - packOptions); - if (failed(tileAndFuseResult)) - return signalPassFailure(); - rewriter.replaceOp(op, - tileAndFuseResult->replacements[op.getResult(0)]); - }); - - auto unpackTilingOptions = - scf::SCFTilingOptions().setTileSizeComputationFunction( - [](OpBuilder &builder, Operation *op) { - Location loc = op->getLoc(); - auto unpackOp = cast(op); - int numLoops = unpackOp.getOutputRank(); - auto dimAndTileMapping = unpackOp.getDimAndTileMapping(); - SmallVector tileSizes; - for (int i = 0; i < numLoops; ++i) { - if (dimAndTileMapping.count(i)) { - tileSizes.push_back(getValueOrCreateConstantIndexOp( - builder, loc, dimAndTileMapping[i])); - } else { - tileSizes.push_back( - getDimValue(builder, loc, unpackOp.getOutput(), i)); - } - } - return tileSizes; - }); - funcOp->walk([&](LinalgExt::UnPackOp op) { - FailureOr tilingResult = scf::tileUsingSCFForOp( - rewriter, cast(op.getOperation()), - unpackTilingOptions); - if (failed(tilingResult)) - return signalPassFailure(); - rewriter.replaceOp(op, tilingResult->replacements); - }); - } - - // Generalize pack and unpack ops and canonicalize tiled ops. - { - RewritePatternSet patterns(ctx); - linalg::populateLinalgTilingCanonicalizationPatterns(patterns); - patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } - - // Kick in generic vectorizer. - { - RewritePatternSet patterns(ctx); - patterns.add(ctx); - linalg::populatePadOpVectorizationPatterns(patterns); - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); - vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); - vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } - } -}; -} // namespace - -std::unique_ptr> -createLinalgExtVectorizationPass() { - return std::make_unique(); -} - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Utils.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Utils.cpp deleted file mode 100644 index 4248ed2b2afd..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/Utils.cpp +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h" - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; -using namespace mlir::iree_compiler::IREE::LinalgExt; - -void mlir::iree_compiler::IREE::LinalgExt::completeOffsetsSizesAndStrides( - OpBuilder &b, Location loc, Value tensor, ArrayRef leadingOffsets, - ArrayRef leadingSizes, ArrayRef leadingStrides, - SmallVectorImpl &offsets, SmallVectorImpl &sizes, - SmallVectorImpl &strides) { - assert(leadingOffsets.size() == leadingSizes.size() && - "expected matching lengths"); - assert(leadingSizes.size() == leadingStrides.size() && - "expected matching lengths"); - - auto rankedTensorType = tensor.getType().cast(); - int64_t tensorRank = rankedTensorType.getRank(); - int64_t leadingRank = leadingOffsets.size(); - offsets = SmallVector(leadingOffsets.begin(), leadingOffsets.end()); - sizes = SmallVector(leadingSizes.begin(), leadingSizes.end()); - strides = SmallVector(leadingStrides.begin(), leadingStrides.end()); - if (leadingRank >= tensorRank) - return; - Value zero = b.create(loc, 0); - Value one = b.create(loc, 1); - for (int64_t i = leadingRank, e = tensorRank; i < e; ++i) { - offsets.push_back(zero); - sizes.push_back(b.createOrFold(loc, tensor, i)); - strides.push_back(one); - } -} - -/// Create a tensor::ExtractSliceOp by auto-completing the missing trailing -/// dimensions to always be offset = 0, size = dim, stride = 1. -Value mlir::iree_compiler::IREE::LinalgExt:: - createSubsetExtractOpFromLeadingOffsetsSizesAndStrides( - OpBuilder &b, Location loc, Value tensor, - ArrayRef leadingOffsets, ArrayRef leadingSizes, - ArrayRef leadingStrides) { - SmallVector offsets, sizes, strides; - completeOffsetsSizesAndStrides(b, loc, tensor, leadingOffsets, leadingSizes, - leadingStrides, offsets, sizes, strides); - return b.createOrFold(loc, tensor, offsets, sizes, - strides); -} - -/// Create a tensor::InsertSliceOp by auto-completing the missing trailing -/// dimensions to always be offset = 0, size = dim, stride = 1. -Value mlir::iree_compiler::IREE::LinalgExt:: - createSubsetInsertOpFromLeadingOffsetsSizesAndStrides( - OpBuilder &b, Location loc, Value tensor, Value dest, - ArrayRef leadingOffsets, ArrayRef leadingSizes, - ArrayRef leadingStrides) { - SmallVector offsets, sizes, strides; - completeOffsetsSizesAndStrides(b, loc, tensor, leadingOffsets, leadingSizes, - leadingStrides, offsets, sizes, strides); - return b.createOrFold(loc, tensor, dest, offsets, - sizes, strides); -} - -/// Insert the `source` tensor into the `dest` tensor by creating the relevant -/// `subset_insert` op. The details of the `subset_insert` op are retrieved -/// from the `subset_extract` op so that they form a matching extract/insert -/// pair. -Value mlir::iree_compiler::IREE::LinalgExt::createMatchingSubsetInsertOp( - OpBuilder &b, Location loc, tensor::ExtractSliceOp subsetExtractOp, - Value source, Value dest) { - return b.create( - loc, subsetExtractOp.getSource().getType(), source, dest, - subsetExtractOp.offsets(), subsetExtractOp.sizes(), - subsetExtractOp.strides(), subsetExtractOp.static_offsets(), - subsetExtractOp.static_sizes(), subsetExtractOp.static_strides()); -} - -void mlir::iree_compiler::IREE::LinalgExt::createMatchingParallelSubsetInsertOp( - OpBuilder &b, Location loc, tensor::ExtractSliceOp subsetExtractOp, - Value source, Value dest) { - b.create( - loc, source, dest, subsetExtractOp.getMixedOffsets(), - subsetExtractOp.getMixedSizes(), subsetExtractOp.getMixedStrides()); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt deleted file mode 100644 index 57d063baf509..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -add_mlir_library(IREELinalgExtUtils - Utils.cpp - - PARTIAL_SOURCES_INTENDED - DEPENDS - mlir-headers - - MLIRDialectUtils - MLIRIR - MLIRTensorDialect - MLIRMemRefDialect -) diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp deleted file mode 100644 index 5eef5b40ada2..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Utils/Utils.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Builders.h" -#include "llvm/ADT/TypeSwitch.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace LinalgExt { - -Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim) { - ShapedType type = v.getType().cast(); - if (!type.isDynamicDim(dim)) { - return builder.create(loc, type.getDimSize(dim)); - } - return TypeSwitch(v.getType()) - .Case([&](RankedTensorType t) -> Value { - return builder.create(loc, v, dim); - }) - .Case([&](MemRefType t) -> Value { - return builder.create(loc, v, dim); - }); -} - -OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim) { - auto t = v.getType().cast(); - if (t.isDynamicDim(dim)) { - return getDimValue(builder, loc, v, dim); - } - return builder.getI64IntegerAttr(t.getDimSize(dim)); -} - -SmallVector getDims(OpBuilder &builder, Location loc, - Value shapedTypeValue) { - return llvm::to_vector(llvm::map_range( - llvm::seq( - 0, shapedTypeValue.getType().cast().getRank()), - [&](int64_t dim) { return getDim(builder, loc, shapedTypeValue, dim); })); -} - -SmallVector computeInterchangeFromDimPos(ArrayRef dimsPos, - int64_t rank) { - SmallVector interchangeVector; - interchangeVector.reserve(dimsPos.size()); - // First map dims and their position. For example, dims_pos = [2, 0] will map - // to: - // [ - // [ key: 2, value: 0] - // [ key: 0, value: 1] - // ] - // where key is the idx in dims_pos while value its position in dims_pos. - DenseMap dimsAndPosMapping; - for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++) - dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx; - - // Scan the position in order and insert the value in the map - // to compute the interchange vector. - for (int64_t dimsIdx = 0; dimsIdx < rank; dimsIdx++) { - if (dimsAndPosMapping.count(dimsIdx)) - interchangeVector.push_back(dimsAndPosMapping[dimsIdx]); - } - return interchangeVector; -} - -Value createValueFrom2DConstant(const float *val, int64_t rows, int64_t cols, - Location loc, PatternRewriter &rewriter) { - ArrayRef vector(val, rows * cols); - SmallVector shape{rows, cols}; - return rewriter.create( - loc, DenseFPElementsAttr::get( - RankedTensorType::get(shape, rewriter.getF32Type()), vector)); -} - -} // namespace LinalgExt -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt deleted file mode 100644 index 5a7289b10131..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Passes) diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt deleted file mode 100644 index 4a5428fb0fcf..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -add_mlir_library(IREELinalgTransformDialect - LinalgTransformOps.cpp - ScopedTransform.cpp - StructuredTransformOpsExt.cpp - - DEPENDS - mlir-headers - - LINK_LIBS PUBLIC - IREEDialectsTransforms - MLIRIR - - # Dialects - IREELinalgExtDialect - IREELinalgExtTransforms - - MLIRAsyncDialect - MLIRControlFlowInterfaces - MLIRLinalgDialect - MLIRPDLDialect - MLIRRewrite - MLIRTransformDialect - - # Transforms - MLIRAsyncTransforms - MLIRLinalgTransforms - MLIRAffineToStandard - MLIRTransforms - MLIRReconcileUnrealizedCasts - - # Conversions - MLIRAsyncToLLVM - MLIRMemRefToLLVM - MLIRMathToLLVM - MLIRVectorToLLVM - MLIRLinalgToLLVM - MLIRSCFToControlFlow -) diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp deleted file mode 100644 index 95e83c826e51..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h" - -#include - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h" -#include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h" -#include "iree-dialects/Transforms/Listener.h" -#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h" -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" -#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" -#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" -#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Dialect/Async/Passes.h" -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/PDL/IR/PDLTypes.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Parser/Parser.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/InliningUtils.h" -#include "mlir/Transforms/Passes.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" - -#define DEBUG_TYPE "linalg-transform-dialect" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") - -using namespace mlir; -using namespace mlir::iree_compiler::IREE; - -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformDialect.cpp.inc" - -void linalg::transform::LinalgTransformDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.cpp.inc" - >(); -} - -//===---------------------------------------------------------------------===// -// ScopeOp -//===---------------------------------------------------------------------===// - -void linalg::transform::ScopeOp::getSuccessorRegions( - Optional index, ArrayRef operands, - SmallVectorImpl ®ions) { - if (index) - regions.emplace_back(getResults()); - else - regions.emplace_back(&getBody()); -} - -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.cpp.inc" diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/ScopedTransform.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/ScopedTransform.cpp deleted file mode 100644 index e40b8d5ad1b4..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/ScopedTransform.cpp +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h" - -#include "mlir/Transforms/InliningUtils.h" - -using namespace mlir; - -namespace { -struct Rewriter : public PatternRewriter { - Rewriter(MLIRContext *ctx) : PatternRewriter(ctx) {} -}; -} // namespace - -linalg::transform::ScopeOp linalg::transform::wrapInScope(Operation *op) { - Rewriter rewriter(op->getContext()); - rewriter.setInsertionPoint(op); - - auto scope = rewriter.create( - op->getLoc(), op->getResultTypes(), op->getOperands()); - Region &body = scope.getBody(); - rewriter.setInsertionPointToStart(&body.emplaceBlock()); - BlockAndValueMapping bv; - SmallVector locs(op->getOperandTypes().size(), op->getLoc()); - bv.map(op->getOperands(), body.addArguments(op->getOperandTypes(), locs)); - - Operation *cloneInScope = rewriter.clone(*op, bv); - rewriter.create(op->getLoc(), cloneInScope->getResults()); - - rewriter.replaceOp(op, scope.getResults()); - return scope; -} - -namespace { -/// Instruct the inliner to inline everything. Scopes have no semantic meaning -/// so moving operations in and out of them, regardless of whether their -/// dialects have implemented an inliner interface, is valid. -struct ScopeInliner : public InlinerInterface { - using InlinerInterface::InlinerInterface; - - bool isLegalToInline(Operation *call, Operation *callable, - bool wouldBeCloned) const override { - return true; - } - bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, - BlockAndValueMapping &valueMapping) const override { - return true; - } - bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, - BlockAndValueMapping &valueMapping) const override { - return true; - } - - /// Don't recursively analyze operations, because they can all be "inlined". - bool shouldAnalyzeRecursively(Operation *op) const override { return false; } - - /// Replace uses of the results with the `forward` op's operands. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const override { - assert(isa(op)); - for (auto value : llvm::zip(op->getOperands(), valuesToRepl)) - std::get<1>(value).replaceAllUsesWith(std::get<0>(value)); - } -}; -} // namespace - -FailureOr> -linalg::transform::unwrapScope(linalg::transform::ScopeOp scope) { - ScopeInliner interface(scope->getContext()); - SmallVector ops; - scope.getBody().walk([&](Operation *op) { ops.push_back(op); }); - if (failed(inlineRegion(interface, &scope.getBody(), scope, - scope.getOperands(), scope.getResults(), - /*inlineLoc=*/{}, - /*shouldCloneInlinedRegion=*/false))) - return failure(); - Rewriter(scope->getContext()).eraseOp(scope); - return ops; -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp deleted file mode 100644 index 2afd47702ead..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp +++ /dev/null @@ -1,1359 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" - -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h" -#include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h" -#include "iree-dialects/Transforms/Listener.h" -#include "iree-dialects/Transforms/ListenerCSE.h" -#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h" -#include "iree-dialects/Transforms/TransformMatchers.h" -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" -#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" -#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" -#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Affine/LoopUtils.h" -#include "mlir/Dialect/Async/Passes.h" -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/PDL/IR/PDLTypes.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Parser/Parser.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/InliningUtils.h" -#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" -#include "mlir/Transforms/Passes.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" -#include -#include - -#define DEBUG_TYPE "transform-ops-ext" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") - -using namespace mlir; -using mlir::iree_compiler::IREE::LinalgExt::LinalgEnablingOptions; - -//===----------------------------------------------------------------------===// -// Additional constraints for PDLMatchOp. -//===----------------------------------------------------------------------===// - -/// Hook for PDL driver to check if an operation (`values[0]`) is directly -/// nested in a function with the name provided by an attribute -/// (`values[1]`). -/// TODO: PDL needs user-defined "questions". -static LogicalResult nestedInFunc(PatternRewriter &rewriter, - Operation *operation, Attribute attr) { - auto func = operation->getParentOfType(); - if (!func) - return rewriter.notifyMatchFailure(operation, "not nested in a function"); - auto functionSymbol = attr.dyn_cast(); - if (!functionSymbol) - return rewriter.notifyMatchFailure(operation, "not a function identifier"); - return success(functionSymbol.getLeafReference() == func.getName()); -} - -/// Construct a BlockAndValueMapping from `linalgOp` to `genericLinalgModelOp`. -/// Walk both ops and check whether all subops are the same. -static LogicalResult -haveIdenticalBodiesImpl(linalg::LinalgOp linalgOp, - linalg::LinalgOp genericLinalgModelOp) { - BlockAndValueMapping bvm; - bvm.map(linalgOp.getBlock()->getArguments(), - genericLinalgModelOp.getBlock()->getArguments()); - SmallVector linalgBodyOps; - linalgOp.getBlock()->walk( - [&](Operation *op) { linalgBodyOps.push_back(op); }); - - unsigned idx = 0; - WalkResult res = genericLinalgModelOp.getBlock()->walk([&](Operation *op) { - Operation *linalgSubOp = linalgBodyOps[idx++]; - if (op->getName() != linalgSubOp->getName()) - return WalkResult::interrupt(); - if (op->getAttrs() != linalgSubOp->getAttrs()) - return WalkResult::interrupt(); - for (auto it : llvm::zip(op->getOperands(), linalgSubOp->getOperands())) - if (std::get<0>(it) != bvm.lookupOrNull(std::get<1>(it))) - return WalkResult::interrupt(); - bvm.map(linalgSubOp->getResults(), op->getResults()); - return WalkResult::advance(); - }); - - return success(!res.wasInterrupted()); -} - -/// Dispatch body equivalence check depending on case. -static LogicalResult haveEquivalentBodies(linalg::LinalgOp linalgOp, - linalg::LinalgOp genericLinalgModelOp, - PatternRewriter &rewriter) { - if (succeeded(haveIdenticalBodiesImpl(linalgOp, genericLinalgModelOp))) - return success(); - // TODO: haveEquivalentBodiesImpl, see e.g. - // https://gist.github.com/nicolasvasilache/39e89e18c46e02335c16db6ec20a07e3 - return failure(); -} - -/// Succeed when `linalgOp` and `linalgModelOp` are deemed equivalent. -static LogicalResult isEquivalentToOpImpl(PatternRewriter &rewriter, - linalg::LinalgOp linalgOp, - linalg::LinalgOp linalgModelOp) { - // If basic properties do not match, return failure. - { - OpOperandVector opInputs = linalgOp.getDpsInputOperands(); - OpOperandVector modelInputs = linalgModelOp.getDpsInputOperands(); - OpOperandVector opOutputs = linalgOp.getDpsInitOperands(); - OpOperandVector modelOutputs = linalgModelOp.getDpsInitOperands(); - auto notEqualFn = [](std::tuple in) -> bool { - return std::get<0>(in)->get() != std::get<1>(in)->get(); - }; - - if (opInputs.size() != modelInputs.size() || - opOutputs.size() != modelOutputs.size() || - llvm::any_of(llvm::zip(opInputs, modelInputs), notEqualFn) || - llvm::any_of(llvm::zip(opOutputs, modelOutputs), notEqualFn) || - linalgOp.getIndexingMaps() != linalgModelOp.getIndexingMaps() || - linalgOp.getIteratorTypesArray() != - linalgModelOp.getIteratorTypesArray()) - return failure(); - } - - // Build the block and go perform a body comparison. - { - // createBlock moves the insertion point, scope it in an RAII block. - OpBuilder::InsertionGuard guard(rewriter); - Region &r = linalgModelOp->getRegion(0); - Block *bodyBlock = rewriter.createBlock( - &r, r.end(), linalgOp.getBlock()->getArgumentTypes(), - llvm::to_vector<4>( - llvm::map_range(linalgOp.getBlock()->getArguments(), - [](Value v) { return v.getLoc(); }))); - ImplicitLocOpBuilder b(linalgModelOp.getLoc(), rewriter); - auto regionBuilder = linalgModelOp.getRegionBuilder(); - llvm::ArrayRef attrs = {}; - regionBuilder(b, *bodyBlock, attrs); - } - - return haveEquivalentBodies(linalgOp, linalgModelOp, rewriter); -} - -/// Check whether the unique Operation* stored in `values[0]` (assumed) is -/// equivalent to the unique StringRefAttr passed in `values[1]` (assumed). -/// Equivalence is achieved when either: -/// 1. `values[0]` has the name stored in `values[1]`. -/// 2. `values[0]` and `values[1]` are both linalg ops and their structured -/// interfaces as well as their bodies are equivalent. -/// Structured interfaces equivalence is a simple attribute level check. -/// Body equivalence is more involved and currently limited: -/// a. the current impl constructs an instance of the op whose name is -/// specified in `values[1]` and checks for exact body equality. -/// b. a more advanced version would "subtract" the bodies and fold, cse -/// and canonicalize to fixed point. If the result is "all zeros", -/// then the bodies would be equivalent (really isomorphic). -/// 3. other cases TBD (e.g. vector.generic when available). -static LogicalResult isEquivalentToOp(PatternRewriter &rewriter, - Operation *operation, - Attribute attribute) { - auto modelOpNameAttr = attribute.dyn_cast(); - if (!modelOpNameAttr) - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - auto modelOpName = modelOpNameAttr.strref(); - - // 1. If op has name `modelOpName`, the match is trivial. - if (operation->getName().getStringRef() == modelOpName) - return success(); - - // 2. Linalg vs Linalg. - // Create op from `modelOpName`. - OperationState modelOpState( - operation->getLoc(), modelOpName, operation->getOperands(), - operation->getResultTypes(), operation->getAttrs()); - modelOpState.addRegion(); - Operation *modelOp = rewriter.create(modelOpState); - auto g1 = llvm::make_scope_exit([&]() { rewriter.eraseOp(modelOp); }); - linalg::LinalgOp linalgOp = dyn_cast(operation); - linalg::LinalgOp linalgModelOp = dyn_cast(modelOp); - if (linalgOp && linalgModelOp) - return isEquivalentToOpImpl(rewriter, linalgOp, linalgModelOp); - - // 3. TBD - return failure(); -} - -/// Assume that: -/// 1. `values[0]` is an operands range -/// 2. `values[1]` contains a DictAttr with `operand_number`, `dim` and -/// `divisor` IntegerAttr entries. -/// Succeed if `operands`[`operand_number`] is a ranked type whose `dim` is a -/// multiple of `divisor`. -/// Note: 0 is the convention to express "do not tile", it is considered to -/// divide everything. -static LogicalResult isDimMultipleOf(PatternRewriter &rewriter, - ValueRange operands, Attribute attribute) { - auto dict = attribute.dyn_cast(); - if (!dict) - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - - int64_t dim; - auto dimAttr = dict.getAs("dim"); - if (!dimAttr) - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - dim = dimAttr.getInt(); - - int64_t divisor; - auto divisorAttr = dict.getAs("divisor"); - if (!divisorAttr) - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - divisor = divisorAttr.getInt(); - - int64_t operandNumber; - auto operandNumberAttr = dict.getAs("operand_number"); - if (!operandNumberAttr) - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - operandNumber = operandNumberAttr.getInt(); - - ShapedType shapedType; - if (static_cast(operands.size()) > operandNumber) - shapedType = operands[operandNumber].getType().dyn_cast(); - if (!shapedType || shapedType.getRank() <= dim) - return failure(); - return success(divisor == 0 || (shapedType.getShape()[dim] > 0 && - shapedType.getShape()[dim] % divisor == 0)); -} - -/// Assume that: -/// 1. `values[0]` is an operands range -/// 2. `values[1]` contains a DictAttr with `operand_number` and `dim` -/// IntegerAttr entries. -/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is -/// dynamic. -static LogicalResult isDimStatic(PatternRewriter &rewriter, ValueRange operands, - Attribute attribute) { - auto dict = attribute.dyn_cast(); - if (!dict) - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - - int64_t dim; - auto dimAttr = dict.getAs("dim"); - if (!dimAttr) - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - dim = dimAttr.getInt(); - - int64_t operandNumber; - auto operandNumberAttr = dict.getAs("operand_number"); - if (!operandNumberAttr) - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - operandNumber = operandNumberAttr.getInt(); - - ShapedType shapedType; - if (static_cast(operands.size()) > operandNumber) - shapedType = operands[operandNumber].getType().dyn_cast(); - return success(shapedType && !shapedType.isDynamicDim(dim)); -} - -/// Assume that: -/// 1. `values[0]` is an operands range -/// 2. `values[1]` contains a DictAttr with `operand_number` and `dim` -/// IntegerAttr entries. -/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is -/// dynamic. -static LogicalResult isDimDynamic(PatternRewriter &rewriter, - ValueRange operands, Attribute attribute) { - auto dict = attribute.dyn_cast(); - if (!dict) - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - - int64_t dim; - auto dimAttr = dict.getAs("dim"); - if (!dimAttr) - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - dim = dimAttr.getInt(); - - int64_t operandNumber; - auto operandNumberAttr = dict.getAs("operand_number"); - if (!operandNumberAttr) - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - operandNumber = operandNumberAttr.getInt(); - - ShapedType shapedType; - if (static_cast(operands.size()) > operandNumber) - shapedType = operands[operandNumber].getType().dyn_cast(); - return success(shapedType && shapedType.isDynamicDim(dim)); -} - -//===----------------------------------------------------------------------===// -// StructuredTransformOpsExtension -//===----------------------------------------------------------------------===// - -mlir::transform_ext::StructuredTransformOpsExtension:: - StructuredTransformOpsExtension() { - registerTransformOps< -#define GET_OP_LIST -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc" - >(); - - registerPDLMatchConstraintFn("nestedInFunc", nestedInFunc); - registerPDLMatchConstraintFn("isDimDynamic", isDimDynamic); - registerPDLMatchConstraintFn("isDimMultipleOf", isDimMultipleOf); - registerPDLMatchConstraintFn("isDimStatic", isDimStatic); - registerPDLMatchConstraintFn("isEquivalentToOp", isEquivalentToOp); - - declareDependentDialect(); - declareDependentDialect(); -} - -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc" - -//===----------------------------------------------------------------------===// -// TrackingListener -//===----------------------------------------------------------------------===// - -/// Find the linalg op that defines all values in range, potentially -/// transitively through tensor casts. -static linalg::LinalgOp findSingleLinalgOpDefiningAll(ValueRange range) { - linalg::LinalgOp sourceOp = nullptr; - for (Value value : range) { - // See through tensor casts and reshape ops. - // - // TODO: we may need some generalization (interfaces?) of this for other - // operations, especially multi-operand ones to understand which of their - // operands may be coming from a Linalg op. Or a completely different - // mechanism of tracking op replacement at creation, or even different - // patterns that identify the "main" result of a transformation. - while (isa( - value.getDefiningOp())) { - value = llvm::TypeSwitch(value.getDefiningOp()) - .Case([](tensor::CastOp op) { return op.getSource(); }) - .Case([](tensor::CollapseShapeOp op) { return op.getSrc(); }) - .Case([](tensor::ExpandShapeOp op) { return op.getSrc(); }) - .Default([](Operation *) { - llvm_unreachable("Wrong op type"); - return Value(); - }); - } - - if (auto currentSourceOp = value.getDefiningOp()) { - if (!sourceOp || sourceOp == currentSourceOp) { - sourceOp = currentSourceOp; - continue; - } - - LLVM_DEBUG( - DBGS() << "different source linalg ops for replacing one op: \n" - << sourceOp << "\n" - << currentSourceOp << "\n"); - return nullptr; - } - LLVM_DEBUG(DBGS() << "replacing linalg op with unknown non-linalg op:\n" - << *value.getDefiningOp() << "\n"); - return nullptr; - } - return sourceOp; -} - -/// Find the scf "for" op that defines all values in the range. -static scf::ForOp findSingleForOpDefiningAll(ValueRange range) { - scf::ForOp forOp = nullptr; - for (Value value : range) { - if (auto currentSourceOp = value.getDefiningOp()) { - if (!forOp || forOp == currentSourceOp) { - forOp = currentSourceOp; - continue; - } - LLVM_DEBUG( - DBGS() << "different source scf.for ops when replacing one op\n"); - return nullptr; - } - - LLVM_DEBUG( - DBGS() - << "could not find a source scf.for when replacing another scf.for\n"); - return nullptr; - } - return forOp; -} - -/// Find the op that defines all values in the range. -static Operation *findSingleOpDefiningAll(ValueRange range) { - Operation *op = nullptr; - for (Value value : range) { - if (auto currentSourceOp = value.getDefiningOp()) { - if (!op || op == currentSourceOp) { - op = currentSourceOp; - continue; - } - LLVM_DEBUG(DBGS() << "different source op when replacing one op\n"); - return nullptr; - } - - LLVM_DEBUG( - DBGS() << "could not find a source op when replacing another op\n"); - return nullptr; - } - return op; -} - -// Find a single op that defines all values in the range, optionally -// transitively through other operations in an op-specific way. -static Operation *findSingleDefiningOp(Operation *replacedOp, - ValueRange range) { - return llvm::TypeSwitch(replacedOp) - .Case([&](linalg::LinalgOp) -> Operation * { - return findSingleLinalgOpDefiningAll(range); - }) - .Case([&](scf::ForOp) -> Operation * { - return findSingleForOpDefiningAll(range); - }) - .Default([&](Operation *) -> Operation * { - return findSingleOpDefiningAll(range); - }); -} - -void mlir::TrackingListener::notifyRootReplaced(Operation *op, - ValueRange newValues) { - // Bail out if in error state. - if (hadErrors) - return; - - // Exit early if the op is not tracked. - SmallVector handles; - if (failed(getTransformState().getHandlesForPayloadOp(op, handles))) - return; - - Operation *replacement = findSingleDefiningOp(op, newValues); - if (!replacement) { - emitError(op) << "could not find replacement for tracked op"; - return; - } - - LLVM_DEBUG(DBGS() << "replacing tracked @" << op << " : " << *op << " with " - << *replacement << "\n"); - mayFail(replacePayloadOp(op, replacement)); -} - -void mlir::TrackingListener::notifyOperationRemoved(Operation *op) { - // Bail out if in error state. - if (hadErrors) - return; - - // Exit early if the op is not tracked. - SmallVector handles; - if (failed(getTransformState().getHandlesForPayloadOp(op, handles))) - return; - - LLVM_DEBUG(DBGS() << "removing tracked @" << op << " : " << *op << "\n"); - mayFail(replacePayloadOp(op, nullptr)); -} - -void mlir::TrackingListener::removeMappings(Operation *op) { - // Bail if in error state. - if (hadErrors) - return; - - // Replacing the tracked op with null will stop the tracking. - LLVM_DEBUG(DBGS() << "removing mappings @" << op << " : " << *op << "\n"); - mayFail(replacePayloadOp(op, nullptr)); -} - -//===----------------------------------------------------------------------===// -// CanonicalizedSequenceOp -//===----------------------------------------------------------------------===// - -void ::transform_ext::CanonicalizedSequenceOp::build( - OpBuilder &builder, OperationState &state, - transform::FailurePropagationMode failurePropagationMode, - ::transform_ext::CanonicalizedSequenceOp::BodyBuilderFn bodyBuilder) { - assert(state.name.isRegistered() && "not registered!!"); - assert(bodyBuilder && "requires a body builder"); - MLIRContext *ctx = builder.getContext(); - state.addAttribute( - CanonicalizedSequenceOp::getFailurePropagationModeAttrName(state.name), - transform::FailurePropagationModeAttr::get(ctx, failurePropagationMode)); - Region *bodyRegion = state.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); - bodyBlock.addArgument(pdl::OperationType::get(ctx), state.location); - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); - bodyBuilder(builder, state.location, bodyBlock.getArgument(0)); -} - -/// Run enabling transformations (LICM and its variants, single-iteration loop -/// removal, CSE) on the given function. -static LogicalResult performEnablerTransformations( - func::FuncOp func, RewriteListener &listener, - LinalgEnablingOptions options = LinalgEnablingOptions()) { - MLIRContext *ctx = func->getContext(); - RewritePatternSet patterns(ctx); - linalg::populateLinalgTilingCanonicalizationPatterns(patterns); - scf::populateSCFForLoopCanonicalizationPatterns(patterns); - if (failed(applyPatternsTrackAndFoldGreedily(func, listener, - std::move(patterns)))) - return failure(); - - // This assumes LICM never removes operations so we don't need tracking. - if (options.licm) { - func->walk( - [](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); - } - - func.walk([](Operation *op) { - (void)llvm::TypeSwitch(op) - .Case( - [](auto loop) { return promoteIfSingleIteration(loop); }) - .Default([](Operation *) { return success(); }); - }); - - if (options.hoistRedundantVectorTransfers) - linalg::hoistRedundantVectorTransfers(func); - if (options.hoistRedundantVectorTransfersOnTensor) - linalg::hoistRedundantVectorTransfersOnTensor(func); - - return eliminateCommonSubexpressions(func, /*domInfo=*/nullptr, &listener); -} - -/// Run enabling transformations on the given `containerOp` while preserving the -/// operation tracking information. -static LogicalResult performEnablerTransformations( - Operation *containerOp, RewriteListener &listener, - LinalgEnablingOptions options = LinalgEnablingOptions()) { - auto res = containerOp->walk([&](func::FuncOp func) { - if (failed(performEnablerTransformations(func, listener, options))) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - return failure(res.wasInterrupted()); -} - -/// Drop the association between payload operations and transform dialect -/// handles when it is no longer necessary in a canonicalized sequence. -/// Specifically, drop the association between payload operations and the -/// operand handles if all handles to them will not be used after the current -/// `transform`. Also drop the association between payload operations and result -/// handles if results are never read. Note that the operand part is specific to -/// sequence-like execution that is not guaranteed in the transform dialect in -/// general. -static void -forgetUnnecessaryHandles(transform::TransformState &state, - transform_ext::CanonicalizedSequenceOp sequence, - transform::TransformOpInterface transform) { - auto *listener = state.getExtension(); - assert(transform->getParentOp() == sequence && - "only works for transform ops immediately nested in a canonicalized " - "sequence"); - assert(listener && "expected tracking listener to be present"); - - // Checks if the operation or its ancestor is before `transform` in its block - // or is `transform` itself. - auto userIsBefore = [&](Operation *user) { - while (user && user->getParentOp() != sequence) - user = user->getParentOp(); - if (!user) - return false; - return user->isBeforeInBlock(transform) || user == transform; - }; - - // Drop associations for operands that will not be read again. Ignore consumed - // operands that have been deassociated already. Consider all handles to each - // payload operation and only drop the association if all handles pointing to - // the same operation will are not used after the current transform op. The - // handle will be erased automatically after the last payload operation is - // deassociated from it. - llvm::SmallDenseSet seen; - llvm::SmallDenseMap handlesUsedAfterTransform; - for (Value operand : transform->getOperands()) { - if (transform::isHandleConsumed(operand, transform)) - continue; - - for (Operation *payload : state.getPayloadOps(operand)) { - if (!payload || seen.contains(payload)) - continue; - SmallVector allHandles; - (void)state.getHandlesForPayloadOp(payload, allHandles); - bool allHandlesUnused = llvm::all_of(allHandles, [&](Value handle) { - if (!handlesUsedAfterTransform.count(handle)) { - handlesUsedAfterTransform[handle] = - !llvm::all_of(handle.getUsers(), userIsBefore); - } - return !handlesUsedAfterTransform[handle]; - }); - if (allHandlesUnused) { - listener->removeMappings(payload); - seen.insert(payload); - } - } - } - - // Drop associations for results that will never be read. - for (Value result : transform->getResults()) { - if (!result.getUses().empty()) - continue; - for (Operation *payload : state.getPayloadOps(result)) { - if (!payload || seen.contains(payload)) - continue; - listener->removeMappings(payload); - seen.insert(payload); - } - } -} - -DiagnosedSilenceableFailure transform_ext::CanonicalizedSequenceOp::apply( - transform::TransformResults &results, transform::TransformState &state) { - - MLIRContext *ctx = getContext(); - RewritePatternSet patternList(ctx); - for (Dialect *dialect : ctx->getLoadedDialects()) - dialect->getCanonicalizationPatterns(patternList); - for (RegisteredOperationName op : ctx->getRegisteredOperations()) - op.getCanonicalizationPatterns(patternList, ctx); - FrozenRewritePatternSet patterns(std::move(patternList)); - - transform::TransformState::RegionScope regionScope = - state.make_region_scope(getBodyRegion()); - auto &listener = state.addExtension<::mlir::TrackingListener>(); - auto detachListener = llvm::make_scope_exit( - [&] { state.removeExtension<::mlir::TrackingListener>(); }); - if (failed(mapBlockArguments(state))) - return DiagnosedSilenceableFailure::definiteFailure(); - - auto checkedListenerTransform = - [&](function_ref - transform) { - SmallVector roots; - if (Value root = getRoot()) - llvm::append_range(roots, state.getPayloadOps(root)); - else - roots.push_back(state.getTopLevel()); - - for (Operation *target : roots) { - // Make sure we always check the error state, no boolean - // short-circuting. - if (failed(transform(target, listener))) { - target->emitOpError("Transform application failed."); - return failure(); - } - if (failed(listener.checkErrorState())) { - target->emitOpError("Listener failed."); - return failure(); - } - } - return success(); - }; - - auto performCSE = [](Operation *root, RewriteListener &listener) { - LogicalResult result = - eliminateCommonSubexpressions(root, /*domInfo=*/nullptr, &listener); - LLVM_DEBUG( - DBGS() << (succeeded(result) ? "successfully performed" : "failed") - << " CSE\n"); - return result; - }; - auto performEnabler = [](Operation *root, RewriteListener &listener) { - LogicalResult result = performEnablerTransformations(root, listener); - LLVM_DEBUG( - DBGS() << (succeeded(result) ? "successfully performed" : "failed") - << " enabling transformations\n"); - return result; - }; - auto performCanonicalization = [&patterns](Operation *root, - RewriteListener &listener) { - LogicalResult result = - applyPatternsTrackAndFoldGreedily(root, listener, patterns); - LLVM_DEBUG( - DBGS() << (succeeded(result) ? "successfully performed" : "failed") - << " canonicalization\n"); - return result; - }; - - LLVM_DEBUG(DBGS() << "begin canonicalizing sequence\n"); - if (failed(checkedListenerTransform(performCSE))) { - return mlir::emitDefiniteFailure( - *this, "Failed to performCSE beform transform sequence"); - } - if (failed(checkedListenerTransform(performCanonicalization))) { - return mlir::emitDefiniteFailure( - *this, "Failed to performCanonicalization beform transform sequence"); - } - - // Apply the sequenced ops one by one. - for (Operation &transform : getBodyBlock()->without_terminator()) { - auto transformOp = cast(transform); - DiagnosedSilenceableFailure result = state.applyTransform(transformOp); - if (result.isDefiniteFailure()) { - LLVM_DEBUG(DBGS() << "failed: " << transform << "\n"); - return result; - } - if (result.isSilenceableFailure()) { - LLVM_DEBUG(DBGS() << "failed silently: " << transform << "\n"); - if (getFailurePropagationMode() == - transform::FailurePropagationMode::Propagate) - return result; - (void)result.silence(); - } - LLVM_DEBUG(DBGS() << "successfully performed: " << transform << "\n"); - - // Canonicalization may replace payload operations associated with the - // transform dialect handles. Post-canonicalize reassociation is fragile and - // may fail. To make this less likely, drop any association that are no - // longer necessary, i.e., if the operand is no longer used in the sequence - // or elsewhere or if the result is never read. - forgetUnnecessaryHandles(state, *this, transformOp); - - if (failed(checkedListenerTransform(performCSE))) { - return mlir::emitDefiniteFailure(&transform, - "Failed to performCSE after transform"); - } - if (failed(checkedListenerTransform(performEnabler))) { - return mlir::emitDefiniteFailure( - &transform, "Failed to performEnabler after transform"); - } - if (failed(checkedListenerTransform(performCanonicalization))) { - return mlir::emitDefiniteFailure( - &transform, "Failed to performCanonicalization after transform"); - } - } - - // Forward the operation mapping for values yielded from the sequence to the - // values produced by the sequence op. - for (const auto &pair : - llvm::zip(getBodyBlock()->getTerminator()->getOperands(), - getOperation()->getOpResults())) { - Value terminatorOperand = std::get<0>(pair); - OpResult result = std::get<1>(pair); - results.set(result, state.getPayloadOps(terminatorOperand)); - } - - LLVM_DEBUG(DBGS() << "end canonicalizing sequence\n"); - return DiagnosedSilenceableFailure::success(); -} - -/// Returns `true` if the given op operand may be consuming the handle value in -/// the Transform IR. That is, if it may have a Free effect on it. -static bool isValueUsePotentialConsumer(OpOperand &use) { - // Conservatively assume the effect being present in absence of the interface. - auto memEffectInterface = dyn_cast(use.getOwner()); - if (!memEffectInterface) - return true; - - SmallVector effects; - memEffectInterface.getEffectsOnValue(use.get(), effects); - return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { - return isa(effect.getResource()) && - isa(effect.getEffect()); - }); -} - -// TODO: Add declaration to TransformOps.h, then we do not have to duplicate -// this function. -static LogicalResult -checkDoubleConsume(Value value, - function_ref reportError) { - OpOperand *potentialConsumer = nullptr; - for (OpOperand &use : value.getUses()) { - if (!isValueUsePotentialConsumer(use)) - continue; - - if (!potentialConsumer) { - potentialConsumer = &use; - continue; - } - - InFlightDiagnostic diag = reportError() - << " has more than one potential consumer"; - diag.attachNote(potentialConsumer->getOwner()->getLoc()) - << "used here as operand #" << potentialConsumer->getOperandNumber(); - diag.attachNote(use.getOwner()->getLoc()) - << "used here as operand #" << use.getOperandNumber(); - return diag; - } - - return success(); -} - -LogicalResult transform_ext::CanonicalizedSequenceOp::verify() { - // Check if the block argument has more than one consuming use. - for (BlockArgument argument : getBodyBlock()->getArguments()) { - auto report = [&]() { - return (emitOpError() << "block argument #" << argument.getArgNumber()); - }; - if (failed(checkDoubleConsume(argument, report))) - return failure(); - } - - // Check properties of the nested operations they cannot check themselves. - for (Operation &child : *getBodyBlock()) { - if (!isa(child) && - &child != &getBodyBlock()->back()) { - InFlightDiagnostic diag = - emitOpError() - << "expected children ops to implement TransformOpInterface"; - diag.attachNote(child.getLoc()) << "op without interface"; - return diag; - } - - for (OpResult result : child.getResults()) { - auto report = [&]() { - return (child.emitError() << "result #" << result.getResultNumber()); - }; - if (failed(checkDoubleConsume(result, report))) - return failure(); - } - } - - if (getBodyBlock()->getTerminator()->getOperandTypes() != - getOperation()->getResultTypes()) { - InFlightDiagnostic diag = emitOpError() - << "expects the types of the terminator operands " - "to match the types of the result"; - diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; - return diag; - } - return success(); -} - -void transform_ext::CanonicalizedSequenceOp::getEffects( - SmallVectorImpl &effects) { - auto *mappingResource = transform::TransformMappingResource::get(); - // Effects on root if present. - if (getRoot()) - effects.emplace_back(MemoryEffects::Read::get(), getRoot(), - mappingResource); - // Effects on results. - for (Value result : getResults()) { - effects.emplace_back(MemoryEffects::Allocate::get(), result, - mappingResource); - effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource); - } - - for (Operation &op : *getBodyBlock()) { - auto iface = dyn_cast(&op); - if (!iface) { - // TODO: fill all possible effects; or require ops to actually implement - // the memory effect interface always - assert(false); - } - if (getRoot()) { - // Carry over all effects on the argument of the entry block as those on - // the operand, this is the same value just remapped. - SmallVector nestedEffects; - iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects); - for (const auto &effect : nestedEffects) - effects.emplace_back(effect.getEffect(), getRoot(), - effect.getResource()); - } else { - // Otherwise, get all the effects. - iface.getEffects(effects); - } - } -} - -OperandRange transform_ext::CanonicalizedSequenceOp::getSuccessorEntryOperands( - Optional index) { - assert(index && index.value() == 0 && "unexpected region index"); - if (getOperation()->getNumOperands() == 1) - return getOperation()->getOperands(); - return OperandRange(getOperation()->operand_end(), - getOperation()->operand_end()); -} - -void transform_ext::CanonicalizedSequenceOp::getSuccessorRegions( - Optional index, ArrayRef operands, - SmallVectorImpl ®ions) { - if (!index.has_value()) { - Region *bodyRegion = &getBody(); - regions.emplace_back(bodyRegion, !operands.empty() - ? bodyRegion->getArguments() - : Block::BlockArgListType()); - return; - } - - assert(*index == 0 && "unexpected region index"); - regions.emplace_back(getOperation()->getResults()); -} - -void transform_ext::CanonicalizedSequenceOp::getRegionInvocationBounds( - ArrayRef operands, SmallVectorImpl &bounds) { - (void)operands; - bounds.emplace_back(1, 1); -} - -//===----------------------------------------------------------------------===// -// TODO: WILL MIGRATE -//===----------------------------------------------------------------------===// - -using namespace mlir::linalg; - -//===---------------------------------------------------------------------===// -// BufferizeOp -//===---------------------------------------------------------------------===// - -static void applyBufferizationEnablingTransformations(ModuleOp moduleOp) { - RewritePatternSet patterns(moduleOp.getContext()); - patterns.add(moduleOp.getContext()); - (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); -} - -DiagnosedSilenceableFailure -transform_ext::BufferizeOp::apply(mlir::transform::TransformResults &result, - mlir::transform::TransformState &state) { - bufferization::OneShotBufferizationOptions options; - options.bufferizeFunctionBoundaries = true; - options.memCpyFn = [](OpBuilder &builder, Location loc, Value from, - Value to) { - return success(linalg::makeMemRefCopyOp(builder, loc, from, to)); - }; - - auto moduleOp = cast(state.getTopLevel()); - applyBufferizationEnablingTransformations(moduleOp); - if (failed(runOneShotModuleBufferize(moduleOp, options))) - return DiagnosedSilenceableFailure::definiteFailure(); - - // Perform buffer-level hoistings. - state.getTopLevel()->walk( - [&](func::FuncOp funcOp) { hoistRedundantVectorTransfers(funcOp); }); - return DiagnosedSilenceableFailure::success(); -} - -//===---------------------------------------------------------------------===// -// LowerToLLVMOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform_ext::LowerToLLVMOp::apply(mlir::transform::TransformResults &result, - mlir::transform::TransformState &state) { - // TODO: it is feasible to scope lowering at arbitrary level and introduce - // unrealized casts, but there needs to be the final module-wise cleanup in - // the end. Keep module-level for now. - PassManager pm(getContext()); - - pm.addNestedPass(createConvertVectorToSCFPass()); - pm.addNestedPass(createConvertLinalgToLoopsPass()); - if (getEnableAsync()) { - pm.addPass(createAsyncToAsyncRuntimePass()); - pm.addPass(createAsyncRuntimeRefCountingPass()); - pm.addPass(createAsyncRuntimeRefCountingOptPass()); - } - pm.addPass(createCanonicalizerPass()); - pm.addPass(createLowerAffinePass()); - pm.addPass(createConvertSCFToCFPass()); - pm.addPass(createConvertLinalgToLLVMPass()); - pm.addPass(createConvertVectorToLLVMPass( - // clang-format off - LowerVectorToLLVMOptions() - .enableReassociateFPReductions(getReassociateFpReductions()) - .enableIndexOptimizations(getEnableIndexOptimizations()) - .enableArmNeon(getEnableArmNeon()) - .enableArmSVE(getEnableArmSve()) - .enableAMX(getEnableAmx()) - .enableX86Vector(getEnableX86vector()))); - // clang-format on - pm.addNestedPass(createConvertMathToLLVMPass()); - pm.addPass(createMemRefToLLVMConversionPass()); - if (getEnableAsync()) - pm.addPass(createConvertAsyncToLLVMPass()); - pm.addPass(createConvertFuncToLLVMPass()); - pm.addPass(createReconcileUnrealizedCastsPass()); - if (failed(pm.run(state.getTopLevel()))) - return DiagnosedSilenceableFailure::definiteFailure(); - - // Make all arguments noalias for now. - // FIXME: this is a terrible hack! - state.getTopLevel()->walk([](LLVM::LLVMFuncOp funcOp) { - for (int64_t i = 0; i < funcOp.getNumArguments(); ++i) { - if (!funcOp.getFunctionType() - .getParamType(i) - .isa()) - continue; - funcOp.setArgAttr(i, "llvm.noalias", UnitAttr::get(funcOp.getContext())); - } - }); - return DiagnosedSilenceableFailure::success(); -} - -//===---------------------------------------------------------------------===// -// LowerVectorsOp -//===---------------------------------------------------------------------===// - -/// Returns true of the numbered vector lowering stage is included into the list -/// of stages specified on the given lowerVectors operation. -static bool stageIncluded(int stage, - transform_ext::LowerVectorsOp lowerVectorsOp) { - for (auto s : lowerVectorsOp.getStages().getAsValueRange()) { - if (s.getSExtValue() == stage) - return true; - } - return false; -} - -// Applies the transformation specified by the given lower vectors operation -/// to the given function. -DiagnosedSilenceableFailure -transform_ext::LowerVectorsOp::apply(mlir::transform::TransformResults &results, - mlir::transform::TransformState &state) { - MLIRContext *ctx = getContext(); - RewritePatternSet patterns(ctx); - - vector::VectorTransposeLowering vectorTransposeLowering = - llvm::StringSwitch( - getTransposeLowering()) - .Case("eltwise", vector::VectorTransposeLowering::EltWise) - .Case("flat_transpose", vector::VectorTransposeLowering::Flat) - .Case("shuffle", vector::VectorTransposeLowering::Shuffle) - .Default(vector::VectorTransposeLowering::EltWise); - vector::VectorMultiReductionLowering vectorMultiReductionLowering = - llvm::StringSwitch( - getMultireductionLowering()) - .Case("innerreduction", - vector::VectorMultiReductionLowering::InnerReduction) - .Default(vector::VectorMultiReductionLowering::InnerParallel); - vector::VectorContractLowering vectorContractLowering = - llvm::StringSwitch( - getContractionLowering()) - .Case("matrixintrinsics", vector::VectorContractLowering::Matmul) - .Case("dot", vector::VectorContractLowering::Dot) - .Case("outerproduct", vector::VectorContractLowering::OuterProduct) - .Default(vector::VectorContractLowering::OuterProduct); - // TODO: fix the annoying name mismatch (vector-transfers vs vector-transfer). - vector::VectorTransferSplit vectorTransferSplit = - llvm::StringSwitch(getSplitTransfers()) - .Case("none", vector::VectorTransferSplit::None) - .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy) - .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer) - .Default(vector::VectorTransferSplit::None); - - vector::VectorTransformsOptions vectorTransformOptions; - vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering) - .setVectorMultiReductionLowering(vectorMultiReductionLowering) - .setVectorTransposeLowering(vectorTransposeLowering) - .setVectorTransferSplit(vectorTransferSplit); - - VectorTransferToSCFOptions vectorTransferToSCFOptions = - VectorTransferToSCFOptions().enableFullUnroll(getUnrollVectorTransfers()); - - int maxTransferRank = 1; - - auto avx2LoweringOptions = - x86vector::avx2::LoweringOptions().setTransposeOptions( - x86vector::avx2::TransposeLoweringOptions() - .lower4x8xf32(getTransposeAvx2Lowering()) - .lower8x8xf32(getTransposeAvx2Lowering())); - - // TODO: this is copy-pasta from LinalgStrategyLowerVectorsPass, shouldn't be. - vector::populateVectorToVectorCanonicalizationPatterns(patterns); - if (stageIncluded(1, *this)) { - patterns.add(vectorTransformOptions, - ctx); - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); - } - if (stageIncluded(2, *this)) { - vector::populateVectorMultiReductionLoweringPatterns( - patterns, vectorTransformOptions.vectorMultiReductionLowering); - } - if (stageIncluded(3, *this)) { - patterns.add( - ctx, vectorTransformOptions); - } - if (stageIncluded(4, *this)) { - vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank); - } - if (stageIncluded(5, *this)) { - populateVectorToSCFConversionPatterns( - patterns, vectorTransferToSCFOptions.setTargetRank(maxTransferRank)); - } - if (stageIncluded(6, *this)) { - vector::populateVectorShapeCastLoweringPatterns(patterns); - } - if (stageIncluded(7, (*this))) { - vector::populateVectorTransposeLoweringPatterns(patterns, - vectorTransformOptions); - if (getTransposeAvx2Lowering()) - x86vector::avx2::populateSpecializedTransposeLoweringPatterns( - patterns, avx2LoweringOptions, /*benefit=*/10); - } - - // TODO: these transformations are currently not targeted at concrete ops. - // LinalgTransformationFilter filter = makeTransformationFilter(target); - if (failed(applyPatternsAndFoldGreedily(state.getTopLevel(), - std::move(patterns)))) - return DiagnosedSilenceableFailure::definiteFailure(); - - // TODO: make composable... - return DiagnosedSilenceableFailure::success(); -} - -//===---------------------------------------------------------------------===// -// MatchCallbackOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure transform_ext::MatchCallbackOp::apply( - mlir::transform::TransformResults &results, - mlir::transform::TransformState &state) { - auto setEmptyResults = [&results, this] { - for (OpResult value : getResults()) { - results.set(value, {}); - } - }; - auto errorOut = [this, &setEmptyResults] { - setEmptyResults(); - return emitSilenceableError(); - }; - - auto *registry = state.getExtension(); - if (!registry) - return errorOut() << "match registry not available"; - - const transform_ext::MatchCallbacksRegistry::MatchCallbackFn *callback = - registry->get(getCallbackName()); - if (!callback) { - return errorOut() << "callback '" << getCallbackName() - << "' not found in the registry"; - } - - MatchCallbackResult result; - DiagnosedSilenceableFailure status = - (*callback)(result, getLoc(), state, getInputs()); - if (!status.succeeded()) { - setEmptyResults(); - if (status.isDefiniteFailure()) - return status; - if (getFailurePropagationMode() == - mlir::transform::FailurePropagationMode::Propagate) { - return emitSilenceableError() << "failed to match"; - } else { - return DiagnosedSilenceableFailure::success(); - } - } - if (getNumResults() != result.getNumPayloadGroups()) { - return errorOut() - << "callback produced a different number of handles than expected ( " - << result.getNumPayloadGroups() << " vs " << getNumResults() << " )"; - } - - for (OpResult value : getResults()) { - results.set(value, result.getPayloadGroup(value.getResultNumber())); - } - return DiagnosedSilenceableFailure::success(); -} - -void transform_ext::MatchCallbackOp::getEffects( - SmallVectorImpl &effects) { - mlir::transform::onlyReadsHandle(getInputs(), effects); - mlir::transform::producesHandle(getOutputs(), effects); - // TODO: it doesn't really modify the payload, we need a separate resource for - // this mapping. - mlir::transform::modifiesPayload(effects); -} - -//===---------------------------------------------------------------------===// -// RegisterMatchCallbacksOp -//===---------------------------------------------------------------------===// - -/// Match callback for "_test_match_callback" hook. Matches any payload -/// operations associated with operand handles unless they have the -/// "test.iree_transform_do_not_match" attribute, in which case produces a -/// silenceable failure. -static DiagnosedSilenceableFailure -testMatchCallbackCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - bool hadFailures = false; - for (Value handle : handles) { - if (llvm::any_of(state.getPayloadOps(handle), [](Operation *op) { - return op->hasAttr("test.iree_transform_do_not_match"); - })) { - res.addPayloadGroup(ArrayRef()); - hadFailures = true; - } else { - res.addPayloadGroup(state.getPayloadOps(handle)); - } - } - if (hadFailures) - return emitSilenceableFailure(loc) << "failed to match"; - return DiagnosedSilenceableFailure::success(); -} - -/// Match callback for a reduction with optional leading and trailing -/// elementwise operations. Matches *the first* occurrence of such a reduction -/// within an op associated with the given handle. -/// -/// Input handles: -/// -/// - container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - leading elementwise op, if any; -/// - the "fill" op preceding the reduction; -/// - reduction op; -/// - trailing elementwise op, if any. -static DiagnosedSilenceableFailure -reductionCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - if (handles.size() != 1 || state.getPayloadOps(handles[0]).size() != 1) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::StructuredOpMatcher pattern, fill, leadingEltwise, - trailingEltwise; - makeReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise); - - // TODO: need a mechanism for this to go around the entire IR, - // potentially with list matches for each group. - Operation *root = state.getPayloadOps(handles[0])[0]; - WalkResult walkResult = root->walk([&](Operation *op) { - pattern.resetCapture(); - if (!matchPattern(op, pattern)) - return WalkResult::advance(); - - res.addPotentiallyEmptyPayloadGroup(leadingEltwise.getCaptured()); - res.addPayloadGroup({fill.getCaptured()}); - res.addPayloadGroup({pattern.getCaptured()}); - res.addPotentiallyEmptyPayloadGroup(trailingEltwise.getCaptured()); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) - return DiagnosedSilenceableFailure::success(); - return emitSilenceableFailure(loc) << "failed to match"; -} - -/// Match callback for a reduction after splitting with optional leading and -/// trailing elementwise operations. Matches *the first* occurrence of such a -/// reduction within an op associated with the given handle. -/// -/// Input handles: -/// -/// - container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - leading elementwise op, if any; -/// - the "fill" op preceding the original reduction; -/// - the "fill" op preceding the split, more parallel reduction; -/// - the split, more parallel reduction op; -/// - reduction op; -/// - trailing elementwise op, if any. -static DiagnosedSilenceableFailure -splitReductionCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - if (handles.size() != 1 || state.getPayloadOps(handles[0]).size() != 1) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::StructuredOpMatcher parallel_reduction, combiner_reduction, - parallel_fill, original_fill, leading, trailing; - makeSplitReductionMatcher(parallel_reduction, combiner_reduction, - parallel_fill, original_fill, leading, trailing); - - // TODO: need a mechanism for this to go around the entire IR, - // potentially with list matches for each group. - Operation *root = state.getPayloadOps(handles[0])[0]; - WalkResult walkResult = root->walk([&](Operation *op) { - combiner_reduction.resetCapture(); - if (!matchPattern(op, combiner_reduction)) - return WalkResult::advance(); - - res.addPotentiallyEmptyPayloadGroup(leading.getCaptured()); - res.addPayloadGroup({original_fill.getCaptured()}); - res.addPayloadGroup({parallel_fill.getCaptured()}); - res.addPayloadGroup({parallel_reduction.getCaptured()}); - res.addPayloadGroup({combiner_reduction.getCaptured()}); - res.addPotentiallyEmptyPayloadGroup(trailing.getCaptured()); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) - return DiagnosedSilenceableFailure::success(); - return emitSilenceableFailure(loc) << "failed to match"; -} - -DiagnosedSilenceableFailure transform_ext::RegisterMatchCallbacksOp::apply( - mlir::transform::TransformResults &results, - mlir::transform::TransformState &state) { - auto ®istry = state.addExtension(); - registry.registerCallback("_test_match_callback", testMatchCallbackCallback); - registry.registerCallback("reduction", reductionCallback); - registry.registerCallback("split_reduction", splitReductionCallback); - return DiagnosedSilenceableFailure::success(); -} - -void transform_ext::RegisterMatchCallbacksOp::getEffects( - SmallVectorImpl &effects) { - // TODO: it doesn't really modify the payload, we need a separate resource for - // this mapping. - mlir::transform::modifiesPayload(effects); -} - -//===---------------------------------------------------------------------===// -// TakeFirstOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform_ext::TakeFirstOp::apply(mlir::transform::TransformResults &results, - mlir::transform::TransformState &state) { - SmallVector concatenated; - bool found = false; - for (Value handle : getInputs()) { - ArrayRef payloads = state.getPayloadOps(handle); - if (payloads.empty()) - continue; - if (!found) { - results.set(getFirst().cast(), payloads); - found = true; - } else { - llvm::append_range(concatenated, payloads); - } - } - - if (!found) - results.set(getFirst().cast(), {}); - results.set(getRest().cast(), concatenated); - return DiagnosedSilenceableFailure::success(); -} - -void transform_ext::TakeFirstOp::getEffects( - SmallVectorImpl &effects) { - mlir::transform::onlyReadsHandle(getInputs(), effects); - mlir::transform::producesHandle(getFirst(), effects); - mlir::transform::producesHandle(getRest(), effects); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/CMakeLists.txt deleted file mode 100644 index 682e5147dfba..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -add_mlir_library(IREELinalgTransformDialectPasses - ExpertExpansion.cpp - TransformInterpreter.cpp - - DEPENDS - mlir-headers - - LINK_LIBS PUBLIC - IREELinalgTransformDialect - - MLIRBufferizationDialect - MLIRIR - MLIRLinalgDialect - MLIRLLVMDialect - MLIRMathDialect - MLIRMathToLLVM - MLIRMemRefDialect - MLIRMemRefToLLVM - MLIRPass - MLIRTensorDialect - MLIRTransforms - MLIRVectorDialect - MLIRVectorToLLVM -) diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/ExpertExpansion.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/ExpertExpansion.cpp deleted file mode 100644 index 72a335046490..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/ExpertExpansion.cpp +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h" -#include "iree-dialects/Dialect/LinalgTransform/Passes.h" -#include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h" -#include "mlir/Dialect/PDL/IR/PDLOps.h" -#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Rewrite/PatternApplicator.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "expert-expansion" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]") - -using namespace mlir; - -/// Expands the linalg::transform::ExpertOp instances in the `module` into lists -/// of transformations as described by the `expansions` module that contains -/// PDL. -static void expandStrategyOps(ModuleOp module, ModuleOp expansions) { - mlir::OwningOpRef clonedExpansions( - cast(expansions->clone())); - RewritePatternSet patterns(std::move(clonedExpansions)); - FrozenRewritePatternSet frozen(std::move(patterns)); - PatternApplicator applicator(frozen); - applicator.applyDefaultCostModel(); - - module.walk([&](linalg::transform::ExpertOp expertOp) { - SimplePatternRewriter rewriter(expertOp); - if (failed(applicator.matchAndRewrite(expertOp, rewriter))) { - LLVM_DEBUG(DBGS() << "failed to rewrite strategy \"" - << expertOp.getExpertName() << "\"\n"); - } - }); -} - -namespace { -struct ExpertExpansion : public PassWrapper { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ExpertExpansion) - - Pass::Option strategyModuleName{ - *this, "strategy-module-name", llvm::cl::init("strategies"), - llvm::cl::desc( - "Name of the nested module containing expert strategies.")}; - - explicit ExpertExpansion(StringRef name = "strategies") - : PassWrapper() { - strategyModuleName = name.str(); - } - - ExpertExpansion(const ExpertExpansion &other) - : PassWrapper(other) { - strategyModuleName = other.strategyModuleName.getValue(); - } - - StringRef getArgument() const final { - return "linalg-transform-expert-expansion"; - } - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - StringRef getDescription() const final { - return "Expands transformation experts into individual transformations"; - } - - bool canScheduleOn(RegisteredOperationName opName) const override { - return true; - } - - void runOnOperation() override { - auto module = dyn_cast(getOperation()); - if (!module) - return signalPassFailure(); - - ModuleOp strategyModule = nullptr; - for (auto nestedModule : module.getOps()) { - Optional name = nestedModule.getSymName(); - if (!name) - continue; - - if (*name == strategyModuleName) { - if (!strategyModule) { - strategyModule = nestedModule; - continue; - } - InFlightDiagnostic diag = nestedModule->emitError() - << "more than one strategy module provided"; - diag.attachNote(strategyModule->getLoc()) << "previous strategy module"; - return signalPassFailure(); - } - } - - if (!strategyModule) { - module->emitError() << "expected a nested strategy module"; - return signalPassFailure(); - } - - expandStrategyOps(module, strategyModule); - strategyModule->erase(); - } -}; -} // namespace - -void mlir::linalg::transform::registerLinalgTransformExpertExpansionPass() { - PassRegistration( - []() { return std::make_unique(); }); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/TransformInterpreter.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/TransformInterpreter.cpp deleted file mode 100644 index d499db9f35da..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgTransform/Passes/TransformInterpreter.cpp +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h" -#include "iree-dialects/Dialect/LinalgTransform/Passes.h" -#include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterUtils.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/PDL/IR/PDL.h" -#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Parser/Parser.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/FileUtilities.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/SourceMgr.h" - -#define DEBUG_TYPE "transform-dialect-interpreter" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") - -using namespace mlir; - -LogicalResult mlir::transform::parseTransformModuleFromFile( - MLIRContext *context, llvm::StringRef transformFileName, - OwningOpRef &transformModule) { - if (transformFileName.empty()) { - llvm::errs() << "no transform file name specified, assuming the transform " - "module is embedded in the IR next to the top-level\n"; - return success(); - } - // Parse transformFileName content into a ModuleOp. - std::string errorMessage; - auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage); - if (!memoryBuffer) { - llvm::errs() << "failed to parse transform file: " << transformFileName - << "\n"; - return failure(); - } - // Tell sourceMgr about this buffer, the parser will pick it up. - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc()); - transformModule = - OwningOpRef(parseSourceFile(sourceMgr, context)); - return success(); -} - -LogicalResult mlir::transform::applyTransformsInRegion(Region &transformRegion, - Operation *target) { - SmallVector transforms; - if (failed( - transform::extractTopLevelTransformOps(transformRegion, transforms))) - return failure(); - - for (transform::TransformOpInterface transform : transforms) { - // TransformState::applyTransform requires that the parent region is a - // proper ancestor of the transform op to perform SSA liveness assertions. - // In multithreaded state however, we cannot clone into `transformRegion` so - // we build a new single-block region and clone the transform op into it. - Region r; - OpBuilder b(target->getContext()); - b.createBlock(&r); - TransformOptions options; -#ifndef NDEBUG - options = options.enableExpensiveChecks(); -#endif - auto xform = cast(b.clone(*transform)); - auto g = llvm::make_scope_exit([&]() { xform->erase(); }); - if (failed(transform::applyTransforms(target, xform, options))) - return failure(); - } - return success(); -} - -LogicalResult mlir::transform::extractTopLevelTransformOps( - Region &r, SmallVectorImpl &res) { - assert(r.getBlocks().size() == 1 && - "Expected single-block region to extract transform ops from"); - r.walk([&](transform::TransformOpInterface transform) { - if (transform->hasTrait()) { - assert(llvm::none_of(res, [&](transform::TransformOpInterface seen) { - return seen->isAncestor(transform); - })); - res.push_back(transform); - return WalkResult::skip(); - } - return WalkResult::advance(); - }); - return success(); -} - -namespace { - -/// Pass declaration. -/// Interpreter pass that applies transform dialect ops for codegen. -/// This needs to be its own pass because the registration mechanism and ops -/// available are different than for other interpreters. -class TransformDialectInterpreter - : public PassWrapper { -public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformDialectInterpreter) - - void getDependentDialects(DialectRegistry ®istry) const override { - // TODO: this is only necessary to make registry subset happy when running - // the lowering to LLVM. The lowering should be changed to stop using the - // nested pass manager and this will go away. - - // clang-format off - registry.insert(); - - // TODO: these should be registered by the extension instead, but there is - // no support for it in core currently. - arith::registerBufferizableOpInterfaceExternalModels(registry); - linalg::registerBufferizableOpInterfaceExternalModels(registry); - scf::registerBufferizableOpInterfaceExternalModels(registry); - bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( - registry); - tensor::registerBufferizableOpInterfaceExternalModels(registry); - vector::registerBufferizableOpInterfaceExternalModels(registry); - } - - StringRef getArgument() const override { - return "transform-dialect-interpreter"; - } - - StringRef getDescription() const override { - return "apply transform dialect operations one by one"; - } - - bool canScheduleOn(RegisteredOperationName name) const override { - return true; - } - - TransformDialectInterpreter(StringRef transformFileName = StringRef()) { - this->transformFileName = transformFileName.str(); - } - TransformDialectInterpreter(const TransformDialectInterpreter &pass) { - this->transformFileName = pass.transformFileName; - // TODO: if we really don't like shared_ptr, we could also clone the - // transformModule here. - sharedTransformModule = pass.sharedTransformModule; - } - - LogicalResult initialize(MLIRContext *context) override { - OwningOpRef module; - if (failed(transform::parseTransformModuleFromFile( - context, transformFileName, module))) - return failure(); - sharedTransformModule = - std::make_shared>(std::move(module)); - return success(); - } - - void runOnOperation() override { - Operation *target = getOperation(); - bool parsedTransform = (sharedTransformModule && *sharedTransformModule); - assert(parsedTransform || (target->getNumRegions() == 1 && - target->getRegion(0).getBlocks().size() == 1) && - "Cannot extract transform from op"); - Region &transformRegion = parsedTransform - ? (*sharedTransformModule)->getRegion() - : target->getRegion(0); - if (failed(transform::applyTransformsInRegion(transformRegion, target))) - return signalPassFailure(); - } - -protected: - Pass::Option transformFileName{ - *this, "transform-file-name", - ::llvm::cl::desc( - "File name containing a transform dialect specification to apply.")}; - -private: - // The parsed transform module to be used for scheduling. - // TODO: Figure a better way to build a transform module and transport it in - // the proper places in the IR as it is transformed by IREE so that it is - // available with better ownership semantics. - // Note: we wrap the OwningOpRef to get the desired destruction mechanism. - // Note: shared_ptr is not great but we know the sharedTransformModule is - // readonly. - // Alternatives comprise: - // 1. no shared_ptr but copying the module with every pass clone that the - // OpPassManager decides to perform. - // 2. lifting ownership of the parsed transform module higher up in the - // IREE stack. This may be only shift the problem as we have passes - // building pass managers in IREE. - // 3. build better support to embed the transformation module in the - // input IR and transport it to the place of use in IREE. This is deemed - // too intrusive atm. - // 4. (future) config/resources mechanism that is being proposed in core? - std::shared_ptr> sharedTransformModule; -}; - -struct DropSchedulePass : public PassWrapper { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DropSchedulePass) - - StringRef getArgument() const final { - return "transform-dialect-drop-schedule"; - } - - StringRef getDescription() const final { - return "Drop the schedule from the operation"; - } - - bool canScheduleOn(RegisteredOperationName opName) const override { - return true; - } - - void runOnOperation() override { - getOperation()->walk([&](Operation *nestedOp) { - if (isa(nestedOp)) - nestedOp->erase(); - if (isa<::mlir::transform::TransformOpInterface>(nestedOp)) { - nestedOp->erase(); - return WalkResult::skip(); - } - return WalkResult::advance(); - }); - // Remove potential empty module after cleanup. - getOperation()->walk([&](ModuleOp module) { - if (module.getBodyRegion().hasOneBlock() && module.getBody()->empty()) { - module->erase(); - return WalkResult::skip(); - } - return WalkResult::advance(); - }); - } -}; -} // namespace - -/// Create a Transform dialect interpreter pass. -std::unique_ptr -mlir::createTransformDialectInterpreterPass(llvm::StringRef transformFileName) { - return std::make_unique(transformFileName); -} - -/// Create a Linalg pass to drop the schedule from the module. -std::unique_ptr mlir::createDropSchedulePass() { - return std::make_unique(); -} - -/// Registration hook for the Linalg drop schedule from module pass. -void mlir::linalg::transform::registerDropSchedulePass() { - PassRegistration(); -} - -/// Registration hook for the Transform dialect interpreter pass. -void mlir::linalg::transform::registerTransformDialectInterpreterPass() { - PassRegistration(); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Transforms/CMakeLists.txt b/integrations/tensorflow/iree-dialects/lib/Transforms/CMakeLists.txt deleted file mode 100644 index 00eda6f75ef9..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Transforms/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ - -add_mlir_library(IREEDialectsTransforms - Listener.cpp - ListenerCSE.cpp - ListenerGreedyPatternRewriteDriver.cpp - TransformMatchers.cpp - - LINK_LIBS PRIVATE - # TODO: break dialect dependency by implementing the transformation separately - # and registering it. - MLIRAsyncDialect - MLIRLinalgDialect - MLIRLinalgTransforms - - DEPENDS - mlir-headers - IREELinalgExtIncGen - IREELinalgExtInterfacesIncGen -) diff --git a/integrations/tensorflow/iree-dialects/lib/Transforms/Listener.cpp b/integrations/tensorflow/iree-dialects/lib/Transforms/Listener.cpp deleted file mode 100644 index 58bde22742cf..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Transforms/Listener.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Transforms/Listener.h" - -namespace mlir { - -//===----------------------------------------------------------------------===// -// RewriteListener -//===----------------------------------------------------------------------===// - -RewriteListener::~RewriteListener() = default; - -//===----------------------------------------------------------------------===// -// ListenerList -//===----------------------------------------------------------------------===// - -void ListenerList::notifyOperationInserted(Operation *op) { - for (RewriteListener *listener : listeners) - listener->notifyOperationInserted(op); -} - -void ListenerList::notifyBlockCreated(Block *block) { - for (RewriteListener *listener : listeners) - listener->notifyBlockCreated(block); -} - -void ListenerList::notifyRootReplaced(Operation *op, ValueRange newValues) { - for (RewriteListener *listener : listeners) - listener->notifyRootReplaced(op, newValues); -} - -void ListenerList::notifyOperationRemoved(Operation *op) { - for (RewriteListener *listener : listeners) - listener->notifyOperationRemoved(op); -} - -LogicalResult ListenerList::notifyMatchFailure( - Location loc, function_ref reasonCallback) { - bool failed = false; - for (RewriteListener *listener : listeners) - failed |= listener->notifyMatchFailure(loc, reasonCallback).failed(); - return failure(failed); -} - -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/lib/Transforms/ListenerCSE.cpp b/integrations/tensorflow/iree-dialects/lib/Transforms/ListenerCSE.cpp deleted file mode 100644 index e9ba6d7ad257..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Transforms/ListenerCSE.cpp +++ /dev/null @@ -1,448 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Transforms/ListenerCSE.h" - -#include - -#include "mlir/IR/Dominance.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "llvm/ADT/ScopedHashTable.h" -#include "llvm/Support/RecyclingAllocator.h" - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// BEGIN copied from mlir/lib/Transforms/CSE.cpp -//===----------------------------------------------------------------------===// -namespace { -struct SimpleOperationInfo : public llvm::DenseMapInfo { - static unsigned getHashValue(const Operation *opC) { - return OperationEquivalence::computeHash( - const_cast(opC), - /*hashOperands=*/OperationEquivalence::directHashValue, - /*hashResults=*/OperationEquivalence::ignoreHashValue, - OperationEquivalence::IgnoreLocations); - } - static bool isEqual(const Operation *lhsC, const Operation *rhsC) { - auto *lhs = const_cast(lhsC); - auto *rhs = const_cast(rhsC); - if (lhs == rhs) - return true; - if (lhs == getTombstoneKey() || lhs == getEmptyKey() || - rhs == getTombstoneKey() || rhs == getEmptyKey()) - return false; - - // If op has no regions, operation equivalence w.r.t operands alone is - // enough. - if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) { - return OperationEquivalence::isEquivalentTo( - const_cast(lhsC), const_cast(rhsC), - OperationEquivalence::exactValueMatch, - OperationEquivalence::ignoreValueEquivalence, - OperationEquivalence::IgnoreLocations); - } - - // If lhs or rhs does not have a single region with a single block, they - // aren't CSEed for now. - if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 || - !llvm::hasSingleElement(lhs->getRegion(0)) || - !llvm::hasSingleElement(rhs->getRegion(0))) - return false; - - // Compare the two blocks. - Block &lhsBlock = lhs->getRegion(0).front(); - Block &rhsBlock = rhs->getRegion(0).front(); - - // Don't CSE if number of arguments differ. - if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments()) - return false; - - // Map to store `Value`s from `lhsBlock` that are equivalent to `Value`s - // in `rhsBlock`. `Value`s from `lhsBlock` are the key. - DenseMap areEquivalentValues; - for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(), - rhs->getRegion(0).getArguments())) { - areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs); - } - - // Helper function to get the parent operation. - auto getParent = [](Value v) -> Operation * { - if (auto blockArg = v.dyn_cast()) - return blockArg.getParentBlock()->getParentOp(); - return v.getDefiningOp()->getParentOp(); - }; - - // Callback to compare if operands of ops in the region of `lhs` and `rhs` - // are equivalent. - auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult { - if (lhsValue == rhsValue) - return success(); - if (areEquivalentValues.lookup(lhsValue) == rhsValue) - return success(); - return failure(); - }; - - // Callback to compare if results of ops in the region of `lhs` and `rhs` - // are equivalent. - auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult { - if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) { - auto insertion = areEquivalentValues.insert({lhsResult, rhsResult}); - return success(insertion.first->second == rhsResult); - } - return success(); - }; - - return OperationEquivalence::isEquivalentTo( - const_cast(lhsC), const_cast(rhsC), - mapOperands, mapResults, OperationEquivalence::IgnoreLocations); - } -}; -} // namespace - -namespace { -/// Simple common sub-expression elimination. -//===----------------------------------------------------------------------===// -// END copied from mlir/lib/Transforms/CSE.cpp -//===----------------------------------------------------------------------===// -/// Copy of CSE::runOnOperation, without the pass baggage. -// struct CSE : public impl::CSEBase { -struct CSE { - //===----------------------------------------------------------------------===// - // BEGIN copied from mlir/lib/Transforms/CSE.cpp - //===----------------------------------------------------------------------===// - /// Shared implementation of operation elimination and scoped map - /// definitions. - using AllocatorTy = llvm::RecyclingAllocator< - llvm::BumpPtrAllocator, - llvm::ScopedHashTableVal>; - using ScopedMapTy = llvm::ScopedHashTable; - - /// Cache holding MemoryEffects information between two operations. The - /// first operation is stored has the key. The second operation is stored - /// inside a pair in the value. The pair also hold the MemoryEffects between - /// those two operations. If the MemoryEffects is nullptr then we assume - /// there is no operation with MemoryEffects::Write between the two - /// operations. - using MemEffectsCache = - DenseMap>; - - /// Represents a single entry in the depth first traversal of a CFG. - struct CFGStackNode { - CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node) - : scope(knownValues), node(node), childIterator(node->begin()) {} - - /// Scope for the known values. - ScopedMapTy::ScopeTy scope; - - DominanceInfoNode *node; - DominanceInfoNode::const_iterator childIterator; - - /// If this node has been fully processed yet or not. - bool processed = false; - }; - - /// Attempt to eliminate a redundant operation. Returns success if the - /// operation was marked for removal, failure otherwise. - LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op, - bool hasSSADominance); - void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance); - void simplifyRegion(ScopedMapTy &knownValues, Region ®ion); - - // void runOnOperation() override; - void doItOnOperation(Operation *rootOp, DominanceInfo *domInfo, - RewriteListener *listener); - -private: - void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, - Operation *existing, bool hasSSADominance); - - /// Check if there is side-effecting operations other than the given effect - /// between the two operations. - bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp); - - /// Operations marked as dead and to be erased. - std::vector opsToErase; - DominanceInfo *domInfo = nullptr; - MemEffectsCache memEffectsCache; - //===----------------------------------------------------------------------===// - // END copied from mlir/lib/Transforms/CSE.cpp - //===----------------------------------------------------------------------===// - /// An optional listener to notify of replaced or erased operations. - RewriteListener *listener; - int64_t numDCE = 0, numCSE = 0; - - //===----------------------------------------------------------------------===// - // BEGIN copied from mlir/lib/Transforms/CSE.cpp - //===----------------------------------------------------------------------===// -}; -} // namespace - -void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, - Operation *existing, bool hasSSADominance) { - // If we find one then replace all uses of the current operation with the - // existing one and mark it for deletion. We can only replace an operand in - // an operation if it has not been visited yet. - if (hasSSADominance) { - // If the region has SSA dominance, then we are guaranteed to have not - // visited any use of the current operation. - //===----------------------------------------------------------------------===// - // END copied from mlir/lib/Transforms/CSE.cpp - //===----------------------------------------------------------------------===// - if (listener) - listener->notifyRootReplaced(op, existing->getResults()); - //===----------------------------------------------------------------------===// - // BEGIN copied from mlir/lib/Transforms/CSE.cpp - //===----------------------------------------------------------------------===// - op->replaceAllUsesWith(existing); - opsToErase.push_back(op); - } else { - // When the region does not have SSA dominance, we need to check if we - // have visited a use before replacing any use. - for (auto it : llvm::zip(op->getResults(), existing->getResults())) { - std::get<0>(it).replaceUsesWithIf( - std::get<1>(it), [&](OpOperand &operand) { - return !knownValues.count(operand.getOwner()); - }); - } - - // There may be some remaining uses of the operation. - if (op->use_empty()) - opsToErase.push_back(op); - } - - // If the existing operation has an unknown location and the current - // operation doesn't, then set the existing op's location to that of the - // current op. - if (existing->getLoc().isa() && !op->getLoc().isa()) - existing->setLoc(op->getLoc()); - - ++numCSE; -} - -bool CSE::hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp) { - assert(fromOp->getBlock() == toOp->getBlock()); - assert( - isa(fromOp) && - cast(fromOp).hasEffect() && - isa(toOp) && - cast(toOp).hasEffect()); - Operation *nextOp = fromOp->getNextNode(); - auto result = - memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr)); - if (result.second) { - auto memEffectsCachePair = result.first->second; - if (memEffectsCachePair.second == nullptr) { - // No MemoryEffects::Write has been detected until the cached operation. - // Continue looking from the cached operation to toOp. - nextOp = memEffectsCachePair.first; - } else { - // MemoryEffects::Write has been detected before so there is no need to - // check further. - return true; - } - } - while (nextOp && nextOp != toOp) { - auto nextOpMemEffects = dyn_cast(nextOp); - // TODO: Do we need to handle other effects generically? - // If the operation does not implement the MemoryEffectOpInterface we - // conservatively assumes it writes. - if ((nextOpMemEffects && - nextOpMemEffects.hasEffect()) || - !nextOpMemEffects) { - result.first->second = - std::make_pair(nextOp, MemoryEffects::Write::get()); - return true; - } - nextOp = nextOp->getNextNode(); - } - result.first->second = std::make_pair(toOp, nullptr); - return false; -} - -/// Attempt to eliminate a redundant operation. -LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op, - bool hasSSADominance) { - // Don't simplify terminator operations. - if (op->hasTrait()) - return failure(); - - // If the operation is already trivially dead just add it to the erase list. - if (isOpTriviallyDead(op)) { - opsToErase.push_back(op); - ++numDCE; - return success(); - } - - // Don't simplify operations with nested blocks. We don't currently model - // equality comparisons correctly among other things. It is also unclear - // whether we would want to CSE such operations. - if (!(op->getNumRegions() == 0 || - (op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0))))) - return failure(); - - // Some simple use case of operation with memory side-effect are dealt with - // here. Operations with no side-effect are done after. - if (!isMemoryEffectFree(op)) { - auto memEffects = dyn_cast(op); - // TODO: Only basic use case for operations with MemoryEffects::Read can - // be eleminated now. More work needs to be done for more complicated - // patterns and other side-effects. - if (!memEffects || !memEffects.onlyHasEffect()) - return failure(); - - // Look for an existing definition for the operation. - if (auto *existing = knownValues.lookup(op)) { - if (existing->getBlock() == op->getBlock() && - !hasOtherSideEffectingOpInBetween(existing, op)) { - // The operation that can be deleted has been reach with no - // side-effecting operations in between the existing operation and - // this one so we can remove the duplicate. - replaceUsesAndDelete(knownValues, op, existing, hasSSADominance); - return success(); - } - } - knownValues.insert(op, op); - return failure(); - } - - // Look for an existing definition for the operation. - if (auto *existing = knownValues.lookup(op)) { - replaceUsesAndDelete(knownValues, op, existing, hasSSADominance); - ++numCSE; - return success(); - } - - // Otherwise, we add this operation to the known values map. - knownValues.insert(op, op); - return failure(); -} - -void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb, - bool hasSSADominance) { - for (auto &op : *bb) { - // If the operation is simplified, we don't process any held regions. - if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance))) - continue; - - // Most operations don't have regions, so fast path that case. - if (op.getNumRegions() == 0) - continue; - - // If this operation is isolated above, we can't process nested regions - // with the given 'knownValues' map. This would cause the insertion of - // implicit captures in explicit capture only regions. - if (op.mightHaveTrait()) { - ScopedMapTy nestedKnownValues; - for (auto ®ion : op.getRegions()) - simplifyRegion(nestedKnownValues, region); - continue; - } - - // Otherwise, process nested regions normally. - for (auto ®ion : op.getRegions()) - simplifyRegion(knownValues, region); - } - // Clear the MemoryEffects cache since its usage is by block only. - memEffectsCache.clear(); -} - -void CSE::simplifyRegion(ScopedMapTy &knownValues, Region ®ion) { - // If the region is empty there is nothing to do. - if (region.empty()) - return; - - bool hasSSADominance = domInfo->hasSSADominance(®ion); - - // If the region only contains one block, then simplify it directly. - if (region.hasOneBlock()) { - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(knownValues, ®ion.front(), hasSSADominance); - return; - } - - // If the region does not have dominanceInfo, then skip it. - // TODO: Regions without SSA dominance should define a different - // traversal order which is appropriate and can be used here. - if (!hasSSADominance) - return; - - // Note, deque is being used here because there was significant performance - // gains over vector when the container becomes very large due to the - // specific access patterns. If/when these performance issues are no - // longer a problem we can change this to vector. For more information see - // the llvm mailing list discussion on this: - // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html - std::deque> stack; - - // Process the nodes of the dom tree for this region. - stack.emplace_back(std::make_unique( - knownValues, domInfo->getRootNode(®ion))); - - while (!stack.empty()) { - auto ¤tNode = stack.back(); - - // Check to see if we need to process this node. - if (!currentNode->processed) { - currentNode->processed = true; - simplifyBlock(knownValues, currentNode->node->getBlock(), - hasSSADominance); - } - - // Otherwise, check to see if we need to process a child node. - if (currentNode->childIterator != currentNode->node->end()) { - auto *childNode = *(currentNode->childIterator++); - stack.emplace_back( - std::make_unique(knownValues, childNode)); - } else { - // Finally, if the node and all of its children have been processed - // then we delete the node. - stack.pop_back(); - } - } -} - -//===----------------------------------------------------------------------===// -// END copied from mlir/lib/Transforms/CSE.cpp -//===----------------------------------------------------------------------===// - -/// Copy of CSE::runOnOperation, without the pass baggage. -void CSE::doItOnOperation(Operation *rootOp, DominanceInfo *domInfo, - RewriteListener *listener) { - /// A scoped hash table of defining operations within a region. - ScopedMapTy knownValues; - this->domInfo = domInfo; - this->listener = listener; - - for (auto ®ion : rootOp->getRegions()) - simplifyRegion(knownValues, region); - - /// Erase any operations that were marked as dead during simplification. - for (auto *op : opsToErase) { - if (listener) - listener->notifyOperationRemoved(op); - op->erase(); - } - opsToErase.clear(); -} - -/// Run CSE on the provided operation -LogicalResult mlir::eliminateCommonSubexpressions(Operation *op, - DominanceInfo *domInfo, - RewriteListener *listener) { - assert(op->hasTrait() && - "can only do CSE on isolated-from-above ops"); - Optional defaultDomInfo; - if (domInfo == nullptr) { - defaultDomInfo.emplace(op); - domInfo = &*defaultDomInfo; - } - CSE().doItOnOperation(op, domInfo, listener); - return success(); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp b/integrations/tensorflow/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp deleted file mode 100644 index cc38f93e9166..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp +++ /dev/null @@ -1,469 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h" - -#include "iree-dialects/Transforms/Listener.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Rewrite/PatternApplicator.h" -#include "mlir/Transforms/FoldUtils.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/RegionUtils.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ScopedPrinter.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; - -#define DEBUG_TYPE "listener-greedy-rewriter" - -//===----------------------------------------------------------------------===// -// GreedyPatternRewriteDriver -//===----------------------------------------------------------------------===// - -namespace { -/// This is a worklist-driven driver for the PatternMatcher, which repeatedly -/// applies the locally optimal patterns in a roughly "bottom up" way. -class GreedyPatternRewriteDriver : public RewriteListener { -public: - explicit GreedyPatternRewriteDriver(Operation *rootOp, - const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, - RewriteListener *listener); - //===--------------------------------------------------------------------===// - // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - //===--------------------------------------------------------------------===// - - /// Simplify the operations within the given regions. - bool simplify(MutableArrayRef regions); - - /// Add the given operation to the worklist. - void addToWorklist(Operation *op); - - /// Pop the next operation from the worklist. - Operation *popFromWorklist(); - - /// If the specified operation is in the worklist, remove it. - void removeFromWorklist(Operation *op); - -protected: - // Implement the hook for inserting operations, and make sure that newly - // inserted ops are added to the worklist for processing. - void notifyOperationInserted(Operation *op) override; - - // Look over the provided operands for any defining operations that should - // be re-added to the worklist. This function should be called when an - // operation is modified or removed, as it may trigger further - // simplifications. - void addOperandsToWorklist(ValueRange operands); - - // If an operation is about to be removed, make sure it is not in our - // worklist anymore because we'd get dangling references to it. - void notifyOperationRemoved(Operation *op) override; - - // When the root of a pattern is about to be replaced, it can trigger - // simplifications to its users - make sure to add them to the worklist - // before the root is changed. - void notifyRootReplaced(Operation *op, ValueRange replacement) override; - - //===--------------------------------------------------------------------===// - // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - //===--------------------------------------------------------------------===// - // This seems unused - /// PatternRewriter hook for erasing a dead operation. - // void eraseOp(Operation *op) override; - // //===-----------------------------------------------------------------===// - // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - //===--------------------------------------------------------------------===// - - /// PatternRewriter hook for notifying match failure reasons. - LogicalResult - notifyMatchFailure(Location loc, - function_ref reasonCallback) override; - - /// The low-level pattern applicator. - PatternApplicator matcher; - - /// The worklist for this transformation keeps track of the operations that - /// need to be revisited, plus their index in the worklist. This allows us to - /// efficiently remove operations from the worklist when they are erased, even - /// if they aren't the root of a pattern. - std::vector worklist; - DenseMap worklistMap; - - /// Non-pattern based folder for operations. - OperationFolder folder; - -private: - /// Configuration information for how to simplify. - GreedyRewriteConfig config; - - //===--------------------------------------------------------------------===// - // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - //===--------------------------------------------------------------------===// - /// The pattern rewriter to use. - PatternRewriterListener rewriter; - /// The operation under which all processed ops must be nested. - Operation *rootOp; - //===--------------------------------------------------------------------===// - // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - //===--------------------------------------------------------------------===// - -#ifndef NDEBUG - /// A logger used to emit information during the application process. - llvm::ScopedPrinter logger{llvm::dbgs()}; -#endif -}; -} // namespace - -//===----------------------------------------------------------------------===// -// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp -//===----------------------------------------------------------------------===// -GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( - Operation *rootOp, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, RewriteListener *listener) - : matcher(patterns), folder(rootOp->getContext()), config(config), - rewriter(rootOp->getContext()), rootOp(rootOp) { - // Add self as a listener and the user-provided listener. - rewriter.addListener(this); - if (listener) - rewriter.addListener(listener); - //===--------------------------------------------------------------------===// - // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - //===--------------------------------------------------------------------===// - - worklist.reserve(64); - - // Apply a simple cost model based solely on pattern benefit. - matcher.applyDefaultCostModel(); -} - -bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions) { -#ifndef NDEBUG - const char *logLineComment = - "//===-------------------------------------------===//\n"; - - /// A utility function to log a process result for the given reason. - auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) { - logger.unindent(); - logger.startLine() << "} -> " << result; - if (!msg.isTriviallyEmpty()) - logger.getOStream() << " : " << msg; - logger.getOStream() << "\n"; - }; - auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) { - logResult(result, msg); - logger.startLine() << logLineComment; - }; -#endif - - auto insertKnownConstant = [&](Operation *op) { - // Check for existing constants when populating the worklist. This avoids - // accidentally reversing the constant order during processing. - Attribute constValue; - if (matchPattern(op, m_Constant(&constValue))) - if (!folder.insertKnownConstant(op, constValue)) - return true; - return false; - }; - - bool changed = false; - unsigned iteration = 0; - do { - worklist.clear(); - worklistMap.clear(); - - if (!config.useTopDownTraversal) { - // Add operations to the worklist in postorder. - for (auto ®ion : regions) { - region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) - addToWorklist(op); - }); - } - } else { - // Add all nested operations to the worklist in preorder. - for (auto ®ion : regions) { - region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) { - worklist.push_back(op); - return WalkResult::advance(); - } - return WalkResult::skip(); - }); - } - - // Reverse the list so our pop-back loop processes them in-order. - std::reverse(worklist.begin(), worklist.end()); - // Remember the reverse index. - for (size_t i = 0, e = worklist.size(); i != e; ++i) - worklistMap[worklist[i]] = i; - } - - // These are scratch vectors used in the folding loop below. - SmallVector originalOperands, resultValues; - - changed = false; - while (!worklist.empty()) { - auto *op = popFromWorklist(); - - // Nulls get added to the worklist when operations are removed, ignore - // them. - if (op == nullptr) - continue; - - LLVM_DEBUG({ - logger.getOStream() << "\n"; - logger.startLine() << logLineComment; - logger.startLine() << "Processing operation : '" << op->getName() - << "'(" << op << ") {\n"; - logger.indent(); - - // If the operation has no regions, just print it here. - if (op->getNumRegions() == 0) { - op->print( - logger.startLine(), - OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); - logger.getOStream() << "\n\n"; - } - }); - - // If the operation is trivially dead - remove it. - if (isOpTriviallyDead(op)) { - notifyOperationRemoved(op); - op->erase(); - changed = true; - - LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); - continue; - } - - // Collects all the operands and result uses of the given `op` into work - // list. Also remove `op` and nested ops from worklist. - originalOperands.assign(op->operand_begin(), op->operand_end()); - auto preReplaceAction = [&](Operation *op) { - // Add the operands to the worklist for visitation. - addOperandsToWorklist(originalOperands); - - // Add all the users of the result to the worklist so we make sure - // to revisit them. - for (auto result : op->getResults()) - for (auto *userOp : result.getUsers()) - addToWorklist(userOp); - - notifyOperationRemoved(op); - }; - - // Add the given operation to the worklist. - auto collectOps = [this](Operation *op) { addToWorklist(op); }; - - // Try to fold this op. - bool inPlaceUpdate; - if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, - &inPlaceUpdate)))) { - LLVM_DEBUG(logResultWithLine("success", "operation was folded")); - - changed = true; - if (!inPlaceUpdate) - continue; - } - - // Try to match one of the patterns. The rewriter is automatically - // notified of any necessary changes, so there is nothing else to do - // here. -#ifndef NDEBUG - auto canApply = [&](const Pattern &pattern) { - LLVM_DEBUG({ - logger.getOStream() << "\n"; - logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" - << op->getName() << " -> ("; - llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); - logger.getOStream() << ")' {\n"; - logger.indent(); - }); - return true; - }; - auto onFailure = [&](const Pattern &pattern) { - LLVM_DEBUG(logResult("failure", "pattern failed to match")); - }; - auto onSuccess = [&](const Pattern &pattern) { - LLVM_DEBUG(logResult("success", "pattern applied successfully")); - return success(); - }; - - LogicalResult matchResult = - //===------------------------------------------------------------===// - // BEGIN single line change from - // mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - // END - //===------------------------------------------------------------===// - matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess); - if (succeeded(matchResult)) - LLVM_DEBUG(logResultWithLine("success", "pattern matched")); - else - LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match")); -#else - //===----------------------------------------------------------------===// - // BEGIN single line change from - // mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - // END - //===----------------------------------------------------------------===// - LogicalResult matchResult = matcher.matchAndRewrite(op, rewriter); -#endif - changed |= succeeded(matchResult); - } - - // After applying patterns, make sure that the CFG of each of the regions - // is kept up to date. - if (config.enableRegionSimplification) - //===----------------------------------------------------------------===// - // BEGIN single line change from - // mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - // END - //===----------------------------------------------------------------===// - changed |= succeeded(simplifyRegions(rewriter, regions)); - } while (changed && - (iteration++ < config.maxIterations || - config.maxIterations == GreedyRewriteConfig::kNoIterationLimit)); - - // Whether the rewrite converges, i.e. wasn't changed in the last iteration. - return !changed; -} - -void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { - // Check to see if the worklist already contains this op. - if (worklistMap.count(op)) - return; - //===--------------------------------------------------------------------===// - // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - //===--------------------------------------------------------------------===// - // Enforce nested under constraint before adding to worklist. - if (!rootOp->isProperAncestor(op)) - return; - //===--------------------------------------------------------------------===// - // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - //===--------------------------------------------------------------------===// - - worklistMap[op] = worklist.size(); - worklist.push_back(op); -} - -Operation *GreedyPatternRewriteDriver::popFromWorklist() { - auto *op = worklist.back(); - worklist.pop_back(); - - // This operation is no longer in the worklist, keep worklistMap up to date. - if (op) - worklistMap.erase(op); - return op; -} - -void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) { - auto it = worklistMap.find(op); - if (it != worklistMap.end()) { - assert(worklist[it->second] == op && "malformed worklist data structure"); - worklist[it->second] = nullptr; - worklistMap.erase(it); - } -} - -void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) { - LLVM_DEBUG({ - logger.startLine() << "** Insert : '" << op->getName() << "'(" << op - << ")\n"; - }); - addToWorklist(op); -} - -void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) { - for (Value operand : operands) { - // If the use count of this operand is now < 2, we re-add the defining - // operation to the worklist. - // TODO: This is based on the fact that zero use operations - // may be deleted, and that single use values often have more - // canonicalization opportunities. - if (!operand || (!operand.use_empty() && !operand.hasOneUse())) - continue; - if (auto *defOp = operand.getDefiningOp()) - addToWorklist(defOp); - } -} - -void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) { - addOperandsToWorklist(op->getOperands()); - op->walk([this](Operation *operation) { - removeFromWorklist(operation); - folder.notifyRemoval(operation); - }); -} - -void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op, - ValueRange replacement) { - LLVM_DEBUG({ - logger.startLine() << "** Replace : '" << op->getName() << "'(" << op - << ")\n"; - }); - for (auto result : op->getResults()) - for (auto *user : result.getUsers()) - addToWorklist(user); -} - -//===----------------------------------------------------------------------===// -// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp -//===----------------------------------------------------------------------===// -// This seems unused -// void GreedyPatternRewriteDriver::eraseOp(Operation *op) { -// LLVM_DEBUG({ -// logger.startLine() << "** Erase : '" << op->getName() << "'(" << op -// << ")\n"; -// }); -// PatternRewriter::eraseOp(op); -// } -//===----------------------------------------------------------------------===// -// BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp -//===----------------------------------------------------------------------===// - -LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure( - Location loc, function_ref reasonCallback) { - LLVM_DEBUG({ - Diagnostic diag(loc, DiagnosticSeverity::Remark); - reasonCallback(diag); - logger.startLine() << "** Failure : " << diag.str() << "\n"; - }); - return failure(); -} - -/// Rewrite the regions of the specified operation, which must be isolated from -/// above, by repeatedly applying the highest benefit patterns in a greedy -/// work-list driven manner. Return success if no more patterns can be matched -/// in the result operation regions. Note: This does not apply patterns to the -/// top-level operation itself. -/// -//===----------------------------------------------------------------------===// -// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp -//===----------------------------------------------------------------------===// -LogicalResult mlir::applyPatternsAndFoldGreedily( - Operation *op, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, RewriteListener *listener) { - if (op->getRegions().empty()) - return success(); - - // Start the pattern driver. - GreedyPatternRewriteDriver driver(op, patterns, config, listener); - auto regions = op->getRegions(); - //===--------------------------------------------------------------------===// - // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp - //===--------------------------------------------------------------------===// - bool converged = driver.simplify(regions); - LLVM_DEBUG(if (!converged) { - llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " - << config.maxIterations << " times\n"; - }); - return success(converged); -} diff --git a/integrations/tensorflow/iree-dialects/lib/Transforms/TransformMatchers.cpp b/integrations/tensorflow/iree-dialects/lib/Transforms/TransformMatchers.cpp deleted file mode 100644 index eda5d55fbfe3..000000000000 --- a/integrations/tensorflow/iree-dialects/lib/Transforms/TransformMatchers.cpp +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Transforms/TransformMatchers.h" - -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" - -using namespace mlir; - -//===---------------------------------------------------------------------===// -// StructuredOpMatcher and friends. -//===---------------------------------------------------------------------===// - -bool transform_ext::StructuredOpMatcher::match(Operation *op) { - auto linalgOp = dyn_cast(op); - if (!linalgOp) - return false; - - if (!llvm::all_of(predicates, [linalgOp](const PredicateFn &fn) { - return fn(linalgOp); - })) { - return false; - } - - captured = linalgOp; - return true; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(int64_t dimension, ShapeKind kind) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - SmallVector shape = linalgOp.getStaticLoopRanges(); - int64_t transformedDimension = - dimension >= 0 ? dimension : shape.size() + dimension; - if (transformedDimension >= shape.size()) - return false; - return ShapedType::isDynamic(shape[transformedDimension]) ^ - (kind == ShapeKind::Static); - }); - return *this; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - SmallVector shape = linalgOp.getStaticLoopRanges(); - return llvm::all_of(shape, [=](int64_t dimension) { - return ShapedType::isDynamic(dimension) ^ (kind == ShapeKind::Static); - }); - }); - return *this; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(int64_t dimension, - utils::IteratorType kind) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - unsigned rank = linalgOp.getNumLoops(); - int64_t transformedDimension = - dimension >= 0 ? dimension : rank + dimension; - if (transformedDimension >= rank) - return false; - - utils::IteratorType iteratorKind = - linalgOp.getIteratorTypesArray()[transformedDimension]; - return iteratorKind == kind; - }); - return *this; -} -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDims tag, utils::IteratorType kind) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - return llvm::all_of( - linalgOp.getIteratorTypesArray(), - [=](utils::IteratorType iteratorType) { return iteratorType == kind; }); - }); - return *this; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(int64_t dimension, - DivisibleBy divisibleBy) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - unsigned rank = linalgOp.getNumLoops(); - int64_t transformedDimension = - dimension >= 0 ? dimension : rank + dimension; - if (transformedDimension >= rank) - return false; - - int64_t size = linalgOp.getStaticLoopRanges()[transformedDimension]; - return !ShapedType::isDynamic(size) && (size % divisibleBy.value == 0); - }); - return *this; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - // all_of with a lambda requires const-casting dance, so using a loop. - for (OpOperand *operand : linalgOp.getDpsInputOperands()) { - if (!linalgOp.getMatchingIndexingMap(operand).isPermutation()) - return false; - } - return true; - }); - return *this; -} - -/// Traverses the transitive sources of `val` until it reaches an operation that -/// is not a known "subset-like" operation, i.e. `extract_slice` or -/// `foreach_thread`. -static Operation *traverseSubsetsBackwards(Value val) { - do { - Operation *op = val.getDefiningOp(); - if (!op) { - // TODO: This should likely be done via RegionBranchOpInterface as a sort - // of data flow analysis. - auto bbArg = val.cast(); - Operation *blockOp = bbArg.getOwner()->getParentOp(); - assert(blockOp && "detached block"); - if (auto loop = dyn_cast(blockOp)) { - val = loop.getTiedOpOperand(bbArg)->get(); - continue; - } - return blockOp; - } - - // TODO: We may eventually want a "subset-like" interface that we can use to - // traverse ops here and in post-canonicalization replacement - // identification. - if (auto extractSlice = dyn_cast(op)) { - val = extractSlice.getSource(); - continue; - } - return op; - } while (true); -} - -/// Greedily traverses the transitive uses of `val` until it reaches an -/// operation that is not a known "subset-like" operation, i.e. `extract_slice` -/// or `foreach_thread`. -static Operation *traverseSubsetsForwardAnyUse(Value val) { - do { - for (OpOperand &use : val.getUses()) { - Operation *user = use.getOwner(); - if (auto loop = dyn_cast(user)) { - auto range = loop.getOutputBlockArguments(); - auto it = llvm::find_if(range, [&](BlockArgument bbarg) { - return loop.getTiedOpOperand(bbarg) != &use; - }); - if (it == range.end()) - return user; - val = *it; - continue; - } - if (auto slice = dyn_cast(user)) { - val = slice.getResult(); - continue; - } - return user; - } - } while (true); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, SubsetOf subset) { - // Implementation note: SubsetOf must *not* be passed by-reference because - // it is typically a temporary constructed within the argument of a function - // call, but it will be used in the lambda that outlives the temporary. The - // lambda itself must capture by value for the same reason. - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - int64_t transformedPosition = - position >= 0 ? position : linalgOp.getNumDpsInputs() + position; - if (transformedPosition >= linalgOp.getNumDpsInputs()) - return false; - - Operation *producer = traverseSubsetsBackwards( - linalgOp.getDpsInputOperand(transformedPosition)->get()); - return subset.matcher.match(producer); - }); - recordNestedMatcher(subset.matcher); - return *this; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, IsPermutation) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - for (OpOperand *operand : linalgOp.getDpsInitOperands()) { - if (!linalgOp.getMatchingIndexingMap(operand).isPermutation()) - return false; - } - return true; - }); - return *this; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - ElementTypeBitWidth width) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - int64_t updatedPosition = - position >= 0 ? position : linalgOp.getNumDpsInits() + position; - if (updatedPosition >= linalgOp.getNumDpsInits()) - return false; - auto shapedType = linalgOp.getDpsInitOperand(updatedPosition) - ->get() - .getType() - .dyn_cast(); - return shapedType && shapedType.getElementType().isIntOrFloat() && - shapedType.getElementType().getIntOrFloatBitWidth() == width.value; - }); - return *this; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - SingleCombinerReduction tag) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - int64_t updatedPosition = - position >= 0 ? position : linalgOp.getNumDpsInits() + position; - if (updatedPosition >= linalgOp.getNumDpsInits()) - return false; - SmallVector combinerOps; - return matchReduction(linalgOp.getRegionOutputArgs(), updatedPosition, - combinerOps) && - llvm::hasSingleElement(combinerOps); - }); - return *this; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, SubsetOf subset) { - // Implementation note: SubsetOf must *not* be passed by-reference because - // it is typically a temporary constructed within the argument of a function - // call, but it will be used in the lambda that outlives the temporary. The - // lambda itself must capture by value for the same reason. - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - int64_t transformedPosition = - position >= 0 ? position : linalgOp.getNumDpsInputs() + position; - if (transformedPosition >= linalgOp.getNumDpsInputs()) - return false; - - Operation *producer = traverseSubsetsBackwards( - linalgOp.getDpsInitOperand(transformedPosition)->get()); - return subset.matcher.match(producer); - }); - recordNestedMatcher(subset.matcher); - return *this; -} - -transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::result( - int64_t position, HasAnyUse tag, SubsetOf subset, OptionalMatch optional) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - int64_t transformedPosition = - position >= 0 ? position : linalgOp->getNumResults() + position; - if (transformedPosition >= linalgOp->getNumResults()) - return false; - - Operation *user = - traverseSubsetsForwardAnyUse(linalgOp->getResult(transformedPosition)); - return subset.matcher.match(user) || optional.value; - }); - return *this; -} - -//===---------------------------------------------------------------------===// -// MatchCallbackResult. -//===---------------------------------------------------------------------===// - -ArrayRef -transform_ext::MatchCallbackResult::getPayloadGroup(unsigned position) const { - assert(position < payloadGroupLengths.size()); - int64_t start = 0; - for (unsigned i = 0; i < position; ++i) { - start += payloadGroupLengths[i]; - } - return llvm::makeArrayRef(payloadOperations) - .slice(start, payloadGroupLengths[position]); -} - -//===---------------------------------------------------------------------===// -// Case-specific matcher builders. -//===---------------------------------------------------------------------===// - -static constexpr unsigned kCudaWarpSize = 32; - -void transform_ext::makeReductionMatcher( - transform_ext::StructuredOpMatcher &reduction, - transform_ext::StructuredOpMatcher &fill, - transform_ext::StructuredOpMatcher &leading, - transform_ext::StructuredOpMatcher &trailing) { - fill = m_StructuredOp(); - trailing = m_StructuredOp() - .input(AllOperands(), IsPermutation()) - .output(AllOperands(), IsPermutation()) - .input(NumEqualsTo(1)) - .output(NumEqualsTo(1)); - leading = trailing; - reduction = m_StructuredOp() - .dim(AllDims(), ShapeKind::Static) - .dim(-1, utils::IteratorType::reduction) - .dim(-1, DivisibleBy(kCudaWarpSize)) - // Can be extended to projected permutation with broadcast. - .input(AllOperands(), IsPermutation()) - // TODO: we want to accept any input position here. - .input(0, leading, OptionalMatch()) - .output(NumEqualsTo(1)) - .output(0, fill) - // Only single combiner over 32 bits for now due to - // reduction distribution. - .output(0, ElementTypeBitWidth(32)) - .output(0, SingleCombinerReduction()) - .result(0, HasAnyUse(), trailing, OptionalMatch()); -} - -void transform_ext::makeSplitReductionMatcher( - transform_ext::StructuredOpMatcher ¶llel_reduction, - transform_ext::StructuredOpMatcher &combiner_reduction, - transform_ext::StructuredOpMatcher ¶llel_fill, - transform_ext::StructuredOpMatcher &original_fill, - transform_ext::StructuredOpMatcher &leading, - transform_ext::StructuredOpMatcher &trailing) { - original_fill = m_StructuredOp(); - parallel_fill = m_StructuredOp(); - trailing = m_StructuredOp() - .input(AllOperands(), IsPermutation()) - .output(AllOperands(), IsPermutation()) - .input(NumEqualsTo(1)) - .output(NumEqualsTo(1)); - leading = m_StructuredOp() - .input(AllOperands(), IsPermutation()) - .output(AllOperands(), IsPermutation()) - .input(NumEqualsTo(1)) - .output(NumEqualsTo(1)); - parallel_reduction = m_StructuredOp() - .dim(AllDims(), ShapeKind::Static) - .dim(-1, utils::IteratorType::reduction) - .input(AllOperands(), IsPermutation()) - // TODO: we want to accept any input position here. - .input(0, leading, OptionalMatch()) - .output(NumEqualsTo(1)) - .output(0, parallel_fill); - combiner_reduction = - m_StructuredOp() - .dim(AllDims(), ShapeKind::Static) - .dim(-1, utils::IteratorType::reduction) - // Can be extended to projected permutation with broadcast. - .input(AllOperands(), IsPermutation()) - .input(0, SubsetOf(parallel_reduction)) - .output(NumEqualsTo(1)) - .output(0, SubsetOf(original_fill)) - .output(0, ElementTypeBitWidth(32)) - .output(0, SingleCombinerReduction()) - .result(0, HasAnyUse(), SubsetOf(trailing), OptionalMatch()); -} diff --git a/integrations/tensorflow/iree-dialects/python/CMakeLists.txt b/integrations/tensorflow/iree-dialects/python/CMakeLists.txt index ad42597f6536..b9127f27bede 100644 --- a/integrations/tensorflow/iree-dialects/python/CMakeLists.txt +++ b/integrations/tensorflow/iree-dialects/python/CMakeLists.txt @@ -23,33 +23,6 @@ declare_mlir_dialect_python_bindings( DIALECT_NAME iree_input ) -declare_mlir_dialect_python_bindings( - ADD_TO_PARENT IREEDialectsPythonSources.Dialects - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler" - TD_FILE dialects/IreeLinalgExtBinding.td - SOURCES dialects/iree_linalg_ext.py - DIALECT_NAME iree_linalg_ext -) - -declare_mlir_dialect_python_bindings( - ADD_TO_PARENT IREEDialectsPythonSources.Dialects - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler" - TD_FILE dialects/LinalgTransformBinding.td - SOURCES dialects/iree_linalg_transform.py - dialects/_iree_linalg_transform_ops_ext.py - DIALECT_NAME iree_linalg_transform - ) - -declare_mlir_dialect_extension_python_bindings( - ADD_TO_PARENT IREEDialectsPythonSources.Dialects - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler" - TD_FILE dialects/IreeStructuredTransformOps.td - SOURCES - dialects/transform/iree_structured.py - dialects/_iree_structured_transform_ops_ext.py - DIALECT_NAME transform - EXTENSION_NAME iree_structured_transform) - ################################################################################ # Extensions ################################################################################ diff --git a/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp b/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp index 1a85609afbf8..f5497980b263 100644 --- a/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp +++ b/integrations/tensorflow/iree-dialects/python/IREEDialectsModule.cpp @@ -34,54 +34,4 @@ PYBIND11_MODULE(_ireeDialects, m) { } }, py::arg("context") = py::none(), py::arg("load") = true); - - //===--------------------------------------------------------------------===// - // IREELinalgExt - //===--------------------------------------------------------------------===// - auto iree_linalg_ext_m = m.def_submodule("iree_linalg_ext"); - iree_linalg_ext_m.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle handle = mlirGetDialectHandle__iree_linalg_ext__(); - mlirDialectHandleRegisterDialect(handle, context); - if (load) { - mlirDialectHandleLoadDialect(handle, context); - } - }, - py::arg("context") = py::none(), py::arg("load") = true); - - //===--------------------------------------------------------------------===// - // LinalgTransform - //===--------------------------------------------------------------------===// - auto iree_linalg_transform_m = m.def_submodule("iree_linalg_transform"); - mlirIREELinalgTransformRegisterPasses(); - iree_linalg_transform_m.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle handle = - mlirGetDialectHandle__iree_linalg_transform__(); - mlirDialectHandleRegisterDialect(handle, context); - if (load) { - mlirDialectHandleLoadDialect(handle, context); - } - }, - py::arg("context") = py::none(), py::arg("load") = true); - - //===--------------------------------------------------------------------===// - // TransformDialect - //===--------------------------------------------------------------------===// - auto transform_m = m.def_submodule("transform"); - mlirIREETransformRegisterPasses(); - - transform_m.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle handle = mlirGetDialectHandle__transform__(); - mlirDialectHandleRegisterDialect(handle, context); - ireeRegisterTransformExtensions(context); - if (load) { - mlirDialectHandleLoadDialect(handle, context); - } - }, - py::arg("context") = py::none(), py::arg("load") = true); } diff --git a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/IreeLinalgExtBinding.td b/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/IreeLinalgExtBinding.td deleted file mode 100644 index da2ceaed9f8c..000000000000 --- a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/IreeLinalgExtBinding.td +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef PYTHON_BINDINGS_IREE_LINALGEXT_OPS -#define PYTHON_BINDINGS_IREE_LINALGEXT_OPS - -include "mlir/Bindings/Python/Attributes.td" -include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td" - -#endif // PYTHON_BINDINGS_IREE_LINALGEXT_OPS diff --git a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/IreeStructuredTransformOps.td b/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/IreeStructuredTransformOps.td deleted file mode 100644 index 42c69e22ee82..000000000000 --- a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/IreeStructuredTransformOps.td +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef PYTHON_BINDINGS_IREE_TRANSFORMEXT_BINDING -#define PYTHON_BINDINGS_IREE_TRANSFORMEXT_BINDING - -include "mlir/Bindings/Python/Attributes.td" -include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td" - -#endif // PYTHON_BINDINGS_IREE_TRANSFORMEXT_BINDING diff --git a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/LinalgTransformBinding.td b/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/LinalgTransformBinding.td deleted file mode 100644 index e908e7951599..000000000000 --- a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/LinalgTransformBinding.td +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef PYTHON_BINDINGS_IREE_LINALGTRANSFORM_BINDING -#define PYTHON_BINDINGS_IREE_LINALGTRANSFORM_BINDING - -include "mlir/Bindings/Python/Attributes.td" -include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td" - -#endif // PYTHON_BINDINGS_IREE_LINALGTRANSFORM_BINDING diff --git a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/_iree_linalg_transform_ops_ext.py b/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/_iree_linalg_transform_ops_ext.py deleted file mode 100644 index e315cb3573d9..000000000000 --- a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/_iree_linalg_transform_ops_ext.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Disable PyType, it does not seem to like the specialization pattern used in -# MLIR. -# pytype: skip-file -try: - from .. import ir - from ..dialects import pdl - from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values - from typing import Optional, Sequence, Union -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e -BoolArg = Optional[Union[bool, ir.BoolAttr]] -IntArg = Optional[Union[int, ir.IntegerAttr]] -IntListArg = Optional[Union[Sequence[int], ir.ArrayAttr]] -IntListListArg = Optional[Union[Sequence[Union[Sequence[int], ir.ArrayAttr]], - ir.ArrayAttr]] -StringArg = Optional[Union[str, ir.StringAttr]] -StringListArg = Optional[Union[Sequence[str], ir.ArrayAttr]] - - -def _defaulted_ensure(f): - - def inner(value, default=None): - assert value is not None or default is not None - return f(default if value is None else value) - - return inner - - -@_defaulted_ensure -def _ensure_int_array_attr(value: IntListArg): - i64 = ir.IntegerType.get_signless(64) - if isinstance(value, Sequence): - return ir.ArrayAttr.get([ir.IntegerAttr.get(i64, i) for i in value]) - return value - - -@_defaulted_ensure -def _ensure_string_array_attr(value: StringListArg): - if isinstance(value, Sequence): - return ir.ArrayAttr.get([ir.StringAttr.get(str(i)) for i in value]) - return value - - -@_defaulted_ensure -def _ensure_array_of_array_attr(value: IntListListArg): - if isinstance(value, Sequence): - return ir.ArrayAttr.get([_ensure_int_array_attr(inner) for inner in value]) - return value - - -@_defaulted_ensure -def _ensure_int_attr(value: IntArg): - if isinstance(value, int): - return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) - return value - - -@_defaulted_ensure -def _ensure_bool_attr(value: BoolArg): - if isinstance(value, bool): - return ir.BoolAttr.get(value) - return value - - -@_defaulted_ensure -def _ensure_string_attr(value: StringArg): - if isinstance(value, str): - return ir.StringAttr.get(value) - return value - - -def _count_expected_loops(tile_sizes: ir.ArrayAttr) -> int: - # Number of loops = number of tile sizes != 0 - zero = _ensure_int_attr(0) - return len(list(tile_sizes)) - list(tile_sizes).count(zero) - - -##===----------------------------------------------------------------------===## -## LinalgExt specific transforms -##===----------------------------------------------------------------------===## - - -class TileToLinalgExtTileOp: - """Specialization for the TileToLinalgExtTileOp class.""" - - def __init__(self, - target: Union[ir.Value, ir.Operation, ir.OpView], - *, - sizes: IntListArg = None, - loc=None, - ip=None): - sizes = _ensure_int_array_attr(sizes, []) - operation_type = pdl.OperationType.get() - super().__init__(operation_type, target, sizes, loc=loc, ip=ip) - - -class FuseIntoContainingOp: - """Specialization for the FuseIntoContainingOp class.""" - - def __init__(self, - producerOp: Union[ir.Value, ir.Operation, ir.OpView], - *, - containingOp: Union[ir.Value, ir.Operation, ir.OpView], - loc=None, - ip=None): - super().__init__([], producerOp, containingOp, loc=loc, ip=ip) diff --git a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py b/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py deleted file mode 100644 index 4a1b08b6c74e..000000000000 --- a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Disable PyType, it does not seem to like the specialization pattern used in -# MLIR. -# pytype: skip-file -try: - from ..ir import * - from ..dialects import pdl - from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values - from typing import Optional, overload, Sequence, Union -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e -BoolArg = Optional[Union[bool, BoolAttr]] -IntListArg = Optional[Union[Sequence[int], ArrayAttr]] -StringArg = Optional[Union[str, StringAttr]] - - -def _defaulted_ensure(f): - - def inner(value, default=None): - assert value is not None or default is not None - return f(default if value is None else value) - - return inner - - -@_defaulted_ensure -def _ensure_int_array_attr(value: IntListArg): - i64 = IntegerType.get_signless(64) - if isinstance(value, Sequence): - return ArrayAttr.get([IntegerAttr.get(i64, i) for i in value]) - return value - - -@_defaulted_ensure -def _ensure_bool_attr(value: BoolArg): - if isinstance(value, bool): - return BoolAttr.get(value) - return value - - -@_defaulted_ensure -def _ensure_string_attr(value: StringArg): - if isinstance(value, str): - return StringAttr.get(value) - return value - - -class CanonicalizedSequenceOp: - - @overload - def __init__(self, resultsOrRoot: Sequence[Type], - optionalRoot: Optional[Union[Operation, Value]]): - ... - - @overload - def __init__(self, resultsOrRoot: Optional[Union[Operation, Value]], - optionalRoot: NoneType): - ... - - def __init__(self, resultsOrRoot=None, optionalRoot=None): - results = resultsOrRoot if isinstance(resultsOrRoot, Sequence) else [] - root = (resultsOrRoot - if not isinstance(resultsOrRoot, Sequence) else optionalRoot) - root = _get_op_result_or_value(root) if root else None - super().__init__(results_=results, root=root) - self.regions[0].blocks.append(pdl.OperationType.get()) - - @property - def body(self) -> Block: - return self.regions[0].blocks[0] - - @property - def bodyTarget(self) -> Value: - return self.body.arguments[0] - - -class LowerVectorsOp: - """Specialization for the LowerVectorsOp class.""" - - def __init__(self, - *, - stages: IntListArg = None, - contraction_lowering: StringArg = None, - multireduction_lowering: StringArg = None, - split_transfers: StringArg = None, - unroll_vector_transfers: BoolArg = None, - transpose_lowering: StringArg = None, - transpose_avx2_lowering: BoolArg = None, - loc=None, - ip=None): - stages = _ensure_int_array_attr(stages, [0, 1, 2, 3, 4, 5, 6]) - contraction_lowering = _ensure_string_attr(contraction_lowering, - "outerproduct") - multireduction_lowering = _ensure_string_attr(multireduction_lowering, - "innerparallel") - split_transfers = _ensure_string_attr(split_transfers, "linalg-copy") - unroll_vector_transfers = _ensure_bool_attr(unroll_vector_transfers, True) - transpose_lowering = _ensure_string_attr(transpose_lowering, "eltwise") - transpose_avx2_lowering = _ensure_bool_attr(transpose_avx2_lowering, False) - super().__init__(stages=stages, - contraction_lowering=contraction_lowering, - multireduction_lowering=multireduction_lowering, - split_transfers=split_transfers, - unroll_vector_transfers=unroll_vector_transfers, - transpose_lowering=transpose_lowering, - transpose_avx2_lowering=transpose_avx2_lowering, - loc=loc, - ip=ip) - - -class LowerToLLVMOp: - """Specialization for the LowerToLLVMOp class.""" - - def __init__(self, - *, - reassociate_fp_reductions: BoolArg = None, - enable_index_optimizations: BoolArg = None, - enable_arm_neon: BoolArg = None, - enable_arm_sve: BoolArg = None, - enable_amx: BoolArg = None, - enable_x86vector: BoolArg = None, - enable_async: BoolArg = None, - loc=None, - ip=None): - super().__init__(_ensure_bool_attr(reassociate_fp_reductions, False), - _ensure_bool_attr(enable_index_optimizations, False), - _ensure_bool_attr(enable_arm_neon, False), - _ensure_bool_attr(enable_arm_sve, False), - _ensure_bool_attr(enable_amx, False), - _ensure_bool_attr(enable_x86vector, False), - _ensure_bool_attr(enable_async, False), - loc=loc, - ip=ip) diff --git a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/iree_linalg_ext.py b/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/iree_linalg_ext.py deleted file mode 100644 index 01fb4305a686..000000000000 --- a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/iree_linalg_ext.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from ._iree_linalg_ext_ops_gen import * -from .._mlir_libs._ireeDialects.iree_linalg_ext import * diff --git a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/iree_linalg_transform.py b/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/iree_linalg_transform.py deleted file mode 100644 index 8bb7799d5c04..000000000000 --- a/integrations/tensorflow/iree-dialects/python/iree/compiler/dialects/iree_linalg_transform.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from ._iree_linalg_transform_ops_gen import * -from .._mlir_libs._ireeDialects.iree_linalg_transform import * diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/canonicalize.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/canonicalize.mlir deleted file mode 100644 index fec40b55f541..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/canonicalize.mlir +++ /dev/null @@ -1,42 +0,0 @@ -// RUN: iree-dialects-opt --canonicalize --split-input-file %s | FileCheck %s - -func.func @tensor_cast(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { - %init = tensor.empty() : tensor<3x5xi32> - - %casted_arg0 = tensor.cast %arg0 : tensor<3x5xi32> to tensor - %casted_init = tensor.cast %init : tensor<3x5xi32> to tensor - - %0 = iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%casted_arg0 : tensor) - outs(%casted_init : tensor) : tensor - - %1 = tensor.cast %0 : tensor to tensor<3x5xi32> - - return %1: tensor<3x5xi32> -} -// CHECK-LABEL: func.func @tensor_cast( -// CHECK: iree_linalg_ext.reverse -// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<3x5xi32>) -// CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<3x5xi32>) - -// ----- - -func.func @pack_canonicalize(%arg0 : tensor, - %arg1 : tensor<1x2x3x3xi32>) -> tensor<1x?x3x3xi32> { - %c0_i32 = arith.constant 0 : i32 - %0 = tensor.cast %arg1 : tensor<1x2x3x3xi32> to tensor<1x?x3x3xi32> - %1 = iree_linalg_ext.pack %arg0 padding_value(%c0_i32 : i32) - inner_dims_pos = [0, 1] inner_tiles = [3, 3] into %0 - : (tensor tensor<1x?x3x3xi32>) -> tensor<1x?x3x3xi32> - return %1 : tensor<1x?x3x3xi32> -} -// CHECK-LABEL: func.func @pack_canonicalize -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-SAME: %[[ARG1:.+]]: tensor<1x2x3x3xi32> -// CHECK: %[[PAD_VALUE:.+]] = arith.constant 0 : i32 -// CHECK: %[[PACK:.+]] = iree_linalg_ext.pack %[[ARG0]] -// CHECK-SAME: padding_value(%[[PAD_VALUE]] : i32) -// CHECK-SAME: into %[[ARG1]] -// CHECK: %[[CAST:.+]] = tensor.cast %[[PACK]] -// CHECK: return %[[CAST]] diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/conv2d_to_winograd.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/conv2d_to_winograd.mlir deleted file mode 100644 index be10a1daba68..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/conv2d_to_winograd.mlir +++ /dev/null @@ -1,77 +0,0 @@ -// RUN: iree-dialects-opt --split-input-file -iree-linalg-ext-convert-conv2d-to-winograd -mlir-elide-elementsattrs-if-larger=4 %s | FileCheck %s - -func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { - %c0 = arith.constant dense<0.1> : tensor<3x3x4x16xf32> - %0 = linalg.conv_2d_nhwc_hwcf - {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } - ins(%arg0, %c0: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) - outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> - return %0 : tensor<1x14x14x16xf32> -} -// CHECK: func.func @conv_16433136(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: -// CHECK-SAME: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { -// CHECK: %[[CST:.+]] = arith.constant dense_resource<__elided__> : tensor<64x4x16xf32> -// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x3x3x4xf32> -// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) -// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x16x16x4xf32>) outs(%[[D0]] : -// CHECK-SAME: tensor<8x8x1x3x3x4xf32>) -> tensor<8x8x1x3x3x4xf32> -// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[D1]] -// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]] -// CHECK-SAME: tensor<8x8x1x3x3x4xf32> into tensor<64x9x4xf32> -// CHECK: %[[D2:.+]] = tensor.empty() : tensor<64x9x16xf32> -// CHECK: %[[D3:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x9x16xf32>) -> -// CHECK-SAME: tensor<64x9x16xf32> -// CHECK: %[[D4:.+]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x9x4xf32>, -// CHECK-SAME: tensor<64x4x16xf32>) outs(%[[D3]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32> -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[D4]] -// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]] -// CHECK-SAME: tensor<64x9x16xf32> into tensor<8x8x1x3x3x16xf32> -// CHECK: %[[D5:.+]] = tensor.empty() : tensor<1x18x18x16xf32> -// CHECK: %[[D6:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) -// CHECK-SAME: image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x3x3x16xf32>) outs(%[[D5]] : -// CHECK-SAME: tensor<1x18x18x16xf32>) -> tensor<1x18x18x16xf32> -// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : -// CHECK-SAME: tensor<1x18x18x16xf32> to tensor<1x14x14x16xf32> -// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x14x14x16xf32> -// CHECK: } - -// ----- - -func.func @conv2d_non_splat_weights(%inputs : tensor<1x4x4x1xf32>, %arg2: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> { - %c0 = arith.constant dense<[[ [[1.0]], [[3.0]], [[5.0]] ], - [ [[7.0]], [[9.0]], [[11.0]] ], - [ [[13.0]], [[15.0]], [[17.0]] ]]> : tensor<3x3x1x1xf32> - %0 = linalg.conv_2d_nhwc_hwcf - {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } - ins(%inputs, %c0: tensor<1x4x4x1xf32>, tensor<3x3x1x1xf32>) - outs(%arg2: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> - return %0 : tensor<1x2x2x1xf32> -} -// CHECK: func.func @conv2d_non_splat_weights(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x4x4x1xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> { -// CHECK: %[[CST:.+]] = arith.constant dense_resource<__elided__> : tensor<64x1x1xf32> -// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x1x1x1xf32> -// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) -// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x4x4x1xf32>) outs(%[[D0]] : tensor<8x8x1x1x1x1xf32>) -// CHECK-SAME: -> tensor<8x8x1x1x1x1xf32> -// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[D1]] -// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]] -// CHECK-SAME: tensor<8x8x1x1x1x1xf32> into tensor<64x1x1xf32> -// CHECK: %[[D2:.+]] = tensor.empty() : tensor<64x1x1xf32> -// CHECK: %[[D3:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x1x1xf32>) -> -// CHECK-SAME: tensor<64x1x1xf32> -// CHECK: %[[D4:.+]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x1x1xf32>, tensor<64x1x1xf32>) -// CHECK-SAME: outs(%[[D3]] : tensor<64x1x1xf32>) -> tensor<64x1x1xf32> -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[D4]] -// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]] -// CHECK-SAME: tensor<64x1x1xf32> into tensor<8x8x1x1x1x1xf32> -// CHECK: %[[D5:.+]] = tensor.empty() : tensor<1x6x6x1xf32> -// CHECK: %[[D6:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) -// CHECK-SAME: image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x1x1x1xf32>) outs(%[[D5]] : -// CHECK-SAME: tensor<1x6x6x1xf32>) -> tensor<1x6x6x1xf32> -// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 2, 2, 1] [1, 1, 1, 1] : -// CHECK-SAME: tensor<1x6x6x1xf32> to tensor<1x2x2x1xf32> -// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x2x2x1xf32> -// CHECK: } diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/convert_to_loops.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/convert_to_loops.mlir deleted file mode 100644 index 5bc4e3ab64b7..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/convert_to_loops.mlir +++ /dev/null @@ -1,1443 +0,0 @@ -// RUN: iree-dialects-opt --split-input-file --iree-linalg-ext-to-loops %s | FileCheck %s - -func.func @sort_1d(%arg0: memref<128xi32>) { - iree_linalg_ext.sort dimension(0) - outs(%arg0 : memref<128xi32>) { - ^bb0(%arg2: i32, %arg3: i32): // no predecessors - %0 = arith.cmpi sgt, %arg2, %arg3 : i32 - iree_linalg_ext.yield %0 : i1 - } - return -} -// CHECK-LABEL: func.func @sort_1d -// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C127:.+]] = arith.constant 127 : index -// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]] -// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C127]] step %[[C1]] -// CHECK: %[[T1:.+]] = arith.addi %[[ARG2]], %[[C1]] : index -// CHECK: %[[V1:.+]] = memref.load %[[BUF]][%[[ARG2]]] -// CHECK: %[[V2:.+]] = memref.load %[[BUF]][%[[T1]]] -// CHECK: %[[COND:.+]] = arith.cmpi sgt, %[[V1]], %[[V2]] : i32 -// CHECK: scf.if %[[COND]] { -// CHECK: } else { -// CHECK: %[[T2:.+]] = arith.addi %[[ARG2]], %[[C1]] : index -// CHECK: memref.store %[[V2]], %[[BUF]][%[[ARG2]]] -// CHECK: memref.store %[[V1]], %[[BUF]][%[[T2]]] -// CHECK: } - -// ----- - -func.func @sort_2d(%arg0: memref<16x32xi32>) { - iree_linalg_ext.sort dimension(0) - outs(%arg0 : memref<16x32xi32>) { - ^bb0(%arg2: i32, %arg3: i32): // no predecessors - %0 = arith.cmpi sgt, %arg2, %arg3 : i32 - iree_linalg_ext.yield %0 : i1 - } - return -} -// CHECK-LABEL: func.func @sort_2d -// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index -// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C16]] step %[[C1]] -// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C32]] step %[[C1]] -// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C15]] step %[[C1]] -// CHECK: %[[T1:.+]] = arith.addi %[[ARG3]], %[[C1]] : index -// CHECK: %[[V1:.+]] = memref.load %[[BUF]][%[[ARG3]], %[[ARG2]]] -// CHECK: %[[V2:.+]] = memref.load %[[BUF]][%[[T1]], %[[ARG2]]] -// CHECK: %[[COND:.+]] = arith.cmpi sgt, %[[V1]], %[[V2]] : i32 -// CHECK: scf.if %[[COND]] { -// CHECK: } else { -// CHECK: %[[T2:.+]] = arith.addi %[[ARG3]], %[[C1]] : index -// CHECK: memref.store %[[V2]], %[[BUF]][%[[ARG3]], %[[ARG2]]] -// CHECK: memref.store %[[V1]], %[[BUF]][%[[T2]], %[[ARG2]]] -// CHECK: } - -// ----- - -func.func @sort_multi(%arg0: memref<128xf32>, %arg1: memref<128xi32>) { - iree_linalg_ext.sort - dimension(0) - outs(%arg0, %arg1 : memref<128xf32>, memref<128xi32>) { - ^bb0(%arg2: f32, %arg3: f32, %arg4: i32, %arg5: i32): // no predecessors - %0 = arith.cmpf ogt, %arg2, %arg3 : f32 - iree_linalg_ext.yield %0 : i1 - } - return -} -// CHECK-LABEL: func.func @sort_multi -// CHECK-SAME: %[[BUF1:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[BUF2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C127:.+]] = arith.constant 127 : index -// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]] -// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C127]] step %[[C1]] -// CHECK: %[[T1:.+]] = arith.addi %[[ARG2]], %[[C1]] : index -// CHECK: %[[V1:.+]] = memref.load %[[BUF1]][%[[ARG2]]] -// CHECK: %[[V2:.+]] = memref.load %[[BUF1]][%[[T1]]] -// CHECK: %[[V3:.+]] = memref.load %[[BUF2]][%[[ARG2]]] -// CHECK: %[[V4:.+]] = memref.load %[[BUF2]][%[[T1]]] -// CHECK: %[[COND:.+]] = arith.cmpf ogt, %[[V1]], %[[V2]] : f32 -// CHECK: scf.if %[[COND]] { -// CHECK: } else { -// CHECK: %[[T2:.+]] = arith.addi %[[ARG2]], %[[C1]] : index -// CHECK: memref.store %[[V2]], %[[BUF1]][%[[ARG2]]] -// CHECK: memref.store %[[V1]], %[[BUF1]][%[[T2]]] -// CHECK: memref.store %[[V4]], %[[BUF2]][%[[ARG2]]] -// CHECK: memref.store %[[V3]], %[[BUF2]][%[[T2]]] -// CHECK: } - -// ----- - -func.func @scatter_update_scalar_1D( - %original: memref<8xi32>, %indices: memref<3x1xi32>, - %updates: memref<3xi32>) { - iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>) - outs(%original : memref<8xi32>) { - ^bb0(%arg0: i32, %arg1: i32): // no predecessors - iree_linalg_ext.yield %arg0 : i32 - } - return -} -// CHECK-LABEL: func.func @scatter_update_scalar_1D -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] { -// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32> -// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32> -// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index -// CHECK: memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]] - -// ----- - -func.func @scatter_add_scalar_2D( - %original: memref<4x3xi32>, %indices: memref<3x2xi32>, - %updates: memref<3xi32>) { - iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) - ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>) - outs(%original : memref<4x3xi32>) { - ^bb0(%arg0: i32, %arg1: i32): // no predecessors - %0 = arith.addi %arg1, %arg0 : i32 - iree_linalg_ext.yield %0 : i32 - } - return -} -// CHECK-LABEL: func.func @scatter_add_scalar_2D -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] { -// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32> -// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x2xi32> -// CHECK: %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index -// CHECK: %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<3x2xi32> -// CHECK: %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index -// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<4x3xi32> -// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32 -// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] - -// ----- - -func.func @scatter_update_slice_2D( - %original: memref<4x3xi32>, %indices: memref<2x1xi32>, - %updates: memref<2x3xi32>) { - iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>) - outs(%original : memref<4x3xi32>) { - ^bb0(%arg0: i32, %arg1: i32): // no predecessors - iree_linalg_ext.yield %arg0 : i32 - } - return -} -// CHECK: func.func @scatter_update_slice_2D -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] { -// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] { -// CHECK: %[[UPDATE:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]] -// CHECK: %[[INDEX:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] -// CHECK: %[[LOC:.+]] = arith.index_cast %[[INDEX]] : i32 to index -// CHECK: memref.store %[[UPDATE]], %[[ORIGINAL]][%[[LOC]], %[[J]]] -// CHECK: } -// CHECK: } - -// ----- - -func.func @scatter_add_scalar_1D( - %original: memref<8xi32>, %indices: memref<3x1xi32>, - %updates: memref<3xi32>) { - iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>) - outs(%original : memref<8xi32>) { - ^bb0(%arg0: i32, %arg1: i32): // no predecessors - %0 = arith.addi %arg1, %arg0 : i32 - iree_linalg_ext.yield %0 : i32 - } - return -} -// CHECK-LABEL: func.func @scatter_add_scalar_1D -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] { -// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32> -// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32> -// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index -// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX]]] : memref<8xi32> -// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32 -// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX]]] - -// ----- - -func.func @scatter_add_slice_2D( - %original: memref<4x3xi32>, %indices: memref<2x1xi32>, - %updates: memref<2x3xi32>) { - iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>) - outs(%original : memref<4x3xi32>) { - ^bb0(%arg0: i32, %arg1: i32): // no predecessors - %0 = arith.addi %arg1, %arg0 : i32 - iree_linalg_ext.yield %0 : i32 - } - return -} -// CHECK: func.func @scatter_add_slice_2D -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] { -// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] { -// CHECK: %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]] -// CHECK: %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] -// CHECK: %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index -// CHECK: %[[ORIGINALVAL:.+]] = memref.load %[[ORIGINAL]][%[[INDEX]], %[[J]]] -// CHECK: %[[STOREVAL:.+]] = arith.addi %[[ORIGINALVAL]], %[[UPDATEVAL]] -// CHECK: memref.store %[[STOREVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]] - -// ----- - -func.func @scatter_update_scalar_dynamic_1D( - %original: memref, %indices: memref, - %updates: memref) { - iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%updates, %indices : memref, memref) - outs(%original : memref) { - ^bb0(%arg0: i32, %arg1: i32): // no predecessors - iree_linalg_ext.yield %arg0 : i32 - } - return -} -// CHECK-LABEL: func.func @scatter_update_scalar_dynamic_1D -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] { -// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref -// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref -// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index -// CHECK: memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]] - -// ----- - -func.func @scatter_add_scalar_dynamic_2D( - %original: memref, %indices: memref, - %updates: memref) { - iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) - ins(%updates, %indices : memref, memref) - outs(%original : memref) { - ^bb0(%arg0: i32, %arg1: i32): // no predecessors - %0 = arith.addi %arg1, %arg0 : i32 - iree_linalg_ext.yield %0 : i32 - } - return -} -// CHECK-LABEL: func.func @scatter_add_scalar_dynamic_2D -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] { -// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref -// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref -// CHECK: %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index -// CHECK: %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref -// CHECK: %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index -// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref -// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32 -// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] - -// ----- - -func.func @scatter_update_slice_dynamic_2D( - %original: memref, %indices: memref, - %updates: memref) { - iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%updates, %indices : memref, memref) - outs(%original : memref) { - ^bb0(%arg0: i32, %arg1: i32): // no predecessors - iree_linalg_ext.yield %arg0 : i32 - } - return -} -// CHECK: func.func @scatter_update_slice_dynamic_2D -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[UB1:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref -// CHECK-DAG: %[[UB2:.+]] = memref.dim %[[UPDATES]], %[[C1]] : memref -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB1]] step %[[C1]] { -// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[UB2]] step %[[C1]] { -// CHECK: %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]] -// CHECK: %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] -// CHECK: %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index -// CHECK: memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]] - -// ----- - -func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) { - iree_linalg_ext.scatter - dimension_map = [0, 1, 2] - unique_indices(true) - ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>) - outs(%arg0 : memref<2x64x12xf32>) { - ^bb0(%arg3: f32, %arg4: f32): - iree_linalg_ext.yield %arg4 : f32 - } - return -} - -// CHECK-LABEL: @scatter_partial_slices -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant -// CHECK-DAG: %[[C1:.+]] = arith.constant -// CHECK-DAG: %[[C2:.+]] = arith.constant -// CHECK-DAG: %[[C12:.+]] = arith.constant -// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] { -// CHECK-NEXT: scf.for %[[ARG4:.+]] = %[[C0]] to %[[C1]] step %[[C1]] { -// CHECK-NEXT: scf.for %[[ARG5:.+]] = %[[C0]] to %[[C12]] step %[[C1]] { -// CHECK-NEXT: %[[LOAD0:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C0]]] : memref<2x3xi32> -// CHECK-NEXT: %[[CAST0:.+]] = arith.index_cast %[[LOAD0]] : i32 to index -// CHECK-NEXT: %[[LOAD1:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C1]]] : memref<2x3xi32> -// CHECK-NEXT: %[[CAST1:.+]] = arith.index_cast %[[LOAD1]] : i32 to index -// CHECK-NEXT: %[[ADD1:.+]] = arith.addi %[[CAST1]], %[[ARG4]] : index -// CHECK-NEXT: %[[LOAD2:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C2]]] : memref<2x3xi32> -// CHECK-NEXT: %[[CAST2:.+]] = arith.index_cast %[[LOAD2]] : i32 to index -// CHECK-NEXT: %[[ADD2:.+]] = arith.addi %[[CAST2]], %[[ARG5]] : index -// CHECK-NEXT: %[[LOAD3:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32> -// CHECK-NEXT: memref.store %[[LOAD3]], %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32> - -// ----- - -func.func @fft_1D(%real: memref<16xf32>, %imag: memref<16xf32>) { - %stage = arith.constant 1 : index - iree_linalg_ext.fft - ins(%stage: index) - outs(%real, %imag: memref<16xf32>, memref<16xf32>) - return -} -// CHECK: #[[MAP1:.+]] = affine_map<(d0) -> (d0)> -// CHECK: func.func @fft_1D -// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK-DAG: %[[COEFF:.+]] = arith.constant -3.14159274 : f32 -// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[C2]] -// CHECK: %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[K]]] [%[[C1]]] [1] -// CHECK: %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[K]]] [%[[C1]]] [1] -// CHECK: %[[R_OFFSET:.+]] = arith.addi %[[K]], %[[C1]] : index -// CHECK: %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[R_OFFSET]]] [%[[C1]]] [1] -// CHECK: %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[R_OFFSET]]] [%[[C1]]] [1] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP1]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel"] -// CHECK-SAME: outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]] -// CHECK: ^bb0(%[[L_REAL:.+]]: f32, %[[L_IMAG:.+]]: f32, %[[R_REAL:.+]]: f32, %[[R_IMAG:.+]]: f32) -// -// Compute exp coeff. -// CHECK: %[[J_IDX:.+]] = linalg.index 0 : index -// CHECK: %[[J_I32:.+]] = arith.index_cast %[[J_IDX]] : index to i32 -// CHECK: %[[J_F32:.+]] = arith.sitofp %[[J_I32]] : i32 to f32 -// CHECK: %[[EXP_COEF:.+]] = arith.mulf %[[J_F32]], %[[COEFF]] : f32 -// CHECK: %[[W_REAL:.+]] = math.cos %[[EXP_COEF]] -// CHECK: %[[W_IMAG:.+]] = math.sin %[[EXP_COEF]] -// -// Compute "t = w * a[k + j + mh]" by expanding -// (x + yi)(u + vi) = (xu - yv) + (xv + yu)i -// CHECK-DAG: %[[XU:.+]] = arith.mulf %[[W_REAL]], %[[R_REAL]] -// CHECK-DAG: %[[YV:.+]] = arith.mulf %[[W_IMAG]], %[[R_IMAG]] -// CHECK-DAG: %[[XV:.+]] = arith.mulf %[[W_REAL]], %[[R_IMAG]] -// CHECK-DAG: %[[YU:.+]] = arith.mulf %[[W_IMAG]], %[[R_REAL]] -// CHECK: %[[T_REAL:.+]] = arith.subf %[[XU]], %[[YV]] -// CHECK: %[[T_IMAG:.+]] = arith.addf %[[XV]], %[[YU]] -// -// Compute the results. -// u = a[k + j]; -// a[k + j] = u + t; -// a[k + j + mh] = u - t; -// CHECK: %[[RES1:.+]] = arith.addf %[[L_REAL]], %[[T_REAL]] -// CHECK: %[[RES2:.+]] = arith.addf %[[L_IMAG]], %[[T_IMAG]] -// CHECK: %[[RES3:.+]] = arith.subf %[[L_REAL]], %[[T_REAL]] -// CHECK: %[[RES4:.+]] = arith.subf %[[L_IMAG]], %[[T_IMAG]] -// CHECK: linalg.yield %[[RES1]], %[[RES2]], %[[RES3]], %[[RES4]] - -// ----- - -func.func @fft_2D(%real: memref, %imag: memref) { - %stage = arith.constant 2 : index - iree_linalg_ext.fft - ins(%stage: index) - outs(%real, %imag: memref, memref) - return -} -// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func.func @fft_2D( -// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[D0:.+]] = memref.dim %[[REAL]], %[[C0]] : memref -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]] -// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[C4]] -// CHECK: %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[K]]] [1, %[[C2]]] [1, 1] -// CHECK: %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[K]]] [1, %[[C2]]] [1, 1] -// CHECK: %[[R_OFFSET:.+]] = arith.addi %[[K]], %[[C2]] : index -// CHECK: %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[R_OFFSET]]] [1, %[[C2]]] [1, 1] -// CHECK: %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[R_OFFSET]]] [1, %[[C2]]] [1, 1] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP1]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]] -// -// The computation is bascially the same, and they are -// checked above. Here only checks the different part. -// CHECK: %{{.+}} = linalg.index 1 : index - -// ----- - -func.func @fft_2D_coef_buf(%real: memref, %imag: memref, - %coef_real: memref<1xf32>, %coef_imag: memref<1xf32>) { - %stage = arith.constant 1 : index - iree_linalg_ext.fft - ins(%stage, %coef_real, %coef_imag: index, memref<1xf32>, memref<1xf32>) - outs(%real, %imag: memref, memref) - return -} -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func.func @fft_2D_coef_buf -// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = memref.dim %[[REAL]], %[[C0]] : memref -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]] -// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[C2]] -// CHECK: %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[K]]] [1, %[[C1]]] [1, 1] -// CHECK: %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[K]]] [1, %[[C1]]] [1, 1] -// CHECK: %[[R_OFFSET:.+]] = arith.addi %[[K]], %[[C1]] : index -// CHECK: %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[R_OFFSET]]] [1, %[[C1]]] [1, 1] -// CHECK: %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[R_OFFSET]]] [1, %[[C1]]] [1, 1] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP2]], #[[MAP2]], #[[MAP2]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: ins(%[[COEF_REAL]], %[[COEF_IMAG]] -// CHECK-SAME: outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]] -// CHECK: ^bb0(%[[W_REAL:.+]]: f32, %[[W_IMAG:.+]]: f32, %[[L_REAL:.+]]: f32, %[[L_IMAG:.+]]: f32, %[[R_REAL:.+]]: f32, %[[R_IMAG:.+]]: f32) -// Compute "t = w * a[k + j + mh]" by expanding -// (x + yi)(u + vi) = (xu - yv) + (xv + yu)i -// CHECK-DAG: %[[XU:.+]] = arith.mulf %[[W_REAL]], %[[R_REAL]] -// CHECK-DAG: %[[YV:.+]] = arith.mulf %[[W_IMAG]], %[[R_IMAG]] -// CHECK-DAG: %[[XV:.+]] = arith.mulf %[[W_REAL]], %[[R_IMAG]] -// CHECK-DAG: %[[YU:.+]] = arith.mulf %[[W_IMAG]], %[[R_REAL]] -// CHECK: %[[T_REAL:.+]] = arith.subf %[[XU]], %[[YV]] -// CHECK: %[[T_IMAG:.+]] = arith.addf %[[XV]], %[[YU]] -// -// Compute the results. -// u = a[k + j]; -// a[k + j] = u + t; -// a[k + j + mh] = u - t; -// CHECK: %[[RES1:.+]] = arith.addf %[[L_REAL]], %[[T_REAL]] -// CHECK: %[[RES2:.+]] = arith.addf %[[L_IMAG]], %[[T_IMAG]] -// CHECK: %[[RES3:.+]] = arith.subf %[[L_REAL]], %[[T_REAL]] -// CHECK: %[[RES4:.+]] = arith.subf %[[L_IMAG]], %[[T_IMAG]] -// CHECK: linalg.yield %[[RES1]], %[[RES2]], %[[RES3]], %[[RES4]] - -// ----- - -func.func @reverse_dim_0(%arg0: memref, %arg1: memref) { - iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0 : memref) - outs(%arg1 : memref) - return -} -// CHECK-LABEL: func.func @reverse_dim_0 -// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = memref.dim %arg0, %c0 : memref -// CHECK-DAG: %[[D1:.+]] = memref.dim %arg0, %c1 : memref -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]] -// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C1]] -// CHECK: %[[T0:.+]] = memref.dim %[[IN]], %[[C0]] -// CHECK: %[[T1:.+]] = arith.subi %[[T0]], %[[C1]] : index -// CHECK: %[[T2:.+]] = arith.subi %[[T1]], %[[I]] : index -// CHECK: %[[V0:.+]] = memref.load %[[IN]][%[[I]], %[[J]]] -// CHECK: memref.store %[[V0]], %[[OUT]][%[[T2]], %[[J]]] : memref - -func.func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) { - %c0 = memref.alloc() : memref - iree_linalg_ext.scan dimension(0) inclusive(true) - ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref) { - ^bb0(%arg0 : i32, %arg1 : i32): - %sum = arith.addi %arg0, %arg1 : i32 - iree_linalg_ext.yield %sum : i32 - } - return -} -// CHECK-LABEL: func.func @scan_1d_inclusive -// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref -// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]] -// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index -// CHECK: scf.if %[[COND]] { -// CHECK: %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]]] -// CHECK: memref.store %[[V1]], %[[BUFO]][%[[ARG1]]] -// CHECK: } else { -// CHECK: %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index -// CHECK: %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]] -// CHECK: %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]]] -// CHECK: %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32 -// CHECK: memref.store %[[V4]], %[[BUFO]][%[[ARG1]]] -// CHECK: memref.store %[[V4]], %[[ACC]][] -// CHECK: } - -// ----- - -func.func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) { - %c0 = memref.alloc() : memref - iree_linalg_ext.scan dimension(0) inclusive(false) - ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref) { - ^bb0(%arg0 : i32, %arg1 : i32): - %sum = arith.addi %arg0, %arg1 : i32 - iree_linalg_ext.yield %sum : i32 - } - return -} -// CHECK-LABEL: func.func @scan_1d_exclusive -// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref -// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]] -// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index -// CHECK: scf.if %[[COND]] { -// CHECK: %[[V0:.+]] = memref.load %[[ACC]][] : memref -// CHECK: memref.store %[[V0]], %[[BUFO]][%[[ARG1]]] -// CHECK: } else { -// CHECK: %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index -// CHECK: %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]] -// CHECK: %[[V3:.+]] = memref.load %[[BUFI]][%[[T1]]] -// CHECK: %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32 -// CHECK: memref.store %[[V4]], %[[BUFO]][%[[ARG1]]] -// CHECK: memref.store %[[V4]], %[[ACC]][] -// CHECK: } - -// ----- - -func.func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) { - %t0 = memref.alloc() : memref<32xi32> - iree_linalg_ext.scan dimension(0) inclusive(true) - ins(%0 : memref<16x32xi32>) outs(%1, %t0 : memref<16x32xi32>, memref<32xi32>) { - ^bb0(%arg0 : i32, %arg1 : i32): - %sum = arith.addi %arg0, %arg1 : i32 - iree_linalg_ext.yield %sum : i32 - } - return -} -// CHECK-LABEL: func.func @scan_2d -// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref<32xi32> -// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C16]] step %[[C1]] -// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C32]] step %[[C1]] -// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index -// CHECK: scf.if %[[COND]] { -// CHECK: %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]] -// CHECK: memref.store %[[V1]], %[[BUFO]][%[[ARG1]], %[[ARG2]]] -// CHECK: } else { -// CHECK: %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index -// CHECK: %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]], %[[ARG2]]] -// CHECK: %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]] -// CHECK: %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32 -// CHECK: memref.store %[[V4]], %[[BUFO]][%[[ARG1]], %[[ARG2]]] -// CHECK: memref.store %[[V4]], %[[ACC]][%[[ARG2]]] -// CHECK: } - -// ----- - -func.func @topk_memref(%input_values: memref<2x10xf32>, %input_indices: memref<2x10xi32>, %out_values: memref<2x3xf32>, %out_indices: memref<2x3xi32>) { - iree_linalg_ext.topk - dimension(1) - ins(%input_values, %input_indices : memref<2x10xf32> , memref<2x10xi32>) - outs(%out_values, %out_indices : memref<2x3xf32>, memref<2x3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } - return -} - -// CHECK-LABEL: func.func @topk_memref -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK: scf.for %[[ARG4:.+]] = %[[C0]] to %[[C2]] step %[[C1]] -// CHECK: scf.for %[[ARG5:.+]] = %[[C0]] to %[[C10]] step %[[C1]] -// CHECK: %[[D0:.+]] = memref.load %[[ARG0]][%[[ARG4]], %[[ARG5]]] -// CHECK: %[[D1:.+]] = memref.load %[[ARG1]][%[[ARG4]], %[[ARG5]]] -// CHECK: %[[D2:.+]]:2 = scf.for %[[ARG6:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.+]] = %[[D0]], %[[ARG8:.+]] = %[[D1]]) -// CHECK: %[[D3:.+]] = memref.load %[[ARG2]][%[[ARG4]], %[[ARG6]]] -// CHECK: %[[D4:.+]] = memref.load %[[ARG3]][%[[ARG4]], %[[ARG6]]] -// CHECK: %[[D5:.+]] = arith.cmpf ogt, %[[ARG7]], %[[D3]] : f32 -// CHECK: %[[D6:.+]] = arith.cmpf ogt, %[[D3]], %[[ARG7]] : f32 -// CHECK: %[[D7:.+]] = arith.cmpi eq, %[[D5]], %[[D6]] : i1 -// CHECK: %[[D8:.+]] = arith.cmpi slt, %[[ARG8]], %[[D4]] : i32 -// CHECK: %[[D9:.+]] = arith.andi %[[D7]], %[[D8]] : i1 -// CHECK: %[[D10:.+]] = arith.ori %[[D5]], %[[D9]] : i1 -// CHECK: %[[D11:.+]] = arith.select %[[D5]], %[[ARG7]], %[[D3]] : f32 -// CHECK: %[[D12:.+]] = arith.select %[[D10]], %[[ARG8]], %[[D4]] : i32 -// CHECK: memref.store %[[D11]], %[[ARG2]][%[[ARG4]], %[[ARG6]]] -// CHECK: memref.store %[[D12]], %[[ARG3]][%[[ARG4]], %[[ARG6]]] -// CHECK: %[[D13:.+]] = arith.select %[[D5]], %[[D3]], %[[ARG7]] : f32 -// CHECK: %[[D14:.+]] = arith.select %[[D10]], %[[D4]], %[[ARG8]] : i32 -// CHECK: scf.yield %[[D13]], %[[D14]] : f32, i32 - -// ----- - -func.func @topk_memref_dynamic(%input_values: memref, %input_indices: memref, %out_values: memref, %out_indices: memref) { - iree_linalg_ext.topk - dimension(1) - ins(%input_values, %input_indices : memref , memref) - outs(%out_values, %out_indices : memref, memref) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } - return -} - -// CHECK-LABEL: func.func @topk_memref_dynamic -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK: %[[D0:.+]] = memref.dim %[[ARG0:.+]], %[[C0]] -// CHECK: %[[D1:.+]] = memref.dim %[[ARG0:.+]], %[[C1]] -// CHECK: scf.for %[[ARG4:.+]] = %[[C0]] to %[[D0]] step %[[C1]] -// CHECK: scf.for %[[ARG5:.+]] = %[[C0]] to %[[D1]] step %[[C1]] -// CHECK: %[[D2:.+]] = memref.load %[[ARG0]][%[[ARG4]], %[[ARG5]]] -// CHECK: %[[D3:.+]] = memref.load %[[ARG1]][%[[ARG4]], %[[ARG5]]] -// CHECK: %[[D4:.+]]:2 = scf.for %[[ARG6:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.+]] = %[[D2]], %[[ARG8:.+]] = %[[D3]]) -// CHECK: %[[D5:.+]] = memref.load %[[ARG2]][%[[ARG4]], %[[ARG6]]] -// CHECK: %[[D6:.+]] = memref.load %[[ARG3]][%[[ARG4]], %[[ARG6]]] -// CHECK: %[[D7:.+]] = arith.cmpf ogt, %[[ARG7]], %[[D5]] : f32 -// CHECK: %[[D8:.+]] = arith.cmpf ogt, %[[D5]], %[[ARG7]] : f32 -// CHECK: %[[D9:.+]] = arith.cmpi eq, %[[D7]], %[[D8]] : i1 -// CHECK: %[[D10:.+]] = arith.cmpi slt, %[[ARG8]], %[[D6]] : i32 -// CHECK: %[[D11:.+]] = arith.andi %[[D9]], %[[D10]] : i1 -// CHECK: %[[D12:.+]] = arith.ori %[[D7]], %[[D11]] : i1 -// CHECK: %[[D13:.+]] = arith.select %[[D7]], %[[ARG7]], %[[D5]] : f32 -// CHECK: %[[D14:.+]] = arith.select %[[D12]], %[[ARG8]], %[[D6]] : i32 -// CHECK: memref.store %[[D13]], %[[ARG2]][%[[ARG4]], %[[ARG6]]] -// CHECK: memref.store %[[D14]], %[[ARG3]][%[[ARG4]], %[[ARG6]]] -// CHECK: %[[D15:.+]] = arith.select %[[D7]], %[[D5]], %[[ARG7]] : f32 -// CHECK: %[[D16:.+]] = arith.select %[[D12]], %[[D6]], %[[ARG8]] : i32 -// CHECK: scf.yield %[[D15]], %[[D16]] : f32, i32 - -// ----- - -func.func @topk_memref_optional(%input_values: memref<2x10xf32>, %out_values: memref<2x3xf32>, %out_indices: memref<2x3xi32>) { - iree_linalg_ext.topk - dimension(1) - ins(%input_values : memref<2x10xf32>) - outs(%out_values, %out_indices : memref<2x3xf32>, memref<2x3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } - return -} - -// CHECK-LABEL: func.func @topk_memref -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] -// CHECK: scf.for %[[ARG4:.+]] = %[[C0]] to %[[C10]] step %[[C1]] -// CHECK: %[[D0:.+]] = memref.load %[[ARG0]][%[[ARG3]], %[[ARG4]]] -// CHECK: %[[D1:.+]] = arith.index_cast %[[ARG4]] : index to i32 -// CHECK: %[[D2:.+]]:2 = scf.for %[[ARG5:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.+]] = %[[D0]], %[[ARG7:.+]] = %[[D1]]) -// CHECK: %[[D3:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[ARG5]]] -// CHECK: %[[D4:.+]] = memref.load %[[ARG2]][%[[ARG3]], %[[ARG5]]] -// CHECK: %[[D5:.+]] = arith.cmpf ogt, %[[ARG6]], %[[D3]] : f32 -// CHECK: %[[D6:.+]] = arith.cmpf ogt, %[[D3]], %[[ARG6]] : f32 -// CHECK: %[[D7:.+]] = arith.cmpi eq, %[[D5]], %[[D6]] : i1 -// CHECK: %[[D8:.+]] = arith.cmpi slt, %[[ARG7]], %[[D4]] : i32 -// CHECK: %[[D9:.+]] = arith.andi %[[D7]], %[[D8]] : i1 -// CHECK: %[[D10:.+]] = arith.ori %[[D5]], %[[D9]] : i1 -// CHECK: %[[D11:.+]] = arith.select %[[D5]], %[[ARG6]], %[[D3]] : f32 -// CHECK: %[[D12:.+]] = arith.select %[[D10]], %[[ARG7]], %[[D4]] : i32 -// CHECK: memref.store %[[D11]], %[[ARG1]][%[[ARG3]], %[[ARG5]]] -// CHECK: memref.store %[[D12]], %[[ARG2]][%[[ARG3]], %[[ARG5]]] -// CHECK: %[[D13:.+]] = arith.select %[[D5]], %[[D3]], %[[ARG6]] : f32 -// CHECK: %[[D14:.+]] = arith.select %[[D10]], %[[D4]], %[[ARG7]] : i32 -// CHECK: scf.yield %[[D13]], %[[D14]] : f32, i32 - -// ----- - -func.func @NC_to_NCnc(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) { - iree_linalg_ext.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : (memref<128x256xf32> memref<4x8x32x32xf32>) - return -} -// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)> -// CHECK-LABEL: func.func @NC_to_NCnc( -// CHECK-DAG: %[[lb:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[ubN:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[step:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[ubC:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[block:.*]] = arith.constant 32 : index -// CHECK: scf.for %[[N:.*]] = %[[lb]] to %[[ubN]] step %[[step]] { -// CHECK: scf.for %[[C:.*]] = %[[lb]] to %[[ubC]] step %[[step]] { -// CHECK: scf.for %[[n:.*]] = %[[lb]] to %[[block]] step %[[step]] { -// CHECK: scf.for %[[c:.*]] = %[[lb]] to %[[block]] step %[[step]] { -// CHECK-DAG: %[[applyMapI:.*]] = affine.apply #[[MAP]](%[[N]], %[[n]]) -// CHECK-DAG: %[[applyMapJ:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]]) -// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[applyMapI]], %[[applyMapJ]]] : memref<128x256xf32> -// CHECK: memref.store %[[scalar]], %arg1[%[[N]], %[[C]], %[[n]], %[[c]]] : memref<4x8x32x32xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @NC_to_NCnc_pad_static(%arg0: memref<13x15xf32>, %arg1: memref<2x8x8x2xf32>, %arg2: f32) { - iree_linalg_ext.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : (memref<13x15xf32> memref<2x8x8x2xf32>) - return -} -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> -// CHECK-LABEL: func.func @NC_to_NCnc_pad_static( -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C13:.*]] = arith.constant 13 : index -// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index -// CHECK: scf.for %[[N:.*]] = %[[C0]] to %[[C2]] step %[[step]] { -// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[step]] { -// CHECK: scf.for %[[n:.*]] = %[[C0]] to %[[C8]] step %[[step]] { -// CHECK: scf.for %[[c:.*]] = %[[C0]] to %[[C2]] step %[[step]] { -// CHECK-DAG: %[[applyMapI:.*]] = affine.apply #[[MAP0]](%[[N]], %[[n]]) -// CHECK-DAG: %[[applyMapJ:.*]] = affine.apply #[[MAP1]](%[[C]], %[[c]]) -// CHECK: %[[isIInBound:.*]] = arith.cmpi slt, %[[applyMapI]], %[[C13]] : index -// CHECK: %[[isJInBound:.*]] = arith.cmpi slt, %[[applyMapJ]], %[[C15]] : index -// CHECK: %[[isAllInBounds:.*]] = arith.andi %[[isIInBound]], %[[isJInBound]] : i1 -// CHECK: %[[scalar:.*]] = scf.if %[[isAllInBounds]] -> (f32) { -// CHECK: %[[load:.*]] = memref.load %arg0[%[[applyMapI]], %[[applyMapJ]]] : memref<13x15xf32> -// CHECK: scf.yield %[[load]] -// CHECK: } else { -// CHECK: scf.yield %arg2 -// CHECK: } -// CHECK: memref.store %[[scalar]], %arg1[%[[N]], %[[C]], %[[n]], %[[c]]] : memref<2x8x8x2xf32> - -// ----- - -func.func @KC_to_KCck(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) { - iree_linalg_ext.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg1 : (memref<128x256xf32> memref<4x8x32x32xf32>) - return -} -// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)> -// CHECK-LABEL: func.func @KC_to_KCck( -// CHECK-DAG: %[[lb:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[ubK:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[step:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[ubC:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[block:.*]] = arith.constant 32 : index -// CHECK: scf.for %[[K:.*]] = %[[lb]] to %[[ubK]] step %[[step]] { -// CHECK: scf.for %[[C:.*]] = %[[lb]] to %[[ubC]] step %[[step]] { -// CHECK: scf.for %[[c:.*]] = %[[lb]] to %[[block]] step %[[step]] { -// CHECK: scf.for %[[k:.*]] = %[[lb]] to %[[block]] step %[[step]] { -// CHECK-DAG: %[[applyMapC:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]]) -// CHECK-DAG: %[[applyMapK:.*]] = affine.apply #[[MAP]](%[[K]], %[[k]]) -// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[applyMapK]], %[[applyMapC]]] : memref<128x256xf32> -// CHECK: memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[c]], %[[k]]] : memref<4x8x32x32xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -// This should be a simple expand shape. -func.func @KC_to_KCc(%arg0: memref<128x256xf32>, %arg1: memref<128x8x32xf32>) { - iree_linalg_ext.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %arg1 : (memref<128x256xf32> memref<128x8x32xf32>) - return -} -// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)> -// CHECK-LABEL: func.func @KC_to_KCc( -// CHECK-DAG: %[[lb:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[step:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[ubK:.*]] = arith.constant 128 : index -// CHECK-DAG: %[[ubC:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[block:.*]] = arith.constant 32 : index -// CHECK: scf.for %[[K:.*]] = %[[lb]] to %[[ubK]] step %[[step]] { -// CHECK: scf.for %[[C:.*]] = %[[lb]] to %[[ubC]] step %[[step]] { -// CHECK: scf.for %[[c:.*]] = %[[lb]] to %[[block]] step %[[step]] { -// CHECK: %[[applyMapC:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]]) -// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[K]], %[[applyMapC]]] : memref<128x256xf32> -// CHECK: memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[c]]] : memref<128x8x32xf32> -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @KC_to_KCk(%arg0: memref<128x256xf32>, %arg1: memref<4x256x32xf32>) { - iree_linalg_ext.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %arg1 : (memref<128x256xf32> memref<4x256x32xf32>) - return -} - -// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)> -// CHECK-LABEL: func.func @KC_to_KCk( -// CHECK-DAG: %[[lb:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[step:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[ubC:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[ubK:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[block:.*]] = arith.constant 32 : index -// CHECK: scf.for %[[K:.*]] = %[[lb]] to %[[ubK]] step %[[step]] { -// CHECK: scf.for %[[C:.*]] = %[[lb]] to %[[ubC]] step %[[step]] { -// CHECK: scf.for %[[k:.*]] = %[[lb]] to %[[block]] step %[[step]] { -// CHECK: %[[applyMapK:.*]] = affine.apply #[[MAP]](%[[K]], %[[k]]) -// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[applyMapK]], %[[C]]] : memref<128x256xf32> -// CHECK: memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[k]]] : memref<4x256x32xf32> -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @KCRS_to_KCRSck(%arg0: memref<128x64x1x1xf32>, %arg1: memref<4x8x1x1x8x32xf32>) { - iree_linalg_ext.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [8, 32] into %arg1 : (memref<128x64x1x1xf32> memref<4x8x1x1x8x32xf32>) - return -} - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)> -// CHECK-LABEL: func.func @KCRS_to_KCRSck( -// CHECK-DAG: %[[lb:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[one:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[ubK:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[ubC:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[blockK:.*]] = arith.constant 32 : index -// CHECK: scf.for %[[K:.*]] = %[[lb]] to %[[ubK]] step %[[one]] { -// CHECK: scf.for %[[C:.*]] = %[[lb]] to %[[ubC]] step %[[one]] { -// CHECK: scf.for %[[R:.*]] = %[[lb]] to %[[one]] step %[[one]] { -// CHECK: scf.for %[[S:.*]] = %[[lb]] to %[[one]] step %[[one]] { -// CHECK: scf.for %[[c:.*]] = %[[lb]] to %[[ubC]] step %[[one]] { -// CHECK: scf.for %[[k:.*]] = %[[lb]] to %[[blockK]] step %[[one]] { -// CHECK-DAG: %[[affineMapK:.*]] = affine.apply #[[MAP0]](%[[K]], %[[k]]) -// CHECK-DAG: %[[affineMapC:.*]] = affine.apply #[[MAP1]](%[[C]], %[[c]]) -// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[affineMapK]], %[[affineMapC]], %[[R]], %[[S]]] : memref<128x64x1x1xf32> -// CHECK: memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[R]], %[[S]], %[[c]], %[[k]]] : memref<4x8x1x1x8x32xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @KCRS_to_KCRSsr(%arg0: memref<1x1x128x64xf32>, %arg1: memref<1x1x4x8x8x32xf32>) { - iree_linalg_ext.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : (memref<1x1x128x64xf32> memref<1x1x4x8x8x32xf32>) - return -} - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)> -// CHECK-LABEL: func.func @KCRS_to_KCRSsr( -// CHECK-DAG: %[[lb:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[one:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[ubR:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[ubS:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[blockR:.*]] = arith.constant 32 : index -// CHECK: scf.for %[[K:.*]] = %[[lb]] to %[[one]] step %[[one]] { -// CHECK: scf.for %[[C:.*]] = %[[lb]] to %[[one]] step %[[one]] { -// CHECK: scf.for %[[R:.*]] = %[[lb]] to %[[ubR]] step %[[one]] { -// CHECK: scf.for %[[S:.*]] = %[[lb]] to %[[ubS]] step %[[one]] { -// CHECK: scf.for %[[s:.*]] = %[[lb]] to %[[ubS]] step %[[one]] { -// CHECK: scf.for %[[r:.*]] = %[[lb]] to %[[blockR]] step %[[one]] { -// CHECK-DAG: %[[affineMapR:.*]] = affine.apply #[[MAP0]](%[[R]], %[[r]]) -// CHECK-DAG: %[[affineMapS:.*]] = affine.apply #[[MAP1]](%[[S]], %[[s]]) -// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[K]], %[[C]], %[[affineMapR]], %[[affineMapS]]] : memref<1x1x128x64xf32> -// CHECK: memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[R]], %[[S]], %[[s]], %[[r]]] : memref<1x1x4x8x8x32xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -// Test to check that we properly handle shuffled `inner_dims_pos` and `tiles. -// In this example, the dimension at position `0` (aka `128`) is tiled with a factor of `32`. -// While the dimension at position `2` (aka `2`) is tiled with a factor of `2`. -func.func @shuffled_dim_pos_and_tiles(%arg0: memref<128x256x2x1000xf32>, %arg1: memref<4x256x1x1000x2x32xf32>) { - iree_linalg_ext.pack %arg0 inner_dims_pos = [2, 0] inner_tiles = [2, 32] into %arg1 : (memref<128x256x2x1000xf32> memref<4x256x1x1000x2x32xf32>) - return -} - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> -// CHECK-LABEL: func.func @shuffled_dim_pos_and_tiles( -// CHECK-DAG: %[[lb:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[step:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[ubDimZero:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[ubDimOne:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[ubDimThree:.*]] = arith.constant 1000 : index -// CHECK-DAG: %[[ubDimFour:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[ubDimFive:.*]] = arith.constant 32 : index -// CHECK: scf.for %[[i:.*]] = %[[lb]] to %[[ubDimZero]] step %[[step]] { -// CHECK: scf.for %[[j:.*]] = %[[lb]] to %[[ubDimOne]] step %[[step]] { -// CHECK: scf.for %[[k:.*]] = %[[lb]] to %[[step]] step %[[step]] { -// CHECK: scf.for %[[l:.*]] = %[[lb]] to %[[ubDimThree]] step %[[step]] { -// CHECK: scf.for %[[m:.*]] = %[[lb]] to %[[ubDimFour]] step %[[step]] { -// CHECK: scf.for %[[n:.*]] = %[[lb]] to %[[ubDimFive]] step %[[step]] { -// CHECK-DAG: %[[affineApplyZero:.*]] = affine.apply #[[MAP0]](%[[i]], %[[n]]) -// CHECK-DAG: %[[affineApplyOne:.*]] = affine.apply #[[MAP1]](%[[k]], %[[m]]) -// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[affineApplyZero]], %[[j]], %[[affineApplyOne]], %[[l]]] : memref<128x256x2x1000xf32> -// CHECK: memref.store %[[scalar]], %arg1[%[[i]], %[[j]], %[[k]], %[[l]], %[[m]], %[[n]]] : memref<4x256x1x1000x2x32xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @KCRS_to_KCRSsr(%arg0: memref, %arg1: memref) { - iree_linalg_ext.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : (memref memref) - return -} - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)> -// CHECK: func.func @KCRS_to_KCRSsr( -// CHECK-DAG: %[[zero:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[one:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[two:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[three:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[eight:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[thirtyTwo:.*]] = arith.constant 32 : index -// CHECK-DAG: %[[dimZero:.*]] = memref.dim %arg1, %[[zero]] : memref -// CHECK-DAG: %[[dimOne:.*]] = memref.dim %arg1, %[[one]] : memref -// CHECK-DAG: %[[dimTwo:.*]] = memref.dim %arg1, %[[two]] : memref -// CHECK-DAG: %[[dimThree:.*]] = memref.dim %arg1, %[[three]] : memref -// CHECK: scf.for %[[K:.*]] = %[[zero]] to %[[dimZero]] step %[[one]] { -// CHECK: scf.for %[[C:.*]] = %[[zero]] to %[[dimOne]] step %[[one]] { -// CHECK: scf.for %[[R:.*]] = %[[zero]] to %[[dimTwo]] step %[[one]] { -// CHECK: scf.for %[[S:.*]] = %[[zero]] to %[[dimThree]] step %[[one]] { -// CHECK: scf.for %[[s:.*]] = %[[zero]] to %[[eight]] step %[[step]] { -// CHECK: scf.for %[[r:.*]] = %[[zero]] to %[[thirtyTwo]] step %[[step]] { -// CHECK-DAG: %[[affineMapR:.*]] = affine.apply #[[MAP0]](%[[R]], %[[r]]) -// CHECK-DAG: %[[affineMapS:.*]] = affine.apply #[[MAP1]](%[[S]], %[[s]]) -// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[K]], %[[C]], %[[affineMapR]], %[[affineMapS]]] : memref -// CHECK: memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[R]], %[[S]], %[[s]], %[[r]]] : memref -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @KCRS_to_KCRSsr(%arg0: memref, %arg1: memref, %block : index) { - iree_linalg_ext.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, %block] into %arg1 : (memref memref) - return -} - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)> -// CHECK: func.func @KCRS_to_KCRSsr -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[zero:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[one:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[two:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[three:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[eight:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[five:.*]] = arith.constant 5 : index -// CHECK-DAG: %[[dimZero:.*]] = memref.dim %[[ARG1]], %[[zero]] : memref -// CHECK-DAG: %[[dimOne:.*]] = memref.dim %[[ARG1]], %[[one]] : memref -// CHECK-DAG: %[[dimTwo:.*]] = memref.dim %[[ARG1]], %[[two]] : memref -// CHECK-DAG: %[[dimThree:.*]] = memref.dim %[[ARG1]], %[[three]] : memref -// CHECK: scf.for %[[K:.*]] = %[[zero]] to %[[dimZero]] step %[[one]] { -// CHECK: scf.for %[[C:.*]] = %[[zero]] to %[[dimOne]] step %[[one]] { -// CHECK: scf.for %[[R:.*]] = %[[zero]] to %[[dimTwo]] step %[[one]] { -// CHECK: scf.for %[[S:.*]] = %[[zero]] to %[[dimThree]] step %[[one]] { -// CHECK: %[[dimFive:.*]] = memref.dim %[[ARG1]], %[[five]] : memref -// CHECK: scf.for %[[s:.*]] = %[[zero]] to %[[eight]] step %[[step]] { -// CHECK: scf.for %[[r:.*]] = %[[zero]] to %[[dimFive]] step %[[step]] { -// CHECK-DAG: %[[affineMapR:.*]] = affine.apply #[[MAP0]](%[[R]], %[[r]])[%[[ARG2]]] -// CHECK-DAG: %[[affineMapS:.*]] = affine.apply #[[MAP1]](%[[S]], %[[s]]) -// CHECK: %[[scalar:.*]] = memref.load %[[ARG0]][%[[K]], %[[C]], %[[affineMapR]], %[[affineMapS]]] : memref -// CHECK: memref.store %[[scalar]], %[[ARG1]][%[[K]], %[[C]], %[[R]], %[[S]], %[[s]], %[[r]]] : memref -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @NCnc_to_NC(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) { - iree_linalg_ext.unpack %arg1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : (memref<4x8x32x32xf32> memref<128x256xf32>) - return -} - -// CHECK-DAG: #[[MAP_FLOOR:.*]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-DAG: #[[MAP_MOD:.*]] = affine_map<(d0) -> (d0 mod 32)> -// CHECK: func.func @NCnc_to_NC -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[UBI:.*]] = arith.constant 128 : index -// CHECK-DAG: %[[UBJ:.*]] = arith.constant 256 : index -// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UBI]] step %[[STEP]] { -// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UBJ]] step %[[STEP]] { -// CHECK-DAG: %[[FLOORI:.*]] = affine.apply #[[MAP_FLOOR]](%[[I]]) -// CHECK-DAG: %[[FLOORJ:.*]] = affine.apply #[[MAP_FLOOR]](%[[J]]) -// CHECK-DAG: %[[MODI:.*]] = affine.apply #[[MAP_MOD]](%[[I]]) -// CHECK-DAG: %[[MODJ:.*]] = affine.apply #[[MAP_MOD]](%[[J]]) -// CHECK: %[[SCALAR:.*]] = memref.load %[[ARG1]][%[[FLOORI]], %[[FLOORJ]], %[[MODI]], %[[MODJ]]] : memref<4x8x32x32xf32> -// CHECK: memref.store %[[SCALAR]], %[[ARG0]][%[[I]], %[[J]]] : memref<128x256xf32> -// CHECK: } -// CHECK: } - -// ----- - -func.func @KCck_to_KC(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) { - iree_linalg_ext.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg0 : (memref<4x8x32x32xf32> memref<128x256xf32>) - return -} - -// CHECK-DAG: #[[MAP_FLOOR:.*]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-DAG: #[[MAP_MOD:.*]] = affine_map<(d0) -> (d0 mod 32)> -// CHECK: func.func @KCck_to_KC -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[UBI:.*]] = arith.constant 128 : index -// CHECK-DAG: %[[UBJ:.*]] = arith.constant 256 : index -// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UBI]] step %[[STEP]] { -// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UBJ]] step %[[STEP]] { -// CHECK-DAG: %[[FLOORI:.*]] = affine.apply #[[MAP_FLOOR]](%[[I]]) -// CHECK-DAG: %[[FLOORJ:.*]] = affine.apply #[[MAP_FLOOR]](%[[J]]) -// CHECK-DAG: %[[MODI:.*]] = affine.apply #[[MAP_MOD]](%[[I]]) -// CHECK-DAG: %[[MODJ:.*]] = affine.apply #[[MAP_MOD]](%[[J]]) -// CHECK: %[[SCALAR:.*]] = memref.load %[[ARG1]][%[[FLOORI]], %[[FLOORJ]], %[[MODJ]], %[[MODI]]] : memref<4x8x32x32xf32> -// CHECK: memref.store %[[SCALAR]], %[[ARG0]][%[[I]], %[[J]]] : memref<128x256xf32> -// CHECK: } -// CHECK: } - -// ----- - -// This should be a simple collapse shape. -func.func @KCc_to_KC(%arg0: memref<128x256xf32>, %arg1: memref<128x8x32xf32>) { - iree_linalg_ext.unpack %arg1 inner_dims_pos = [1] inner_tiles = [32] into %arg0 : (memref<128x8x32xf32> memref<128x256xf32>) - return -} - -// CHECK-DAG: #[[MAP_FLOOR:.*]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-DAG: #[[MAP_MOD:.*]] = affine_map<(d0) -> (d0 mod 32)> -// CHECK: func.func @KCc_to_KC -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[UBI:.*]] = arith.constant 128 : index -// CHECK-DAG: %[[UBJ:.*]] = arith.constant 256 : index -// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UBI]] step %[[STEP]] { -// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UBJ]] step %[[STEP]] { -// CHECK-DAG: %[[FLOORJ:.*]] = affine.apply #[[MAP_FLOOR]](%[[J]]) -// CHECK-DAG: %[[MODJ:.*]] = affine.apply #[[MAP_MOD]](%[[J]]) -// CHECK: %[[SCALAR:.*]] = memref.load %[[ARG1]][%[[I]], %[[FLOORJ]], %[[MODJ]]] : memref<128x8x32xf32> -// CHECK: memref.store %[[SCALAR]], %[[ARG0]][%[[I]], %[[J]]] : memref<128x256xf32> -// CHECK: } -// CHECK: } - - - -// ----- - -func.func @KCRSsr_to_KCRS(%arg0: memref<1x1x128x64xf32>, %arg1: memref<1x1x4x8x8x32xf32>) { - iree_linalg_ext.unpack %arg1 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg0 : (memref<1x1x4x8x8x32xf32> memref<1x1x128x64xf32>) - return -} - -// CHECK-DAG: #[[MAP_FLOORK:.*]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-DAG: #[[MAP_MODK:.*]] = affine_map<(d0) -> (d0 mod 32)> -// CHECK-DAG: #[[MAP_FLOORL:.*]] = affine_map<(d0) -> (d0 floordiv 8)> -// CHECK-DAG: #[[MAP_MODL:.*]] = affine_map<(d0) -> (d0 mod 8)> -// CHECK: func.func @KCRSsr_to_KCRS -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[UBK:.*]] = arith.constant 128 : index -// CHECK-DAG: %[[UBL:.*]] = arith.constant 64 : index -// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[STEP]] step %[[STEP]] { -// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[STEP]] step %[[STEP]] { -// CHECK: scf.for %[[K:.*]] = %[[LB]] to %[[UBK]] step %[[STEP]] { -// CHECK: scf.for %[[L:.*]] = %[[LB]] to %[[UBL]] step %[[STEP]] { -// CHECK-DAG: %[[FLOORK:.*]] = affine.apply #[[MAP_FLOORK]](%[[K]]) -// CHECK-DAG: %[[FLOORL:.*]] = affine.apply #[[MAP_FLOORL]](%[[L]]) -// CHECK-DAG: %[[MODK:.*]] = affine.apply #[[MAP_MODK]](%[[K]]) -// CHECK-DAG: %[[MODL:.*]] = affine.apply #[[MAP_MODL]](%[[L]]) -// CHECK: %[[SCALAR:.*]] = memref.load %[[ARG1]][%[[I]], %[[J]], %[[FLOORK]], %[[FLOORL]], %[[MODL]], %[[MODK]]] : memref<1x1x4x8x8x32xf32> -// CHECK: memref.store %[[SCALAR]], %[[ARG0]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<1x1x128x64xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @shuffled_dim_pos_and_tiles(%arg0: memref<128x256x2x1000xf32>, %arg1: memref<4x256x1x1000x2x32xf32>) { - iree_linalg_ext.unpack %arg1 inner_dims_pos = [2, 0] inner_tiles = [2, 32] into %arg0 : (memref<4x256x1x1000x2x32xf32> memref<128x256x2x1000xf32>) - return -} - -// CHECK-DAG: #[[MAP_FLOORI:.*]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-DAG: #[[MAP_MODI:.*]] = affine_map<(d0) -> (d0 mod 32)> -// CHECK-DAG: #[[MAP_FLOORK:.*]] = affine_map<(d0) -> (d0 floordiv 2)> -// CHECK-DAG: #[[MAP_MODK:.*]] = affine_map<(d0) -> (d0 mod 2)> -// CHECK: func.func @shuffled_dim_pos_and_tiles -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[UBI:.*]] = arith.constant 128 : index -// CHECK-DAG: %[[UBJ:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[UBK:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[UBL:.*]] = arith.constant 1000 : index -// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UBI]] step %[[STEP]] { -// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UBJ]] step %[[STEP]] { -// CHECK: scf.for %[[K:.*]] = %[[LB]] to %[[UBK]] step %[[STEP]] { -// CHECK: scf.for %[[L:.*]] = %[[LB]] to %[[UBL]] step %[[STEP]] { -// CHECK-DAG: %[[FLOORI:.*]] = affine.apply #[[MAP_FLOORI]](%[[I]]) -// CHECK-DAG: %[[MODI:.*]] = affine.apply #[[MAP_MODI]](%[[I]]) -// CHECK-DAG: %[[FLOORK:.*]] = affine.apply #[[MAP_FLOORK]](%[[K]]) -// CHECK-DAG: %[[MODK:.*]] = affine.apply #[[MAP_MODK]](%[[K]]) -// CHECK: %[[SCALAR:.*]] = memref.load %[[ARG1]][%[[FLOORI]], %[[J]], %[[FLOORK]], %[[L]], %[[MODK]], %[[MODI]]] : memref<4x256x1x1000x2x32xf32> -// CHECK: memref.store %[[SCALAR]], %[[ARG0]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<128x256x2x1000xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @KCRSsr_to_KCRS(%arg0: memref, %arg1: memref) { - iree_linalg_ext.unpack %arg1 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg0 : (memref memref) - return -} - -// CHECK-DAG: #[[MAP_FLOORK:.*]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-DAG: #[[MAP_MODK:.*]] = affine_map<(d0) -> (d0 mod 32)> -// CHECK-DAG: #[[MAP_FLOORL:.*]] = affine_map<(d0) -> (d0 floordiv 8)> -// CHECK-DAG: #[[MAP_MODL:.*]] = affine_map<(d0) -> (d0 mod 8)> -// CHECK: func.func @KCRSsr_to_KCRS -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[UBI:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref -// CHECK-DAG: %[[UBJ:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref -// CHECK-DAG: %[[UBK:.*]] = memref.dim %[[ARG0]], %[[C2]] : memref -// CHECK-DAG: %[[UBL:.*]] = memref.dim %[[ARG0]], %[[C3]] : memref -// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[UBI]] step %[[C1]] { -// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[UBJ]] step %[[C1]] { -// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[UBK]] step %[[C1]] { -// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[UBL]] step %[[C1]] { -// CHECK-DAG: %[[FLOORK:.*]] = affine.apply #[[MAP_FLOORK]](%[[K]]) -// CHECK-DAG: %[[FLOORL:.*]] = affine.apply #[[MAP_FLOORL]](%[[L]]) -// CHECK-DAG: %[[MODK:.*]] = affine.apply #[[MAP_MODK]](%[[K]]) -// CHECK-DAG: %[[MODL:.*]] = affine.apply #[[MAP_MODL]](%[[L]]) -// CHECK: %[[SCALAR:.*]] = memref.load %[[ARG1]][%[[I]], %[[J]], %[[FLOORK]], %[[FLOORL]], %[[MODL]], %[[MODK]]] : memref -// CHECK: memref.store %[[SCALAR]], %[[ARG0]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @KCRSsr_to_KCRS(%arg0: memref, %arg1: memref, %block : index) { - iree_linalg_ext.unpack %arg1 inner_dims_pos = [3, 2] inner_tiles = [8, %block] into %arg0 : (memref memref) - return -} - -// CHECK-DAG: #[[MAP_FLOORK:.*]] = affine_map<(d0)[s0] -> (d0 floordiv s0)> -// CHECK-DAG: #[[MAP_MODK:.*]] = affine_map<(d0)[s0] -> (d0 mod s0)> -// CHECK-DAG: #[[MAP_FLOORL:.*]] = affine_map<(d0) -> (d0 floordiv 8)> -// CHECK-DAG: #[[MAP_MODL:.*]] = affine_map<(d0) -> (d0 mod 8)> -// CHECK: func.func @KCRSsr_to_KCRS -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[UBI:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref -// CHECK-DAG: %[[UBJ:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref -// CHECK-DAG: %[[UBK:.*]] = memref.dim %[[ARG0]], %[[C2]] : memref -// CHECK-DAG: %[[UBL:.*]] = memref.dim %[[ARG0]], %[[C3]] : memref -// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[UBI]] step %[[C1]] { -// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[UBJ]] step %[[C1]] { -// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[UBK]] step %[[C1]] { -// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[UBL]] step %[[C1]] { -// CHECK-DAG: %[[FLOORK:.*]] = affine.apply #[[MAP_FLOORK]](%[[K]])[%[[ARG2]]] -// CHECK-DAG: %[[FLOORL:.*]] = affine.apply #[[MAP_FLOORL]](%[[L]]) -// CHECK-DAG: %[[MODK:.*]] = affine.apply #[[MAP_MODK]](%[[K]])[%[[ARG2]]] -// CHECK-DAG: %[[MODL:.*]] = affine.apply #[[MAP_MODL]](%[[L]]) -// CHECK: %[[SCALAR:.*]] = memref.load %[[ARG1]][%[[I]], %[[J]], %[[FLOORK]], %[[FLOORL]], %[[MODL]], %[[MODK]]] : memref -// CHECK: memref.store %[[SCALAR]], %[[ARG0]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @unpack_undo_padding(%input: memref<2x8x8x2xf32>, %output: memref<13x15xf32>) { - iree_linalg_ext.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (memref<2x8x8x2xf32> memref<13x15xf32>) - return -} -// CHECK-DAG: #[[MAP_FLOORI:.*]] = affine_map<(d0) -> (d0 floordiv 8)> -// CHECK-DAG: #[[MAP_MODI:.*]] = affine_map<(d0) -> (d0 mod 8)> -// CHECK-DAG: #[[MAP_FLOORJ:.*]] = affine_map<(d0) -> (d0 floordiv 2)> -// CHECK-DAG: #[[MAP_MODJ:.*]] = affine_map<(d0) -> (d0 mod 2)> -// CHECK: func.func @unpack_undo_padding -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C13:.+]] = arith.constant 13 : index -// CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C13]] step %[[C1]] { -// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C15]] step %[[C1]] { -// CHECK-DAG: %[[OUTER_I:.+]] = affine.apply #[[MAP_FLOORI]](%[[I]]) -// CHECK-DAG: %[[INNER_I:.+]] = affine.apply #[[MAP_MODI]](%[[I]]) -// CHECK-DAG: %[[OUTER_J:.+]] = affine.apply #[[MAP_FLOORJ]](%[[J]]) -// CHECK-DAG: %[[INNER_J:.+]] = affine.apply #[[MAP_MODJ]](%[[J]]) -// CHECK: %[[VAL:.+]] = memref.load %[[INPUT]][%[[OUTER_I]], %[[OUTER_J]], %[[INNER_I]], %[[INNER_J]]] -// CHECK: memref.store %[[VAL]], %[[OUTPUT]][%[[I]], %[[J]]] - -// ----- - -func.func @KC_to_CKkc(%arg0: memref<128x256xf32>, %arg1: memref<32x4x32x8xf32>) { - iree_linalg_ext.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<32x4x32x8xf32>) - return -} - -// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)> -// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)> -// CHECK: func.func @KC_to_CKkc -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index -// CHECK: scf.for %[[C:.+]] = %[[C0]] to %[[C32]] step %[[C1]] { -// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C4]] step %[[C1]] { -// CHECK: scf.for %[[k:.+]] = %[[C0]] to %[[C32]] step %[[C1]] { -// CHECK: scf.for %[[c:.+]] = %[[C0]] to %[[C8]] step %[[C1]] { -// CHECK: %[[MAPK:.+]] = affine.apply #[[MAP0]](%[[K]], %[[k]]) -// CHECK: %[[MAPC:.+]] = affine.apply #[[MAP1]](%[[C]], %[[c]]) -// CHECK: %[[VAL:.+]] = memref.load %[[ARG0]][%[[MAPK]], %[[MAPC]]] : memref<128x256xf32> -// CHECK: memref.store %[[VAL]], %[[ARG1]][%[[C]], %[[K]], %[[k]], %[[c]]] : memref<32x4x32x8xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @CKkc_to_KC(%arg0: memref<128x256xf32>, %arg1: memref<32x4x32x8xf32>) { - iree_linalg_ext.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg0 : (memref<32x4x32x8xf32> memref<128x256xf32>) - return -} - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 mod 8)> -// CHECK: func.func @CKkc_to_KC -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index -// CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index -// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C128]] step %[[C1]] { -// CHECK: scf.for %[[C:.+]] = %[[C0]] to %[[C256]] step %[[C1]] { -// CHECK-DAG: %[[FLOORK:.+]] = affine.apply #[[MAP0]](%[[K]]) -// CHECK-DAG: %[[MODK:.+]] = affine.apply #[[MAP1]](%[[K]]) -// CHECK-DAG: %[[FLOORC:.+]] = affine.apply #[[MAP2]](%[[C]]) -// CHECK-DAG: %[[MODC:.+]] = affine.apply #[[MAP3]](%[[C]]) -// CHECK: %[[VAL:.+]] = memref.load %[[ARG1]][%[[FLOORC]], %[[FLOORK]], %[[MODK]], %[[MODC]]] : memref<32x4x32x8xf32> -// CHECK: memref.store %[[VAL]], %[[ARG0]][%[[K]], %[[C]]] : memref<128x256xf32> -// CHECK: } -// CHECK: } - -// ----- - -func.func @NPQK_to_NKPQk(%arg0: memref<1x56x56x64xf32>, %arg1: memref<1x2x56x56x32xf32>) { - iree_linalg_ext.pack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %arg1 : (memref<1x56x56x64xf32> memref<1x2x56x56x32xf32>) - return -} - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 * 32 + d1)> -// CHECK: func.func @NPQK_to_NKPQk -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[C56:.+]] = arith.constant 56 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK: scf.for %[[N:.+]] = %[[C0]] to %[[C1]] step %[[C1]] { -// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C2]] step %[[C1]] { -// CHECK: scf.for %[[P:.+]] = %[[C0]] to %[[C56]] step %[[C1]] { -// CHECK: scf.for %[[Q:.+]] = %[[C0]] to %[[C56]] step %[[C1]] { -// CHECK: scf.for %[[k:.+]] = %[[C0]] to %[[C32]] step %[[C1]] { -// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP0]](%[[K]], %[[k]]) -// CHECK: %[[VAL:.+]] = memref.load %[[INPUT]][%[[N]], %[[P]], %[[Q]], %[[APPLY]]] : memref<1x56x56x64xf32> -// CHECK: memref.store %[[VAL]], %[[OUTPUT]][%[[N]], %[[K]], %[[P]], %[[Q]], %[[k]]] : memref<1x2x56x56x32xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } - -// ----- - -func.func @unpack(%arg0: memref<1x4x6x6x2xf32>, %arg1: memref<1x6x6x8xf32>) { - iree_linalg_ext.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2] into %arg1 : (memref<1x4x6x6x2xf32> memref<1x6x6x8xf32>) - return -} - -// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0 floordiv 2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 2)> -// CHECK: func.func @unpack( -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C1]] step %[[C1]] { -// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C6]] step %[[C1]] { -// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C6]] step %[[C1]] { -// CHECK: scf.for %[[L:.+]] = %[[C0]] to %[[C8]] step %[[C1]] { -// CHECK: %[[APPLY_TILE:.+]] = affine.apply #[[MAP]](%[[L]]) -// CHECK: %[[APPLY_LOOP:.+]] = affine.apply #[[MAP1]](%[[L]]) -// CHECK: %[[LOAD:.+]] = memref.load %[[INPUT]][%[[I]], %[[APPLY_TILE]], %[[J]], %[[K]], %[[APPLY_LOOP]]] : memref<1x4x6x6x2xf32> -// CHECK: memref.store %[[LOAD]], %[[OUTPUT]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<1x6x6x8xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/fold_into_pack_unpack_ops.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/fold_into_pack_unpack_ops.mlir deleted file mode 100644 index 3c7bd18a29db..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/fold_into_pack_unpack_ops.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: iree-dialects-opt -iree-linalg-ext-fold-into-pack-unpack-ops %s | FileCheck %s - -func.func @fold_unpack_slice(%arg0 : tensor, %arg1 : tensor, - %arg2 : index, %arg3 : index) -> tensor { - %0 = iree_linalg_ext.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1 - : (tensor tensor) -> tensor - %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor to tensor - return %1 : tensor -} -// CHECK: func @fold_unpack_slice( -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index -// CHECK: %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4] -// CHECK-SAME: into %[[INIT]] -// CHECK: return %[[UNPACK]] - -// ----- - -func.func @nofold_unpack_slice_non_zero_offset(%arg0 : tensor, %arg1 : tensor, - %arg2 : index, %arg3 : index, %arg4 : index) -> tensor { - %0 = iree_linalg_ext.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1 - : (tensor tensor) -> tensor - %1 = tensor.extract_slice %0[0, %arg4] [%arg2, %arg3] [1, 1] : tensor to tensor - return %1 : tensor -} -// CHECK-LABEL: func @nofold_unpack_slice_non_zero_offset( -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack -// CHECK: tensor.extract_slice %[[UNPACK]] - -// ----- - -func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor, %arg1 : tensor, - %arg2 : index, %arg3 : index, %arg4 : index) -> tensor { - %0 = iree_linalg_ext.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1 - : (tensor tensor) -> tensor - %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [%arg4, 1] : tensor to tensor - return %1 : tensor -} -// CHECK-LABEL: func @nofold_unpack_slice_non_unit_stride( -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack -// CHECK: tensor.extract_slice %[[UNPACK]] diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir deleted file mode 100644 index 45b26ef1baef..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir +++ /dev/null @@ -1,57 +0,0 @@ -// RUN: iree-dialects-opt %s --transform-dialect-interpreter --split-input-file | FileCheck %s - -// CHECK-DAG: #[[$MUL_MAP:.*]] = affine_map<(d0)[s0] -> (d0 * s0)> -// CHECK-DAG: #[[$SUB_MAP:.*]] = affine_map<(d0)[s0, s1] -> (-(d0 * s0) + s1, s0)> -// CHECK-DAG: #[[$ID1_MAP:.*]] = affine_map<(d0) -> (d0)> -#map0 = affine_map<(d0)[s0] -> (d0 ceildiv s0)> -#map1 = affine_map<(d0)[s0] -> (d0 * s0)> -#map2 = affine_map<(d0, d1) -> (d0 - d1)> -#map3 = affine_map<(d0, d1) -> (d0, d1)> -#map4 = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: func.func @static_tile -// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index -// CHECK-SAME: %[[IN:[0-9a-z]+]]: memref -// CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref -func.func @static_tile(%arg0: index, %arg1: memref, %arg2: memref) { - %cst = arith.constant 4.200000e+01 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = memref.dim %arg2, %c0 : memref - %1 = affine.apply #map0(%0)[%arg0] - - // CHECK: %[[M:.*]] = memref.dim %{{.*}}, %{{.*}} : memref - // CHECK: %[[group:.*]] = async.create_group {{.*}}: !async.group - // CHECK: scf.for %[[IV:.*]] = {{.*}} - // CHECK: %[[token:.*]] = async.execute { - // CHECK: subview - // CHECK: subview - // CHECK: linalg.generic - // CHECK: async.yield - // CHECK: } - // CHECK: async.add_to_group %[[token]], %[[group]] : !async.token - // CHECK: } - // CHECK: async.await_all %[[group]] - scf.foreach_thread (%arg3) in (%1) shared_outs() -> () { - %3 = affine.apply #map1(%arg3)[%arg0] - %4 = affine.apply #map2(%0, %3) - %5 = affine.min #map3(%4, %arg0) - - %6 = memref.subview %arg2[%3] [%5] [%c1] : memref to memref> - %7 = memref.subview %arg1[%3] [%5] [1] : memref to memref> - - linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel"]} - ins(%7 : memref>) outs(%6 : memref>) { - ^bb0(%arg4: f32, %arg5: f32): // no predecessors - %9 = arith.mulf %arg4, %cst : f32 - linalg.yield %9 : f32 - } - } - return -} - -transform.structured.canonicalized_sequence failures(propagate) { -^bb1(%module_op: !pdl.operation): - %0 = transform.structured.match ops{["scf.foreach_thread"]} in %module_op - %1 = foreach_thread_to_async %0 -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir deleted file mode 100644 index a261e9e8b75f..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: iree-dialects-opt %s --transform-dialect-interpreter --split-input-file | FileCheck %s - -// CHECK-DAG: #[[$MUL_MAP:.*]] = affine_map<(d0)[s0] -> (d0 * s0)> -// CHECK-DAG: #[[$SUB_MAP:.*]] = affine_map<(d0)[s0, s1] -> (-(d0 * s0) + s1, s0)> -// CHECK-DAG: #[[$ID1_MAP:.*]] = affine_map<(d0) -> (d0)> -#map0 = affine_map<(d0)[s0] -> (d0 ceildiv s0)> -#map1 = affine_map<(d0)[s0] -> (d0 * s0)> -#map2 = affine_map<(d0, d1) -> (d0 - d1)> -#map3 = affine_map<(d0, d1) -> (d0, d1)> -#map4 = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: func.func @static_tile_buffers -// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index -// CHECK-SAME: %[[IN:[0-9a-z]+]]: memref -// CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref -func.func @static_tile_buffers(%arg0: index, %arg1: memref, %arg2: memref) { - %cst = arith.constant 4.200000e+01 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = memref.dim %arg2, %c0 : memref - %1 = affine.apply #map0(%0)[%arg0] - - // CHECK: %[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[M:.*]] = memref.dim %{{.*}}, %{{.*}} : memref - // CHECK: scf.for %[[IV:.*]] = {{.*}} step %[[C1]] { - scf.foreach_thread (%arg3) in (%1) shared_outs() -> () { - %3 = affine.apply #map1(%arg3)[%arg0] - %4 = affine.apply #map2(%0, %3) - %5 = affine.min #map3(%4, %arg0) - - %6 = memref.subview %arg2[%3] [%5] [%c1] : memref to memref> - %7 = memref.subview %arg1[%3] [%5] [1] : memref to memref> - - linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel"]} - ins(%7 : memref>) outs(%6 : memref>) { - ^bb0(%arg4: f32, %arg5: f32): // no predecessors - %9 = arith.mulf %arg4, %cst : f32 - linalg.yield %9 : f32 - } - - // Nothing is yielded, skip the terminator. - // CHECK-NOT: scf.yield - } - return -} - -transform.structured.canonicalized_sequence failures(propagate) { -^bb1(%module_op: !pdl.operation): - %0 = transform.structured.match ops{["scf.foreach_thread"]} in %module_op - %1 = foreach_thread_to_scf_for %0 -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir deleted file mode 100644 index b21c5d3702f1..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir +++ /dev/null @@ -1,122 +0,0 @@ -// RUN: iree-dialects-opt %s --transform-dialect-interpreter --split-input-file | FileCheck %s - -#map0 = affine_map<()[s0] -> (64 ceildiv s0)> -#map1 = affine_map<(d0)[s0] -> (d0 * s0)> -#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)> - -module { - // CHECK-LABEL: func.func @fuse_static - // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index - // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32> - // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32> - func.func @fuse_static(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { - %cst = arith.constant 4.200000e+01 : f32 - %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<64xf32>) -> tensor<64xf32> - %1 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<64xf32>) -> tensor<64xf32> - - %2 = affine.apply #map0()[%arg0] - // CHECK: scf.foreach_thread - %3 = scf.foreach_thread (%arg3) in (%2) shared_outs(%O = %arg2) -> (tensor<64xf32>) { - // CHECK: %[[OFFSET:.*]] = affine.apply - // CHECK: %[[SIZE:.*]] = affine.min - %4 = affine.apply #map1(%arg3)[%arg0] - %5 = affine.min #map2(%arg3)[%arg0] - %6 = tensor.extract_slice %0[%4] [%5] [1] : tensor<64xf32> to tensor - - // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%[[OFFSET]]] [%[[SIZE]]] [{{.*}}] - // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]] - // CHECK: %[[T2:.*]] = tensor.extract_slice %[[OUT]][%[[OFFSET]]] [%[[SIZE]]] [{{.*}}] - // CHECK: %[[T3:.*]] = linalg.fill {{.*}} outs(%[[T2]] - %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<64xf32> to tensor - - // CHECK: %[[T4:.*]] = linalg.elemwise_unary ins(%[[T1]] {{.*}} outs(%[[T3]] - %8 = linalg.elemwise_unary ins(%6 : tensor) outs(%7 : tensor) -> tensor - scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %8 into %O[%4] [%5] [1] : tensor into tensor<64xf32> - } - } - func.return %3 : tensor<64xf32> - } - - transform.with_pdl_patterns { - ^bb0(%arg0: !pdl.operation): - pdl.pattern @match_elemwise : benefit(1) { - %0 = operands - %1 = types - %2 = operation "linalg.elemwise_unary"(%0 : !pdl.range) -> (%1 : !pdl.range) - rewrite %2 with "transform.dialect" - } - pdl.pattern @match_in_parallel : benefit(1) { - %0 = operands - %1 = types - %2 = operation "scf.foreach_thread"(%0 : !pdl.range) -> (%1 : !pdl.range) - rewrite %2 with "transform.dialect" - } - transform.structured.canonicalized_sequence %arg0 failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @match_elemwise in %arg1 : (!pdl.operation) -> !pdl.operation - %1, %fusedOps:2 = fuse_producers %0 {operands_to_fuse=[0, 1]} - } - } -} - -// ----- - -#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)> -#map1 = affine_map<(d0)[s0] -> (d0 * s0)> -#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)> - -module { - // CHECK-LABEL: func.func @fuse_dynamic - // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index - // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor - // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor - func.func @fuse_dynamic(%arg0: index, %arg1: tensor, %arg2: tensor) -> tensor { - %cst = arith.constant 4.200000e+01 : f32 - %c0 = arith.constant 0 : index - %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor) -> tensor - // TODO: Choosing %arg2 here complicates the size computation. - %d0 = tensor.dim %arg1, %c0 : tensor - %1 = affine.apply #map0()[%d0, %arg0] - // CHECK: scf.foreach_thread - %2 = scf.foreach_thread (%arg3) in (%1) shared_outs(%O = %arg2) -> (tensor) { - // CHECK: %[[OFFSET:.*]] = affine.apply - // CHECK: %[[SIZE:.*]] = affine.min - %3 = affine.apply #map1(%arg3)[%arg0] - %4 = affine.min #map2(%arg3)[%d0, %arg0] - %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor to tensor - - // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%[[OFFSET]]] [%[[SIZE]]] [{{.*}}] - // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]] - %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor to tensor - - // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]] - %7 = linalg.elemwise_unary ins(%6 : tensor) outs(%5 : tensor) -> tensor - scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %7 into %O[%3] [%4] [1] : tensor into tensor - } - } - func.return %2 : tensor - } - - transform.with_pdl_patterns { - ^bb0(%arg0: !pdl.operation): - pdl.pattern @match_elemwise : benefit(1) { - %0 = operands - %1 = types - %2 = operation "linalg.elemwise_unary"(%0 : !pdl.range) -> (%1 : !pdl.range) - rewrite %2 with "transform.dialect" - } - pdl.pattern @match_in_parallel : benefit(1) { - %0 = operands - %1 = types - %2 = operation "scf.foreach_thread"(%0 : !pdl.range) -> (%1 : !pdl.range) - rewrite %2 with "transform.dialect" - } - transform.structured.canonicalized_sequence %arg0 failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @match_elemwise in %arg1 : (!pdl.operation) -> !pdl.operation - %1, %fusedOps = fuse_producers %0 {operands_to_fuse=[0]} - } - } -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir deleted file mode 100644 index 038f29350c59..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir +++ /dev/null @@ -1,694 +0,0 @@ -// RUN: iree-dialects-opt --split-input-file --verify-diagnostics %s - -func.func @sort_invalid_dimension(%arg0: tensor<128xi32>) -> tensor<128xi32> { - // expected-error @+1 {{dimension must be within (0, 1]}} - %0 = iree_linalg_ext.sort dimension(1) - outs(%arg0 : tensor<128xi32>) { - ^bb0(%arg1: i32, %arg2: i32): // no predecessors - %1 = arith.cmpi sgt, %arg1, %arg2 : i32 - iree_linalg_ext.yield %1 : i1 - } -> tensor<128xi32> - return %0 : tensor<128xi32> -} - -// ----- - -func.func @sort_mismatch_rank(%arg0: tensor, %arg1: tensor) - -> (tensor, tensor) { - // expected-error @+1 {{expected operand 1 to be rank 2, same as other operands}} - %0:2 = iree_linalg_ext.sort dimension(0) - outs(%arg0, %arg1 : tensor, tensor) { - ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors - %1 = arith.cmpf ogt, %arg4, %arg5 : f32 - iree_linalg_ext.yield %1 : i1 - } -> tensor, tensor - return %0#0, %0#1 : tensor, tensor -} - -// ----- - -func.func @sort_mismatch_shape(%arg0: tensor, %arg1: tensor<42xf32>) - -> (tensor, tensor<42xf32>) { - // expected-error @+1 {{expected operand 1 to have same shape as other operands}} - %0:2 = iree_linalg_ext.sort dimension(0) - outs(%arg0, %arg1 : tensor, tensor<42xf32>) { - ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors - %1 = arith.cmpf ogt, %arg4, %arg5 : f32 - iree_linalg_ext.yield %1 : i1 - } -> tensor, tensor<42xf32> - return %0#0, %0#1 : tensor, tensor<42xf32> -} - -// ----- - -func.func @scatter_extra_outputs( - %update : tensor, %indices : tensor, - %original : tensor) -> (tensor, tensor) { - // expected-error @+1 {{expected number of outputs to be same as the number of results}} - %0, %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor, tensor - return %0, %1 : tensor, tensor -} - -// ----- - -func.func @scatter_mistmatch_dim_map_entries( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{invalid number of dimension map entries}} - %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_duplicate_dim_map_entries( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{dimension map is invalid}} - %0 = iree_linalg_ext.scatter dimension_map = [1, 1] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_invalid_dim_map_entries( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{dimension map is invalid}} - %0 = iree_linalg_ext.scatter dimension_map = [2] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_output_type_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor<4x?xf32> { - // expected-error @+1 {{expected type of `outs` operand #0 'tensor' to be same as result type 'tensor<4x?xf32>'}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor<4x?xf32> - return %0 : tensor<4x?xf32> -} - -// ----- - -func.func @scatter_dim_mismatch( - %update : tensor, %indices : tensor<48x1xi32>, - %original : tensor) -> tensor { - // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor<48x1xi32>) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_dim_mismatch( - %update : tensor<64x?xf32>, %indices : tensor<48x1xi32>, - %original : tensor) -> tensor { - // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor<64x?xf32>, tensor<48x1xi32>) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_dim_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{op update value rank exceeds the rank of the original value}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_dim_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{op shape of update value dim#1 exceeds original value at dim#1}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_region_type_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{expected region to have scalar argument of integer or float types}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: index, %arg2: index): - %1 = arith.addi %arg1, %arg2 : index - %2 = arith.index_cast %1 : index to i32 - iree_linalg_ext.yield %2 : i32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_region_type_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: i64, %arg2: i32): - %1 = arith.trunci %arg1 : i64 to i32 - %2 = arith.addi %1, %arg2 : i32 - iree_linalg_ext.yield %2 : i32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_region_type_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: i32, %arg2: i64): - %1 = arith.trunci %arg2 : i64 to i32 - %2 = arith.addi %1, %arg1 : i32 - iree_linalg_ext.yield %2 : i32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_region_type_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: i32, %arg2: i64): - %1 = arith.extsi %arg1 : i32 to i64 - %2 = arith.addi %1, %arg2 : i64 - iree_linalg_ext.yield %2 : i64 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_region_type_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{expected region to have two arguments}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: i64, %arg2: i64, %arg3 : i64): - %1 = arith.addi %arg1, %arg2 : i64 - iree_linalg_ext.yield %1 : i64 - } -> tensor - return %0 : tensor -} - - -// ----- - -func.func @scatter_yield_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: i64, %arg2: i64): - %1 = arith.addi %arg1, %arg2 : i64 - %2 = arith.trunci %1 : i64 to i32 - // expected-error @+1 {{mismatch in type of yielded value 'i32' and argument of the region 'i64'}} - iree_linalg_ext.yield %2 : i32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_yield_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: i64, %arg2: i64): - %1 = arith.addi %arg1, %arg2 : i64 - %2 = arith.trunci %1 : i64 to i32 - // expected-error @+1 {{expected region to yield a single value}} - iree_linalg_ext.yield %1, %2 : i64, i32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_index_depth_dynamic( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{expected index depth is static}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: i64, %arg2: i64): - %1 = arith.addi %arg1, %arg2 : i64 - %2 = arith.trunci %1 : i64 to i32 - iree_linalg_ext.yield %1, %2 : i64, i32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @scatter_original_rank_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{op index depth and update value does not cover rank of original value}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: i64, %arg2: i64): - %1 = arith.addi %arg1, %arg2 : i64 - %2 = arith.trunci %1 : i64 to i32 - iree_linalg_ext.yield %1, %2 : i64, i32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @reverse_diff_element_type(%arg0: tensor<3x5xi32>) -> tensor<3x5xf32> { - %init = tensor.empty() : tensor<3x5xf32> - // expected-error @+1 {{expected input/output element types to be identical}} - %0 = iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor<3x5xf32>) : tensor<3x5xf32> - return %0 : tensor<3x5xf32> -} - -// ----- - -func.func @reverse_diff_shape(%arg0: tensor<3x5xi32>) -> tensor<3x6xi32> { - %init = tensor.empty() : tensor<3x6xi32> - // expected-error @+1 {{incompatible input/output shapes}} - %0 = iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor<3x6xi32>) : tensor<3x6xi32> - return %0 : tensor<3x6xi32> -} - -// ----- - -func.func @reverse_dup_dims(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { - %init = tensor.empty() : tensor<3x5xi32> - // expected-error @+1 {{expected dimensions numbers are all unique}} - %0 = iree_linalg_ext.reverse - dimensions(dense<[0, 0]> : tensor<2xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor<3x5xi32>) : tensor<3x5xi32> - return %0 : tensor<3x5xi32> -} - -// ----- - -func.func @topk_invalid(%input_values: tensor<2x10xf32>, %input_indices: tensor<2x10xi32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) { - // expected-error@+1 {{expected one or two input operands}} - %0:2 = iree_linalg_ext.topk - dimension(1) - ins(%input_indices, %input_indices, %input_indices : tensor<2x10xi32>, tensor<2x10xi32>, tensor<2x10xi32>) - outs(%out_values, %out_indices : tensor<2x3xf32>, tensor<2x3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<2x3xf32>, tensor<2x3xi32> - return %0#0, %0#1 : tensor<2x3xf32>, tensor<2x3xi32> -} - -// ----- - -func.func @topk_invalid(%input_values: tensor<2x10xi32>, %input_indices: tensor<2x10xi32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) { - // expected-error@+1 {{expected input/output value types to be identical}} - %0:2 = iree_linalg_ext.topk - dimension(1) - ins(%input_values, %input_indices : tensor<2x10xi32> , tensor<2x10xi32>) - outs(%out_values, %out_indices : tensor<2x3xf32>, tensor<2x3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<2x3xf32>, tensor<2x3xi32> - return %0#0, %0#1 : tensor<2x3xf32>, tensor<2x3xi32> -} - -// ----- - -func.func @topk_invalid(%input_values: tensor<2x10xf32>, %input_indices: tensor<2x10xf32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) { - // expected-error@+1 {{expected input/output indices types to be int}} - %0:2 = iree_linalg_ext.topk - dimension(1) - ins(%input_values, %input_indices : tensor<2x10xf32> , tensor<2x10xf32>) - outs(%out_values, %out_indices : tensor<2x3xf32>, tensor<2x3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<2x3xf32>, tensor<2x3xi32> - return %0#0, %0#1 : tensor<2x3xf32>, tensor<2x3xi32> -} - -// ----- - -func.func @topk_invalid(%input_values: tensor<10x2x10xf32>, %input_indices: tensor<10x2x10xi32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) { - // expected-error@+1 {{expected input/output to have the same rank}} - %0:2 = iree_linalg_ext.topk - dimension(1) - ins(%input_values, %input_indices : tensor<10x2x10xf32> , tensor<10x2x10xi32>) - outs(%out_values, %out_indices : tensor<2x3xf32>, tensor<2x3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<2x3xf32>, tensor<2x3xi32> - return %0#0, %0#1 : tensor<2x3xf32>, tensor<2x3xi32> -} - -// ----- - -func.func @topk_invalid(%input_values: tensor<3x10xf32>, %input_indices: tensor<2x10xi32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) { - // expected-error@+1 {{input indices/values shape must match}} - %0:2 = iree_linalg_ext.topk - dimension(1) - ins(%input_values, %input_indices : tensor<3x10xf32> , tensor<2x10xi32>) - outs(%out_values, %out_indices : tensor<2x3xf32>, tensor<2x3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<2x3xf32>, tensor<2x3xi32> - return %0#0, %0#1 : tensor<2x3xf32>, tensor<2x3xi32> -} - -// ----- - -func.func @topk_invalid(%input_values: tensor<2x10xf32>, %input_indices: tensor<2x10xi32>, %out_values : tensor<3x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<3x3xf32>, tensor<2x3xi32>) { - // expected-error@+1 {{output indices/values shape must match}} - %0:2 = iree_linalg_ext.topk - dimension(1) - ins(%input_values, %input_indices : tensor<2x10xf32> , tensor<2x10xi32>) - outs(%out_values, %out_indices : tensor<3x3xf32>, tensor<2x3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<3x3xf32>, tensor<2x3xi32> - return %0#0, %0#1 : tensor<3x3xf32>, tensor<2x3xi32> -} - -// ----- - -func.func @topk_invalid(%input_values: tensor<3x10xf32>, %input_indices: tensor<3x10xi32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) { - // expected-error@+1 {{incompatible input/output shapes}} - %0:2 = iree_linalg_ext.topk - dimension(1) - ins(%input_values, %input_indices : tensor<3x10xf32> , tensor<3x10xi32>) - outs(%out_values, %out_indices : tensor<2x3xf32>, tensor<2x3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<2x3xf32>, tensor<2x3xi32> - return %0#0, %0#1 : tensor<2x3xf32>, tensor<2x3xi32> -} - -// ----- - -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}} - %0 = iree_linalg_ext.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : (tensor<256x128xf32> tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> - return %0 : tensor<8x8x32x16xf32> -} - -// ----- - -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> { - // expected-error@+1 {{invalid tile factor provided. Only full tiles are supported when padding_value is not set}} - %0 = iree_linalg_ext.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : (tensor<256x128xf32> tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> - return %0 : tensor<8x8x16x33xf32> -} - -// ----- - -func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> { - // expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}} - %0 = iree_linalg_ext.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<13x15xf32> tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> - return %0 : tensor<2x8x8x2xf32> -} - -// ----- - -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { - // expected-error@+1 {{invalid inner_dims_pos vector}} - %0 = iree_linalg_ext.pack %input inner_dims_pos = [2, 0] inner_tiles = [2, 2] into %output : (tensor<256x128xf32> tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> - return %0 : tensor<8x8x32x16xf32> -} - -// ----- - -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { - // expected-error@+1 {{invalid tile factor}} - %0 = iree_linalg_ext.pack %input inner_dims_pos = [1, 0] inner_tiles = [0, 2] into %output : (tensor<256x128xf32> tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> - return %0 : tensor<8x8x32x16xf32> -} - -// ----- - -// duplicate element in `inner_dims_pos`, fail. -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { - // expected-error@+1 {{invalid inner_dims_pos vector}} - %0 = iree_linalg_ext.pack %input inner_dims_pos = [1, 1] inner_tiles = [2, 2] into %output : (tensor<256x128xf32> tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> - return %0 : tensor<8x8x32x16xf32> -} - -// ----- - -func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}} - %0 = iree_linalg_ext.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : (tensor<8x8x32x16xf32> tensor<256x128xf32>) -> tensor<256x128xf32> - return %0 : tensor<256x128xf32> -} - -// ----- - -// duplicate element in `outer_dims_perm`, fail. -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { - // expected-error@+1 {{invalid outer_dims_perm vector}} - %0 = iree_linalg_ext.pack %input outer_dims_perm = [1, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %output : (tensor<256x128xf32> tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> - return %0 : tensor<8x8x32x16xf32> -} - -// ----- - -// duplicate element in `outer_dims_perm`, fail. -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { - // expected-error@+1 {{invalid outer_dims_perm vector}} - %0 = iree_linalg_ext.unpack %output outer_dims_perm = [1, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %input : (tensor<8x8x32x16xf32> tensor<256x128xf32>) -> tensor<256x128xf32> - return %0 : tensor<256x128xf32> -} - -// ----- - -// `outer_dims_perm` is out of bound. -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { - // expected-error@+1 {{invalid outer_dims_perm vector}} - %0 = iree_linalg_ext.unpack %output outer_dims_perm = [2, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %input : (tensor<8x8x32x16xf32> tensor<256x128xf32>) -> tensor<256x128xf32> - return %0 : tensor<256x128xf32> -} - -// ----- -func.func @pack_mismatch_inner_tile_size_and_output_shape( - %input : tensor, %output : tensor) -> tensor { - // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}} - %0 = iree_linalg_ext.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output - : (tensor tensor) -> tensor - return %0 : tensor -} - -// ----- - -func.func @unpack_mismatch_inner_tile_size_and_output_shape( - %input : tensor, %output : tensor) -> tensor { - // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}} - %0 = iree_linalg_ext.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output - : (tensor tensor) -> tensor - return %0 : tensor -} - -// ----- - -func.func @illegal_set_encoding_op_with_no_result_encoding(%arg0 : tensor) -> tensor { - // expected-error @+1 {{result of set_encoding op expected to have a valid tensor encoding}} - %0 = iree_linalg_ext.set_encoding %arg0: tensor -> tensor - return %0 : tensor -} - -// ----- - -func.func @illegal_set_encoding_op_with_source_encoding(%arg0 : tensor>) -> tensor { - // expected-error @+1 {{source of set_encoding op cannot have a tensor encoding}} - %0 = iree_linalg_ext.set_encoding %arg0: tensor> -> tensor - return %0 : tensor -} - -// ----- - -func.func @illegal_set_encoding_op_with_unknown_encoding(%arg0 : tensor) -> tensor { - // expected-error @+1 {{result of set_encoding op expected to have a valid tensor encoding}} - %0 = iree_linalg_ext.set_encoding %arg0: tensor -> tensor - return %0 : tensor -} - -// ----- - -func.func @illegal_set_encoding_op_with_rank_change(%arg0 : tensor) -> tensor> { - // expected-error @+1 {{cannot change the rank of the tensor}} - %0 = iree_linalg_ext.set_encoding %arg0: tensor -> tensor> - return %0 : tensor> -} - -// ----- - -func.func @illegal_set_encoding_op_with_shape_change(%arg0 : tensor<10x20xf32>) -> tensor<20x30xf32, #iree_linalg_ext.encoding> { - // expected-error @+1 {{expected to preserve the logical shape of the tensor}} - %0 = iree_linalg_ext.set_encoding %arg0: tensor<10x20xf32> -> tensor<20x30xf32, #iree_linalg_ext.encoding> - return %0 : tensor<20x30xf32, #iree_linalg_ext.encoding> -} - -// ----- - -func.func @illegal_unset_encoding_op_with_no_source_encoding(%arg0 : tensor) -> tensor { - // expected-error @+1 {{source of unset_encoding op expected to have a valid tensor encoding}} - %0 = iree_linalg_ext.unset_encoding %arg0: tensor -> tensor - return %0 : tensor -} - -// ----- - -func.func @illegal_unset_encoding_op_with_result_encoding(%arg0 : tensor) -> tensor> { - // expected-error @+1 {{result of unset_encoding op cannot have a tensor encoding}} - %0 = iree_linalg_ext.unset_encoding %arg0: tensor -> tensor> - return %0 : tensor> -} - -// ----- - -func.func @illegal_unset_encoding_op_with_unknown_encoding(%arg0 : tensor) -> tensor { - // expected-error @+1 {{source of unset_encoding op expected to have a valid tensor encoding}} - %0 = iree_linalg_ext.unset_encoding %arg0: tensor -> tensor - return %0 : tensor -} - -// ----- - -func.func @illegal_unset_encoding_op_with_rank_change(%arg0 : tensor>) -> tensor { - // expected-error @+1 {{cannot change the rank of the tensor}} - %0 = iree_linalg_ext.unset_encoding %arg0: tensor> -> tensor - return %0 : tensor -} - -// ----- - -func.func @illegal_unset_encoding_op_with_shape_change(%arg0 : tensor<20x30xf32, #iree_linalg_ext.encoding>) -> tensor<10x20xf32> { - // expected-error @+1 {{expected to preserve the logical shape of the tensor}} - %0 = iree_linalg_ext.unset_encoding %arg0: tensor<20x30xf32, #iree_linalg_ext.encoding> -> tensor<10x20xf32> - return %0 : tensor<10x20xf32> -} - -// ----- - -func.func @illegal_winograd_input_shape(%arg0: tensor<1x10x10x32xf32>) -> tensor<8x8x1x6x6x32xf32> { - %0 = tensor.empty() : tensor<8x8x1x6x6x32xf32> - // expected-error @+1 {{incompatible output shape}} - %1 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) - ins(%arg0 : tensor<1x10x10x32xf32>) outs(%0 : tensor<8x8x1x6x6x32xf32>) -> tensor<8x8x1x6x6x32xf32> - return %1 : tensor<8x8x1x6x6x32xf32> -} - -// ----- - -func.func @illegal_winograd_input_rank(%arg0: tensor<1x10x10x32xf32>) -> tensor<8x8x1x6xf32> { - %0 = tensor.empty() : tensor<8x8x1x6xf32> - // expected-error @+1 {{expected output rank to be equal to input rank + 2}} - %1 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) - ins(%arg0 : tensor<1x10x10x32xf32>) outs(%0 : tensor<8x8x1x6xf32>) -> tensor<8x8x1x6xf32> - return %1 : tensor<8x8x1x6xf32> -} - -// ----- - -func.func @illegal_winograd_output_shape(%arg0: tensor<8x8x1x2x2x32xf32>) -> tensor<1x8x8x32xf32> { - %0 = tensor.empty() : tensor<1x8x8x32xf32> - // expected-error @+1 {{incompatible output shape}} - %1 = iree_linalg_ext.winograd.output_transform output_tile_size(6) - kernel_size(3) image_dimensions([1, 2]) - ins(%arg0 : tensor<8x8x1x2x2x32xf32>) outs(%0 : tensor<1x8x8x32xf32>) -> tensor<1x8x8x32xf32> - return %1 : tensor<1x8x8x32xf32> -} - -// ----- diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir deleted file mode 100644 index 35e2053fec3a..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir +++ /dev/null @@ -1,166 +0,0 @@ -// RUN: iree-dialects-opt --iree-linalg-ext-materialize-encoding -cse -split-input-file %s | FileCheck %s - -func.func @pack_unpack_gemm_lhs(%arg0 : tensor) -> tensor { - %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> - %1 = iree_linalg_ext.unset_encoding %0 : tensor> -> tensor - return %1 : tensor -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> -// CHECK: func @pack_unpack_gemm_lhs( -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK-DAG: %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]] -// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]] -// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor -// CHECK: %[[PACK:.+]] = iree_linalg_ext.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[PACK_DEST]] -// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[PACK]] inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[UNPACK_DEST]] -// CHECK: return %[[UNPACK]] - -// ----- - -func.func @pack_unpack_gemm_rhs(%arg0 : tensor) -> tensor { - %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> - %1 = iree_linalg_ext.unset_encoding %0 : tensor> -> tensor - return %1 : tensor -} -// CHECK-LABEL: func @pack_unpack_gemm_rhs( -// CHECK: linalg_ext.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8] -// CHECK: linalg_ext.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8] - -// ----- - -func.func @pack_unpack_gemm_rhs_transpose(%arg0 : tensor) -> tensor { - %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> - %1 = iree_linalg_ext.unset_encoding %0 : tensor> -> tensor - return %1 : tensor -} -// CHECK-LABEL: func @pack_unpack_gemm_rhs_transpose( -// CHECK: linalg_ext.pack %{{.+}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4] -// CHECK: linalg_ext.unpack %{{.+}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4] - -// ----- - -func.func @pack_unpack_gemm_result(%arg0 : tensor) -> tensor { - %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> - %1 = iree_linalg_ext.unset_encoding %0 : tensor> -> tensor - return %1 : tensor -} -// CHECK-LABEL: func @pack_unpack_gemm_result( -// CHECK: linalg_ext.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 8] -// CHECK: linalg_ext.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 8] - -// ----- - -func.func @pack_gemm(%arg0 : tensor<100x250xf32>, %arg1 : tensor<250x500xf32>, %arg2 : tensor<100x500xf32>) -> tensor<100x500xf32> { - %pad_value = arith.constant 0.0 : f32 - %pad_lhs = tensor.pad %arg0 low[0, 0] high[4, 2] { - ^bb0(%b0: index, %b1 : index): - tensor.yield %pad_value : f32 - } : tensor<100x250xf32> to tensor<104x252xf32> - %lhs = iree_linalg_ext.set_encoding %pad_lhs : tensor<104x252xf32> -> tensor<104x252xf32, #iree_linalg_ext.encoding> - %pad_rhs = tensor.pad %arg1 low[0, 0] high[2, 4] { - ^bb0(%b0: index, %b1 : index): - tensor.yield %pad_value : f32 - } : tensor<250x500xf32> to tensor<252x504xf32> - %rhs = iree_linalg_ext.set_encoding %pad_rhs : tensor<252x504xf32> -> tensor<252x504xf32, #iree_linalg_ext.encoding> - %pad_output = tensor.pad %arg2 low[0, 0] high[4, 4] { - ^bb0(%b0: index, %b1 : index): - tensor.yield %pad_value : f32 - } : tensor<100x500xf32> to tensor<104x504xf32> - %output = iree_linalg_ext.set_encoding %pad_output : tensor<104x504xf32> -> tensor<104x504xf32, #iree_linalg_ext.encoding> - %gemm_packed = linalg.matmul ins(%lhs, %rhs : tensor<104x252xf32, #iree_linalg_ext.encoding>, tensor<252x504xf32, #iree_linalg_ext.encoding>) - outs(%output : tensor<104x504xf32, #iree_linalg_ext.encoding>) -> tensor<104x504xf32, #iree_linalg_ext.encoding> - %gemm = iree_linalg_ext.unset_encoding %gemm_packed : tensor<104x504xf32, #iree_linalg_ext.encoding> -> tensor<104x504xf32> - %result = tensor.extract_slice %gemm[0, 0] [100, 500] [1, 1] : tensor<104x504xf32> to tensor<100x500xf32> - return %result : tensor<100x500xf32> -} -// CHECK: func @pack_gemm( -// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32> -// CHECK-SAME: %[[ARG1:.+]]: tensor<250x500xf32> -// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf32> -// CHECK: %[[CST:.+]] = arith.constant 0.0 -// CHECK: %[[INIT_LHS:.+]] = tensor.empty() : tensor<13x63x8x4xf32> -// CHECK: %[[PACK_LHS:.+]] = iree_linalg_ext.pack %[[ARG0]] padding_value(%[[CST]] : f32) -// CHECK-SAME: into %[[INIT_LHS]] -// CHECK: %[[INIT_RHS:.+]] = tensor.empty() : tensor<63x63x8x4xf32> -// CHECK: %[[PACK_RHS:.+]] = iree_linalg_ext.pack %[[ARG1]] padding_value(%[[CST]] : f32) -// CHECK-SAME: into %[[INIT_RHS]] -// CHECK: %[[INIT_RESULT:.+]] = tensor.empty() : tensor<13x63x8x8xf32> -// CHECK: %[[PACK_RESULT:.+]] = iree_linalg_ext.pack %[[ARG2]] padding_value(%[[CST]] : f32) -// CHECK-SAME: into %[[INIT_RESULT]] -// CHECK: %[[MMT4D:.+]] = linalg.mmt4d -// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] : -// CHECK-SAME: outs(%[[PACK_RESULT]] : -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[MMT4D]] -// CHECK: return %[[UNPACK]] - -// ----- - -func.func @pack_gemm_dynamic(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> - %1 = iree_linalg_ext.set_encoding %arg1 : tensor -> tensor> - %2 = iree_linalg_ext.set_encoding %arg2 : tensor -> tensor> - %3 = linalg.matmul ins(%0, %1 : tensor>, tensor>) - outs(%2 : tensor>) -> tensor> - %4 = iree_linalg_ext.unset_encoding %3 : tensor> -> tensor - return %4 : tensor -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> -// CHECK: func @pack_gemm_dynamic( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor -// CHECK: %[[PACK_LHS:.+]] = iree_linalg_ext.pack %[[ARG0]] -// CHECK: %[[PACK_RHS:.+]] = iree_linalg_ext.pack %[[ARG1]] -// CHECK: %[[PACK_RESULT:.+]] = iree_linalg_ext.pack %[[ARG2]] -// CHECK: %[[MMT4D:.+]] = linalg.mmt4d -// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] : -// CHECK-SAME: outs(%[[PACK_RESULT]] : -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[MMT4D]] -// CHECK: return %[[UNPACK]] - -// ----- - -func.func @pack_gemm_fill_dynamic(%arg0 : tensor, %arg1 : tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.0 : f32 - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg1, %c1 : tensor - %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> - %1 = iree_linalg_ext.set_encoding %arg1 : tensor -> tensor> - %2 = tensor.empty(%d0, %d1) : tensor> - %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor>) - -> tensor> - %4 = linalg.matmul ins(%0, %1 : tensor>, tensor>) - outs(%3 : tensor>) -> tensor> - %5 = iree_linalg_ext.unset_encoding %4 : tensor> -> tensor - return %5 : tensor -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> -// CHECK: func @pack_gemm_fill_dynamic( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK-DAG: %[[OUT_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]] -// CHECK-DAG: %[[PACK_LHS:.+]] = iree_linalg_ext.pack %[[ARG0]] -// CHECK-DAG: %[[PACK_RHS:.+]] = iree_linalg_ext.pack %[[ARG1]] -// CHECK-DAG: %[[OUT_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]] -// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]], %[[OUT_D1]]) : tensor -// CHECK: %[[FILL:.+]] = linalg.fill -// CHECK-SAME: outs(%[[EMPTY]] : -// CHECK: %[[MMT4D:.+]] = linalg.mmt4d -// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] : -// CHECK-SAME: outs(%[[FILL]] : -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[MMT4D]] -// CHECK: return %[[UNPACK]] diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/pad_contraction_to_block_size.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/pad_contraction_to_block_size.mlir deleted file mode 100644 index 9caaf6b1b162..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/pad_contraction_to_block_size.mlir +++ /dev/null @@ -1,92 +0,0 @@ -// RUN: iree-dialects-opt --pass-pipeline='builtin.module(iree-linalg-pad-contraction-to-block-size{rowAlignment=16 columnAlignment=32})' --split-input-file %s | FileCheck %s - -// CHECK-LABEL: @pad_matmul_static -// Full verification is done on this case. Others use reduced checks. -// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_4:.*]] = tensor.pad %arg0 low[0, 0] high[6, 12] { -// CHECK: ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: index): -// CHECK: tensor.yield %[[VAL_3]] : f32 -// CHECK: } : tensor<250x500xf32> to tensor<256x512xf32> -// CHECK: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_8:.*]] = tensor.pad %arg1 low[0, 0] high[12, 4] { -// CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: index): -// CHECK: tensor.yield %[[VAL_7]] : f32 -// CHECK: } : tensor<500x1020xf32> to tensor<512x1024xf32> -// CHECK: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_12:.*]] = tensor.pad %arg2 low[0, 0] high[6, 4] { -// CHECK: ^bb0(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: index): -// CHECK: tensor.yield %[[VAL_11]] : f32 -// CHECK: } : tensor<250x1020xf32> to tensor<256x1024xf32> -// CHECK: %[[VAL_15:.*]] = linalg.matmul ins(%[[VAL_16:.*]], %[[VAL_17:.*]] : tensor<256x512xf32>, tensor<512x1024xf32>) outs(%[[VAL_18:.*]] : tensor<256x1024xf32>) -> tensor<256x1024xf32> -// CHECK: %[[VAL_19:.*]] = tensor.extract_slice %[[VAL_15]][0, 0] [250, 1020] [1, 1] : tensor<256x1024xf32> to tensor<250x1020xf32> -// CHECK: return %[[VAL_19]] : tensor<250x1020xf32> -func.func @pad_matmul_static(%arg0 : tensor<250x500xf32>, %arg1 : tensor<500x1020xf32>, - %arg2 : tensor<250x1020xf32>) -> tensor<250x1020xf32> { - %matmul = linalg.matmul - ins(%arg0, %arg1 : tensor<250x500xf32>, tensor<500x1020xf32>) - outs(%arg2 : tensor<250x1020xf32>) -> tensor<250x1020xf32> - return %matmul : tensor<250x1020xf32> -} - -// ----- -// CHECK-LABEL: @pad_matmul_noop -// CHECK-NOT: pad_tensor -// CHECK-NOT: extract_slice -func.func @pad_matmul_noop(%arg0 : tensor<256x512xf32>, %arg1 : tensor<512x1024xf32>, - %arg2 : tensor<256x1024xf32>) -> tensor<256x1024xf32> { - %matmul = linalg.matmul - ins(%arg0, %arg1 : tensor<256x512xf32>, tensor<512x1024xf32>) - outs(%arg2 : tensor<256x1024xf32>) -> tensor<256x1024xf32> - return %matmul : tensor<256x1024xf32> -} - -// ----- -// CHECK-LABEL: @pad_matmul_dynamic_row -// Should trigger row alignment (16). -// Pad LHS: -// CHECK: %[[LHS_DIM0:.*]] = arith.constant 0 : index -// CHECK: %[[LHS_DIM:.*]] = tensor.dim %arg0, %[[LHS_DIM0]] : tensor -// CHECK: %[[LHS_ALIGN:.*]] = arith.constant 16 : index -// CHECK: %[[LHS_DIM_ALIGNED:.*]] = iree_input.align %[[LHS_DIM]], %[[LHS_ALIGN]] : index -// CHECK: %[[LHS_ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[LHS_PADDED:.*]] = tensor.pad %arg0 low[0, 0] high{{\[}}%[[LHS_DIM_ALIGNED]], 0] { -// CHECK: } : tensor to tensor -// Pad Output: -// CHECK: %[[OUTPUT_PADDED:.*]] = tensor.pad %arg2 low[0, 0] high{{\[}}{{.*}}, 0] { -// CHECK: } : tensor to tensor -// Matmul: -// CHECK: %[[PADDED_RESULT:.*]] = linalg.matmul ins(%[[LHS_PADDED]], %arg1 : tensor, tensor<512x1024xf32>) outs(%[[OUTPUT_PADDED]] : tensor) -> tensor -// CHECK: %[[DIM0:.*]] = arith.constant 0 : index -// CHECK: %[[ORIG_DIM_VALUE:.*]] = tensor.dim %arg2, %[[DIM0]] -// CHECK: %[[RETURN:.*]] = tensor.extract_slice %[[PADDED_RESULT]][0, 0] {{\[}}%[[ORIG_DIM_VALUE]], 1024] [1, 1] : tensor to tensor -// CHECK: return %[[RETURN]] : tensor -func.func @pad_matmul_dynamic_row(%arg0 : tensor, %arg1 : tensor<512x1024xf32>, - %arg2 : tensor) -> tensor { - %matmul = linalg.matmul - ins(%arg0, %arg1 : tensor, tensor<512x1024xf32>) - outs(%arg2 : tensor) -> tensor - return %matmul : tensor -} - -// ----- -// CHECK-LABEL: @pad_matmul_dynamic_col -// Should trigger column alignment (32). -// Pad RHS: -// CHECK: %[[RHS_ALIGNMENT:.*]] = arith.constant 32 : index -// CHECK: %[[RHS_ALIGNED_DIM:.*]] = iree_input.align %{{.*}}, %[[RHS_ALIGNMENT]] : index -// CHECK: %[[RHS_PADDED:.*]] = tensor.pad %arg1 low[0, 0] high[0, %[[RHS_ALIGNED_DIM]]] { -// CHECK: } : tensor<512x?xf32> to tensor<512x?xf32> -// Pad Output: -// CHECK: %[[OUTPUT_ALIGNMENT:.*]] = arith.constant 32 : index -// CHECK: %[[OUTPUT_ALIGNED_DIM:.*]] = iree_input.align %{{.*}}, %[[OUTPUT_ALIGNMENT]] : index -// CHECK: %[[OUTPUT_PADDED:.*]] = tensor.pad %arg2 low[0, 0] high[0, %[[OUTPUT_ALIGNED_DIM]]] { -// CHECK: } : tensor<256x?xf32> to tensor<256x?xf32> -// Matmul: -// CHECK: %{{.*}} = linalg.matmul ins(%arg0, %[[RHS_PADDED]] : tensor<256x512xf32>, tensor<512x?xf32>) outs(%[[OUTPUT_PADDED]] : tensor<256x?xf32>) -> tensor<256x?xf32> -func.func @pad_matmul_dynamic_col(%arg0 : tensor<256x512xf32>, %arg1 : tensor<512x?xf32>, - %arg2 : tensor<256x?xf32>) -> tensor<256x?xf32> { - %matmul = linalg.matmul - ins(%arg0, %arg1 : tensor<256x512xf32>, tensor<512x?xf32>) - outs(%arg2 : tensor<256x?xf32>) -> tensor<256x?xf32> - return %matmul : tensor<256x?xf32> -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/pad_tiling.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/pad_tiling.mlir deleted file mode 100644 index 21f7af365067..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/pad_tiling.mlir +++ /dev/null @@ -1,41 +0,0 @@ -// RUN: iree-dialects-opt --iree-linalg-ext-tile --split-input-file %s | FileCheck %s -// XFAIL: * -// TODO: Re-enable when upstream tensor.pad op properly implements the tiling -// interface. - -func.func @pad_tensor(%arg0 : tensor, %arg1 : index, %arg2 : index, - %arg3 : index, %arg4 : index, %arg5 : f32) -> tensor { - %0 = tensor.pad %arg0 low[%arg1, %arg2] high[%arg3, %arg4] { - ^bb0(%arg6 : index, %arg7 : index): - tensor.yield %arg5 : f32 - } {__internal_iree_linalg_transform__ = "tiling_input"} - : tensor to tensor - return %0 : tensor -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)> -// CHECK: func.func @pad_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: f32 -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() -// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[UBY:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG3]], %[[D0]]] -// CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[UBX:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG4]], %[[D1]]] -// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[UBY]] step %[[C10]] -// CHECK-SAME: iter_args(%[[ARG7:.+]] = %[[INIT]]) -// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[UBX]] step %[[C20]] -// CHECK-SAME: iter_args(%[[ARG9:.+]] = %[[ARG7]]) -// CHECK: %[[PAD_TILE:.+]] = scf.if -// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[PAD_TILE]] into %[[ARG9]] -// CHECK-SAME: [%[[IV0]], %[[IV1]]] -// CHECK: scf.yield %[[INSERT]] -// CHECK: scf.yield %[[YIELD]] -// CHECK: return %[[RESULT]] diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir deleted file mode 100644 index 41bd0f86e89c..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/resolve-shaped-type-result-dims.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: iree-dialects-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s - -func.func @pack_static(%arg0 : tensor<100x250xf32>) -> (index, index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = iree_linalg_ext.set_encoding %arg0 : tensor<100x250xf32> -> tensor<100x250xf32, #iree_linalg_ext.encoding> - %1 = tensor.dim %0, %c0 : tensor<100x250xf32, #iree_linalg_ext.encoding> - %2 = tensor.dim %0, %c1 : tensor<100x250xf32, #iree_linalg_ext.encoding> - return %1, %2 : index, index -} -// CHECK-LABEL: func @pack_static( -// CHECK-DAG: %[[C100:.+]] = arith.constant 100 : index -// CHECK-DAG: %[[C250:.+]] = arith.constant 250 : index -// CHECK: return %[[C100]], %[[C250]] - -// ----- - -func.func @pack_dynamic(%arg0 : tensor) -> (index, index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> - %1 = tensor.dim %0, %c0 : tensor> - %2 = tensor.dim %0, %c1 : tensor> - return %1, %2 : index, index -} -// CHECK: func @pack_dynamic(%[[ARG0:.+]]: tensor) -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: return %[[D0]], %[[D1]] diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir deleted file mode 100644 index 4cb706ca5927..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir +++ /dev/null @@ -1,1000 +0,0 @@ -// RUN: iree-dialects-opt --split-input-file %s | FileCheck %s - -// CHECK-LABEL: func.func @sort_tensor -// CHECK: iree_linalg_ext.sort -// CHECK-SAME: dimension(0) -// CHECK-SAME: outs({{.*}}) -// CHECK: iree_linalg_ext.yield -func.func @sort_tensor(%arg0: tensor<128xi32>) -> tensor<128xi32> { - %0 = iree_linalg_ext.sort - dimension(0) - outs(%arg0 : tensor<128xi32>) { - ^bb0(%arg1: i32, %arg2: i32): // no predecessors - %1 = arith.cmpi sgt, %arg1, %arg2 : i32 - iree_linalg_ext.yield %1 : i1 - } -> tensor<128xi32> - return %0 : tensor<128xi32> -} - -// ----- - -// CHECK-LABEL: func.func @sort_memref -// CHECK: iree_linalg_ext.sort -// CHECK-SAME: dimension(0) -// CHECK-SAME: outs({{.*}}) -// CHECK: iree_linalg_ext.yield -func.func @sort_memref(%arg0: memref<128xi32>) { - iree_linalg_ext.sort dimension(0) - outs(%arg0 : memref<128xi32>) { - ^bb0(%arg1: i32, %arg2: i32): // no predecessors - %0 = arith.cmpi sgt, %arg1, %arg2 : i32 - iree_linalg_ext.yield %0 : i1 - } - return -} - -// ----- - -func.func @sort_multi_result_tensor( - %arg0: tensor, %arg1: tensor) - -> (tensor, tensor) { - %0:2 = iree_linalg_ext.sort dimension(0) - outs(%arg0, %arg1 : tensor, tensor) { - ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors - %1 = arith.cmpf ogt, %arg4, %arg5 : f32 - iree_linalg_ext.yield %1 : i1 - } -> tensor, tensor - return %0#0, %0#1 : tensor, tensor -} -// CHECK-LABEL: func.func @sort_multi_result_tensor -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-SAME: %[[ARG1:.+]]: tensor -// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.sort dimension(0) -// CHECK-SAME: outs(%[[ARG0]], %[[ARG1]] -// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 - -// ----- - -func.func @sort_multi_result_memref( - %arg0: memref, %arg1: memref) { - iree_linalg_ext.sort dimension(0) - outs(%arg0, %arg1 : memref, memref) { - ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors - %1 = arith.cmpf ogt, %arg4, %arg5 : f32 - iree_linalg_ext.yield %1 : i1 - } - return -} -// CHECK-LABEL: func.func @sort_multi_result_memref -// CHECK-SAME: %[[ARG0:.+]]: memref -// CHECK-SAME: %[[ARG1:.+]]: memref -// CHECK: iree_linalg_ext.sort dimension(0) -// CHECK-SAME: outs(%[[ARG0]], %[[ARG1]] - -// ----- - -func.func @scatter_tensor_dynamic( - %original: tensor, %indices: tensor, - %update: tensor) -> tensor { - %0 = iree_linalg_ext.scatter - dimension_map = [0] - unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original: tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} -// CHECK-LABEL: func.func @scatter_tensor_dynamic( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: dimension_map = [0] -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] -// CHECK-SAME: outs(%[[ORIGINAL]] -// CHECK: iree_linalg_ext.yield %{{.+}} : f32 -// CHECK: return %[[RESULT]] - -// ----- - -func.func @scatter_repeated_tensor_dynamic( - %original: tensor, %indices: tensor, - %update: tensor) -> tensor { - %0 = iree_linalg_ext.scatter - dimension_map = [0] - unique_indices(false) - ins(%update, %indices : tensor, tensor) - outs(%original: tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} -// CHECK-LABEL: func.func @scatter_repeated_tensor_dynamic( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: dimension_map = [0] -// CHECK-SAME: unique_indices(false) -// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] -// CHECK-SAME: outs(%[[ORIGINAL]] -// CHECK: iree_linalg_ext.yield %{{.+}} : f32 -// CHECK: return %[[RESULT]] - -// ----- - -func.func @scatter_tensor_static( - %original: tensor<128x3xf32>, %indices: tensor<48x1xi32>, - %update: tensor<48x3xf32>) -> tensor<128x3xf32> { - %0 = iree_linalg_ext.scatter - dimension_map = [0] - unique_indices(true) - ins(%update, %indices : tensor<48x3xf32>, tensor<48x1xi32>) - outs(%original: tensor<128x3xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor<128x3xf32> - return %0 : tensor<128x3xf32> -} -// CHECK-LABEL: func.func @scatter_tensor_static( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<128x3xf32> -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<48x1xi32> -// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<48x3xf32> -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter -// CHECK: dimension_map = [0] -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] -// CHECK-SAME: outs(%[[ORIGINAL]] -// CHECK: iree_linalg_ext.yield %{{.+}} : f32 -// CHECK: return %[[RESULT]] - -// ----- - -func.func @scatter_tensor_multi_index_depth( - %original: tensor<1x128x3xf32>, %indices: tensor<48x2xi32>, - %update: tensor<48x3xf32>) -> tensor<1x128x3xf32> { - %0 = iree_linalg_ext.scatter - dimension_map = [0, 1] - unique_indices(true) - ins(%update, %indices : tensor<48x3xf32>, tensor<48x2xi32>) - outs(%original: tensor<1x128x3xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor<1x128x3xf32> - return %0 : tensor<1x128x3xf32> -} -// CHECK-LABEL: func.func @scatter_tensor_multi_index_depth( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<1x128x3xf32> -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<48x2xi32> -// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<48x3xf32> -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: dimension_map = [0, 1] -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] -// CHECK-SAME: outs(%[[ORIGINAL]] -// CHECK: iree_linalg_ext.yield %{{.+}} : f32 -// CHECK: return %[[RESULT]] - -// ----- - -func.func @scatter_memref_dynamic( - %original: memref, %indices: memref, - %update: memref) { - iree_linalg_ext.scatter - dimension_map = [0] - unique_indices(true) - ins(%update, %indices : memref, memref) - outs(%original: memref) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } - return -} -// CHECK-LABEL: func.func @scatter_memref_dynamic( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: memref -// CHECK: iree_linalg_ext.scatter -// CHECK-SAME: dimension_map = [0] -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] -// CHECK-SAME: outs(%[[ORIGINAL]] -// CHECK: iree_linalg_ext.yield %{{.+}} : f32 -// CHECK: return - -// ----- - -func.func @scatter_memref_static( - %original: memref<128x3xf32>, %indices: memref<48x1xi32>, - %update: memref<48x3xf32>) { - iree_linalg_ext.scatter - dimension_map = [0] - unique_indices(true) - ins(%update, %indices : memref<48x3xf32>, memref<48x1xi32>) - outs(%original: memref<128x3xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } - return -} -// CHECK-LABEL: func.func @scatter_memref_static( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<128x3xf32> -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref<48x1xi32> -// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: memref<48x3xf32> -// CHECK: iree_linalg_ext.scatter -// CHECK-SAME: dimension_map = [0] -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] -// CHECK-SAME: outs(%[[ORIGINAL]] -// CHECK: iree_linalg_ext.yield %{{.+}} : f32 -// CHECK: return - -// ----- - -func.func @scatter_memref_multi_index_depth( - %original: memref<1x128x3xf32>, %indices: memref<48x2xi32>, - %update: memref<48x3xf32>) { - iree_linalg_ext.scatter - dimension_map = [0, 1] - unique_indices(true) - ins(%update, %indices : memref<48x3xf32>, memref<48x2xi32>) - outs(%original: memref<1x128x3xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } - return -} -// CHECK-LABEL: func.func @scatter_memref_multi_index_depth( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<1x128x3xf32> -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref<48x2xi32> -// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: memref<48x3xf32> -// CHECK: iree_linalg_ext.scatter -// CHECK-SAME: dimension_map = [0, 1] -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] -// CHECK-SAME: outs(%[[ORIGINAL]] -// CHECK: iree_linalg_ext.yield %{{.+}} : f32 -// CHECK: return - -// ----- - -func.func @scatter_update_scalar_1D( - %original: tensor<8xi32>, %indices: tensor<3x1xi32>, - %updates: tensor<3xi32>) -> tensor<8xi32> { - %0 = iree_linalg_ext.scatter - dimension_map = [0] - unique_indices(true) - ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>) - outs(%original : tensor<8xi32>) { - ^bb0(%arg0: i32, %arg1: i32): // no predecessors - iree_linalg_ext.yield %arg0 : i32 - } -> tensor<8xi32> - return %0 : tensor<8xi32> -} -// CHECK-LABEL: func.func @scatter_update_scalar_1D( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]] -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: dimension_map = [0] -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] -// CHECK-SAME: outs(%[[ORIGINAL]] -// CHECK: iree_linalg_ext.yield %{{.+}} : i32 -// CHECK: return %[[RESULT]] - -// ----- - -func.func @scatter_update_scalar_2D( - %original: tensor<4x3xi32>, %indices: tensor<3x2xi32>, - %updates: tensor<3xi32>) -> tensor<4x3xi32> { - %0 = iree_linalg_ext.scatter - dimension_map = [0, 1] - unique_indices(true) - ins(%updates, %indices : tensor<3xi32>, tensor<3x2xi32>) - outs(%original : tensor<4x3xi32>) { - ^bb0(%arg0: i32, %arg1: i32): // no predecessors - iree_linalg_ext.yield %arg0 : i32 - } -> tensor<4x3xi32> - return %0 : tensor<4x3xi32> -} -// CHECK-LABEL: func.func @scatter_update_scalar_2D( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]] -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: dimension_map = [0, 1] -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] -// CHECK-SAME: outs(%[[ORIGINAL]] -// CHECK: iree_linalg_ext.yield %{{.+}} : i32 -// CHECK: return %[[RESULT]] - -// ----- - -func.func @scatter_update_slice_2D( - %original: tensor<4x3xi32>, %indices: tensor<1x1xi32>, - %updates: tensor<1x3xi32>) -> tensor<4x3xi32> { - %0 = iree_linalg_ext.scatter - dimension_map = [0] - unique_indices(true) - ins(%updates, %indices : tensor<1x3xi32>, tensor<1x1xi32>) - outs(%original : tensor<4x3xi32>) { - ^bb0(%arg0: i32, %arg1: i32): // no predecessors - iree_linalg_ext.yield %arg0 : i32 - } -> tensor<4x3xi32> - return %0 : tensor<4x3xi32> -} -// CHECK-LABEL: func.func @scatter_update_slice_2D( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]] -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: dimension_map = [0] -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] -// CHECK-SAME: outs(%[[ORIGINAL]] -// CHECK: iree_linalg_ext.yield %{{.+}} : i32 -// CHECK: return %[[RESULT]] - -// ----- - -func.func @fft_tensor(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>) - -> (tensor<1024xf32>, tensor<1024xf32>) { - %cst1 = arith.constant 1 : index - %0:2 = iree_linalg_ext.fft - ins(%cst1: index) - outs(%arg0, %arg1: tensor<1024xf32>, tensor<1024xf32>) - : tensor<1024xf32>, tensor<1024xf32> - return %0#0, %0#1 : tensor<1024xf32>, tensor<1024xf32> -} -// CHECK-LABEL: func.func @fft_tensor( -// CHECK-SAME: %[[REAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[IMAG:[a-zA-Z0-9_]+]] -// CHECK: %[[CST:.+]] = arith.constant 1 : index -// CHECK: %[[RES:.+]]:2 = iree_linalg_ext.fft -// CHECK-SAME: ins(%[[CST]] : index) -// CHECK-SAME: outs(%[[REAL]], %[[IMAG]] : tensor<1024xf32>, tensor<1024xf32>) -// CHECK-SAME: : tensor<1024xf32>, tensor<1024xf32> -// CHECK: return %[[RES]]#0, %[[RES]]#1 - -// ----- - -func.func @fft_memref(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { - %cst1 = arith.constant 1 : index - iree_linalg_ext.fft - ins(%cst1: index) - outs(%arg0, %arg1: memref<1024xf32>, memref<1024xf32>) - return -} -// CHECK-LABEL: func.func @fft_memref( -// CHECK-SAME: %[[REAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[IMAG:[a-zA-Z0-9_]+]] -// CHECK: %[[CST:.+]] = arith.constant 1 : index -// CHECK: iree_linalg_ext.fft -// CHECK-SAME: ins(%[[CST]] : index) -// CHECK-SAME: outs(%[[REAL]], %[[IMAG]] : memref<1024xf32>, memref<1024xf32>) -// CHECK: return - -// ----- - -func.func @fft_tensor_coef(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>, - %arg2: tensor<1xf32>, %arg3: tensor<1xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) { - %cst1 = arith.constant 1 : index - %0:2 = iree_linalg_ext.fft - ins(%cst1, %arg2, %arg3: index, tensor<1xf32>, tensor<1xf32>) - outs(%arg0, %arg1: tensor<1024xf32>, tensor<1024xf32>) - : tensor<1024xf32>, tensor<1024xf32> - return %0#0, %0#1 : tensor<1024xf32>, tensor<1024xf32> -} -// CHECK-LABEL: func.func @fft_tensor_coef( -// CHECK-SAME: %[[REAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[IMAG:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]] -// CHECK: %[[CST:.+]] = arith.constant 1 : index -// CHECK: %[[RES:.+]]:2 = iree_linalg_ext.fft -// CHECK-SAME: ins(%[[CST]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, tensor<1xf32>, tensor<1xf32>) -// CHECK-SAME: outs(%[[REAL]], %[[IMAG]] : tensor<1024xf32>, tensor<1024xf32>) -// CHECK-SAME: : tensor<1024xf32>, tensor<1024xf32> -// CHECK: return %[[RES]]#0, %[[RES]]#1 - -// ----- - -func.func @fft_memref_coef(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>, - %arg2: memref<1xf32>, %arg3: memref<1xf32>) { - %cst1 = arith.constant 1 : index - iree_linalg_ext.fft - ins(%cst1, %arg2, %arg3: index, memref<1xf32>, memref<1xf32>) - outs(%arg0, %arg1: memref<1024xf32>, memref<1024xf32>) - return -} -// CHECK-LABEL: func.func @fft_memref_coef( -// CHECK-SAME: %[[REAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[IMAG:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]] -// CHECK: %[[CST:.+]] = arith.constant 1 : index -// CHECK: iree_linalg_ext.fft -// CHECK-SAME: ins(%[[CST]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, memref<1xf32>, memref<1xf32>) -// CHECK-SAME: outs(%[[REAL]], %[[IMAG]] : memref<1024xf32>, memref<1024xf32>) -// CHECK: return - -// ----- - -// The size of coefficient tensor is 2^(stage-1). -func.func @fft_tensor_coef_stage_5(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>, - %arg2: tensor<16xf32>, %arg3: tensor<16xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) { - %cst1 = arith.constant 5 : index - %0:2 = iree_linalg_ext.fft - ins(%cst1, %arg2, %arg3: index, tensor<16xf32>, tensor<16xf32>) - outs(%arg0, %arg1: tensor<1024xf32>, tensor<1024xf32>) - : tensor<1024xf32>, tensor<1024xf32> - return %0#0, %0#1 : tensor<1024xf32>, tensor<1024xf32> -} -// CHECK-LABEL: func.func @fft_tensor_coef_stage_5( -// CHECK-SAME: %[[REAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[IMAG:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]] -// CHECK: %[[CST:.+]] = arith.constant 5 : index -// CHECK: %[[RES:.+]]:2 = iree_linalg_ext.fft -// CHECK-SAME: ins(%[[CST]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, tensor<16xf32>, tensor<16xf32>) -// CHECK-SAME: outs(%[[REAL]], %[[IMAG]] : tensor<1024xf32>, tensor<1024xf32>) -// CHECK-SAME: : tensor<1024xf32>, tensor<1024xf32> -// CHECK: return %[[RES]]#0, %[[RES]]#1 - -// ----- - -func.func @reverse_tensor(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { - %init = tensor.empty() : tensor<3x5xi32> - %0 = iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor<3x5xi32>) : tensor<3x5xi32> - return %0 : tensor<3x5xi32> -} -// CHECK-LABEL: func.func @reverse_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32> -// CHECK: %[[INIT:.+]] = tensor.empty() -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - -func.func @reverse_memref(%arg0: memref<3x5xi32>, %arg1: memref<3x5xi32>) { - iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0 : memref<3x5xi32>) - outs(%arg1 : memref<3x5xi32>) - return -} -// CHECK-LABEL: func.func @reverse_memref -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x5xi32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x5xi32> -// CHECK: iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[ARG1]] - -// ----- - -func.func @reverse_dynamic_tensor(%arg0: tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg0, %c1 : tensor - %init = tensor.empty(%d0, %d1) : tensor - %0 = iree_linalg_ext.reverse - dimensions(dense<1> : tensor<1xi64>) - ins(%arg0 : tensor) - outs(%init : tensor) : tensor - return %0 : tensor -} -// CHECK-LABEL: func.func @reverse_dynamic_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - -func.func @reverse_static_dynamic_tensor(%arg0: tensor<3x5xi32>) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = tensor.dim %arg0, %c0 : tensor<3x5xi32> - %d1 = tensor.dim %arg0, %c1 : tensor<3x5xi32> - %init = tensor.empty(%d0, %d1) : tensor - %0 = iree_linalg_ext.reverse - dimensions(dense<1> : tensor<1xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor) : tensor - return %0 : tensor -} -// CHECK-LABEL: func.func @reverse_static_dynamic_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32> -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - -func.func @reverse_multi_dims(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { - %init = tensor.empty() : tensor<3x5xi32> - %0 = iree_linalg_ext.reverse - dimensions(dense<[0, 1]> : tensor<2xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor<3x5xi32>) : tensor<3x5xi32> - return %0 : tensor<3x5xi32> -} -// CHECK-LABEL: func.func @reverse_multi_dims -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32> -// CHECK: %[[INIT:.+]] = tensor.empty() -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - -func.func @topk_tensor(%input_values: tensor<20x10x8x4xf32>, %input_indices: tensor<20x10x8x4xi32>) -> (tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>) { - %out_values = tensor.empty() : tensor<20x10x3x4xf32> - %out_indices = tensor.empty() : tensor<20x10x3x4xi32> - %0:2 = iree_linalg_ext.topk - dimension(2) - ins(%input_values, %input_indices : tensor<20x10x8x4xf32> , tensor<20x10x8x4xi32>) - outs(%out_values, %out_indices : tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32> - return %0#0, %0#1 : tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32> -} - -// CHECK-LABEL: func.func @topk_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x10x8x4xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x10x8x4xi32> -// CHECK: %[[OUT_VALUES:.+]] = tensor.empty() -// CHECK: %[[OUT_INDICES:.+]] = tensor.empty() -// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk -// CHECK-SAME: dimension(2) -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] -// CHECK-SAME: outs(%[[OUT_VALUES]], %[[OUT_INDICES]] -// CHECK: iree_linalg_ext.yield -// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 - -// ----- - -func.func @topk_memref(%input_values: memref<4x10xf32>, %input_indices: memref<4x10xi32>, %out_values: memref<4x3xf32>, %out_indices: memref<4x3xi32>) { - iree_linalg_ext.topk - dimension(1) - ins(%input_values, %input_indices : memref<4x10xf32> , memref<4x10xi32>) - outs(%out_values, %out_indices : memref<4x3xf32>, memref<4x3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } - return -} -// CHECK-LABEL: func.func @topk_memref -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<4x10xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<4x10xi32> -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<4x3xf32> -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: memref<4x3xi32> -// CHECK: iree_linalg_ext.topk -// CHECK-SAME: dimension(1) -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] -// CHECK-SAME: outs(%[[ARG2]], %[[ARG3]] -// CHECK: iree_linalg_ext.yield - -// ----- - -func.func @topk_dynamic_tensor(%input_values: tensor, %input_indices: tensor, %out_values: tensor, %out_indices: tensor) -> (tensor, tensor) { - %0:2 = iree_linalg_ext.topk - dimension(1) - ins(%input_values, %input_indices : tensor , tensor) - outs(%out_values, %out_indices : tensor, tensor) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor, tensor - return %0#0, %0#1 : tensor, tensor -} -// CHECK-LABEL: func.func @topk_dynamic_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor -// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk -// CHECK-SAME: dimension(1) -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] -// CHECK-SAME: outs(%[[ARG2]], %[[ARG3]] -// CHECK: iree_linalg_ext.yield -// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 - -// ----- - -func.func @topk_tensor_optional(%input_values: tensor<20x10x8x4xf32>) -> (tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>) { - %out_values = tensor.empty() : tensor<20x10x3x4xf32> - %out_indices = tensor.empty() : tensor<20x10x3x4xi32> - %0:2 = iree_linalg_ext.topk - dimension(2) - ins(%input_values : tensor<20x10x8x4xf32>) - outs(%out_values, %out_indices : tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32> - return %0#0, %0#1 : tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32> -} - -// CHECK-LABEL: func.func @topk_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x10x8x4xf32> -// CHECK: %[[OUT_VALUES:.+]] = tensor.empty() -// CHECK: %[[OUT_INDICES:.+]] = tensor.empty() -// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk -// CHECK-SAME: dimension(2) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[OUT_VALUES]], %[[OUT_INDICES]] -// CHECK: iree_linalg_ext.yield -// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 - -// ----- - -func.func @pack(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<3x3x1x1xf32> { - %1 = iree_linalg_ext.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %arg1 : (tensor<3x3xf32> tensor<3x3x1x1xf32>) -> tensor<3x3x1x1xf32> - return %1 : tensor<3x3x1x1xf32> -} - -// CHECK: func.func @pack( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x3xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x3x1x1xf32>) -> tensor<3x3x1x1xf32> -// CHECK: %[[RES:.*]] = iree_linalg_ext.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG1]] : (tensor<3x3xf32> tensor<3x3x1x1xf32>) -> tensor<3x3x1x1xf32> -// CHECK: return %[[RES]] : tensor<3x3x1x1xf32> - -// ----- - -func.func @pack(%arg0: memref<3x3xf32>, %arg1: memref<3x3x1x1xf32>) { - iree_linalg_ext.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %arg1 : (memref<3x3xf32> memref<3x3x1x1xf32>) - return -} - -// CHECK: func.func @pack( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>) { -// CHECK: iree_linalg_ext.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG1]] : (memref<3x3xf32> memref<3x3x1x1xf32>) - -// ----- - -func.func @extra_pad_and_pack(%input: tensor<13x15xf32>, %output: tensor<3x8x8x2xf32>, %pad: f32) -> tensor<3x8x8x2xf32> { - // expected-error@+1 {{infered type do not match provided output type. Expected 'tensor<2x8x8x2xf32>' but got: 'tensor<3x8x8x2xf32>}} - %0 = iree_linalg_ext.pack %input padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<13x15xf32> tensor<3x8x8x2xf32>) -> tensor<3x8x8x2xf32> - return %0 : tensor<3x8x8x2xf32> -} -// CHECK: func @extra_pad_and_pack( -// CHECK-SAME: %[[INPUT:.+]]: tensor<13x15xf32> -// CHECK-SAME: %[[OUTPUT:.+]]: tensor<3x8x8x2xf32> -// CHECK-SAME: %[[PAD:.+]]: f32 -// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]] -// CHECK-SAME: padding_value(%[[PAD]] : f32) -// CHECK-SAME: inner_dims_pos = [0, 1] -// CHECK-SAME: inner_tiles = [8, 2] -// CHECK-SAME: into %[[OUTPUT]] -// CHECK: return %[[RES]] - -// ----- - -func.func @pad_and_pack_static(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: f32) -> tensor<2x8x8x2xf32> { - %0 = iree_linalg_ext.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<13x15xf32> tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> - return %0 : tensor<2x8x8x2xf32> -} -// CHECK: func.func @pad_and_pack_static -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<13x15xf32> -// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<2x8x8x2xf32> -// CHECK-SAME: %[[PAD:[a-zA-Z0-9_]+]]: f32 -// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]] -// CHECK-SAME: padding_value(%[[PAD]] : f32) -// CHECK-SAME: inner_dims_pos = [0, 1] -// CHECK-SAME: inner_tiles = [8, 2] -// CHECK-SAME: into %[[OUTPUT]] -// CHECK: return %[[RES]] - -// ----- - -func.func @pad_and_pack_partially_dynamic(%input: tensor, %output: tensor, %pad: f32) -> tensor { - %0 = iree_linalg_ext.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor tensor) -> tensor - return %0 : tensor -} -// CHECK: func.func @pad_and_pack_partially_dynamic -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[PAD:[a-zA-Z0-9_]+]]: f32 -// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]] -// CHECK-SAME: padding_value(%[[PAD]] : f32) -// CHECK-SAME: inner_dims_pos = [0, 1] -// CHECK-SAME: inner_tiles = [8, 2] -// CHECK-SAME: into %[[OUTPUT]] -// CHECK: return %[[RES]] - -// ----- - -func.func @pad_and_pack_fully_dynamic(%input: tensor, %output: tensor, %pad: f32, %tile_n : index, %tile_m : index) -> tensor { - %0 = iree_linalg_ext.pack %input padding_value(%pad : f32) - inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %output : (tensor tensor) -> tensor - return %0 : tensor -} -// CHECK: func.func @pad_and_pack_fully_dynamic -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[PAD:[a-zA-Z0-9_]+]]: f32 -// CHECK-SAME: %[[TILE_N:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[TILE_M:[a-zA-Z0-9_]+]]: index -// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]] -// CHECK-SAME: padding_value(%[[PAD]] : f32) -// CHECK-SAME: inner_dims_pos = [0, 1] -// CHECK-SAME: inner_tiles = [%[[TILE_N]], %[[TILE_M]]] -// CHECK-SAME: into %[[OUTPUT]] -// CHECK: return %[[RES]] - -// ----- - -func.func @unpack(%arg0: memref<3x3xf32>, %arg1: memref<3x3x1x1xf32>) { - iree_linalg_ext.unpack %arg1 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %arg0 : (memref<3x3x1x1xf32> memref<3x3xf32>) - return -} - -// CHECK: func.func @unpack( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>) { -// CHECK: iree_linalg_ext.unpack %[[ARG1]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG0]] : (memref<3x3x1x1xf32> memref<3x3xf32>) - -// ----- - -func.func @unpack_static(%input: tensor<8x8x32x16xf32>, %output: tensor<256x128xf32>) -> tensor<256x128xf32> { - %0 = iree_linalg_ext.unpack %input inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %output : (tensor<8x8x32x16xf32> tensor<256x128xf32>) -> tensor<256x128xf32> - return %0 : tensor<256x128xf32> -} - -// CHECK: func.func @unpack_static -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]] -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack -// CHECK-SAME: %[[INPUT]] -// CHECK-SAME dim_pos = [0, 1] -// CHECK-SAME inner_pos = [32, 16] -// CHECK-SAME: into %[[OUTPUT]] -// CHECK: return %[[UNPACK]] - -// ----- - -func.func @unpack_undo_padding(%input: tensor<2x8x8x2xf32>, %output: tensor<13x15xf32>) -> tensor<13x15xf32> { - %0 = iree_linalg_ext.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<2x8x8x2xf32> tensor<13x15xf32>) -> tensor<13x15xf32> - return %0 : tensor<13x15xf32> -} -// CHECK: func.func @unpack_undo_padding -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]] -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack -// CHECK-SAME: %[[INPUT]] -// CHECK-SAME dim_pos = [0, 1] -// CHECK-SAME inner_pos = [32, 16] -// CHECK-SAME: into %[[OUTPUT]] -// CHECK: return %[[UNPACK]] - -// ----- - -func.func @unpack(%arg0: memref<3x3xf32>, %arg1: memref<3x3x1x1xf32>) { - iree_linalg_ext.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %arg0 : (memref<3x3x1x1xf32> memref<3x3xf32>) - return -} - -// CHECK: func.func @unpack( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>) { -// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG0]] : (memref<3x3x1x1xf32> memref<3x3xf32>) - -// ----- - -func.func @pack(%arg0: memref<128x256xf32>, %arg1: memref<32x4x32x8xf32>) { - iree_linalg_ext.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<32x4x32x8xf32>) - return -} - -// CHECK: func.func @pack -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>) { -// CHECK: iree_linalg_ext.pack %[[ARG0]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG1]] : (memref<128x256xf32> memref<32x4x32x8xf32>) - -// ----- - -func.func @pack(%arg0: memref<128x256xf32>, %arg1: memref<4x32x32x8xf32>) { - iree_linalg_ext.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<4x32x32x8xf32>) - return -} - -// CHECK: func.func @pack -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<4x32x32x8xf32>) { -// CHECK: iree_linalg_ext.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG1]] : (memref<128x256xf32> memref<4x32x32x8xf32>) - -// ----- - -func.func @unpack(%arg0: memref<128x256xf32>, %arg1: memref<4x32x32x8xf32>) { - iree_linalg_ext.unpack %arg1 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg0 : (memref<4x32x32x8xf32> memref<128x256xf32>) - return -} - -// CHECK: func.func @unpack -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<4x32x32x8xf32>) { -// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG0]] : (memref<4x32x32x8xf32> memref<128x256xf32>) - -// ----- - -func.func @unpack(%arg0: memref<128x256xf32>, %arg1: memref<32x4x32x8xf32>) { - iree_linalg_ext.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg0 : (memref<32x4x32x8xf32> memref<128x256xf32>) - return -} - -// CHECK: func.func @unpack -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>) { -// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG0]] : (memref<32x4x32x8xf32> memref<128x256xf32>) - -// ----- - -// CHECK: @set_encoding_ops(%[[ARG0:.+]]: tensor) -func.func @set_encoding_ops(%arg0: tensor) -> tensor> { - // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor -> tensor> - %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> - return %0 : tensor> -} - -// ----- - -// CHECK: @set_encoding_ops_mixed_dynamic_static(%[[ARG0:.+]]: tensor) -func.func @set_encoding_ops_mixed_dynamic_static(%arg0: tensor) -> tensor<20x?xf32, #iree_linalg_ext.encoding> { - // CHECK: iree_linalg_ext.set_encoding %[[ARG0]] : tensor -> tensor<20x?xf32, #iree_linalg_ext.encoding> - %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor<20x?xf32, #iree_linalg_ext.encoding> - return %0 : tensor<20x?xf32, #iree_linalg_ext.encoding> -} - -// ----- - -// CHECK: @unset_encoding_ops(%[[ARG0:.+]]: tensor>) -func.func @unset_encoding_ops(%arg0: tensor>) -> tensor { - // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor> -> tensor - %0 = iree_linalg_ext.unset_encoding %arg0 : tensor> -> tensor - return %0 : tensor -} - -// ----- - -// CHECK: @unset_encoding_ops_mixed_dynamic_static(%[[ARG0:.+]]: tensor<10x?xf32, #iree_linalg_ext.encoding>) -func.func @unset_encoding_ops_mixed_dynamic_static(%arg0: tensor<10x?xf32, #iree_linalg_ext.encoding>) -> tensor { - // CHECK: iree_linalg_ext.unset_encoding %[[ARG0]] : tensor<10x?xf32, #iree_linalg_ext.encoding> - %0 = iree_linalg_ext.unset_encoding %arg0 : tensor<10x?xf32, #iree_linalg_ext.encoding> -> tensor - return %0 : tensor -} - -// ----- - -func.func @encoding_tensors_with_ops(%arg0 : tensor, - %arg1 : tensor, %arg2 : tensor) -> tensor { - %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> - %1 = iree_linalg_ext.set_encoding %arg1 : tensor -> tensor> - %2 = iree_linalg_ext.set_encoding %arg2 : tensor -> tensor> - %3 = linalg.matmul - ins(%0, %1 : tensor>, tensor>) - outs(%2 : tensor>) - -> tensor> - %4 = iree_linalg_ext.unset_encoding %3 : tensor> -> tensor - return %4 : tensor -} -// CHECK-LABEL: func.func @encoding_tensors_with_ops -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor -// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[ARG0]] -// CHECK-SAME: tensor -> tensor> -// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[ARG1]] -// CHECK-SAME: tensor -> tensor> -// CHECK: %[[OUT:.+]] = iree_linalg_ext.set_encoding %[[ARG2]] -// CHECK-SAME: tensor -> tensor> -// CHECK: %[[GEMM:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : -// CHECK-SAME: outs(%[[OUT]] : -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.unset_encoding %[[GEMM]] -// CHECK: return %[[RESULT]] - -// ----- - -func.func @winograd_input_transform(%arg0: tensor<1x10x10x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> { - %0 = tensor.empty() : tensor<8x8x1x2x2x1280xf32> - %1 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) - ins(%arg0 : tensor<1x10x10x1280xf32>) outs(%0 : tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> - return %1 : tensor<8x8x1x2x2x1280xf32> -} -// CHECK: func.func @winograd_input_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>) -> -// CHECK-SAME: tensor<8x8x1x2x2x1280xf32> { -// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32> -// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) -// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x10x10x1280xf32>) outs(%[[D0]] : -// CHECK-SAME: tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> -// CHECK: return %[[D1]] : tensor<8x8x1x2x2x1280xf32> -// CHECK: } - -// ----- - -func.func @winograd_input_transform_dynamic(%arg0: tensor, %arg1: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32> { - %1 = iree_linalg_ext.winograd.input_transform - output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) - ins(%arg0 : tensor) outs(%arg1 : tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32> - return %1 : tensor<8x8x?x?x?x?xf32> -} -// CHECK: func.func @winograd_input_transform_dynamic(%[[ARG0:[a-zA-Z0-9_]+]]: tensor, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32> { -// CHECK: %[[D0:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) -// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor) outs(%[[ARG1]] : -// CHECK-SAME: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32> -// CHECK: return %[[D0]] : tensor<8x8x?x?x?x?xf32> -// CHECK: } - -// ----- - -func.func @winograd_output_transform(%arg0: tensor<8x8x1x2x2x1280xf32>) -> tensor<1x12x12x1280xf32> { - %0 = tensor.empty() : tensor<1x12x12x1280xf32> - %1 = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) - ins(%arg0 : tensor<8x8x1x2x2x1280xf32>) outs(%0 : tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32> - return %1 : tensor<1x12x12x1280xf32> -} -// CHECK: func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x1280xf32>) -> -// CHECK-SAME: tensor<1x12x12x1280xf32> { -// CHECK: %[[D0:.+]] = tensor.empty() : tensor<1x12x12x1280xf32> -// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) -// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x1x2x2x1280xf32>) outs(%[[D0]] : -// CHECK-SAME: tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32> -// CHECK: return %[[D1]] : tensor<1x12x12x1280xf32> -// CHECK: } - -// ----- - -func.func @winograd_output_transform(%arg0: tensor<8x8x?x?x?x?xf32>, %arg1: tensor) -> tensor { - %1 = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) - ins(%arg0 : tensor<8x8x?x?x?x?xf32>) outs(%arg1 : tensor) -> tensor - return %1 : tensor -} -// CHECK: func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor) -> tensor { -// CHECK: %[[D0:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) -// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x?x?x?x?xf32>) outs(%[[ARG1]] : -// CHECK-SAME: tensor) -> tensor -// CHECK: return %[[D0]] : tensor -// CHECK: } - -// ----- diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/split-reduction.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/split-reduction.mlir deleted file mode 100644 index 7a7b1ffd5db7..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/split-reduction.mlir +++ /dev/null @@ -1,164 +0,0 @@ -// RUN: iree-dialects-opt --iree-linalg-ext-topk-split-reduction='split-ratios=3' %s | FileCheck %s --check-prefix SINGLE -// RUN: iree-dialects-opt --iree-linalg-ext-topk-split-reduction='split-ratios=4' %s | FileCheck %s --check-prefix MULTIPLE -// RUN: iree-dialects-opt --iree-linalg-ext-topk-split-reduction='split-ratios=40,10' %s | FileCheck %s --check-prefix DOUBLE - -func.func @topk_split_reduction_1d(%input_values: tensor<30xf32>, %out_values: tensor<3xf32>, %out_indices: tensor<3xi32>) -> (tensor<3xf32>, tensor<3xi32>) { - %0:2 = iree_linalg_ext.topk - dimension(0) - ins(%input_values: tensor<30xf32>) - outs(%out_values, %out_indices : tensor<3xf32>, tensor<3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<3xf32>, tensor<3xi32> - return %0#0, %0#1 : tensor<3xf32>, tensor<3xi32> -} - -// SINGLE-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// SINGLE-LABEL: func.func @topk_split_reduction_1d( -// SINGLE-SAME: %[[ARG0:.*]]: tensor<30xf32>, -// SINGLE-SAME: %[[ARG1:.*]]: tensor<3xf32>, -// SINGLE-SAME: %[[ARG2:.*]]: tensor<3xi32>) -> (tensor<3xf32>, tensor<3xi32>) { -// SINGLE-DAG: %[[CNEG:.*]] = arith.constant 0xFF800000 : f32 -// SINGLE-DAG: %[[CPOS:.*]] = arith.constant 2147483647 : i32 -// SINGLE-DAG: %[[C10:.*]] = arith.constant 10 : i32 -// SINGLE: %[[D0:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<30xf32> into tensor<3x10xf32> -// SINGLE: %[[D1:.*]] = tensor.empty() : tensor<3x3xf32> -// SINGLE: %[[D2:.*]] = tensor.empty() : tensor<3x3xi32> -// SINGLE: %[[D3:.*]] = linalg.fill ins(%[[CNEG]] : f32) outs(%[[D1]] : tensor<3x3xf32>) -> tensor<3x3xf32> -// SINGLE: %[[D4:.*]] = linalg.fill ins(%[[CPOS]] : i32) outs(%[[D2]] : tensor<3x3xi32>) -> tensor<3x3xi32> -// SINGLE: %[[D5:.*]]:2 = iree_linalg_ext.topk dimension(1) ins(%[[D0]] : tensor<3x10xf32>) outs(%[[D3]], %[[D4]] : tensor<3x3xf32>, tensor<3x3xi32>) { -// SINGLE: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): -// SINGLE: %[[D10:.*]] = arith.cmpf ogt, %[[ARG3]], %[[ARG4]] : f32 -// SINGLE: iree_linalg_ext.yield %[[D10]] : i1 -// SINGLE: } -> tensor<3x3xf32>, tensor<3x3xi32> -// SINGLE: %[[ARG3:.*]] = linalg.generic {indexing_maps = [#[[MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[D5:.*]]#1 : tensor<3x3xi32>) { -// SINGLE: ^bb0(%[[ARG3:.*]]: i32): -// SINGLE: %[[D10:.*]] = linalg.index 0 : index -// SINGLE: %[[D11:.*]] = arith.index_cast %[[D10]] : index to i32 -// SINGLE: %[[D12:.*]] = arith.muli %[[D11]], %[[C10]] : i32 -// SINGLE: %[[D13:.*]] = arith.addi %[[D12]], %[[ARG3]] : i32 -// SINGLE: linalg.yield %[[D13]] : i32 -// SINGLE: } -> tensor<3x3xi32> -// SINGLE: %[[D7:.*]] = tensor.collapse_shape %[[D5:.*]]#0 {{\[\[}}0, 1]] : tensor<3x3xf32> into tensor<9xf32> -// SINGLE: %[[D8:.*]] = tensor.collapse_shape %[[D6:.*]] {{\[\[}}0, 1]] : tensor<3x3xi32> into tensor<9xi32> -// SINGLE: %[[D9:.*]]:2 = iree_linalg_ext.topk dimension(0) ins(%[[D7]], %[[D8]] : tensor<9xf32>, tensor<9xi32>) outs(%[[ARG1]], %[[ARG2]] : tensor<3xf32>, tensor<3xi32>) { -// SINGLE: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): -// SINGLE: %[[D10:.*]] = arith.cmpf ogt, %[[ARG3]], %[[ARG4]] : f32 -// SINGLE: iree_linalg_ext.yield %[[D10]] : i1 -// SINGLE: } -> tensor<3xf32>, tensor<3xi32> -// SINGLE: return %[[D9:.*]]#0, %[[D9]]#1 : tensor<3xf32>, tensor<3xi32> -// SINGLE: } - -// ----- - -func.func @topk_split_reduction_nd(%input_values: tensor<3x10x40x8xf32>, %out_values: tensor<3x10x4x8xf32>, %out_indices: tensor<3x10x4x8xi32>) -> (tensor<3x10x4x8xf32>, tensor<3x10x4x8xi32>) { - %0:2 = iree_linalg_ext.topk - dimension(2) - ins(%input_values : tensor<3x10x40x8xf32>) - outs(%out_values, %out_indices : tensor<3x10x4x8xf32>, tensor<3x10x4x8xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<3x10x4x8xf32>, tensor<3x10x4x8xi32> - return %0#0, %0#1 : tensor<3x10x4x8xf32>, tensor<3x10x4x8xi32> -} - -// MULTIPLE-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> -// MULTIPLE-LABEL: func.func @topk_split_reduction_nd( -// MULTIPLE-SAME: %[[ARG0:.*]]: tensor<3x10x40x8xf32>, -// MULTIPLE-SAME: %[[ARG1:.*]]: tensor<3x10x4x8xf32>, -// MULTIPLE-SAME: %[[ARG2:.*]]: tensor<3x10x4x8xi32>) -> (tensor<3x10x4x8xf32>, tensor<3x10x4x8xi32>) { -// MULTIPLE-DAG: %[[CNEG:.*]] = arith.constant 0xFF800000 : f32 -// MULTIPLE-DAG: %[[CPOS:.*]] = arith.constant 2147483647 : i32 -// MULTIPLE-DAG: %[[C10:.*]] = arith.constant 10 : i32 -// MULTIPLE: %[[D0:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<3x10x40x8xf32> into tensor<3x10x4x10x8xf32> -// MULTIPLE: %[[D1:.*]] = tensor.empty() : tensor<3x10x4x4x8xf32> -// MULTIPLE: %[[D2:.*]] = tensor.empty() : tensor<3x10x4x4x8xi32> -// MULTIPLE: %[[D3:.*]] = linalg.fill ins(%[[CNEG]] : f32) outs(%[[D1]] : tensor<3x10x4x4x8xf32>) -> tensor<3x10x4x4x8xf32> -// MULTIPLE: %[[D4:.*]] = linalg.fill ins(%[[CPOS]] : i32) outs(%[[D2]] : tensor<3x10x4x4x8xi32>) -> tensor<3x10x4x4x8xi32> -// MULTIPLE: %[[D5:.*]]:2 = iree_linalg_ext.topk dimension(3) ins(%[[D0]] : tensor<3x10x4x10x8xf32>) outs(%[[D3]], %[[D4]] : tensor<3x10x4x4x8xf32>, tensor<3x10x4x4x8xi32>) { -// MULTIPLE: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): -// MULTIPLE: %[[D10:.*]] = arith.cmpf ogt, %[[ARG3]], %[[ARG4]] : f32 -// MULTIPLE: iree_linalg_ext.yield %[[D10]] : i1 -// MULTIPLE: } -> tensor<3x10x4x4x8xf32>, tensor<3x10x4x4x8xi32> -// MULTIPLE: %[[D6:.*]] = linalg.generic {indexing_maps = [#[[MAP0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%[[D5:.*]]#1 : tensor<3x10x4x4x8xi32>) { -// MULTIPLE: ^bb0(%[[ARG3:.*]]: i32): -// MULTIPLE: %[[D10:.*]] = linalg.index 2 : index -// MULTIPLE: %[[D11:.*]] = arith.index_cast %[[D10]] : index to i32 -// MULTIPLE: %[[D12:.*]] = arith.muli %[[D11]], %[[C10]] : i32 -// MULTIPLE: %[[D13:.*]] = arith.addi %[[D12]], %[[ARG3]] : i32 -// MULTIPLE: linalg.yield %[[D13]] : i32 -// MULTIPLE: } -> tensor<3x10x4x4x8xi32> -// MULTIPLE: %[[D7:.*]] = tensor.collapse_shape %[[D5:.*]]#0 {{\[\[}}0], [1], [2, 3], [4]] : tensor<3x10x4x4x8xf32> into tensor<3x10x16x8xf32> -// MULTIPLE: %[[D8:.*]] = tensor.collapse_shape %[[D6:.*]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<3x10x4x4x8xi32> into tensor<3x10x16x8xi32> -// MULTIPLE: %[[D9:.*]]:2 = iree_linalg_ext.topk dimension(2) ins(%[[D7]], %[[D8]] : tensor<3x10x16x8xf32>, tensor<3x10x16x8xi32>) outs(%[[ARG1]], %[[ARG2]] : tensor<3x10x4x8xf32>, tensor<3x10x4x8xi32>) { -// MULTIPLE: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): -// MULTIPLE: %[[D10:.*]] = arith.cmpf ogt, %[[ARG3]], %[[ARG4]] : f32 -// MULTIPLE: iree_linalg_ext.yield %[[D10]] : i1 -// MULTIPLE: } -> tensor<3x10x4x8xf32>, tensor<3x10x4x8xi32> -// MULTIPLE: return %[[D9:.*]]#0, %[[D9]]#1 : tensor<3x10x4x8xf32>, tensor<3x10x4x8xi32> -// MULTIPLE: } - -// ----- - -func.func @topk_split_reduction_double(%input_values: tensor<400xf32>, %out_values: tensor<3xf32>, %out_indices: tensor<3xi32>) -> (tensor<3xf32>, tensor<3xi32>) { - %0:2 = iree_linalg_ext.topk - dimension(0) - ins(%input_values: tensor<400xf32>) - outs(%out_values, %out_indices : tensor<3xf32>, tensor<3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<3xf32>, tensor<3xi32> - return %0#0, %0#1 : tensor<3xf32>, tensor<3xi32> -} - -// DOUBLE-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// DOUBLE-LABEL: func.func @topk_split_reduction_double( -// DOUBLE-SAME: %[[ARG0:.*]]: tensor<400xf32>, -// DOUBLE-SAME: %[[ARG1:.*]]: tensor<3xf32>, -// DOUBLE-SAME: %[[ARG2:.*]]: tensor<3xi32>) -> (tensor<3xf32>, tensor<3xi32>) { -// DOUBLE-DAG: %[[CNEG:.*]] = arith.constant 0xFF800000 : f32 -// DOUBLE-DAG: %[[CPOS:.*]] = arith.constant 2147483647 : i32 -// DOUBLE-DAG: %[[C10:.*]] = arith.constant 10 : i32 -// DOUBLE: %[[D0:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<400xf32> into tensor<40x10xf32> -// DOUBLE: %[[D1:.*]] = tensor.empty() : tensor<40x3xf32> -// DOUBLE: %[[D2:.*]] = tensor.empty() : tensor<40x3xi32> -// DOUBLE: %[[D3:.*]] = linalg.fill ins(%[[CNEG]] : f32) outs(%[[D1]] : tensor<40x3xf32>) -> tensor<40x3xf32> -// DOUBLE: %[[D4:.*]] = linalg.fill ins(%[[CPOS]] : i32) outs(%[[D2]] : tensor<40x3xi32>) -> tensor<40x3xi32> -// DOUBLE: %[[D5:.*]]:2 = iree_linalg_ext.topk dimension(1) ins(%[[D0]] : tensor<40x10xf32>) outs(%[[D3]], %[[D4]] : tensor<40x3xf32>, tensor<40x3xi32>) { -// DOUBLE: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): -// DOUBLE: %[[D19:.*]] = arith.cmpf ogt, %[[ARG3]], %[[ARG4]] : f32 -// DOUBLE: iree_linalg_ext.yield %[[D19]] : i1 -// DOUBLE: } -> tensor<40x3xf32>, tensor<40x3xi32> -// DOUBLE: %[[D6:.*]] = linalg.generic {indexing_maps = [#[[MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[D5:.*]]#1 : tensor<40x3xi32>) { -// DOUBLE: ^bb0(%[[ARG3:.*]]: i32): -// DOUBLE: %[[D19:.*]] = linalg.index 0 : index -// DOUBLE: %[[D20:.*]] = arith.index_cast %[[D19]] : index to i32 -// DOUBLE: %[[D21:.*]] = arith.muli %[[D20]], %[[C10]] : i32 -// DOUBLE: %[[D22:.*]] = arith.addi %[[D21]], %[[ARG3]] : i32 -// DOUBLE: linalg.yield %[[D22]] : i32 -// DOUBLE: } -> tensor<40x3xi32> -// DOUBLE: %[[D7:.*]] = tensor.collapse_shape %[[D5:.*]]#0 {{\[\[}}0, 1]] : tensor<40x3xf32> into tensor<120xf32> -// DOUBLE: %[[D8:.*]] = tensor.collapse_shape %[[D6:.*]] {{\[\[}}0, 1]] : tensor<40x3xi32> into tensor<120xi32> -// DOUBLE: %[[D9:.*]] = tensor.expand_shape %[[D7]] {{\[\[}}0, 1]] : tensor<120xf32> into tensor<10x12xf32> -// DOUBLE: %[[D10:.*]] = tensor.expand_shape %[[D8]] {{\[\[}}0, 1]] : tensor<120xi32> into tensor<10x12xi32> -// DOUBLE: %[[D11:.*]] = tensor.empty() : tensor<10x3xf32> -// DOUBLE: %[[D12:.*]] = tensor.empty() : tensor<10x3xi32> -// DOUBLE: %[[D13:.*]] = linalg.fill ins(%[[CNEG]] : f32) outs(%[[D11]] : tensor<10x3xf32>) -> tensor<10x3xf32> -// DOUBLE: %[[D14:.*]] = linalg.fill ins(%[[CPOS]] : i32) outs(%[[D12]] : tensor<10x3xi32>) -> tensor<10x3xi32> -// DOUBLE: %[[D15:.*]]:2 = iree_linalg_ext.topk dimension(1) ins(%[[D9]], %[[D10]] : tensor<10x12xf32>, tensor<10x12xi32>) outs(%[[D13]], %[[D14]] : tensor<10x3xf32>, tensor<10x3xi32>) { -// DOUBLE: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): -// DOUBLE: %[[D19:.*]] = arith.cmpf ogt, %[[ARG3]], %[[ARG4]] : f32 -// DOUBLE: iree_linalg_ext.yield %[[D19]] : i1 -// DOUBLE: } -> tensor<10x3xf32>, tensor<10x3xi32> -// DOUBLE: %[[D16:.*]] = tensor.collapse_shape %[[D15:.*]]#0 {{\[\[}}0, 1]] : tensor<10x3xf32> into tensor<30xf32> -// DOUBLE: %[[D17:.*]] = tensor.collapse_shape %[[D15:.*]]#1 {{\[\[}}0, 1]] : tensor<10x3xi32> into tensor<30xi32> -// DOUBLE: %[[D18:.*]]:2 = iree_linalg_ext.topk dimension(0) ins(%[[D16]], %[[D17]] : tensor<30xf32>, tensor<30xi32>) outs(%[[ARG1]], %[[ARG2]] : tensor<3xf32>, tensor<3xi32>) { -// DOUBLE: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): -// DOUBLE: %[[D10:.*]] = arith.cmpf ogt, %[[ARG3]], %[[ARG4]] : f32 -// DOUBLE: iree_linalg_ext.yield %[[D10]] : i1 -// DOUBLE: } -> tensor<3xf32>, tensor<3xi32> -// DOUBLE: return %[[D18:.*]]#0, %[[D18]]#1 : tensor<3xf32>, tensor<3xi32> -// DOUBLE: } diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_winograd.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_winograd.mlir deleted file mode 100644 index aabe3d9707ab..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_winograd.mlir +++ /dev/null @@ -1,195 +0,0 @@ -// RUN: iree-dialects-opt --iree-linalg-ext-tile-and-decompose-winograd --split-input-file %s | FileCheck %s - -#map = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)> -#map1 = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -module { - func.func @winograd_input_transform(%arg0: tensor<1x10x10x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c1280 = arith.constant 1280 : index - %c32 = arith.constant 32 : index - %0 = tensor.empty() : tensor<8x8x1x2x2x1280xf32> - %1 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %0) -> (tensor<8x8x1x2x2x1280xf32>) { - %2 = affine.min #map(%arg1)[%c1, %c1] - %3 = scf.for %arg3 = %c0 to %c1280 step %c32 iter_args(%arg4 = %arg2) -> (tensor<8x8x1x2x2x1280xf32>) { - %4 = affine.min #map1(%arg3)[%c32, %c1280] - %extracted_slice = tensor.extract_slice %arg0[%arg1, 0, 0, %arg3] [%2, 10, 10, %4] [1, 1, 1, 1] : tensor<1x10x10x1280xf32> to tensor - %extracted_slice_0 = tensor.extract_slice %0[0, 0, %arg1, 0, 0, %arg3] [8, 8, %2, 2, 2, %4] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x1280xf32> to tensor<8x8x?x2x2x?xf32> - %5 = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) ins(%extracted_slice : tensor) outs(%extracted_slice_0 : tensor<8x8x?x2x2x?xf32>) -> tensor<8x8x?x2x2x?xf32> - %inserted_slice = tensor.insert_slice %5 into %arg4[0, 0, %arg1, 0, 0, %arg3] [8, 8, %2, 2, 2, %4] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> into tensor<8x8x1x2x2x1280xf32> - scf.yield %inserted_slice : tensor<8x8x1x2x2x1280xf32> - } - scf.yield %3 : tensor<8x8x1x2x2x1280xf32> - } - return %1 : tensor<8x8x1x2x2x1280xf32> - } -} -// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 6)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (-d0 + 10, 8)> -// CHECK: func.func @winograd_input_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>) -> -// CHECK-SAME: tensor<8x8x1x2x2x1280xf32> { -// CHECK: %[[C32:.+]] = arith.constant 32 : index -// CHECK: %[[C1280:.+]] = arith.constant 1280 : index -// CHECK: %[[C1:.+]] = arith.constant 1 : index -// CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[CST:.+]] = arith.constant dense< -// CHECK: %[[CST_0:.+]] = arith.constant dense< -// CHECK: %[[CST_1:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[C2:.+]] = arith.constant 2 : index -// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8xf32> -// CHECK: %[[D1:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32> -// CHECK: %[[D2:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D1]]) -> (tensor<8x8x1x2x2x1280xf32>) { -// CHECK-DAG: %[[D3:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]] -// CHECK: %[[D4:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1280]] step %[[C32]] -// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<8x8x1x2x2x1280xf32>) { -// CHECK-DAG: %[[D5:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C1280]]] -// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D3]], 10, -// CHECK-SAME: 10, %[[D5]]] [1, 1, 1, 1] : tensor<1x10x10x1280xf32> to tensor -// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[D1]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8, -// CHECK-SAME: %[[D3]], 2, 2, %[[D5]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x1280xf32> to -// CHECK-SAME: tensor<8x8x?x2x2x?xf32> -// CHECK: %[[D6:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D3]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[EXTRACTED_SLICE_2]]) -> (tensor<8x8x?x2x2x?xf32>) { -// CHECK: %[[D7:.+]] = scf.for %[[ARG7:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG8:[a-zA-Z0-9_]+]] = %[[ARG6]]) -> (tensor<8x8x?x2x2x?xf32>) { -// CHECK-DAG: %[[D8:.+]] = affine.apply #[[MAP2]](%[[ARG7]]) -// CHECK-DAG: %[[D9:.+]] = affine.min #[[MAP3]](%[[D8]]) -// CHECK: %[[D10:.+]] = scf.for %[[ARG9:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG10:[a-zA-Z0-9_]+]] = %[[ARG8]]) -> (tensor<8x8x?x2x2x?xf32>) { -// CHECK-DAG: %[[D11:.+]] = affine.apply #[[MAP2]](%[[ARG9]]) -// CHECK-DAG: %[[D12:.+]] = affine.min #[[MAP3]](%[[D11]]) -// CHECK: %[[D13:.+]] = scf.for %[[ARG11:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D5]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG12:[a-zA-Z0-9_]+]] = %[[ARG10]]) -> (tensor<8x8x?x2x2x?xf32>) { -// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG5]], %[[D8]], -// CHECK-SAME: %[[D11]], %[[ARG11]]] [1, %[[D9]], %[[D12]], 1] [1, 1, 1, 1] : tensor to -// CHECK-SAME: tensor -// CHECK: %[[D14:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[D0]] : tensor<8x8xf32>) -> -// CHECK-SAME: tensor<8x8xf32> -// CHECK: %[[INSERTED_SLICE_4:.+]] = tensor.insert_slice %[[EXTRACTED_SLICE_3]] into %[[D14]][0, 0] -// CHECK-SAME: [%[[D9]], %[[D12]]] [1, 1] : tensor into tensor<8x8xf32> -// CHECK: %[[EXTRACTED_SLICE_5:.+]] = tensor.extract_slice %[[ARG12]][0, 0, %[[ARG5]], %[[ARG7]], -// CHECK-SAME: %[[ARG9]], %[[ARG11]]] [8, 8, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> to -// CHECK-SAME: tensor<8x8xf32> -// CHECK: %[[D15:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[EXTRACTED_SLICE_5]] : -// CHECK-SAME: tensor<8x8xf32>) -> tensor<8x8xf32> -// CHECK: %[[D16:.+]] = linalg.matmul ins(%[[INSERTED_SLICE_4]], %[[CST_0]] : tensor<8x8xf32>, -// CHECK-SAME: tensor<8x8xf32>) outs(%[[D15]] : tensor<8x8xf32>) -> tensor<8x8xf32> -// CHECK: %[[D17:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[EXTRACTED_SLICE_5]] : -// CHECK-SAME: tensor<8x8xf32>) -> tensor<8x8xf32> -// CHECK: %[[D18:.+]] = linalg.matmul ins(%[[CST]], %[[D16]] : tensor<8x8xf32>, tensor<8x8xf32>) -// CHECK-SAME: outs(%[[D17]] : tensor<8x8xf32>) -> tensor<8x8xf32> -// CHECK: %[[INSERTED_SLICE_6:.+]] = tensor.insert_slice %[[D18]] into %[[ARG12]][0, 0, %[[ARG5]], -// CHECK-SAME: %[[ARG7]], %[[ARG9]], %[[ARG11]]] [8, 8, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<8x8xf32> -// CHECK-SAME: into tensor<8x8x?x2x2x?xf32> -// CHECK: scf.yield %[[INSERTED_SLICE_6]] : tensor<8x8x?x2x2x?xf32> -// CHECK: } -// CHECK: scf.yield %[[D13]] : tensor<8x8x?x2x2x?xf32> -// CHECK: } -// CHECK: scf.yield %[[D10]] : tensor<8x8x?x2x2x?xf32> -// CHECK: } -// CHECK: scf.yield %[[D7]] : tensor<8x8x?x2x2x?xf32> -// CHECK: } -// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]] into %[[ARG4]][0, 0, %[[ARG1]], 0, 0, -// CHECK-SAME: %[[ARG3]]] [8, 8, %[[D3]], 2, 2, %[[D5]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> into -// CHECK-SAME: tensor<8x8x1x2x2x1280xf32> -// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<8x8x1x2x2x1280xf32> -// CHECK: } -// CHECK: scf.yield %[[D4]] : tensor<8x8x1x2x2x1280xf32> -// CHECK: } -// CHECK: return %[[D2]] : tensor<8x8x1x2x2x1280xf32> -// CHECK: } - -// ----- - -#map = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)> -#map1 = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -module { - func.func @winograd_output_transform(%arg0: tensor<8x8x1x2x2x32xf32>) -> tensor<1x12x12x32xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %0 = tensor.empty() : tensor<1x12x12x32xf32> - %1 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %0) -> (tensor<1x12x12x32xf32>) { - %2 = affine.min #map(%arg1)[%c1, %c1] - %3 = scf.for %arg3 = %c0 to %c32 step %c32 iter_args(%arg4 = %arg2) -> (tensor<1x12x12x32xf32>) { - %4 = affine.min #map1(%arg3)[%c32, %c32] - %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg1, 0, 0, %arg3] [8, 8, %2, 2, 2, %4] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x32xf32> to tensor<8x8x?x2x2x?xf32> - %extracted_slice_0 = tensor.extract_slice %0[%arg1, 0, 0, %arg3] [%2, 12, 12, %4] [1, 1, 1, 1] : tensor<1x12x12x32xf32> to tensor - %5 = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) ins(%extracted_slice : tensor<8x8x?x2x2x?xf32>) outs(%extracted_slice_0 : tensor) -> tensor - %inserted_slice = tensor.insert_slice %5 into %arg4[%arg1, 0, 0, %arg3] [%2, 12, 12, %4] [1, 1, 1, 1] : tensor into tensor<1x12x12x32xf32> - scf.yield %inserted_slice : tensor<1x12x12x32xf32> - } - scf.yield %3 : tensor<1x12x12x32xf32> - } - return %1 : tensor<1x12x12x32xf32> - } -} -// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 6)> -// CHECK: func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x32xf32>) -> -// CHECK-SAME: tensor<1x12x12x32xf32> { -// CHECK: %[[C32:.+]] = arith.constant 32 : index -// CHECK: %[[C1:.+]] = arith.constant 1 : index -// CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[CST:.+]] = arith.constant dense< -// CHECK: %[[CST_0:.+]] = arith.constant dense< -// CHECK: %[[CST_1:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[C2:.+]] = arith.constant 2 : index -// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x6xf32> -// CHECK: %[[D1:.+]] = tensor.empty() : tensor<1x12x12x32xf32> -// CHECK: %[[D2:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D1]]) -> (tensor<1x12x12x32xf32>) { -// CHECK-DAG: %[[D3:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]] -// CHECK: %[[D4:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C32]] -// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<1x12x12x32xf32>) { -// CHECK-DAG: %[[D5:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C32]]] -// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8, -// CHECK-SAME: %[[D3]], 2, 2, %[[D5]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x32xf32> to tensor<8x8x?x2x2x?xf32> -// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[D1]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D3]], 12, -// CHECK-SAME: 12, %[[D5]]] [1, 1, 1, 1] : tensor<1x12x12x32xf32> to tensor -// CHECK: %[[D6:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D3]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[EXTRACTED_SLICE_2]]) -> (tensor) { -// CHECK: %[[D7:.+]] = scf.for %[[ARG7:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG8:[a-zA-Z0-9_]+]] = %[[ARG6]]) -> (tensor) { -// CHECK-DAG: %[[D8:.+]] = affine.apply #[[MAP2]](%[[ARG7]]) -// CHECK: %[[D9:.+]] = scf.for %[[ARG9:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG10:[a-zA-Z0-9_]+]] = %[[ARG8]]) -> (tensor) { -// CHECK-DAG: %[[D10:.+]] = affine.apply #[[MAP2]](%[[ARG9]]) -// CHECK: %[[D11:.+]] = scf.for %[[ARG11:[a-zA-Z0-9_]+]] = %[[C0]] to %[[D5]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG12:[a-zA-Z0-9_]+]] = %[[ARG10]]) -> (tensor) { -// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, %[[ARG5]], -// CHECK-SAME: %[[ARG7]], %[[ARG9]], %[[ARG11]]] [8, 8, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : -// CHECK-SAME: tensor<8x8x?x2x2x?xf32> to tensor<8x8xf32> -// CHECK: %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[ARG12]][%[[ARG5]], %[[D8]], %[[D10]], -// CHECK-SAME: %[[ARG11]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor to tensor<6x6xf32> -// CHECK: %[[D12:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[D0]] : tensor<8x6xf32>) -> -// CHECK-SAME: tensor<8x6xf32> -// CHECK: %[[D13:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE_3]], %[[CST_0]] : tensor<8x8xf32>, -// CHECK-SAME: tensor<8x6xf32>) outs(%[[D12]] : tensor<8x6xf32>) -> tensor<8x6xf32> -// CHECK: %[[D14:.+]] = linalg.fill ins(%[[CST_1]] : f32) outs(%[[EXTRACTED_SLICE_4]] : -// CHECK-SAME: tensor<6x6xf32>) -> tensor<6x6xf32> -// CHECK: %[[D15:.+]] = linalg.matmul ins(%[[CST]], %[[D13]] : tensor<6x8xf32>, tensor<8x6xf32>) -// CHECK-SAME: outs(%[[D14]] : tensor<6x6xf32>) -> tensor<6x6xf32> -// CHECK: %[[INSERTED_SLICE_5:.+]] = tensor.insert_slice %[[D15]] into %[[ARG12]][%[[ARG5]], %[[D8]], -// CHECK-SAME: %[[D10]], %[[ARG11]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<6x6xf32> into -// CHECK-SAME: tensor -// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor -// CHECK: } -// CHECK: scf.yield %[[D11]] : tensor -// CHECK: } -// CHECK: scf.yield %[[D9]] : tensor -// CHECK: } -// CHECK: scf.yield %[[D7]] : tensor -// CHECK: } -// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]] into %[[ARG4]][%[[ARG1]], 0, 0, %[[ARG3]]] -// CHECK-SAME: [%[[D3]], 12, 12, %[[D5]]] [1, 1, 1, 1] : tensor into tensor<1x12x12x32xf32> -// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<1x12x12x32xf32> -// CHECK: } -// CHECK: scf.yield %[[D4]] : tensor<1x12x12x32xf32> -// CHECK: } -// CHECK: return %[[D2]] : tensor<1x12x12x32xf32> -// CHECK: } diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir deleted file mode 100644 index d89d4ed91a99..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir +++ /dev/null @@ -1,1424 +0,0 @@ -// RUN: iree-dialects-opt --iree-linalg-ext-tile --split-input-file --verify-diagnostics -cse %s | FileCheck %s - -func.func @scatter_tiling( - %original: tensor, %indices: tensor, - %update : tensor) -> tensor { - %0 = iree_linalg_ext.scatter - {__internal_linalg_transform__ = "tiling_input"} - dimension_map = [0] - unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> -// CHECK: func.func @scatter_tiling( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor -// CHECK-DAG: %[[TILESIZEY:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[TILESIZEX:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]] -// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZEY]] -// CHECK-SAME: iter_args(%[[INITY:.+]] = %[[ORIGINAL]]) -// CHECK-DAG: %[[USED_TILESIZEY:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[TILESIZEY]], %[[D0]]] -// CHECK: %[[RESULT_INNER:.+]] = scf.for %[[IV1:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZEX]] -// CHECK-SAME: iter_args(%[[INITX:.+]] = %[[INITY]]) -// CHECK: %[[USED_TILESIZEX:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[TILESIZEX]], %[[D1]]] -// CHECK: %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV0]], %[[IV1]]] -// CHECK-SAME: [%[[USED_TILESIZEY]], %[[USED_TILESIZEX]]] -// CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV0]], 0] -// CHECK-SAME: [%[[USED_TILESIZEY]], 1] -// CHECK: %[[SCATTER_DIM:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]] -// CHECK: %[[ORIGINAL_SLICE:.+]] = tensor.extract_slice %[[ORIGINAL]][0, %[[IV1]]] -// CHECK-SAME: [%[[SCATTER_DIM]], %[[USED_TILESIZEX]]] -// CHECK: %[[SCATTER_TILE:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: __internal_linalg_transform__ = "tiling_output" -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]] -// CHECK-SAME: outs(%[[ORIGINAL_SLICE]] -// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INITX]][0, %[[IV1]]] -// CHECK-SAME: [%[[SCATTER_DIM]], %[[USED_TILESIZEX]]] -// CHECK: scf.yield %[[YIELD]] -// CHECK: scf.yield %[[RESULT_INNER]] -// CHECK: return %[[RESULT]] - -// ----- - -func.func @scatter_tiling_memref( - %original: memref, %indices: memref, - %update : memref) { - iree_linalg_ext.scatter - {__internal_linalg_transform__ = "tiling_input"} - dimension_map = [0] - unique_indices(true) - ins(%update, %indices : memref, memref) - outs(%original : memref) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } - return -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> -// CHECK: func.func @scatter_tiling_memref( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: memref -// CHECK-DAG: %[[TILESIZEY:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[TILESIZEX:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = memref.dim %[[UPDATES]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = memref.dim %[[UPDATES]], %[[C1]] -// CHECK: scf.for %[[IV0:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZEY]] -// CHECK-DAG: %[[USED_TILESIZEY:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[TILESIZEY]], %[[D0]]] -// CHECK: scf.for %[[IV1:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZEX]] -// CHECK-DAG: %[[USED_TILESIZEX:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[TILESIZEX]], %[[D1]]] -// CHECK: %[[UPDATE_SLICE:.+]] = memref.subview %[[UPDATES]][%[[IV0]], %[[IV1]]] -// CHECK-SAME: [%[[USED_TILESIZEY]], %[[USED_TILESIZEX]]] -// CHECK: %[[INDEX_SLICE:.+]] = memref.subview %[[INDICES]][%[[IV0]], 0] -// CHECK-SAME: [%[[USED_TILESIZEY]], 1] -// CHECK: %[[SCATTER_DIM:.+]] = memref.dim %[[ORIGINAL]], %[[C0]] -// CHECK: %[[ORIGINAL_SLICE:.+]] = memref.subview %[[ORIGINAL]][0, %[[IV1]] -// CHECK-SAME: [%[[SCATTER_DIM]], %[[USED_TILESIZEX]]] -// CHECK: iree_linalg_ext.scatter -// CHECK-SAME: __internal_linalg_transform__ = "tiling_output" -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]] -// CHECK-SAME: outs(%[[ORIGINAL_SLICE]] - -// ----- - -func.func @scatter_tiling_distribution( - %original: tensor, %indices: tensor, - %update : tensor) -> tensor { - %0 = iree_linalg_ext.scatter - {__internal_linalg_transform__ = "distribute_input"} - dimension_map = [0] - unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK: func.func @scatter_tiling_distribution( -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[TILESIZE:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]] -// CHECK-DAG: %[[ID:.+]] = iree_input.dispatch.workgroup.id[0] -// CHECK-DAG: %[[COUNT:.+]] = iree_input.dispatch.workgroup.count[0] -// CHECK-DAG: %[[OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[ID]]] -// CHECK-DAG: %[[STEP:.+]] = affine.apply #[[MAP0]]()[%[[COUNT]]] -// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[OFFSET]] to %[[D0]] step %[[STEP]] -// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[ORIGINAL]]) -// CHECK: %[[USED_TILESIZE:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[TILESIZE]], %[[D0]]] -// CHECK: %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV]], 0] -// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] -// CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0] -// CHECK-SAME: [%[[USED_TILESIZE]], 1] -// CHECK: %[[D2:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]] -// CHECK: %[[ORIGINAL_SLICE:.+]] = tensor.extract_slice %[[ORIGINAL]][0, 0] -// CHECK-SAME: [%[[D2]], %[[D1]]] -// CHECK: %[[SCATTER_TILE:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: __internal_linalg_transform__ = "distribute_output" -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]] -// CHECK-SAME: outs(%[[ORIGINAL_SLICE]] -// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0] -// CHECK-SAME: [%[[D2]], %[[D1]]] -// CHECK: return %[[RESULT]] - -// ----- - -func.func @scatter_no_tiling( - %original: tensor, %indices: tensor, - %update : tensor) -> tensor { - %0 = iree_linalg_ext.scatter - {__internal_linalg_transform__ = "no_tiling_input"} - dimension_map = [0] - unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} -// CHECK: func.func @scatter_no_tiling -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: __internal_linalg_transform__ = "no_tiling_output" -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]] -// CHECK-SAME: outs(%[[ORIGINAL]] -// CHECK: return %[[RESULT]] - -// ----- - -func.func @scatter_repeated_indices_tiling( - %original: tensor, %indices: tensor, - %update : tensor) -> tensor { - %0 = iree_linalg_ext.scatter - {__internal_linalg_transform__ = "tiling_repeated_indices_scatter_input"} - dimension_map = [0] - unique_indices(false) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} - -// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> -// CHECK: func.func @scatter_repeated_indices_tiling -// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor -// CHECK-DAG: %[[TILESIZE:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]] -// CHECK: %[[RESULT:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]] -// CHECK-SAME: iter_args(%[[ITER:.+]] = %[[ORIGINAL]]) -// CHECK: %[[SZ:.+]] = affine.min #[[MAP]](%[[I]])[%[[TILESIZE]], %[[D1]]] -// CHECK: %[[UPDATES_TILE:.+]] = tensor.extract_slice -// CHECK-SAME: %[[UPDATES]][0, %[[I]]] [%[[D0]], %[[SZ]]] [1, 1] -// CHECK: %[[INDICES_TILE:.+]] = tensor.extract_slice -// CHECK-SAME: %[[INDICES]][0, 0] [%[[D0]], 1] [1, 1] -// CHECK: %[[ORIGINAL_D0:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]] -// CHECK: %[[ORIGINAL_TILE:.+]] = tensor.extract_slice -// CHECK-SAME: %[[ORIGINAL]][0, %[[I]]] [%[[ORIGINAL_D0]], %[[SZ]]] [1, 1] -// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: __internal_linalg_transform__ = "tiling_repeated_indices_scatter_output" -// CHECK-SAME: unique_indices(false) -// CHECK-SAME: ins(%[[UPDATES_TILE]], %[[INDICES_TILE]] -// CHECK-SAME: outs(%[[ORIGINAL_TILE]] -// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SCATTER]] into -// CHECK-SAME: %[[ITER]][0, %[[I]]] [%[[ORIGINAL_D0]], %[[SZ]]] [1, 1] -// CHECK: scf.yield %[[RES]] -// CHECK: return %[[RESULT]] - -// ----- - -func.func @scatter_repeated_indices_no_tiling( - %original: tensor, %indices: tensor, - %update : tensor) -> tensor { - // expected-error @+1 {{unimplemented tiling of non-parallel loop iterator type}} - %0 = iree_linalg_ext.scatter - {__internal_linalg_transform__ = "tiling_input"} - dimension_map = [0] - unique_indices(false) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - iree_linalg_ext.yield %1 : f32 - } -> tensor - return %0 : tensor -} - -// ----- - -func.func @sort_1d(%arg0: tensor) -> tensor { - %0 = iree_linalg_ext.sort - {__internal_linalg_transform__ = "outer_reduce_input"} - dimension(0) - outs(%arg0 : tensor) { - ^bb0(%arg2: i32, %arg3: i32): // no predecessors - %0 = arith.cmpi sgt, %arg2, %arg3 : i32 - iree_linalg_ext.yield %0 : i1 - } -> tensor - return %0 : tensor -} -// CHECK: func.func @sort_1d( -// CHECK-SAME: %[[OPERAND:.+]]: tensor -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.sort -// CHECK-SAME: {__internal_linalg_transform__ = "outer_reduce_output"} -// CHECK-SAME: outs(%[[OPERAND]] : -// CHECK: return %[[RESULT]] - -// ----- - -func.func @sort_2d(%arg0: tensor) -> tensor { - %0 = iree_linalg_ext.sort - {__internal_linalg_transform__ = "inner_reduce_input"} - dimension(1) - outs(%arg0 : tensor) { - ^bb0(%arg2: i32, %arg3: i32): // no predecessors - %0 = arith.cmpi sgt, %arg2, %arg3 : i32 - iree_linalg_ext.yield %0 : i1 - } -> tensor - return %0 : tensor -} -// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK: func.func @sort_2d( -// CHECK-SAME: %[[OPERAND:.+]]: tensor -// CHECK-DAG: %[[TILESIZE:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]] -// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]] -// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[OPERAND]]) -// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]] -// CHECK: %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[OPERAND]][%[[IV]], 0] -// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] -// CHECK: %[[SORT_TILE:.+]] = iree_linalg_ext.sort -// CHECK-SAME: __internal_linalg_transform__ = "inner_reduce_output" -// CHECK-SAME: outs(%[[OPERAND_SLICE]] -// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][%[[IV]], 0] -// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] -// CHECK: scf.yield %[[YIELD]] -// CHECK: return %[[RESULT]] - -// ----- - -func.func @sort_2d_inner_parallel(%arg0: tensor) -> tensor { - %0 = iree_linalg_ext.sort - {__internal_linalg_transform__ = "outer_reduce_input"} - dimension(0) - outs(%arg0 : tensor) { - ^bb0(%arg2: i32, %arg3: i32): // no predecessors - %0 = arith.cmpi sgt, %arg2, %arg3 : i32 - iree_linalg_ext.yield %0 : i1 - } -> tensor - return %0 : tensor -} -// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> -// CHECK: func.func @sort_2d_inner_parallel( -// CHECK-SAME: %[[OPERAND:.+]]: tensor -// CHECK-DAG: %[[TILESIZE:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]] -// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]] -// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[OPERAND]]) -// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]] -// CHECK: %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[OPERAND]][0, %[[IV]]] -// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]] -// CHECK: %[[SORT_TILE:.+]] = iree_linalg_ext.sort -// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output" -// CHECK-SAME: outs(%[[OPERAND_SLICE]] -// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][0, %[[IV]]] -// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]] -// CHECK: scf.yield %[[YIELD]] -// CHECK: return %[[RESULT]] - -// ----- - -func.func @sort_2d_multi_result( - %arg0: tensor, %arg1: tensor) - -> (tensor, tensor) { - %0:2 = iree_linalg_ext.sort - {__internal_linalg_transform__ = "inner_reduce_input"} - dimension(1) - outs(%arg0, %arg1 : tensor, tensor) { - ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors - %1 = arith.cmpf ogt, %arg4, %arg5 : f32 - iree_linalg_ext.yield %1 : i1 - } -> tensor, tensor - return %0#0, %0#1 : tensor, tensor -} -// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK: func.func @sort_2d_multi_result( -// CHECK-SAME: %[[OPERAND1:.+]]: tensor -// CHECK-SAME: %[[OPERAND2:.+]]: tensor -// CHECK-DAG: %[[TILESIZE:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]] -// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]] -// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = %[[OPERAND2]]) -// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]] -// CHECK: %[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[OPERAND1]][%[[IV]], 0] -// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] -// CHECK: %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[OPERAND2]][%[[IV]], 0] -// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] -// CHECK: %[[SORT_TILE:.+]]:2 = iree_linalg_ext.sort -// CHECK-SAME: __internal_linalg_transform__ = "inner_reduce_output" -// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] -// CHECK: %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_TILE]]#0 into %[[INIT1]][%[[IV]], 0] -// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] -// CHECK: %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_TILE]]#1 into %[[INIT2]][%[[IV]], 0] -// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] -// CHECK: scf.yield %[[YIELD1]], %[[YIELD2]] -// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 - -// ----- - -func.func @sort_2d_multi_result_memref( - %arg0: memref, %arg1: memref) { - iree_linalg_ext.sort - {__internal_linalg_transform__ = "outer_reduce_input"} - dimension(0) - outs(%arg0, %arg1 : memref, memref) { - ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors - %0 = arith.cmpf ogt, %arg4, %arg5 : f32 - iree_linalg_ext.yield %0 : i1 - } - return -} -// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> -// CHECK: func.func @sort_2d_multi_result_memref( -// CHECK-SAME: %[[OPERAND1:.+]]: memref -// CHECK-SAME: %[[OPERAND2:.+]]: memref -// CHECK-DAG: %[[TILESIZE:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]] -// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]] -// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]] -// CHECK: %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][0, %[[IV]]] -// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]] -// CHECK: %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][0, %[[IV]]] -// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]] -// CHECK: iree_linalg_ext.sort -// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output" -// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] - -// ----- - -func.func @sort_3d_multi_result_distribute( - %arg0: tensor, %arg1 : tensor) - -> (tensor, tensor) { - %0, %1 = iree_linalg_ext.sort - {__internal_linalg_transform__ = "distribute_input"} - dimension(1) - outs(%arg0, %arg1 : tensor, tensor) { - ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors - %2 = arith.cmpf ogt, %arg4, %arg5 : f32 - iree_linalg_ext.yield %2 : i1 - } -> tensor, tensor - return %0, %1 : tensor, tensor -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)> -// CHECK: func.func @sort_3d_multi_result_distribute( -// CHECK-SAME: %[[OPERAND1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[OPERAND2:[a-zA-Z0-9_]+]]: tensor -// CHECK-DAG: %[[TILESIZE1:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[TILESIZE2:.+]] = arith.constant 30 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]] -// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[OPERAND1]], %[[C2]] -// CHECK-DAG: %[[IDX:.+]] = iree_input.dispatch.workgroup.id[0] -// CHECK-DAG: %[[COUNTX:.+]] = iree_input.dispatch.workgroup.count[0] -// CHECK-DAG: %[[IDY:.+]] = iree_input.dispatch.workgroup.id[1] -// CHECK-DAG: %[[COUNTY:.+]] = iree_input.dispatch.workgroup.count[1] -// CHECK-DAG: %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]] -// CHECK-DAG: %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]] -// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]] -// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = %[[OPERAND2]]) -// CHECK-DAG: %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]] -// CHECK-DAG: %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]] -// CHECK-DAG: %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]] -// CHECK: %[[RESULT_INNER:.+]]:2 = scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]] -// CHECK-SAME: iter_args(%[[INIT3:.+]] = %[[INIT1]], %[[INIT4:.+]] = %[[INIT2]]) -// CHECK-DAG: %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]] -// CHECK: %[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[OPERAND1]][%[[IV0]], 0, %[[IV1]]] -// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] -// CHECK: %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[OPERAND2]][%[[IV0]], 0, %[[IV1]]] -// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] -// CHECK: %[[SORT_SLICE:.+]]:2 = iree_linalg_ext.sort -// CHECK-SAME: __internal_linalg_transform__ = "distribute_output" -// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] -// CHECK: %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_SLICE]]#0 -// CHECK-SAME: into %[[INIT3]][%[[IV0]], 0, %[[IV1]]] -// CHECK: %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_SLICE]]#1 -// CHECK-SAME: into %[[INIT4]][%[[IV0]], 0, %[[IV1]]] -// CHECK: scf.yield %[[YIELD1]], %[[YIELD2]] -// CHECK: scf.yield %[[RESULT_INNER]]#0, %[[RESULT_INNER]]#1 -// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 - -// ----- - -func.func @sort_3d_multi_result_distribute_memref( - %arg0: memref, %arg1 : memref) { - iree_linalg_ext.sort - {__internal_linalg_transform__ = "distribute_input"} - dimension(1) - outs(%arg0, %arg1 : memref, memref) { - ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors - %0 = arith.cmpf ogt, %arg4, %arg5 : f32 - iree_linalg_ext.yield %0 : i1 - } - return -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)> -// CHECK: func.func @sort_3d_multi_result_distribute_memref( -// CHECK-SAME: %[[OPERAND1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[OPERAND2:[a-zA-Z0-9_]+]]: memref -// CHECK-DAG: %[[TILESIZE1:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[TILESIZE2:.+]] = arith.constant 30 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]] -// CHECK-DAG: %[[D2:.+]] = memref.dim %[[OPERAND1]], %[[C2]] -// CHECK-DAG: %[[IDX:.+]] = iree_input.dispatch.workgroup.id[0] -// CHECK-DAG: %[[COUNTX:.+]] = iree_input.dispatch.workgroup.count[0] -// CHECK-DAG: %[[IDY:.+]] = iree_input.dispatch.workgroup.id[1] -// CHECK-DAG: %[[COUNTY:.+]] = iree_input.dispatch.workgroup.count[1] -// CHECK-DAG: %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]] -// CHECK-DAG: %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]] -// CHECK: scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]] -// CHECK-DAG: %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]] -// CHECK-DAG: %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]] -// CHECK-DAG: %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]] -// CHECK: scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]] -// CHECK-DAG: %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]] -// CHECK: %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][%[[IV0]], 0, %[[IV1]]] -// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] -// CHECK: %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][%[[IV0]], 0, %[[IV1]]] -// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] -// CHECK: iree_linalg_ext.sort -// CHECK-SAME: __internal_linalg_transform__ = "distribute_output" -// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] - -// ----- - -func.func @fft_1d_stage_5(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>, - %arg2: tensor<16xf32>, %arg3: tensor<16xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) { - %cst1 = arith.constant 5 : index - %0:2 = iree_linalg_ext.fft - {__internal_linalg_transform__ = "tiling_1d_stage5_fft_input"} - ins(%cst1, %arg2, %arg3: index, tensor<16xf32>, tensor<16xf32>) - outs(%arg0, %arg1: tensor<1024xf32>, tensor<1024xf32>) - : tensor<1024xf32>, tensor<1024xf32> - return %0#0, %0#1 : tensor<1024xf32>, tensor<1024xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -// CHECK: func.func @fft_1d_stage_5( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index -// CHECK: %[[RES:.+]]:2 = scf.for %[[I:.+]] = %[[C0]] to %[[C1024]] step %[[C32]] -// CHECK-SAME: iter_args(%[[ARG5:.+]] = %[[ARG0]], %[[ARG6:.+]] = %[[ARG1]]) -// CHECK-SAME: -> (tensor<1024xf32>, tensor<1024xf32>) { -// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C32]], %[[C1024]]] -// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG0]][%[[I]]] [%[[SIZE]]] [1] : tensor<1024xf32> to tensor -// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG1]][%[[I]]] [%[[SIZE]]] [1] : tensor<1024xf32> to tensor -// CHECK: %[[FFT:.+]]:2 = iree_linalg_ext.fft -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_1d_stage5_fft_output"} -// CHECK-SAME: ins(%[[C5]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, tensor<16xf32>, tensor<16xf32>) -// CHECK-SAME: outs(%[[SLICE1]], %[[SLICE2]] : tensor, tensor) -// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[FFT]]#0 into %[[ARG5]][%[[I]]] [%[[SIZE]]] [1] : tensor into tensor<1024xf32> -// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[FFT]]#1 into %[[ARG6]][%[[I]]] [%[[SIZE]]] [1] : tensor into tensor<1024xf32> -// CHECK: scf.yield %[[INSERT1]], %[[INSERT2]] -// CHECK: return %[[RES]]#0, %[[RES]]#1 : tensor<1024xf32>, tensor<1024xf32> - -// ----- - -func.func @fft_2d_stage_5(%arg0: tensor<3x1024xf32>, %arg1: tensor<3x1024xf32>, - %arg2: tensor<16xf32>, %arg3: tensor<16xf32>) -> (tensor<3x1024xf32>, tensor<3x1024xf32>) { - %cst1 = arith.constant 5 : index - %0:2 = iree_linalg_ext.fft - {__internal_linalg_transform__ = "tiling_2d_stage5_fft_input"} - ins(%cst1, %arg2, %arg3: index, tensor<16xf32>, tensor<16xf32>) - outs(%arg0, %arg1: tensor<3x1024xf32>, tensor<3x1024xf32>) - : tensor<3x1024xf32>, tensor<3x1024xf32> - return %0#0, %0#1 : tensor<3x1024xf32>, tensor<3x1024xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -// CHECK: func.func @fft_2d_stage_5( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index -// CHECK: %[[RES:.+]]:2 = scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C10]] -// CHECK-SAME: iter_args(%[[ARG5:.+]] = %[[ARG0]], %[[ARG6:.+]] = %[[ARG1]]) -// CHECK-SAME: -> (tensor<3x1024xf32>, tensor<3x1024xf32>) { -// CHECK: %[[SZ1:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C10]], %[[C3]]] -// CHECK: %{{.+}} = scf.for %[[J:.+]] = %[[C0]] to %[[C1024]] step %[[C32]] -// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG5]], %[[ARG9:.+]] = %[[ARG6]]) -> (tensor<3x1024xf32>, tensor<3x1024xf32>) { -// CHECK: %[[SZ2:.+]] = affine.min #[[MAP1]](%[[J]])[%[[C32]], %[[C1024]]] -// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG0]][%[[I]], %[[J]]] [%[[SZ1]], %[[SZ2]]] [1, 1] -// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG1]][%[[I]], %[[J]]] [%[[SZ1]], %[[SZ2]]] [1, 1] -// CHECK: %[[FFT:.+]]:2 = iree_linalg_ext.fft -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_2d_stage5_fft_output"} -// CHECK-SAME: ins(%[[C5]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, tensor<16xf32>, tensor<16xf32>) -// CHECK-SAME: outs(%[[SLICE1]], %[[SLICE2]] : tensor, tensor) -// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[FFT]]#0 into %[[ARG8]][%[[I]], %[[J]]] [%[[SZ1]], %[[SZ2]]] [1, 1] -// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[FFT]]#1 into %[[ARG9]][%[[I]], %[[J]]] [%[[SZ1]], %[[SZ2]]] [1, 1] -// CHECK: scf.yield %[[INSERT1]], %[[INSERT2]] : tensor<3x1024xf32>, tensor<3x1024xf32> - -// ----- - -func.func @fft_1d_stage_5_memref(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>, - %arg2: memref<16xf32>, %arg3: memref<16xf32>) { - %cst1 = arith.constant 5 : index - iree_linalg_ext.fft - {__internal_linalg_transform__ = "tiling_1d_stage5_fft_input"} - ins(%cst1, %arg2, %arg3: index, memref<16xf32>, memref<16xf32>) - outs(%arg0, %arg1: memref<1024xf32>, memref<1024xf32>) - return -} -// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -// CHECK: func.func @fft_1d_stage_5_memref( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C1024]] step %[[C32]] { -// CHECK: %[[SZ:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C32]], %[[C1024]]] -// CHECK: %[[SUB1:.+]] = memref.subview %[[ARG0]][%[[I]]] [%[[SZ]]] [1] : memref<1024xf32> to memref> -// CHECK: %[[SUB2:.+]] = memref.subview %[[ARG1]][%[[I]]] [%[[SZ]]] [1] : memref<1024xf32> to memref> -// CHECK: iree_linalg_ext.fft -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_1d_stage5_fft_output"} -// CHECK-SAME: ins(%[[C5]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, memref<16xf32>, memref<16xf32>) -// CHECK-SAME: outs(%[[SUB1]], %[[SUB2]] : memref>, memref>) - -// ----- - -func.func @reverse_memref(%arg0: memref, %arg1: memref) { - iree_linalg_ext.reverse - {__internal_linalg_transform__ = "tiling_input"} - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0: memref) - outs(%arg1: memref) - return -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)> -// CHECK: func.func @reverse_memref( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]] : memref -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]] { -// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C10]], %[[D0]]] -// CHECK-DAG: %[[IDX:.+]] = affine.apply #[[MAP2]]()[%[[D0]], %[[I]], %[[SIZE]]] -// CHECK-DAG: %[[SUB_IN:.+]] = memref.subview %[[ARG0]][%[[I]]] [%[[SIZE]]] [1] -// CHECK-DAG: %[[SUB_OUT:.+]] = memref.subview %[[ARG1]][%[[IDX]]] [%[[SIZE]]] [1] -// CHECK: iree_linalg_ext.reverse -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_output"} -// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>) -// CHECK-SAME: ins(%[[SUB_IN]] -// CHECK-SAME: outs(%[[SUB_OUT]] - -// ----- - -func.func @reverse_tensor_multi_dim(%arg0: tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg0, %c1 : tensor - %init = tensor.empty(%d0, %d1) : tensor - %0 = iree_linalg_ext.reverse - {__internal_linalg_transform__ = "tiling_input"} - dimensions(dense<[0, 1]> : tensor<2xi64>) - ins(%arg0: tensor) - outs(%init: tensor) : tensor - return %0 : tensor -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)> -// CHECK: func.func @reverse_tensor_multi_dim( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor -// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor -// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]] -// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT]]) -> (tensor) { -// CHECK: %[[SIZE_I:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C10]], %[[D0]]] -// CHECK: %[[RES2:.+]] = scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C20]] -// CHECK-SAME: iter_args(%[[INIT3:.+]] = %[[INIT2]]) -> (tensor) { -// CHECK-DAG: %[[SIZE_J:.+]] = affine.min #[[MAP1]](%[[J]])[%[[C20]], %[[D1]]] -// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP2]]()[%[[D0]], %[[I]], %[[SIZE_I]]] -// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP2]]()[%[[D1]], %[[J]], %[[SIZE_J]]] -// CHECK: %[[SUB_IN:.+]] = tensor.extract_slice -// CHECK-SAME: %[[ARG0]][%[[I]], %[[J]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1] -// CHECK: %[[SUB_INIT:.+]] = tensor.extract_slice -// CHECK-SAME: %[[INIT]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1] -// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_output"} -// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>) -// CHECK-SAME: ins(%[[SUB_IN]] -// CHECK-SAME: outs(%[[SUB_INIT]] -// CHECK: %[[RES3:.+]] = tensor.insert_slice %[[REV]] into -// CHECK-SAME: %[[INIT3]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1] -// CHECK: scf.yield %[[RES3]] -// CHECK: scf.yield %[[RES2]] -// CHECK: return %[[RES]] - -// ----- - -func.func @scan_1d(%0: tensor<128xi32>) -> tensor<128xi32> { - %c0 = tensor.empty() : tensor - %1 = tensor.empty() : tensor<128xi32> - %2:2 = iree_linalg_ext.scan - {__internal_linalg_transform__ = "outer_reduce_input"} - dimension(0) inclusive(true) - ins(%0 : tensor<128xi32>) outs(%1, %c0 : tensor<128xi32>, tensor) { - ^bb0(%arg0 : i32, %arg1 : i32): - %sum = arith.addi %arg0, %arg1 : i32 - iree_linalg_ext.yield %sum : i32 - } -> tensor<128xi32>, tensor - return %2#0 : tensor<128xi32> -} -// CHECK: func.func @scan_1d( -// CHECK-SAME: %[[OPERAND:.+]]: tensor<128xi32> -// CHECK: %[[ACC:.+]] = tensor.empty() : tensor -// CHECK: %[[OUTPUT:.+]] = tensor.empty() : tensor<128xi32> -// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.scan -// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output" -// CHECK-SAME: ins(%[[OPERAND]] : -// CHECK-SAME: outs(%[[OUTPUT]], %[[ACC]] : -// CHECK: return %[[RESULT]] - -// ----- - -func.func @scan_2d(%0: tensor<16x32xi32>) -> tensor<16x32xi32> { - %c0 = tensor.empty() : tensor<32xi32> - %1 = tensor.empty() : tensor<16x32xi32> - %2:2 = iree_linalg_ext.scan - {__internal_linalg_transform__ = "outer_reduce_input"} - dimension(0) inclusive(true) - ins(%0 : tensor<16x32xi32>) outs(%1, %c0 : tensor<16x32xi32>, tensor<32xi32>) { - ^bb0(%arg0 : i32, %arg1 : i32): - %sum = arith.addi %arg0, %arg1 : i32 - iree_linalg_ext.yield %sum : i32 - } -> tensor<16x32xi32>, tensor<32xi32> - return %2#0 : tensor<16x32xi32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> -// CHECK: func.func @scan_2d( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[C16:.+]] = arith.constant 16 : index -// CHECK: %[[C32:.+]] = arith.constant 32 : index -// CHECK: %[[C20:.+]] = arith.constant 20 : index -// CHECK: %[[ACC:.+]] = tensor.empty() : tensor<32xi32> -// CHECK: %[[OUTPUT:.+]] = tensor.empty() : tensor<16x32xi32> -// CHECK: %[[RESULT:.+]]:2 = scf.for %[[I:.+]] = %[[C0]] to %[[C32]] step %[[C20]] -// CHECK-SAME: iter_args(%[[ARG2:.+]] = %[[OUTPUT]], %[[ARG3:.+]] = %[[ACC]]) -// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C20]], %[[C32]]] -// CHECK: %[[UPDATE_SLICE_IN:.+]] = tensor.extract_slice %[[ARG0]][0, %[[I]]] [%[[C16]], %[[SIZE]]] -// CHECK: %[[UPDATE_SLICE_OUT:.+]] = tensor.extract_slice %[[OUTPUT]][0, %[[I]]] [%[[C16]], %[[SIZE]]] -// CHECK: %[[UPDATE_SLICE_ACC:.+]] = tensor.extract_slice %[[ACC]][%[[I]]] [%[[SIZE]]] -// CHECK: %[[SCAN_TILE:.+]]:2 = iree_linalg_ext.scan -// CHECK-SAME: {__internal_linalg_transform__ = "outer_reduce_output"} -// CHECK-SAME: dimension(0) inclusive(true) -// CHECK-SAME: ins(%[[UPDATE_SLICE_IN]] -// CHECK-SAME: outs(%[[UPDATE_SLICE_OUT]], %[[UPDATE_SLICE_ACC]] -// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCAN_TILE]]#0 into %[[ARG2]][0, %[[I]]] -// CHECK-SAME: [%[[C16]], %[[SIZE]]] -// CHECK: %[[ACC_YIELD:.+]] = tensor.insert_slice %[[SCAN_TILE]]#1 into %[[ARG3]][%[[I]]] -// CHECK-SAME: [%[[SIZE]]] -// CHECK: scf.yield %[[YIELD]], %[[ACC_YIELD]] : tensor<16x32xi32>, tensor<32xi32> -// CHECK: return %[[RESULT]]#0 - -// ----- - -func.func @scan_2d_memref(%0: memref<16x32xi32>, %1: memref<16x32xi32>) { - %c0 = memref.alloc() : memref<32xi32> - iree_linalg_ext.scan - {__internal_linalg_transform__ = "outer_reduce_input"} - dimension(0) inclusive(true) - ins(%0 : memref<16x32xi32>) outs(%1, %c0 : memref<16x32xi32>, memref<32xi32>) { - ^bb0(%arg0 : i32, %arg1 : i32): - %sum = arith.addi %arg0, %arg1 : i32 - iree_linalg_ext.yield %sum : i32 - } - return -} -// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> -// CHECK: func.func @scan_2d_memref( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[C16:.+]] = arith.constant 16 : index -// CHECK: %[[C32:.+]] = arith.constant 32 : index -// CHECK: %[[C20:.+]] = arith.constant 20 : index -// CHECK: %[[ACC:.+]] = memref.alloc() : memref<32xi32> -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C32]] step %[[C20]] -// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C20]], %[[C32]]] -// CHECK: %[[UPDATE_SLICE_IN:.+]] = memref.subview %[[ARG0]][0, %[[I]]] [%[[C16]], %[[SIZE]]] -// CHECK: %[[UPDATE_SLICE_OUT:.+]] = memref.subview %[[ARG1]][0, %[[I]]] [%[[C16]], %[[SIZE]]] -// CHECK: %[[UPDATE_SLICE_ACC:.+]] = memref.subview %[[ACC]][%[[I]]] [%[[SIZE]]] -// CHECK: iree_linalg_ext.scan -// CHECK-SAME: {__internal_linalg_transform__ = "outer_reduce_output"} -// CHECK-SAME: dimension(0) inclusive(true) -// CHECK-SAME: ins(%[[UPDATE_SLICE_IN]] -// CHECK-SAME: outs(%[[UPDATE_SLICE_OUT]], %[[UPDATE_SLICE_ACC]] -// CHECK: return - -// ----- - -func.func @topk_tile_tensor(%input_values: tensor, %input_indices: tensor, %out_values: tensor , %out_indices: tensor) -> (tensor, tensor) { - %0:2 = iree_linalg_ext.topk - {__internal_linalg_transform__ = "inner_reduce_input"} - dimension(1) - ins(%input_values, %input_indices : tensor , tensor) - outs(%out_values, %out_indices : tensor, tensor) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor, tensor - return %0#0, %0#1 : tensor, tensor -} - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-LABEL: func.func @topk_tile_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0:.+]], %[[C0]] -// CHECK: %[[D1:.+]] = tensor.dim %[[ARG0:.+]], %[[C1]] -// CHECK: %[[RESULT:.+]]:2 = scf.for %[[ARG4:.+]] = %[[C0]] to %[[D0]] step %[[C10]] iter_args(%[[ARG5:.+]] = %[[ARG2]], %[[ARG6:.+]] = %[[ARG3]]) -// CHECK: %[[D3:.+]] = affine.min #[[MAP0]](%[[ARG4]])[%[[C10]], %[[D0]]] -// CHECK: %[[D4:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], 0] [%[[D3]], %[[D1]]] [1, 1] -// CHECK: %[[D5:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0] [%[[D3]], %[[D1]]] [1, 1] -// CHECK: %[[D6:.+]] = tensor.extract_slice %[[ARG2]][%[[ARG4]], 0] [%[[D3]], 3] [1, 1] -// CHECK: %[[D7:.+]] = tensor.extract_slice %[[ARG3]][%[[ARG4]], 0] [%[[D3]], 3] [1, 1] -// CHECK: %[[D8:.+]]:2 = iree_linalg_ext.topk {__internal_linalg_transform__ = "inner_reduce_output"} -// CHECK-SAME: dimension(1) -// CHECK-SAME: ins(%[[D4]], %[[D5]] -// CHECK-SAME: outs(%[[D6]], %[[D7]] -// CHECK: %[[D9:.+]] = tensor.insert_slice %[[D8]]#0 into %[[ARG5]][%[[ARG4]], 0] [%[[D3]], 3] [1, 1] -// CHECK: %[[D10:.+]] = tensor.insert_slice %[[D8]]#1 into %[[ARG6]][%[[ARG4]], 0] [%[[D3]], 3] [1, 1] -// CHECK: scf.yield %[[D9]], %[[D10]] -// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 - - -// ----- - -func.func @topk_tile_memref(%input_values: memref, %input_indices: memref, %out_values: memref, %out_indices: memref) { - iree_linalg_ext.topk - {__internal_linalg_transform__ = "inner_reduce_input"} - dimension(1) - ins(%input_values, %input_indices : memref , memref) - outs(%out_values, %out_indices : memref, memref) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } - return -} - -// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-LABEL: func.func @topk_tile_memref -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK: %[[D0:.+]] = memref.dim %[[ARG0:.+]], %[[C0]] -// CHECK: %[[D1:.+]] = memref.dim %[[ARG0:.+]], %[[C1]] -// CHECK: scf.for %[[ARG4:.+]] = %[[C0]] to %[[D0]] step %[[C10]] -// CHECK: %[[D2:.+]] = affine.min #[[MAP0]](%[[ARG4]])[%[[C10]], %[[D0]]] -// CHECK: %[[D3:.+]] = memref.subview %[[ARG0]][%[[ARG4]], 0] [%[[D2]], %[[D1]]] [1, 1] -// CHECK: %[[D4:.+]] = memref.subview %[[ARG1]][%[[ARG4]], 0] [%[[D2]], %[[D1]]] [1, 1] -// CHECK: %[[D5:.+]] = memref.subview %[[ARG2]][%[[ARG4]], 0] [%[[D2]], 3] [1, 1] -// CHECK: %[[D6:.+]] = memref.subview %[[ARG3]][%[[ARG4]], 0] [%[[D2]], 3] [1, 1] -// CHECK: iree_linalg_ext.topk {__internal_linalg_transform__ = "inner_reduce_output"} -// CHECK-SAME: dimension(1) -// CHECK-SAME: ins(%[[D3]], %[[D4]] -// CHECK-SAME: outs(%[[D5]], %[[D6]] -// CHECK: return - -// ----- - -func.func @topk_tile_tensor_optional(%input_values: tensor<20x10xf32>, %out_values: tensor<20x3xf32> , %out_indices: tensor<20x3xi32>) -> (tensor<20x3xf32>, tensor<20x3xi32>) { - %0:2 = iree_linalg_ext.topk - {__internal_linalg_transform__ = "inner_reduce_input"} - dimension(1) - ins(%input_values : tensor<20x10xf32>) - outs(%out_values, %out_indices : tensor<20x3xf32>, tensor<20x3xi32>) { - ^bb0(%arg0: f32, %arg1: f32): // no predecessors - %0 = arith.cmpf ogt, %arg0, %arg1 : f32 - iree_linalg_ext.yield %0 : i1 - } -> tensor<20x3xf32>, tensor<20x3xi32> - return %0#0, %0#1 : tensor<20x3xf32>, tensor<20x3xi32> -} - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> -// CHECK-LABEL: func.func @topk_tile_tensor_optional -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK: %[[RESULT:.+]]:2 = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C20]] step %[[C10]] iter_args(%[[ARG4:.+]] = %[[ARG1]], %[[ARG5:.+]] = %[[ARG2]]) -// CHECK: %[[D1:.+]] = affine.min #[[MAP0]](%[[ARG3]])[%[[C10]], %[[C20]]] -// CHECK: %[[D2:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0] [%[[D1]], %[[C10]]] [1, 1] -// CHECK: %[[D3:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0] [%[[D1]], 3] [1, 1] -// CHECK: %[[D4:.+]] = tensor.extract_slice %[[ARG2]][%[[ARG3]], 0] [%[[D1]], 3] [1, 1] -// CHECK: %[[D5:.+]]:2 = iree_linalg_ext.topk {__internal_linalg_transform__ = "inner_reduce_output"} -// CHECK-SAME: dimension(1) -// CHECK-SAME: ins(%[[D2]] -// CHECK-SAME: outs(%[[D3]], %[[D4]] -// CHECK: %[[D6:.+]] = tensor.insert_slice %[[D5]]#0 into %[[ARG4]][%[[ARG3]], 0] [%[[D1]], 3] [1, 1] -// CHECK: %[[D7:.+]] = tensor.insert_slice %[[D5]]#1 into %[[ARG5]][%[[ARG3]], 0] [%[[D1]], 3] [1, 1] -// CHECK: scf.yield %[[D6]], %[[D7]] -// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 - -// ----- - -func.func @NC_to_NCnc(%arg0: tensor<128x256xf32>, %arg1: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { - %0 = iree_linalg_ext.pack {__internal_linalg_transform__ = "tiling_pack_input"} %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : (tensor<128x256xf32> tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> - return %0 : tensor<4x8x32x32xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (64, d0 * -32 + 128)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (128, d0 * -32 + 256)> -// CHECK-LABEL: func.func @NC_to_NCnc( -// CHECK-SAME: %[[IN:.*]]: tensor<128x256xf32>, -// CHECK-SAME: %[[OUT:.*]]: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[RES0:.*]] = scf.for %[[N:.*]] = %[[C0]] to %[[C4]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<4x8x32x32xf32>) { -// CHECK: %[[RES1:.+]] = scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<4x8x32x32xf32>) { -// CHECK-DAG: %[[IN_N:.+]] = affine.apply #[[MAP0]](%[[N]]) -// CHECK-DAG: %[[IN_N_SZ:.*]] = affine.min #[[MAP1]] -// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP0]](%[[C]]) -// CHECK-DAG: %[[IN_C_SZ:.*]] = affine.min #[[MAP2]] -// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_N]], %[[IN_C]]] [%[[IN_N_SZ]], %[[IN_C_SZ]]] [1, 1] : tensor<128x256xf32> to tensor -// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[N]], %[[C]], 0, 0] [%[[C2]], %[[C4]], 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor -// CHECK: %[[SUB_RES:.*]] = iree_linalg_ext.pack -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output"} -// CHECK-SAME: %[[SUB_IN]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[SUB_OUT]] -// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[SUB_RES]] into %[[ITER1]] -// CHECK: scf.yield %[[INSERT]] : tensor<4x8x32x32xf32> -// CHECK: } -// CHECK: scf.yield %[[RES1:.*]] : tensor<4x8x32x32xf32> -// CHECK: } -// CHECK: return %[[RES0:.*]] : tensor<4x8x32x32xf32> -// CHECK: } - -// ----- - -func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> { - %0 = iree_linalg_ext.pack {__internal_linalg_transform__ = "tiling_pack_input"} %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : (tensor<128x256xf32> tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> - return %0 : tensor<32x4x32x8xf32> -} -// CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0) -> (128, d0 * -32 + 128)> -// CHECK: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)> -// CHECK: #[[MAP3:.+]] = affine_map<(d0) -> (16, d0 * -8 + 256)> -// CHECK-LABEL: func.func @KC_to_CKkc -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK: scf.for %[[C:.+]] = %[[C0]] to %[[C32]] step %[[C2]] -// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C4]] step %[[C4]] -// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]]) -// CHECK-DAG: %[[IN_K_SZ:.+]] = affine.min #[[MAP1]](%[[K]]) -// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]]) -// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.min #[[MAP3]](%[[C]]) -// CHECK: %[[INPUT_SLICE:.+]] = tensor.extract_slice %[[IN]] -// CHECK-SAME: [%[[IN_K]], %[[IN_C]]] [%[[IN_K_SZ]], %[[IN_C_SZ]]] -// CHECK: %[[OUTPUT_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[C]], %[[K]], 0, 0] [%[[C2]], %[[C4]], 32, 8] -// CHECK: iree_linalg_ext.pack -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output"} -// CHECK-SAME: %[[INPUT_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] -// CHECK-SAME: into %[[OUTPUT_SLICE]] - -// ----- - -func.func @pad_and_pack_static(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: f32) -> tensor<2x8x8x2xf32> { - %0 = iree_linalg_ext.pack {__internal_linalg_transform__ = "tiling_pack_input"} %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<13x15xf32> tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> - return %0 : tensor<2x8x8x2xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 8)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (16, d0 * -8 + 13)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 2)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (8, d0 * -2 + 15)> -// CHECK-LABEL: func.func @pad_and_pack_static( -// CHECK-SAME: %[[IN:.*]]: tensor<13x15xf32>, -// CHECK-SAME: %[[OUT:.*]]: tensor<2x8x8x2xf32>, -// CHECK-SAME: %[[PAD:.*]]: f32) -> tensor<2x8x8x2xf32> { -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK: %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<2x8x8x2xf32>) { -// CHECK: %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<2x8x8x2xf32>) { -// CHECK-DAG: %[[IN_I:.*]] = affine.apply #[[MAP0]](%[[I]]) -// CHECK-DAG: %[[IN_I_SZ:.*]] = affine.min #[[MAP1]](%[[I]]) -// CHECK-DAG: %[[IN_J:.*]] = affine.apply #[[MAP2]](%[[J]]) -// CHECK-DAG: %[[IN_J_SZ:.*]] = affine.min #[[MAP3]](%[[J]]) -// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor<13x15xf32> to tensor -// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[C2]], %[[C4]], 8, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor -// CHECK: %[[SUB_RES:.*]] = iree_linalg_ext.pack -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output"} -// CHECK-SAME: %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] -// CHECK-SAME: into %[[SUB_OUT]] -// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[SUB_RES]] into %[[ITER1]] -// CHECK: scf.yield %[[INSERT]] : tensor<2x8x8x2xf32> -// CHECK: } -// CHECK: scf.yield %[[RES1:.*]] : tensor<2x8x8x2xf32> -// CHECK: } -// CHECK: return %[[RES0:.*]] : tensor<2x8x8x2xf32> -// CHECK: } - -// ----- - -func.func @pad_and_pack_partially_dynamic(%input: tensor, %output: tensor, %pad: f32) -> tensor { - %0 = iree_linalg_ext.pack {__internal_linalg_transform__ = "tiling_pack_input"} %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor tensor) -> tensor - return %0 : tensor -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 8, d1 * -8 + s0)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 2)> -// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 2, d1 * -2 + s0)> -// CHECK-LABEL: func.func @pad_and_pack_partially_dynamic( -// CHECK-SAME: %[[IN:.*]]: tensor, -// CHECK-SAME: %[[OUT:.*]]: tensor, -// CHECK-SAME: %[[PAD:.*]]: f32) -> tensor { -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor -// CHECK-DAG: %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor -// CHECK: %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor) { -// CHECK-DAG: %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]] -// CHECK: %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor) { -// CHECK-DAG: %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]] -// CHECK-DAG: %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]]) -// CHECK-DAG: %[[IN_I_SZ:.*]] = affine.min #[[MAP3]] -// CHECK-DAG: %[[IN_J:.*]] = affine.apply #[[MAP4]](%[[J]]) -// CHECK-DAG: %[[IN_J_SZ:.*]] = affine.min #[[MAP5]] -// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor to tensor -// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], 8, 2] [1, 1, 1, 1] : tensor to tensor -// CHECK: %[[SUB_RES:.*]] = iree_linalg_ext.pack -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output"} -// CHECK-SAME: %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] -// CHECK-SAME: into %[[SUB_OUT]] -// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[SUB_RES]] into %[[ITER1]] - -// CHECK: scf.yield %[[INSERT]] : tensor -// CHECK: } -// CHECK: scf.yield %[[RES1:.*]] : tensor -// CHECK: } -// CHECK: return %[[VAL_34:.*]] : tensor -// CHECK: } - -// ----- - -func.func @pad_and_pack_fully_dynamic(%input: tensor, %output: tensor, %pad: f32, %tile_n : index, %tile_m : index) -> tensor { - %0 = iree_linalg_ext.pack {__internal_linalg_transform__ = "tiling_pack_input"} %input padding_value(%pad : f32) - inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %output : (tensor tensor) -> tensor - return %0 : tensor -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 * s0)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0, -(d1 * s0) + s1)> -// CHECK-LABEL: func.func @pad_and_pack_fully_dynamic( -// CHECK-SAME: %[[IN:.*]]: tensor, -// CHECK-SAME: %[[OUT:.*]]: tensor, -// CHECK-SAME: %[[PAD:.*]]: f32, -// CHECK-SAME: %[[TILE_0:.*]]: index, -// CHECK-SAME: %[[TILE_1:.*]]: index) -> tensor { -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor -// CHECK-DAG: %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor -// CHECK: %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor) { -// CHECK: %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]] -// CHECK: %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor) { -// CHECK: %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]] -// CHECK: %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])[%[[TILE_0]]] -// CHECK: %[[IN_D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[IN_I_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_I_SZ]], %[[I]])[%[[TILE_0]], %[[IN_D0]]] -// CHECK: %[[IN_J:.*]] = affine.apply #[[MAP2]](%[[J]])[%[[TILE_1]]] -// CHECK: %[[IN_D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[IN_J_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_J_SZ]], %[[J]])[%[[TILE_1]], %[[IN_D1]]] -// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor to tensor -// CHECK: %[[OUT_D2:.+]] = tensor.dim %[[OUT]], %[[C2]] -// CHECK: %[[OUT_D3:.+]] = tensor.dim %[[OUT]], %[[C3]] -// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], %[[OUT_D2]], %[[OUT_D3]]] [1, 1, 1, 1] : tensor to tensor -// CHECK: %[[PACK:.*]] = iree_linalg_ext.pack -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output"} -// CHECK-SAME: %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_0]], %[[TILE_1]]] -// CHECK-SAME: into %[[SUB_OUT]] -// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[PACK]] into %[[ITER1]] -// CHECK: scf.yield %[[INSERT]] : tensor -// CHECK: } -// CHECK: scf.yield %[[RES1:.*]] : tensor -// CHECK: } -// CHECK: return %[[RES0:.*]] : tensor -// CHECK: } - -// ----- - -func.func @NCnc_to_NC(%input: tensor<8x8x32x16xf32>, %output: tensor<256x128xf32>) -> tensor<256x128xf32> { - %0 = iree_linalg_ext.unpack {__internal_linalg_transform__ = "tiling_pack_input"} %input inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %output : (tensor<8x8x32x16xf32> tensor<256x128xf32>) -> tensor<256x128xf32> - return %0 : tensor<256x128xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ((d0 + 1) floordiv 32 - d0 floordiv 32 + 1)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (((d0 + 1) floordiv 32) * 32 - (d0 floordiv 32) * 32 + 32)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 floordiv 16)> -// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 mod 16)> -// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0) -> ((d0 + 3) floordiv 16 - d0 floordiv 16 + 1)> -// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0) -> (((d0 + 3) floordiv 16) * 16 - (d0 floordiv 16) * 16 + 16)> -// CHECK-LABEL: func.func @NCnc_to_NC -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index -// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index -// CHECK: %{{.+}} = scf.for %[[I:.+]] = %[[C0]] to %[[C256]] step %[[C2]] -// CHECK: %{{.+}} = scf.for %[[J:.+]] = %[[C0]] to %[[C128]] step %[[C4]] -// CHECK-DAG: %[[IN_I:.+]] = affine.apply #[[MAP0]](%[[I]]) -// CHECK-DAG: %[[OFFSET_I:.+]] = affine.apply #[[MAP1]](%[[I]]) -// CHECK-DAG: %[[IN_I_SZ:.+]] = affine.apply #[[MAP2]](%[[I]]) -// CHECK-DAG: %[[IN_J:.+]] = affine.apply #[[MAP4]](%[[J]]) -// CHECK-DAG: %[[OFFSET_J:.+]] = affine.apply #[[MAP5]](%[[J]]) -// CHECK-DAG: %[[IN_J_SZ:.+]] = affine.apply #[[MAP6]](%[[J]]) -// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[IN]] -// CHECK-SAME: [%[[IN_I]], %[[IN_J]], 0, 0] [%[[IN_I_SZ]], %[[IN_J_SZ]], 32, 16] -// CHECK-SAME: : tensor<8x8x32x16xf32> to tensor -// CHECK: %[[EMPTY:.+]] = tensor.empty -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output"} -// CHECK-SAME: %[[SLICE]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] -// CHECK-SAME: into %[[EMPTY]] -// CHECK: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]] -// CHECK-SAME: [%[[OFFSET_I]], %[[OFFSET_J]]] [%[[C2]], %[[C4]]] -// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK_SLICE]] -// CHECK-SAME: into %{{.+}}[%[[I]], %[[J]]] [%[[C2]], %[[C4]]] -// CHECK: scf.yield %[[RES]] - -// ----- - -func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> { - %0 = iree_linalg_ext.unpack {__internal_linalg_transform__ = "tiling_pack_input"} %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : (tensor<32x4x32x8xf32> tensor<128x256xf32>) -> tensor<128x256xf32> - return %0 : tensor<128x256xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ((d0 + 1) floordiv 32 - d0 floordiv 32 + 1)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (((d0 + 1) floordiv 32) * 32 - (d0 floordiv 32) * 32 + 32)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 floordiv 8)> -// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 mod 8)> -// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0) -> ((d0 + 3) floordiv 8 - d0 floordiv 8 + 1)> -// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0) -> (((d0 + 3) floordiv 8) * 8 - (d0 floordiv 8) * 8 + 8)> -// CHECK-LABEL: func.func @CKkc_to_KC -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index -// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index -// CHECK: %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[C128]] step %[[C2]] -// CHECK: %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[C256]] step %[[C4]] -// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]]) -// CHECK-DAG: %[[OFFSET_K:.+]] = affine.apply #[[MAP1]](%[[K]]) -// CHECK-DAG: %[[IN_K_SZ:.+]] = affine.apply #[[MAP2]](%[[K]]) -// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP4]](%[[C]]) -// CHECK-DAG: %[[OFFSET_C:.+]] = affine.apply #[[MAP5]](%[[C]]) -// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.apply #[[MAP6]](%[[C]]) -// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]] -// CHECK: [%[[IN_C]], %[[IN_K]], 0, 0] [%[[IN_C_SZ]], %[[IN_K_SZ]], 32, 8] -// CHECK: %[[EMPTY:.+]] = tensor.empty -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output" -// CHECK-SAME: %[[IN_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] -// CHECK-SAME: into %[[EMPTY]] -// CHECK: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]] -// CHECK-SAME: [%[[OFFSET_K]], %[[OFFSET_C]]] [%[[C2]], %[[C4]]] -// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK_SLICE]] -// CHECK-SAME: into %{{.+}}[%[[K]], %[[C]]] [%[[C2]], %[[C4]]] -// CHECK: scf.yield %[[RES]] - -// ----- - -func.func @perfect_CKkc_to_KC(%arg0: tensor<32x4x2x4xf32>, %arg1: tensor<8x128xf32>) -> tensor<8x128xf32> { - %0 = iree_linalg_ext.unpack {__internal_linalg_transform__ = "tiling_pack_input"} %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %arg1 : (tensor<32x4x2x4xf32> tensor<8x128xf32>) -> tensor<8x128xf32> - return %0 : tensor<8x128xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 4)> -// CHECK-LABEL: func.func @perfect_CKkc_to_KC -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index -// CHECK: %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[C8]] step %[[C2]] -// CHECK: %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[C128]] step %[[C4]] -// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]]) -// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]]) -// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]] -// CHECK: [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 2, 4] -// CHECK: %[[ITER_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[K]], %[[C]]] [%[[C2]], %[[C4]]] -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output" -// CHECK-SAME: %[[IN_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 4] -// CHECK-SAME: into %[[ITER_SLICE]] -// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK]] -// CHECK-SAME: into %{{.+}}[%[[K]], %[[C]]] [%[[C2]], %[[C4]]] -// CHECK: scf.yield %[[RES]] - -// ----- - -func.func @dynamic_perfect_CKkc_to_KC(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = iree_linalg_ext.unpack {__internal_linalg_transform__ = "tiling_pack_input"} %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %arg1 : (tensor tensor) -> tensor - return %0 : tensor -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 2)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 ceildiv 2)> -// CHECK-LABEL: func.func @dynamic_perfect_CKkc_to_KC -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C0]] -// CHECK-DAG: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[DIM_0]] step %[[C2]] -// CHECK-DAG: %[[OUT_K_SZ:.+]] = affine.min #[[MAP0]](%[[K]])[%[[DIM_0]]] -// CHECK: %{{.+}} = scf.for %[[C:.+]] = %[[C0]] to %[[DIM_1]] step %[[C4]] -// CHECK-DAG: %[[OUT_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])[%[[DIM_1]]] -// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP2]](%[[K]]) -// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]]) -// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.apply #[[MAP3]](%[[OUT_C_SZ]]) -// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]] -// CHECK: [%[[IN_C]], %[[IN_K]], 0, 0] [%[[IN_C_SZ]], 1, 2, 2] -// CHECK: %[[ITER_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[K]], %[[C]]] [%[[OUT_K_SZ]], %[[OUT_C_SZ]]] -// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack -// CHECK-SAME: {__internal_linalg_transform__ = "tiling_pack_output" -// CHECK-SAME: %[[IN_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 2] -// CHECK-SAME: into %[[ITER_SLICE]] -// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK]] -// CHECK-SAME: into %{{.+}}[%[[K]], %[[C]]] [%[[OUT_K_SZ]], %[[OUT_C_SZ]]] -// CHECK: scf.yield %[[RES]] - -// ----- - -func.func @winograd_input_transform(%arg0: tensor<1x10x10x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> { - %0 = tensor.empty() : tensor<8x8x1x2x2x1280xf32> - %1 = iree_linalg_ext.winograd.input_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"} - output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) - ins(%arg0 : tensor<1x10x10x1280xf32>) outs(%0 : tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32> - return %1 : tensor<8x8x1x2x2x1280xf32> -} -// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -// CHECK: func.func @winograd_input_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>) -> -// CHECK-SAME: tensor<8x8x1x2x2x1280xf32> { -// CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[C1:.+]] = arith.constant 1 : index -// CHECK: %[[C1280:.+]] = arith.constant 1280 : index -// CHECK: %[[C32:.+]] = arith.constant 32 : index -// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32> -// CHECK: %[[D1:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<8x8x1x2x2x1280xf32>) { -// CHECK-DAG: %[[D2:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]] -// CHECK: %[[D3:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1280]] step %[[C32]] -// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<8x8x1x2x2x1280xf32>) { -// CHECK-DAG: %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C1280]]] -// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D2]], 10, -// CHECK-SAME: 10, %[[D4]]] [1, 1, 1, 1] : tensor<1x10x10x1280xf32> to tensor -// CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[D0]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8, -// CHECK-SAME: %[[D2]], 2, 2, %[[D4]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x1280xf32> to -// CHECK-SAME: tensor<8x8x?x2x2x?xf32> -// CHECK: %[[D5:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) -// CHECK-SAME: image_dimensions([1, 2]) ins(%[[EXTRACTED_SLICE]] : tensor) -// CHECK-SAME: outs(%[[EXTRACTED_SLICE]]_0 : tensor<8x8x?x2x2x?xf32>) -> tensor<8x8x?x2x2x?xf32> -// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG4]][0, 0, %[[ARG1]], 0, 0, -// CHECK-SAME: %[[ARG3]]] [8, 8, %[[D2]], 2, 2, %[[D4]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x?x2x2x?xf32> into -// CHECK-SAME: tensor<8x8x1x2x2x1280xf32> -// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<8x8x1x2x2x1280xf32> -// CHECK: } -// CHECK: scf.yield %[[D3]] : tensor<8x8x1x2x2x1280xf32> -// CHECK: } -// CHECK: return %[[D1]] : tensor<8x8x1x2x2x1280xf32> -// CHECK: } - -// ----- - -func.func @winograd_input_transform_memref(%arg0: memref<1x10x10x1280xf32>, %arg1: memref<8x8x1x2x2x1280xf32>) { - iree_linalg_ext.winograd.input_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"} - output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) - ins(%arg0 : memref<1x10x10x1280xf32>) outs(%arg1 : memref<8x8x1x2x2x1280xf32>) - return -} -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -// CHECK: func.func @winograd_input_transform_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<1x10x10x1280xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<8x8x1x2x2x1280xf32>) { -// CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[C1:.+]] = arith.constant 1 : index -// CHECK: %[[C1280:.+]] = arith.constant 1280 : index -// CHECK: %[[C32:.+]] = arith.constant 32 : index -// CHECK: scf.for %[[ARG2:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]] { -// CHECK-DAG: %[[D0:.+]] = affine.min #[[MAP2]](%[[ARG2]])[%[[C1]], %[[C1]]] -// CHECK: scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1280]] step %[[C32]] { -// CHECK-DAG: %[[D1:.+]] = affine.min #[[MAP3]](%[[ARG3]])[%[[C32]], %[[C1280]]] -// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG3]]] [%[[D0]], 10, 10, %[[D1]]] -// CHECK-SAME: [1, 1, 1, 1] : memref<1x10x10x1280xf32> to memref> -// CHECK: %[[SUBVIEW_0:.+]] = memref.subview %[[ARG1]][0, 0, %[[ARG2]], 0, 0, %[[ARG3]]] [8, 8, %[[D0]], 2, -// CHECK-SAME: 2, %[[D1]]] [1, 1, 1, 1, 1, 1] : memref<8x8x1x2x2x1280xf32> to memref<8x8x?x2x2x?xf32, -// CHECK-SAME: strided<[40960, 5120, 5120, 2560, 1280, 1], offset: ?>> -// CHECK: iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3) image_dimensions([1, -// CHECK-SAME: 2]) ins(%[[SUBVIEW]] : memref>) -// CHECK-SAME: outs(%[[SUBVIEW]]_0 : memref<8x8x?x2x2x?xf32, strided<[40960, 5120, 5120, 2560, 1280, 1], offset: -// CHECK-SAME: ?>>) -// CHECK: } -// CHECK: } -// CHECK: return -// CHECK: } - -// ----- - -func.func @winograd_output_transform(%arg0: tensor<8x8x1x2x2x32xf32>) -> tensor<1x12x12x32xf32> { - %0 = tensor.empty() : tensor<1x12x12x32xf32> - %1 = iree_linalg_ext.winograd.output_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"} - output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) - ins(%arg0 : tensor<8x8x1x2x2x32xf32>) outs(%0 : tensor<1x12x12x32xf32>) -> tensor<1x12x12x32xf32> - return %1 : tensor<1x12x12x32xf32> -} -// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -// CHECK: func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x32xf32>) -> -// CHECK-SAME: tensor<1x12x12x32xf32> { -// CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[C1:.+]] = arith.constant 1 : index -// CHECK: %[[C32:.+]] = arith.constant 32 : index -// CHECK: %[[D0:.+]] = tensor.empty() : tensor<1x12x12x32xf32> -// CHECK: %[[D1:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<1x12x12x32xf32>) { -// CHECK-DAG: %[[D2:.+]] = affine.min #[[MAP]](%[[ARG1]])[%[[C1]], %[[C1]]] -// CHECK: %[[D3:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C32]] -// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<1x12x12x32xf32>) { -// CHECK-DAG: %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG3]])[%[[C32]], %[[C32]]] -// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG1]], 0, 0, %[[ARG3]]] [8, 8, -// CHECK-SAME: %[[D2]], 2, 2, %[[D4]]] [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x32xf32> to tensor<8x8x?x2x2x?xf32> -// CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[D0]][%[[ARG1]], 0, 0, %[[ARG3]]] [%[[D2]], 12, -// CHECK-SAME: 12, %[[D4]]] [1, 1, 1, 1] : tensor<1x12x12x32xf32> to tensor -// CHECK: %[[D5:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) -// CHECK-SAME: image_dimensions([1, 2]) ins(%[[EXTRACTED_SLICE]] : tensor<8x8x?x2x2x?xf32>) -// CHECK-SAME: outs(%[[EXTRACTED_SLICE]]_0 : tensor) -> tensor -// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG4]][%[[ARG1]], 0, 0, %[[ARG3]]] -// CHECK-SAME: [%[[D2]], 12, 12, %[[D4]]] [1, 1, 1, 1] : tensor into tensor<1x12x12x32xf32> -// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<1x12x12x32xf32> -// CHECK: } -// CHECK: scf.yield %[[D3]] : tensor<1x12x12x32xf32> -// CHECK: } -// CHECK: return %[[D1]] : tensor<1x12x12x32xf32> -// CHECK: } - -// ----- - -func.func @winograd_output_transform_memref(%arg0: memref<8x8x1x2x2x32xf32>, %arg1: memref<1x12x12x32xf32>) { - iree_linalg_ext.winograd.output_transform {__internal_linalg_transform__ = "tiling_winograd_input_nhwc"} - output_tile_size(6) kernel_size(3) image_dimensions([1, 2]) - ins(%arg0 : memref<8x8x1x2x2x32xf32>) outs(%arg1 : memref<1x12x12x32xf32>) - return -} -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)> -// CHECK: func.func @winograd_output_transform_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<8x8x1x2x2x32xf32>, -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<1x12x12x32xf32>) { -// CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[C1:.+]] = arith.constant 1 : index -// CHECK: %[[C32:.+]] = arith.constant 32 : index -// CHECK: scf.for %[[ARG2:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1]] step %[[C1]] { -// CHECK-DAG: %[[D0:.+]] = affine.min #[[MAP2]](%[[ARG2]])[%[[C1]], %[[C1]]] -// CHECK: scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C32]] { -// CHECK-DAG: %[[D1:.+]] = affine.min #[[MAP3]](%[[ARG3]])[%[[C32]], %[[C32]]] -// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, 0, %[[ARG2]], 0, 0, %[[ARG3]]] [8, 8, %[[D0]], 2, 2, -// CHECK-SAME: %[[D1]]] [1, 1, 1, 1, 1, 1] : memref<8x8x1x2x2x32xf32> to memref<8x8x?x2x2x?xf32, strided<[1024, -// CHECK-SAME: 128, 128, 64, 32, 1], offset: ?>> -// CHECK: %[[SUBVIEW_0:.+]] = memref.subview %[[ARG1]][%[[ARG2]], 0, 0, %[[ARG3]]] [%[[D0]], 12, 12, %[[D1]]] -// CHECK-SAME: [1, 1, 1, 1] : memref<1x12x12x32xf32> to memref> -// CHECK: iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, -// CHECK-SAME: 2]) ins(%[[SUBVIEW]] : memref<8x8x?x2x2x?xf32, strided<[1024, 128, 128, 64, 32, 1], offset: ?>>) -// CHECK-SAME: outs(%[[SUBVIEW]]_0 : memref>) -// CHECK: } -// CHECK: } -// CHECK: return -// CHECK: } - -// ----- diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir deleted file mode 100644 index 17671aa2e8b5..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/iree_linalg_ext/vectorization.mlir +++ /dev/null @@ -1,354 +0,0 @@ -// RUN: iree-dialects-opt --iree-linalg-ext-vectorization --split-input-file %s | FileCheck %s - -func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x1x1x8x32xf32>) -> tensor<1x1x1x1x8x32xf32> { - %0 = iree_linalg_ext.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : (tensor<1x1x32x8xf32> tensor<1x1x1x1x8x32xf32>) -> tensor<1x1x1x1x8x32xf32> - return %0 : tensor<1x1x1x1x8x32xf32> -} -// CHECK-LABEL: func.func @simple_KCRS_to_KCRSsr -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty -// CHECK: %[[READ:.+]] = vector.transfer_read %[[IN]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] -// CHECK-SAME: {in_bounds = [true, true, true, true] -// CHECK-SAME: : tensor<1x1x32x8xf32>, vector<1x1x32x8xf32> -// CHECK: %[[BCAST:.+]] = vector.broadcast %[[READ]] : vector<1x1x32x8xf32> to vector<1x1x1x1x32x8xf32> -// CHECK: %[[TRANS:.+]] = vector.transpose %[[BCAST]], [2, 3, 0, 1, 5, 4] : vector<1x1x1x1x32x8xf32> to vector<1x1x1x1x8x32xf32> -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[TRANS]], -// CHECK-SAME: %[[EMPTY]][%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] -// CHECK-SAME: {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x8x32xf32>, tensor<1x1x1x1x8x32xf32> - -// ----- - -func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2xf32>, %pad: f32) -> tensor<1x1x8x2xf32> { - %0 = iree_linalg_ext.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<5x1xf32> tensor<1x1x8x2xf32>) -> tensor<1x1x8x2xf32> - return %0 : tensor<1x1x8x2xf32> -} -// CHECK-LABEL: func.func @simple_pad_and_pack -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[PAD:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty -// CHECK: %[[READ:.+]] = vector.transfer_read %[[IN]][%[[C0]], %[[C0]]], %[[PAD]] -// CHECK-SAME: : tensor<5x1xf32>, vector<8x2xf32> -// CHECK: %[[BCAST:.+]] = vector.broadcast %[[READ]] : vector<8x2xf32> to vector<1x1x8x2xf32> -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[BCAST]], -// CHECK-SAME: %[[EMPTY]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] -// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<1x1x8x2xf32>, tensor<1x1x8x2xf32> - -// ----- - -func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{ - %0 = iree_linalg_ext.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : (tensor<32x8xf32> tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> - return %0 : tensor<1x1x32x8xf32> -} -// CHECK-LABEL: func.func @simple_NC_to_CNnc -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty -// CHECK: %[[READ:.+]] = vector.transfer_read %[[IN]][%[[C0]], %[[C0]]] -// CHECK-SAME: {in_bounds = [true, true]} -// CHECK-SAME: : tensor<32x8xf32>, vector<32x8xf32> -// CHECK: %[[BCAST:.+]] = vector.broadcast %[[READ]] : vector<32x8xf32> to vector<1x1x32x8xf32> -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[BCAST]], -// CHECK-SAME: %[[EMPTY]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] -// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<1x1x32x8xf32>, tensor<1x1x32x8xf32> - -// ----- - -func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8x32xf32>) -> tensor<1x1x4x8x8x32xf32> { - %0 = iree_linalg_ext.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : (tensor<1x1x128x64xf32> tensor<1x1x4x8x8x32xf32>) -> tensor<1x1x4x8x8x32xf32> - return %0 : tensor<1x1x4x8x8x32xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 8)> -// CHECK-LABEL: func.func @KCRS_to_KCRSsr -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index -// CHECK: %[[RES0:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[OUT]]) -// CHECK: %[[RES1:.+]] = scf.for %[[J:.+]] = %[[C0]] to %[[C8]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[ITER0]]) -// CHECK-DAG: %[[IDX2:.+]] = affine.apply #[[MAP0]](%[[I]]) -// CHECK-DAG: %[[IDX3:.+]] = affine.apply #[[MAP1]](%[[J]]) -// CHECK: %[[READ:.+]] = vector.transfer_read %[[IN]] -// CHECK-SAME: [%[[C0]], %[[C0]], %[[IDX2]], %[[IDX3]]] -// CHECK-SAME: {in_bounds = [true, true, true, true]} -// CHECK-SAME: : tensor<1x1x128x64xf32>, vector<1x1x32x8xf32> -// CHECK: %[[BCAST:.+]] = vector.broadcast %[[READ]] : vector<1x1x32x8xf32> to vector<1x1x1x1x32x8xf32> -// CHECK: %[[TRANS:.+]] = vector.transpose %[[BCAST]], [2, 3, 0, 1, 5, 4] : vector<1x1x1x1x32x8xf32> to vector<1x1x1x1x8x32xf32> -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[TRANS]] -// CHECK-SAME: %[[ITER1]][%[[C0]], %[[C0]], %[[I]], %[[J]], %[[C0]], %[[C0]] -// CHECK-SAME: {in_bounds = [true, true, true, true, true, true]} -// CHECK-SAME: : vector<1x1x1x1x8x32xf32>, tensor<1x1x4x8x8x32xf32> -// CHECK: scf.yield %[[WRITE]] -// CHECK: } -// CHECK: scf.yield %[[RES1]] -// CHECK: } -// CHECK: return %[[RES0]] - -// ----- - -func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %arg2: f32) -> tensor<2x8x8x2xf32> { - %0 = iree_linalg_ext.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : (tensor<13x15xf32> tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> - return %0 : tensor<2x8x8x2xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 8)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -8 + 13, 8)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 2)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -2 + 15, 2)> -// CHECK-LABEL: func.func @pad_and_pack -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[PAD:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index -// CHECK: %[[RES0:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[OUT]]) -// CHECK: %[[RES1:.+]] = scf.for %[[J:.+]] = %[[C0]] to %[[C8]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[ITER0]]) -// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP0]](%[[I]]) -// CHECK-DAG: %[[SZ0:.+]] = affine.min #[[MAP1]](%[[I]]) -// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP2]](%[[J]]) -// CHECK-DAG: %[[SZ1:.+]] = affine.min #[[MAP3]](%[[J]]) -// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[IN]][%[[IDX0]], %[[IDX1]]] [%[[SZ0]], %[[SZ1]]] -// CHECK: %[[READ:.+]] = vector.transfer_read %[[SLICE]] -// CHECK-SAME: [%[[C0]], %[[C0]]], %[[PAD]] -// CHECK-SAME: : tensor, vector<8x2xf32> -// CHECK: %[[BCAST:.+]] = vector.broadcast %[[READ]] : vector<8x2xf32> to vector<1x1x8x2xf32> -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[BCAST]] -// CHECK-SAME: %[[ITER1]][%[[I]], %[[J]], %[[C0]], %[[C0]] -// CHECK-SAME: {in_bounds = [true, true, true, true]} -// CHECK-SAME: : vector<1x1x8x2xf32>, tensor<2x8x8x2xf32> -// CHECK: scf.yield %[[WRITE]] -// CHECK: } -// CHECK: scf.yield %[[RES1]] -// CHECK: } -// CHECK: return %[[RES0]] - -// ----- - -func.func @KC_to_CKck(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> { - %0 = iree_linalg_ext.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : (tensor<128x256xf32> tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> - return %0 : tensor<32x4x32x8xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 8)> -// CHECK-LABEL: func.func @KC_to_CKck -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK: %[[RES0:.+]] = scf.for %[[C:.+]] = %[[C0]] to %[[C32]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[OUT]]) -// CHECK: %[[RES1:.+]] = scf.for %[[K:.+]] = %[[C0]] to %[[C4]] step %[[C1]] -// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[ITER0]]) -// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]]) -// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]]) -// CHECK: %[[READ:.+]] = vector.transfer_read %[[IN]] -// CHECK-SAME: [%[[IN_K]], %[[IN_C]]] -// CHECK-SAME: {in_bounds = [true, true]} -// CHECK-SAME: : tensor<128x256xf32>, vector<32x8xf32> -// CHECK: %[[BCAST:.+]] = vector.broadcast %[[READ]] : vector<32x8xf32> to vector<1x1x32x8xf32> -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[BCAST]] -// CHECK-SAME: %[[ITER1]][%[[C]], %[[K]], %[[C0]], %[[C0]] -// CHECK-SAME: {in_bounds = [true, true, true, true]} -// CHECK-SAME: : vector<1x1x32x8xf32>, tensor<32x4x32x8xf32> -// CHECK: scf.yield %[[WRITE]] -// CHECK: } -// CHECK: scf.yield %[[RES1]] -// CHECK: } -// CHECK: return %[[RES0]] - -// ----- - -func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> { - %0 = iree_linalg_ext.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : (tensor<1x1x1x1x8x32xf32> tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> - return %0 : tensor<1x1x32x8xf32> -} -// CHECK-LABEL: func.func @simple_KCRSsr_to_KCRS -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[READ:.+]] = vector.transfer_read %[[IN]] -// CHECK-SAME: [%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[ZERO]] -// CHECK-SAME: {in_bounds = [true, true]} : tensor<1x1x1x1x8x32xf32>, vector<8x32xf32> -// CHECK: %[[TRANSP:.+]] = vector.transpose %[[READ]], [1, 0] -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[TRANSP]] -// CHECK-SAME: %[[OUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] -// CHECK-SAME: {in_bounds = [true, true]} : vector<32x8xf32>, tensor<1x1x32x8xf32> -// CHECK: return %[[WRITE]] - -// ----- - -func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> { - %0 = iree_linalg_ext.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<1x1x8x2xf32> tensor<5x1xf32>) -> tensor<5x1xf32> - return %0 : tensor<5x1xf32> -} -// CHECK-LABEL: func.func @simple_unpack_and_extract_slice -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32> -// CHECK: %[[READ:.+]] = vector.transfer_read %[[IN]] -// CHECK-SAME: [%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[ZERO]] -// CHECK-SAME: {in_bounds = [true, true]} : tensor<1x1x8x2xf32>, vector<8x2xf32> -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[READ]], -// CHECK-SAME: %[[EMPTY]][%[[C0]], %[[C0]]] -// CHECK-SAME: {in_bounds = [true, true]} : vector<8x2xf32>, tensor<8x2xf32> -// CHECK: %[[RES:.+]] = tensor.extract_slice %[[WRITE]] -// CHECK-SAME: [0, 0] [5, 1] [1, 1] : tensor<8x2xf32> to tensor<5x1xf32> -// CHECK: return %[[RES:.+]] - -// ----- - -func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32>) -> tensor<32x8xf32>{ - %0 = iree_linalg_ext.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : (tensor<1x1x32x8xf32> tensor<32x8xf32>) -> tensor<32x8xf32> - return %0 : tensor<32x8xf32> -} -// CHECK-LABEL: func.func @simple_CNnc_to_NC -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32> -// CHECK: %[[READ:.+]] = vector.transfer_read %[[IN]] -// CHECK-SAME: [%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[ZERO]] -// CHECK-SAME: {in_bounds = [true, true]} : tensor<1x1x32x8xf32>, vector<32x8xf32> -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[READ]], -// CHECK-SAME: %[[EMPTY]][%[[C0]], %[[C0]]] -// CHECK-SAME: {in_bounds = [true, true]} : vector<32x8xf32>, tensor<32x8xf32> -// CHECK: return %[[WRITE]] - -// ----- - -func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> { - %0 = iree_linalg_ext.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : (tensor<1x1x4x8x8x32xf32> tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> - return %0 : tensor<1x1x128x64xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)> -// CHECK-LABEL: func.func @KCRSsr_to_KCRS -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index -// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[RES0:.+]] = scf.for %[[R:.+]] = %[[C0]] to %[[C128]] step %[[C32]] -// CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[OUT]]) -// CHECK: %[[RES1:.+]] = scf.for %[[S:.+]] = %[[C0]] to %[[C64]] step %[[C8]] -// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[ITER0]]) -// CHECK-DAG: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]]) -// CHECK-DAG: %[[IN_S:.+]] = affine.apply #[[MAP1]](%[[S]]) -// CHECK-DAG: %[[ITER1_SLICE:.+]] = tensor.extract_slice %[[ITER1]] -// CHECK-SAME: [0, 0, %[[R]], %[[S]]] [1, 1, 32, 8] [1, 1, 1, 1] -// CHECK: %[[READ:.+]] = vector.transfer_read %[[IN]] -// CHECK-SAME: [%[[C0]], %[[C0]], %[[IN_R]], %[[IN_S]], %[[C0]], %[[C0]]], %[[ZERO]] -// CHECK-SAME: {in_bounds = [true, true]} : tensor<1x1x4x8x8x32xf32>, vector<8x32xf32> -// CHECK: %[[TRANSP:.+]] = vector.transpose %[[READ]], [1, 0] -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[TRANSP]] -// CHECK-SAME: %[[ITER1_SLICE]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] -// CHECK-SAME: {in_bounds = [true, true]} : vector<32x8xf32>, tensor<1x1x32x8xf32> -// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[WRITE]] -// CHECK-SAME: into %[[ITER1]][0, 0, %[[R]], %[[S]]] [1, 1, 32, 8] [1, 1, 1, 1] -// CHECK: scf.yield %[[INSERT]] -// CHECK: } -// CHECK: scf.yield %[[RES1]] -// CHECK: } -// CHECK: return %[[RES0]] - -// ----- - -func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13x15xf32>) -> tensor<13x15xf32> { - %0 = iree_linalg_ext.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : (tensor<2x8x8x2xf32> tensor<13x15xf32>) -> tensor<13x15xf32> - return %0 : tensor<13x15xf32> -} -//CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 13, 8)> -//CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 15, 2)> -//CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)> -//CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 floordiv 2)> -// CHECK-LABEL: func.func @unpack_and_extract_slice -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index -// CHECK-DAG: %[[C13:.+]] = arith.constant 13 : index -// CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[RES0:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C13]] step %[[C8]] -// CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[OUT]]) -// CHECK-DAG: %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]]) -// CHECK: %[[RES1:.+]] = scf.for %[[J:.+]] = %[[C0]] to %[[C15]] step %[[C2]] -// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[ITER0]]) -// CHECK-DAG: %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]]) -// CHECK-DAG: %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]]) -// CHECK-DAG: %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]]) -// CHECK-DAG: %[[ITER1_SLICE1:.+]] = tensor.extract_slice %[[ITER1]] -// CHECK-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1] -// CHECK-DAG: %[[READ:.+]] = vector.transfer_read %[[IN]] -// CHECK-SAME: [%[[IN_I]], %[[IN_J]], %[[C0]], %[[C0]]], %[[ZERO]] -// CHECK-SAME: {in_bounds = [true, true]} : tensor<2x8x8x2xf32>, vector<8x2xf32> -// CHECK-DAG: %[[ITER1_SLICE2:.+]] = tensor.extract_slice %[[ITER1_SLICE1]] -// CHECK-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1] -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[READ]] -// CHECK-SAME: %[[ITER1_SLICE2]][%[[C0]], %[[C0]]] -// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[WRITE]] -// CHECK-SAME: into %[[ITER1_SLICE1]][0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1] -// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[INSERT1]] -// CHECK-SAME: into %[[ITER1]][%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1] -// CHECK: scf.yield %[[INSERT2]] -// CHECK: } -// CHECK: scf.yield %[[RES1]] -// CHECK: } -// CHECK: return %[[RES0]] - -// ----- - -func.func @CKck_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> { - %0 = iree_linalg_ext.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : (tensor<32x4x32x8xf32> tensor<128x256xf32>) -> tensor<128x256xf32> - return %0 : tensor<128x256xf32> -} -//CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)> -//CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)> -// CHECK-LABEL: func.func @CKck_to_KC -// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]: -// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index -// CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[RES0:.+]] = scf.for %[[K:.+]] = %[[C0]] to %[[C128]] step %[[C32]] -// CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[OUT]]) -// CHECK: %[[RES1:.+]] = scf.for %[[C:.+]] = %[[C0]] to %[[C256]] step %[[C8]] -// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[ITER0]]) -// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]]) -// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]]) -// CHECK-DAG: %[[READ:.+]] = vector.transfer_read %[[IN]] -// CHECK-SAME: [%[[IN_C]], %[[IN_K]], %[[C0]], %[[C0]]], %[[ZERO]] -// CHECK-SAME: {in_bounds = [true, true]} : tensor<32x4x32x8xf32>, vector<32x8xf32> -// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[READ]] -// CHECK-SAME: %[[ITER1]][%[[K]], %[[C]]] -// CHECK: scf.yield %[[WRITE]] -// CHECK: } -// CHECK: scf.yield %[[RES1]] -// CHECK: } -// CHECK: return %[[RES0]] diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir deleted file mode 100644 index 5f4e74e38991..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: iree-dialects-opt --transform-dialect-interpreter %s | FileCheck %s - -// CHECK-LABEL: func.func @matmul_tensors( -// CHECK-SAME: %[[TA:[0-9a-z]+]]: memref<128x128xf32 -// CHECK-SAME: %[[TB:[0-9a-z]+]]: memref<128x128xf32 -// CHECK-SAME: %[[TC:[0-9a-z]+]]: memref<128x128xf32 -// CHECK-NOT: -> tensor -func.func @matmul_tensors( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true}) - -> tensor<128x128xf32> { - // CHECK: linalg.matmul ins(%[[TA]], %[[TB]] : memref{{.*}}, memref{{.*}} outs(%[[TC]] : memref{{.*}}) - %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - - // CHECK: return %[[TC]] - return %0 : tensor<128x128xf32> -// CHECK: } -} - -transform.structured.canonicalized_sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - bufferize -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir deleted file mode 100644 index de014ea0f132..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: iree-dialects-opt --transform-dialect-drop-schedule %s | FileCheck %s - -func.func @matmul_tensors( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true}) - -> tensor<128x128xf32> { - %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - return %0 : tensor<128x128xf32> -} - -// CHECK-NOT: pdl.pattern -transform.with_pdl_patterns { -^bb0(%arg0: !pdl.operation): - pdl.pattern @pdl_target : benefit(1) { - %args = operands - %results = types - %0 = operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) - %1 = pdl.attribute = @matmul_tensors - apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute) - // TODO: we don't want this, but it is the required terminator for pdl.pattern - rewrite %0 with "transform.apply" - } - - // CHECK-NOT: canonicalized_sequence - transform.structured.canonicalized_sequence %arg0 failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @pdl_target in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.tile %0 [4, 4, 4] - } -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/expert.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/expert.mlir deleted file mode 100644 index 3dc056c0e3a4..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/expert.mlir +++ /dev/null @@ -1,168 +0,0 @@ -// _UN: iree-dialects-opt --linalg-transform-expert-expansion --split-input-file %s | FileCheck %s --check-prefix=EXPAND -// _UN: iree-dialects-opt --linalg-transform-expert-expansion --linalg-interp-transforms --split-input-file %s | FileCheck %s -// RUN: true - -// CHECK-LABEL: func.func @matmul_tensors -// CHECK-NOT: linalg -// CHECK: llvm -func.func @matmul_tensors( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true}) - -> tensor<128x128xf32> { - %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - - return %0 : tensor<128x128xf32> -} - -pdl.pattern @pdl_target : benefit(1) { - %args = operands - %results = types - %0 = operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) - %1 = pdl.attribute = @matmul_tensors - apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute) - // TODO: we don't want this, but it is the required terminator for pdl.pattern - rewrite %0 with "iree_linalg_transform.apply" -} - -iree_linalg_transform.sequence { - // This should match the strategy below. - // EXPAND-NOT: expert apply - // EXPAND: %[[OP:.*]] = match @pdl_target - // EXPAND: %[[HANDLE:.*]], %{{.*}}:3 = tile %[[OP]] {sizes = [4, 4, 4]} - // EXPAND: %[[HANDLE2:.*]] = vectorize %[[HANDLE]] vectorize_padding - // EXPAND: bufferize - // EXPAND: lower_vectors {multireduction_lowering = "innerreduce"} - // EXPAND: lower_to_llvm - %0 = match @pdl_target - expert apply "single_tiling" to %0 - { - tile_sizes = [4, 4, 4], - vectorize_padding = true, - multireduction_lowering = "innerreduce" - } -} - -// CHECK-NOT: @strategies -// EXPAND-NOT: @strategies -module @strategies { - pdl.pattern @single_tiling_matcher : benefit(1) { - %tile_sizes = attribute - %vectorize_padding = attribute - %multireduction_lowering = attribute - %name = attribute : "single_tiling" - %type = type : !pdl.operation - %target = operand : %type - %transformed = type - %root = operation "iree_linalg_transform.expert"(%target : !pdl.value) { - "expertName" = %name, - "tile_sizes" = %tile_sizes, - "vectorize_padding" = %vectorize_padding, - "multireduction_lowering" = %multireduction_lowering - } -> (%transformed : !pdl.type) - - rewrite %root { - %tile = operation "iree_linalg_transform.tile"(%target : !pdl.value) { - "sizes" = %tile_sizes - } -> (%transformed, %transformed, %transformed, %transformed : !pdl.type, !pdl.type, !pdl.type, !pdl.type) - %handle = result 0 of %tile - - %vectorize = operation "iree_linalg_transform.vectorize"(%handle : !pdl.value) { - "vectorize_padding" = %vectorize_padding - } -> (%transformed : !pdl.type) - %handle2 = result 0 of %vectorize - - %bufferize = operation "iree_linalg_transform.bufferize" - %lower_vectors = operation "iree_linalg_transform.lower_vectors" { - "multireduction_lowering" = %multireduction_lowering - } - %lower_to_llvm = operation "iree_linalg_transform.lower_to_llvm" - - replace %root with (%handle2 : !pdl.value) - } - } -} - -// ----- - -// CHECK-LABEL: func.func @matmul_tensors2 -// CHECK-NOT: linalg -// CHECK: llvm -func.func @matmul_tensors2( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true}) - -> tensor<128x128xf32> { - %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - - return %0 : tensor<128x128xf32> -} - -pdl.pattern @pdl_target2 : benefit(1) { - %args = pdl.operands - %results = pdl.types - %0 = pdl.operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) - %1 = pdl.attribute = @matmul_tensors2 - apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute) - // TODO: we don't want this, but it is the required terminator for pdl.pattern - pdl.rewrite %0 with "iree_linalg_transform.apply" -} - -iree_linalg_transform.sequence { - // This should match the strategy below. - // EXPAND-NOT: expert apply - // EXPAND: %[[OP:.*]] = match @pdl_target2 - // EXPAND: %[[HANDLE:.*]], %{{.*}}:3 = tile %[[OP]] {sizes = [32, 8, 8]} - // EXPAND: %[[HANDLE2:.*]], %{{.*}}:3 = tile %[[HANDLE]] {sizes = [4, 4, 4]} - // EXPAND: %[[HANDLE3:.*]] = vectorize %[[HANDLE2]] - // EXPAND-NOT: vectorize_padding - // EXPAND: bufferize - // EXPAND: lower_vectors {multireduction_lowering = "innerparallel"} - // EXPAND: lower_to_llvm - %0 = match @pdl_target2 - %1, %loops:3 = tile %0 {sizes = [32, 8, 8]} - expert apply "single_tiling" to %1 - { - tile_sizes = [4, 4, 4], - vectorize_padding = false, - multireduction_lowering = "innerparallel" - } -} - -module @strategies { - pdl.pattern @single_tiling_operand : benefit(1) { - %tile_sizes = attribute - %vectorize_padding = attribute - %multireduction_lowering = attribute - %name = attribute : "single_tiling" - %type = type : !pdl.operation - %target = operand : %type - %transformed = type - %root = operation "iree_linalg_transform.expert"(%target : !pdl.value) { - "expertName" = %name, - "tile_sizes" = %tile_sizes, - "vectorize_padding" = %vectorize_padding, - "multireduction_lowering" = %multireduction_lowering - } -> (%transformed : !pdl.type) - - rewrite %root { - %tile = operation "iree_linalg_transform.tile"(%target : !pdl.value) { - "sizes" = %tile_sizes - } -> (%transformed, %transformed, %transformed, %transformed : !pdl.type, !pdl.type, !pdl.type, !pdl.type) - %handle = result 0 of %tile - - %vectorize = operation "iree_linalg_transform.vectorize"(%handle : !pdl.value) { - "vectorize_padding" = %vectorize_padding - } -> (%transformed : !pdl.type) - %handle2 = result 0 of %vectorize - - %bufferize = operation "iree_linalg_transform.bufferize" - %lower_vectors = operation "iree_linalg_transform.lower_vectors" { - "multireduction_lowering" = %multireduction_lowering - } - %lower_to_llvm = operation "iree_linalg_transform.lower_to_llvm" - - replace %root with (%handle2 : !pdl.value) - } - } -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/failure.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/failure.mlir deleted file mode 100644 index 0927947503ce..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/failure.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: iree-dialects-opt --transform-dialect-interpreter --split-input-file --verify-diagnostics --allow-unregistered-dialect %s - -func.func public @no_outlining() { - // expected-note @below {{target op}} - "some.operation"() ({}, {}) : () -> () - return -} - -transform.with_pdl_patterns { -^bb0(%arg0: !pdl.operation): - pdl.pattern @some_operation : benefit(1) { - %0 = operation "some.operation" - rewrite %0 with "transform.dialect" - } - - transform.structured.canonicalized_sequence %arg0 failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @some_operation in %arg1 : (!pdl.operation) -> !pdl.operation - // Make sure we don't crash on wrong operation type. - // expected-error@below {{failed to outline}} - transform.loop.outline %0 {func_name = "outlined"} : (!pdl.operation) -> !pdl.operation - } -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/invalid.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/invalid.mlir deleted file mode 100644 index 93c21e6425ce..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/invalid.mlir +++ /dev/null @@ -1,53 +0,0 @@ -// RUN: iree-dialects-opt %s --split-input-file -verify-diagnostics - -transform.structured.canonicalized_sequence failures(propagate) { -^bb0(%arg0: !pdl.operation): - %0 = pdl_match @match in %arg0 : (!pdl.operation) -> !pdl.operation - // expected-error@below {{expects iterator_interchange to be a permutation, found [1, 1]}} - transform.structured.interchange %0 {iterator_interchange = [1, 1]} -} - -// ----- - -transform.structured.canonicalized_sequence failures(propagate) { -^bb0(%arg0: !pdl.operation): - %0 = pdl_match @match in %arg0 : (!pdl.operation) -> !pdl.operation - // expected-error@below {{expected 'tile_sizes' attribute}} - transform.structured.fuse %0 -} - -// ----- - -transform.structured.canonicalized_sequence failures(propagate) { -^bb0(%arg0: !pdl.operation): - %0 = pdl_match @match in %arg0 : (!pdl.operation) -> !pdl.operation - // expected-error@below {{expects interchange to be a permutation, found [1, 1]}} - transform.structured.fuse %0 {tile_sizes=[0, 1], tile_interchange = [1, 1]} -} - -// ----- - -transform.structured.canonicalized_sequence failures(propagate) { -^bb0(%arg0: !pdl.operation): - %0 = pdl_match @match in %arg0 : (!pdl.operation) -> !pdl.operation - // expected-error@below {{expects pack_paddings to contain booleans (0/1), found [1, 7]}} - transform.structured.pad %0 {pack_paddings=[1, 7]} -} - -// ----- - -transform.structured.canonicalized_sequence failures(propagate) { -^bb0(%arg0: !pdl.operation): - %0 = pdl_match @match in %arg0 : (!pdl.operation) -> !pdl.operation - // expected-error@below {{expects hoist_paddings to contain positive integers, found [1, -7]}} - transform.structured.pad %0 {hoist_paddings=[1, -7]} -} - -// ----- - -transform.structured.canonicalized_sequence failures(propagate) { -^bb0(%arg0: !pdl.operation): - %0 = pdl_match @match in %arg0 : (!pdl.operation) -> !pdl.operation - // expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}} - transform.structured.pad %0 {transpose_paddings=[[1, 1]]} -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir deleted file mode 100644 index 518e04b73c7d..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: iree-dialects-opt %s | FileCheck %s - -// CHECK: transform.structured.canonicalized_sequence -transform.structured.canonicalized_sequence failures(propagate) { -^bb0(%arg0: !pdl.operation): - // CHECK: %[[OPS:.*]] = pdl_match @match1 in %{{.*}} - %0 = pdl_match @match1 in %arg0 : (!pdl.operation) -> !pdl.operation - // CHECK: %[[TILED:.*]], %{{.*}}:3 = transform.structured.tile %[[OPS]][4, 4, 4] - %1, %loops1:3 = transform.structured.tile %0 [4, 4, 4] - // CHECK: %[[TILED2:.*]], %{{.*}}:3 = transform.structured.tile %[[TILED]] - %2, %loops2:3 = transform.structured.tile %1 [2, 2, 2] - // CHECK: %[[PADDED:.*]] = transform.structured.pad %[[TILED2]] {pack_paddings = [1, 1, 0]} - %3 = transform.structured.pad %2 {pack_paddings = [1, 1, 0]} - // CHECK: %{{.*}} = transform.structured.vectorize %[[PADDED]] {vectorize_padding} - %4 = transform.structured.vectorize %3 { vectorize_padding } - // CHECK: %[[OPS2:.*]] = pdl_match @{{.*}} - %5 = pdl_match @match2 in %arg0 : (!pdl.operation) -> !pdl.operation - // CHECK: transform.structured.vectorize %[[OPS2]] - transform.structured.vectorize %5 - // CHECK: bufferize - bufferize - // CHECK: lower_vectors {multireduction_lowering = "innerreduce"} - lower_vectors { multireduction_lowering = "innerreduce"} - // CHECK: lower_to_llvm - lower_to_llvm -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/scoped.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/scoped.mlir deleted file mode 100644 index 7e66f0942b3e..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/scoped.mlir +++ /dev/null @@ -1,30 +0,0 @@ -// RUN: iree-dialects-opt --test-wrap-scope='opname=arith.addi' %s | FileCheck %s --check-prefix WRAP -// RUN: iree-dialects-opt --test-unwrap-scope %s | FileCheck %s --check-prefix UNWRAP - -// WRAP-LABEL: @test_wrap -// WRAP-SAME: (%[[ARG0:.*]]: i32) -> i32 -func.func @test_wrap(%arg0: i32) -> i32 { - // WRAP: %[[V:.*]] = iree_linalg_transform.util.scope(%[[ARG0]], %[[ARG0]]) { - // WRAP-NEXT: ^[[B:.*]](%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32): - // WRAP-NEXT: %[[ADD:.*]] = arith.addi %[[ARG2]], %[[ARG2]] - // WRAP-NEXT: iree_linalg_transform.util.forward %[[ADD]] - // WRAP-NEXT: } : (i32, i32) -> i32 - %0 = arith.addi %arg0, %arg0 : i32 - // WRAP: return %[[V]] - return %0 : i32 -} - -// UNWRAP-LABEL: @test_unwrap -// UNWRAP-SAME: (%[[ARG0:.*]]: i32) -> (i32, i32) -func.func @test_unwrap(%arg0: i32) -> (i32, i32) { - // UNWRAP: %[[V0:.*]] = arith.addi %[[ARG0]], %[[ARG0]] - // UNWRAP-NEXT: %[[V1:.*]] = arith.addi %[[V0]], %[[ARG0]] - %0:2 = iree_linalg_transform.util.scope(%arg0) { - ^bb0(%arg1: i32): - %1 = arith.addi %arg1, %arg1 : i32 - %2 = arith.addi %1, %arg1 : i32 - iree_linalg_transform.util.forward %1, %2 : i32, i32 - } : (i32) -> (i32, i32) - // UNWRAP-NEXT: return %[[V0]], %[[V1]] - return %0#0, %0#1 : i32, i32 -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/selective-targeting.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/selective-targeting.mlir deleted file mode 100644 index 5782ed4f2a08..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/selective-targeting.mlir +++ /dev/null @@ -1,159 +0,0 @@ -// RUN: iree-dialects-opt %s --transform-dialect-interpreter --split-input-file | FileCheck %s - -// CHECK-LABEL: func.func @matmul_tensors_1( -func.func @matmul_tensors_1( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, - %arg2: tensor<128x128xf32> {linalg.inplaceable = true}) - -> tensor<128x128xf32> { - // This operation is marked for tiling only. - // CHECK-COUNT-3: scf.for - // CHECK-COUNT-3: tensor.extract_slice - // CHECK: linalg.matmul - // CHECK-SAME: -> tensor<4x4xf32> - %0 = linalg.matmul { test.attrA } - ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - func.return %0 : tensor<128x128xf32> -} - -func.func @matmul_tensors_2( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, - %arg2: tensor<128x128xf32> {linalg.inplaceable = true}) - -> tensor<128x128xf32> { - // This operation is marked f - // This operation is marked for tiling and vectorization. - // Note that the loop-invariant read is hoisted out of the innermost loop. - // CHECK: scf.for - // CHECK: scf.for - // CHECK: vector.transfer_read - // CHECK: scf.for - // CHECK: vector.transfer_read - // CHECK: vector.transfer_read - // CHECK: vector.contract - // CHECK-NOT: linalg.matmul - // CHECK: vector.transfer_write - %0 = linalg.matmul { test.attrA, test.attrC } - ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - func.return %0 : tensor<128x128xf32> -} - -func.func @matmul_tensors_3( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, - %arg2: tensor<128x128xf32> {linalg.inplaceable = true}) - -> tensor<128x128xf32> { - // This operation is marked for vectorization only. - // CHECK-NOT: scf.for - // CHECK-COUNT-3: vector.transfer_read - // CHECK: vector.contract - // CHECK-SAME: into vector<128x128xf32> - // CHECK: vector.transfer_write - %0 = linalg.matmul { test.attrC } - ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - func.return %0 : tensor<128x128xf32> -} - -transform.with_pdl_patterns { -^bb0(%arg0: !pdl.operation): - // Match matmul operations inside @matmul_tensors with test.attrA set. - pdl.pattern @pdl_target_attrA : benefit(1) { - %args = operands - %results = types - %attr = attribute - %0 = operation "linalg.matmul"(%args : !pdl.range) {"test.attrA" = %attr}-> (%results : !pdl.range) - // TODO: we don't want this, but it is the required terminator for pdl.pattern - rewrite %0 with "transform.dialect" - } - - // Match matmul operations inside @matmul_tensors with test.attrC set. - pdl.pattern @pdl_target_attrC : benefit(1) { - %args = operands - %results = types - %attr = attribute - %0 = operation "linalg.matmul"(%args : !pdl.range) {"test.attrC" = %attr}-> (%results : !pdl.range) - // TODO: we don't want this, but it is the required terminator for pdl.pattern - rewrite %0 with "transform.dialect" - } - - transform.structured.canonicalized_sequence %arg0 failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @pdl_target_attrA in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.tile %0 [4, 4, 4] - %1 = pdl_match @pdl_target_attrC in %arg1 : (!pdl.operation) -> !pdl.operation - %2 = transform.get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation - transform.structured.vectorize %2 - } -} - -// ----- - -// CHECK-LABEL: @vectorize_one -func.func @vectorize_one( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, - %arg2: tensor<128x128xf32> {linalg.inplaceable = true}) - -> tensor<128x128xf32> { - // CHECK: vector.contract - %0 = linalg.matmul {test.attrA} - ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - func.return %0 : tensor<128x128xf32> -} - -func.func @vectorize_none( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, - %arg2: tensor<128x128xf32> {linalg.inplaceable = true}) - -> tensor<128x128xf32> { - // CHECK: linalg.matmul - %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - func.return %0 : tensor<128x128xf32> -} - -transform.with_pdl_patterns { -^bb0(%arg0: !pdl.operation): - pdl.pattern @pdl_target : benefit(1) { - %args = operands - %results = types - %attr = attribute - %0 = operation "linalg.matmul"(%args : !pdl.range) {"test.attrA" = %attr}-> (%results : !pdl.range) - // TODO: we don't want this, but it is the required terminator for pdl.pattern - rewrite %0 with "transform.dialect" - } - - transform.structured.canonicalized_sequence %arg0 failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @pdl_target in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation - transform.structured.vectorize %1 - } -} - -// ----- - -// CHECK-LABEL: @vectorize_all -func.func @vectorize_all( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, - %arg3: tensor<128x128xf32> {linalg.inplaceable = true}) - -> tensor<128x128xf32> { - // CHECK: vector.contract - %0 = linalg.matmul {test.attrA} - ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - // CHECK: vector.contract - %1 = linalg.matmul ins(%arg0, %0: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg3: tensor<128x128xf32>) - -> tensor<128x128xf32> - return %1 : tensor<128x128xf32> -} - -transform.structured.canonicalized_sequence failures(propagate) { -^bb0(%arg0: !pdl.operation): - transform.structured.vectorize %arg0 -} diff --git a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir b/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir deleted file mode 100644 index d0df60fba6d6..000000000000 --- a/integrations/tensorflow/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// RUN: iree-dialects-opt --transform-dialect-interpreter %s | FileCheck %s - -// CHECK-LABEL: func @matmul_tensors -// CHECK-NOT: linalg -// CHECK: llvm -func.func @matmul_tensors( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true}) - -> tensor<128x128xf32> { - %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - - return %0 : tensor<128x128xf32> -} - -transform.structured.canonicalized_sequence failures(propagate) { -^bb1(%module_op: !pdl.operation): - %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op - %1, %loops:3 = transform.structured.tile %0 [4, 4, 4] - %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation - transform.structured.vectorize %2 { vectorize_padding } - bufferize - lower_vectors { multireduction_lowering = "innerreduce"} - lower_to_llvm -} diff --git a/integrations/tensorflow/iree-dialects/test/lib/Dialect/CMakeLists.txt b/integrations/tensorflow/iree-dialects/test/lib/Dialect/CMakeLists.txt deleted file mode 100644 index 1da2860785ea..000000000000 --- a/integrations/tensorflow/iree-dialects/test/lib/Dialect/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(LinalgTransform) diff --git a/integrations/tensorflow/iree-dialects/test/lib/Dialect/LinalgTransform/CMakeLists.txt b/integrations/tensorflow/iree-dialects/test/lib/Dialect/LinalgTransform/CMakeLists.txt deleted file mode 100644 index 261b91d7e49c..000000000000 --- a/integrations/tensorflow/iree-dialects/test/lib/Dialect/LinalgTransform/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -add_mlir_library(IREELinalgTransformTestPasses - TestScopedTransform.cpp - - EXCLUDE_FROM_LIBMLIR - - DEPENDS - mlir-headers - - LINK_LIBS PUBLIC - IREELinalgTransformDialectPasses - MLIRPass - ) diff --git a/integrations/tensorflow/iree-dialects/test/lib/Dialect/LinalgTransform/TestScopedTransform.cpp b/integrations/tensorflow/iree-dialects/test/lib/Dialect/LinalgTransform/TestScopedTransform.cpp deleted file mode 100644 index 8d46c1a0aa69..000000000000 --- a/integrations/tensorflow/iree-dialects/test/lib/Dialect/LinalgTransform/TestScopedTransform.cpp +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h" -#include "mlir/Pass/Pass.h" - -using namespace mlir; -using namespace mlir::linalg; - -namespace { -struct TestWrapScopePass : public PassWrapper { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrapScopePass) - - TestWrapScopePass() = default; - TestWrapScopePass(const TestWrapScopePass &other) : PassWrapper(other) {} - - StringRef getArgument() const final { return "test-wrap-scope"; } - StringRef getDescription() const final { return "Test wrap scope pass."; } - bool canScheduleOn(RegisteredOperationName opName) const override { - return true; - } - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - getOperation()->walk([&](Operation *op) { - if (op->getName().getStringRef() != opToWrap) - return; - linalg::transform::wrapInScope(op); - }); - } - - Pass::Option opToWrap{*this, "opname", - llvm::cl::desc("Op to wrap")}; -}; - -struct TestUnwrapScopePass : public PassWrapper { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnwrapScopePass) - StringRef getArgument() const final { return "test-unwrap-scope"; } - StringRef getDescription() const final { return "Test unwrap scope pass."; } - bool canScheduleOn(RegisteredOperationName opName) const override { - return true; - } - - void runOnOperation() override { - getOperation()->walk( - [](linalg::transform::ScopeOp scope) { (void)unwrapScope(scope); }); - } -}; -} // namespace - -namespace mlir { -namespace test_ext { -void registerTestLinalgTransformWrapScope() { - PassRegistration(); - PassRegistration(); -} -} // namespace test_ext -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/test/lib/Transforms/CMakeLists.txt b/integrations/tensorflow/iree-dialects/test/lib/Transforms/CMakeLists.txt deleted file mode 100644 index 7d39dbd0f207..000000000000 --- a/integrations/tensorflow/iree-dialects/test/lib/Transforms/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -add_mlir_library(IREETransformsTestPasses - TestListenerPasses.cpp - - DEPENDS - mlir-headers - - EXCLUDE_FROM_LIBMLIR - - LINK_LIBS PUBLIC - IREELinalgTransformDialect - MLIRPass - ) diff --git a/integrations/tensorflow/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp b/integrations/tensorflow/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp deleted file mode 100644 index 68f2103c884a..000000000000 --- a/integrations/tensorflow/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Transforms/Listener.h" -#include "iree-dialects/Transforms/ListenerCSE.h" -#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; - -namespace { - -/// The test listener prints stuff to `stdout` so that it can be checked by lit -/// tests. -struct TestListener : public RewriteListener { - void notifyRootReplaced(Operation *op, ValueRange newValues) override { - llvm::outs() << "REPLACED " << op->getName() << "\n"; - } - void notifyOperationRemoved(Operation *op) override { - llvm::outs() << "REMOVED " << op->getName() << "\n"; - } -}; - -struct TestListenerCanonicalizePass - : public PassWrapper { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestListenerCanonicalizePass) - - TestListenerCanonicalizePass() = default; - TestListenerCanonicalizePass(const TestListenerCanonicalizePass &other) - : PassWrapper(other) {} - - StringRef getArgument() const final { return "test-listener-canonicalize"; } - StringRef getDescription() const final { return "Test canonicalize pass."; } - bool canScheduleOn(RegisteredOperationName opName) const override { - return true; - } - - void runOnOperation() override { - TestListener listener; - RewriteListener *listenerToUse = nullptr; - if (withListener) - listenerToUse = &listener; - - RewritePatternSet patterns(&getContext()); - for (Dialect *dialect : getContext().getLoadedDialects()) - dialect->getCanonicalizationPatterns(patterns); - for (RegisteredOperationName op : getContext().getRegisteredOperations()) - op.getCanonicalizationPatterns(patterns, &getContext()); - - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - GreedyRewriteConfig(), - listenerToUse))) - signalPassFailure(); - } - - Pass::Option withListener{ - *this, "listener", llvm::cl::desc("Whether to run with a test listener"), - llvm::cl::init(false)}; -}; - -struct TestListenerCSEPass : public PassWrapper { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestListenerCSEPass) - - TestListenerCSEPass() = default; - TestListenerCSEPass(const TestListenerCSEPass &other) : PassWrapper(other) {} - - StringRef getArgument() const final { return "test-listener-cse"; } - StringRef getDescription() const final { return "Test CSE pass."; } - bool canScheduleOn(RegisteredOperationName opName) const override { - return true; - } - - void runOnOperation() override { - TestListener listener; - RewriteListener *listenerToUse = nullptr; - if (withListener) - listenerToUse = &listener; - - if (failed(eliminateCommonSubexpressions(getOperation(), - /*domInfo=*/nullptr, - listenerToUse))) - signalPassFailure(); - } - - Pass::Option withListener{ - *this, "listener", llvm::cl::desc("Whether to run with a test listener"), - llvm::cl::init(false)}; -}; - -} // namespace - -namespace mlir { -namespace test_ext { -void registerTestListenerPasses() { - PassRegistration(); - PassRegistration(); -} -} // namespace test_ext -} // namespace mlir diff --git a/integrations/tensorflow/iree-dialects/test/python/dialects/iree_structured_transform.py b/integrations/tensorflow/iree-dialects/test/python/dialects/iree_structured_transform.py deleted file mode 100644 index 57e76a888af8..000000000000 --- a/integrations/tensorflow/iree-dialects/test/python/dialects/iree_structured_transform.py +++ /dev/null @@ -1,30 +0,0 @@ -# RUN: %PYTHON %s | FileCheck %s - -import iree.compiler.ir as ir -import iree.compiler.dialects.transform.iree_structured as iree_structured_transform -import iree.compiler._mlir_libs._ireeDialects.transform - - -def constructAndPrintInModule(f): - print("\nTEST:", f.__name__) - with ir.Context() as ctx, ir.Location.unknown(): - iree.compiler._mlir_libs._ireeDialects.transform.register_dialect(ctx) - module = ir.Module.create() - with ir.InsertionPoint(module.body): - f() - print(module) - return f - - -# CHECK-LABEL: TEST: testLowerVectorsOp -# CHECK: transform.lower_vectors {contraction_lowering = "outerproduct", multireduction_lowering = "innerparallel", split_transfers = "linalg-copy", stages = [1], transpose_avx2_lowering = false, transpose_lowering = "shuffle", unroll_vector_transfers = true} -@constructAndPrintInModule -def testLowerVectorsOp(): - op = iree_structured_transform.LowerVectorsOp( - contraction_lowering="outerproduct", - multireduction_lowering="innerparallel", - split_transfers="linalg-copy", - stages=[1], - transpose_avx2_lowering=False, - transpose_lowering="shuffle", - unroll_vector_transfers=True) diff --git a/integrations/tensorflow/iree-dialects/test/python/smoketest.py b/integrations/tensorflow/iree-dialects/test/python/smoketest.py index 1a0402b3a665..dfc5fac50fad 100644 --- a/integrations/tensorflow/iree-dialects/test/python/smoketest.py +++ b/integrations/tensorflow/iree-dialects/test/python/smoketest.py @@ -2,10 +2,6 @@ import iree.compiler.ir from iree.compiler.dialects import iree_input as iree_d -from iree.compiler.dialects import iree_linalg_ext -from iree.compiler.dialects import iree_linalg_transform with iree.compiler.ir.Context() as ctx: iree_d.register_dialect() - iree_linalg_ext.register_dialect() - iree_linalg_transform.register_dialect() diff --git a/integrations/tensorflow/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/integrations/tensorflow/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt index 3741c34db2ed..b50d4742534f 100644 --- a/integrations/tensorflow/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt +++ b/integrations/tensorflow/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt @@ -1,14 +1,6 @@ set(LIBS # Local dialects. IREEInputDialect - IREELinalgExtDialect - IREELinalgExtPasses - IREELinalgExtTransformOps - IREELinalgExtTransforms - IREELinalgTransformDialect - IREELinalgTransformDialectPasses - IREELinalgTransformTestPasses - IREETransformsTestPasses # Core dialects. MLIRAffineDialect MLIRArithDialect diff --git a/integrations/tensorflow/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/integrations/tensorflow/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp index a3021ce027a5..1206764b28c4 100644 --- a/integrations/tensorflow/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp +++ b/integrations/tensorflow/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp @@ -5,12 +5,6 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree-dialects/Dialect/Input/InputDialect.h" -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" -#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h" -#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h" -#include "iree-dialects/Dialect/LinalgTransform/Passes.h" -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" @@ -40,7 +34,6 @@ namespace mlir { namespace test_ext { /// Test passes, do not deserve an include. void registerTestLinalgTransformWrapScope(); -void registerTestListenerPasses(); } // namespace test_ext } // namespace mlir @@ -53,8 +46,6 @@ int main(int argc, char **argv) { // clang-format off // Local dialects mlir::iree_compiler::IREE::Input::IREEInputDialect, - mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect, - mlir::linalg::transform::LinalgTransformDialect, // Upstream dialects mlir::async::AsyncDialect, mlir::arith::ArithDialect, @@ -75,20 +66,10 @@ int main(int argc, char **argv) { memref::registerMemRefPasses(); registerTransformsPasses(); registerSCFPasses(); - // Local dialect passes. - mlir::iree_compiler::IREE::LinalgExt::registerPasses(); - mlir::linalg::transform::registerTransformDialectInterpreterPass(); - mlir::linalg::transform::registerLinalgTransformExpertExpansionPass(); - mlir::linalg::transform::registerDropSchedulePass(); - // Local test passes. - mlir::test_ext::registerTestLinalgTransformWrapScope(); - mlir::test_ext::registerTestListenerPasses(); // External models. mlir::linalg::registerTilingInterfaceExternalModels(registry); - registry.addExtensions(); mlir::linalg::registerTransformDialectExtension(registry); mlir::scf::registerTransformDialectExtension(registry);